File size: 3,372 Bytes
e4853aa
 
 
 
 
 
3583511
 
474eafa
e4853aa
3583511
 
474eafa
 
 
3583511
 
 
474eafa
3583511
 
 
 
 
 
 
 
 
 
474eafa
4e5e784
3583511
4e5e784
e4853aa
474eafa
 
 
 
 
 
 
 
 
 
e4853aa
474eafa
 
 
 
 
 
 
 
 
 
3583511
 
 
474eafa
 
 
e4853aa
3583511
474eafa
e4853aa
 
 
 
 
 
 
 
 
 
 
 
 
3583511
 
474eafa
e4853aa
3583511
e4853aa
 
 
 
 
 
 
3583511
 
474eafa
e4853aa
3583511
e4853aa
 
 
 
 
 
 
 
3583511
e4853aa
 
 
 
 
 
 
3583511
 
 
 
 
 
 
6d8d3c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Main FastAPI application for Code Security Review.

Exposes RESTful endpoints conforming to standard OpenEnv compliance specifications 
dictating interactions for agent evaluation.
"""

import logging
import os
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Query, status
from fastapi.middleware.cors import CORSMiddleware

from server.models import CodeReviewAction, StepResult, ResetResponse, StateResponse, TaskInfo
from server.tasks import TASKS
from server.environment import CodeSecurityEnv

# The ASGI application object; uvicorn loads it as "server.app:app".
app = FastAPI(
    version="1.0.0",
    title="Code Security Review — OpenEnv",
    description="An RL environment for training AI agents to perform code security review.",
)

# Fully permissive CORS: any origin, HTTP method and request header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# One process-wide environment instance, shared by the /reset, /step and
# /state handlers below.
# NOTE(review): those handlers are plain `def` functions, which FastAPI runs in
# a threadpool; confirm CodeSecurityEnv tolerates concurrent calls.
env = CodeSecurityEnv()


@app.get("/")
def health() -> dict:
    """Health check endpoint.

    Returns:
        A JSON-serialisable dict with a fixed ``"status": "ok"`` marker, the
        project name, the application version and the owning organization.
    """
    return {
        "status": "ok",
        "project": "Code Security Review - OpenEnv",
        # Read the version from the FastAPI app itself so the health payload
        # can never drift from the version declared on the application.
        "version": app.version,
        "organization": "Inmodel Labs",
    }


@app.get("/tasks", response_model=List[TaskInfo])
def list_tasks() -> List[TaskInfo]:
    """Summarise every task registered in ``TASKS``.

    Each ``TaskInfo`` carries the task's id, language, bug class and
    difficulty, in the iteration order of ``TASKS``.
    """
    exported_fields = ("id", "language", "bug_class", "difficulty")
    return [
        TaskInfo(**{field: task[field] for field in exported_fields})
        for task in TASKS.values()
    ]


@app.post("/reset", response_model=ResetResponse)
def reset(
    task_id: str = Query(default="python-off-by-one", description="Task ID to reset to"),
    seed: Optional[int] = Query(default=None, description="Optional seed for reproducibility"),
) -> ResetResponse:
    """Reset the environment to ``task_id`` and return the first observation.

    Args:
        task_id: Key into ``TASKS`` selecting the task to load.
        seed: Optional seed forwarded unchanged to ``env.reset``.

    Returns:
        A ``ResetResponse`` wrapping the observation returned by ``env.reset``.

    Raises:
        HTTPException: 404 if ``task_id`` is not a known task; 500 if the
            environment fails while resetting.
    """
    if task_id not in TASKS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found.",
        )

    try:
        obs = env.reset(task_id=task_id, seed=seed)
        return ResetResponse(observation=obs)
    except Exception as e:
        # Converting to HTTPException hides the traceback from the server
        # logs, so record the original failure before translating it.
        logging.getLogger(__name__).exception(
            "Environment reset failed for task %r", task_id
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"System breakdown during environment reset: {e}",
        ) from e


@app.post("/step", response_model=StepResult)
def step(action: CodeReviewAction) -> StepResult:
    """Submit a code review action and receive a reward signal.

    Args:
        action: The agent's review action, forwarded unchanged to ``env.step``.

    Returns:
        The ``StepResult`` produced by ``env.step``.

    Raises:
        HTTPException: 500 if ``env.step`` raises for any reason.
    """
    try:
        return env.step(action)
    except Exception as e:
        # Converting to HTTPException hides the traceback from the server
        # logs, so record the original failure before translating it.
        logging.getLogger(__name__).exception("Environment step failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error executing agent action logic: {e}",
        ) from e


@app.get("/state", response_model=StateResponse)
def state() -> StateResponse:
    """Return the current environment state.

    Returns:
        The ``StateResponse`` produced by ``env.state()``.

    Raises:
        HTTPException: 500 if ``env.state()`` raises for any reason.
    """
    try:
        return env.state()
    except Exception as e:
        # Converting to HTTPException hides the traceback from the server
        # logs, so record the original failure before translating it.
        logging.getLogger(__name__).exception("Reading environment state failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error analyzing global runtime state tracking: {e}",
        ) from e


def main() -> None:
    """Run the environment's ASGI server with uvicorn.

    The listening port comes from the ``PORT`` environment variable. It falls
    back to 8000 when ``PORT`` is unset or is not an integer; the fallback is
    logged as a warning. The server binds to ``0.0.0.0`` with auto-reload
    disabled.
    """
    default_port = 8000
    raw_port = os.environ.get("PORT", str(default_port))
    try:
        port = int(raw_port)
    except ValueError:
        # Keep the server bootable, but surface the misconfiguration instead
        # of silently ignoring it.
        logging.getLogger(__name__).warning(
            "Invalid PORT value %r; falling back to %d", raw_port, default_port
        )
        port = default_port

    uvicorn.run(
        "server.app:app",
        host="0.0.0.0",
        port=port,
        reload=False,
    )


# Start the server only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()