from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn

from server.schemas import Action, StepResult, ResetResult, StateResult
from server.env import DBMigrationEnv

app = FastAPI(
    title="DB Schema Migration OpenEnv",
    description="RL environment for database schema migration tasks",
    version="1.0.0",
)

# Single module-level env instance: its state is shared by every client of
# this process, not isolated per session.
env = DBMigrationEnv()


class ResetRequest(BaseModel):
    task: Optional[str] = "easy"


@app.get("/")
def root():
    """Return a status marker, the environment name, and the available tasks."""
    return {"status": "ok", "env": "db-schema-migration", "tasks": ["easy", "medium", "hard"]}


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/reset", response_model=ResetResult)
def reset(request: ResetRequest = ResetRequest()):
    """Reset the environment for the requested task (defaults to "easy")."""
    # A null or empty task falls back to "easy".
    task = request.task or "easy"
    if task not in ("easy", "medium", "hard"):
        raise HTTPException(status_code=400, detail=f"Unknown task '{task}'. Choose: easy, medium, hard")
    return env.reset(task_name=task)


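@app.post("/step", response_model=StepResult)
def step(action: Action):
    """Apply one action to the shared environment and return the step result."""
    return env.step(action)


@app.get("/state", response_model=StateResult)
def state():
    """Return the current state of the shared environment."""
    return env.state()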


@app.get("/tasks")
def list_tasks():
    """Return the description, requirements, step limit, and hints for each task."""
    from server.env import TASKS
    return {
        name: {
            "description": task["description"],
            "requirements": task["target_requirements"],
            "max_steps": task["max_steps"],
            "hints": task.get("hints", []),
        }
        for name, task in TASKS.items()
    }


if __name__ == "__main__":
    uvicorn.run("server.main:app", host="0.0.0.0", port=7860, reload=False)