File size: 1,801 Bytes
b641d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64


import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException

from .constants import TaskName
from .env import DistributedDebugEnv
from .models import Action, Observation, StepResult


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared ``DistributedDebugEnv`` for the lifetime of the app.

    On startup a fresh environment is constructed, started, and published on
    ``app.state.env`` so the route handlers can reach it. On shutdown the
    environment is closed; the ``finally`` guarantees this also happens when
    the application exits because of an exception.

    Args:
        app: The FastAPI application whose ``state`` receives the environment.
    """
    debug_env = DistributedDebugEnv()
    debug_env.start()
    app.state.env = debug_env
    try:
        yield  # application serves requests while suspended here
    finally:
        debug_env.close()


# The single ASGI application; ``lifespan`` provisions and tears down the
# shared environment that every route handler reads from ``app.state.env``.
app = FastAPI(
    title="Distributed Systems Debug Environment",
    version="1.0.0",
    lifespan=lifespan,
)


@app.post("/reset", response_model=Observation)
async def reset(task_name: str | None = None) -> Observation:
    """Start a new episode and return its initial observation.

    Args:
        task_name: Optional task identifier, parsed with ``TaskName.parse``.
            When omitted or empty, ``TaskName.CASCADING_TIMEOUT`` is used.

    Returns:
        The observation returned by ``DistributedDebugEnv.reset``.

    Raises:
        HTTPException: 400 if ``task_name`` is rejected by ``TaskName.parse``
            with a ``ValueError``; 500 if resetting the environment fails.
    """
    # Validator and sample inference call /reset without task input.
    # Use a deterministic default task for reproducible episode bootstrapping.
    selected_task_name = task_name or TaskName.CASCADING_TIMEOUT.value
    try:
        task = TaskName.parse(selected_task_name)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    try:
        env: DistributedDebugEnv = app.state.env
        return env.reset(task_name=task)
    except Exception as exc:
        # Once converted to an HTTPException only str(exc) reaches the client
        # and the default handler does not log, so keep the traceback here.
        logging.getLogger(__name__).exception(
            "Environment reset failed (task=%s)", task
        )
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/step", response_model=StepResult)
async def step(action: Action) -> StepResult:
    """Apply one action to the shared environment and return the outcome.

    Args:
        action: The action built by FastAPI from the request, forwarded
            unchanged to ``DistributedDebugEnv.step``.

    Returns:
        The ``StepResult`` returned by ``DistributedDebugEnv.step``.

    Raises:
        HTTPException: 500 if the environment raises while stepping.
    """
    try:
        env: DistributedDebugEnv = app.state.env
        return env.step(action)
    except Exception as exc:
        # Once converted to an HTTPException only str(exc) reaches the client
        # and the default handler does not log, so keep the traceback here.
        logging.getLogger(__name__).exception("Environment step failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/state")
async def state() -> dict:
    """Return whatever ``DistributedDebugEnv.state()`` reports for the episode.

    Returns:
        The dict produced by the shared environment's ``state`` method.

    Raises:
        HTTPException: 500 if reading the environment state fails.
    """
    try:
        env: DistributedDebugEnv = app.state.env
        return env.state()
    except Exception as exc:
        # Once converted to an HTTPException only str(exc) reaches the client
        # and the default handler does not log, so keep the traceback here.
        logging.getLogger(__name__).exception("Reading environment state failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/health")
async def health() -> dict:
    """Report that the server is responding, together with its API version.

    This is a static response: it does not consult the environment.
    """
    # The version string mirrors the ``version`` passed to ``FastAPI(...)``.
    payload = dict(status="ok", version="1.0.0")
    return payload