Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -47,7 +47,8 @@ app.add_middleware(
|
|
| 47 |
|
| 48 |
# Single shared environment instance (stateful, per-session)
|
| 49 |
_env = APIGatewayDefender()
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
# βββ Routes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
|
|
@@ -81,18 +82,13 @@ def root() -> Dict[str, Any]:
|
|
| 81 |
|
| 82 |
|
| 83 |
@app.post("/reset")
|
| 84 |
-
def reset(
|
| 85 |
"""
|
| 86 |
Start a new episode.
|
| 87 |
-
|
| 88 |
-
Request body (JSON):
|
| 89 |
-
{"task_id": "easy" | "medium" | "hard"}
|
| 90 |
-
|
| 91 |
-
Returns the initial Observation.
|
| 92 |
"""
|
| 93 |
-
task_id = (body or {}).get("task_id", "easy")
|
| 94 |
try:
|
| 95 |
-
obs: Observation = _env.reset(task_id=task_id)
|
| 96 |
except ValueError as exc:
|
| 97 |
raise HTTPException(status_code=422, detail=str(exc))
|
| 98 |
return obs.model_dump()
|
|
|
|
| 47 |
|
| 48 |
# Single shared environment instance (stateful, per-session)
|
| 49 |
_env = APIGatewayDefender()
|
| 50 |
+
class ResetRequest(BaseModel):
|
| 51 |
+
task_id: str = "easy"
|
| 52 |
|
| 53 |
# βββ Routes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
@app.post("/reset")
|
| 85 |
+
def reset(req: ResetRequest = ResetRequest()) -> Dict[str, Any]:
|
| 86 |
"""
|
| 87 |
Start a new episode.
|
| 88 |
+
Request body (JSON): {"task_id": "easy" | "medium" | "hard"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
"""
|
|
|
|
| 90 |
try:
|
| 91 |
+
obs: Observation = _env.reset(task_id=req.task_id)
|
| 92 |
except ValueError as exc:
|
| 93 |
raise HTTPException(status_code=422, detail=str(exc))
|
| 94 |
return obs.model_dump()
|