CystronCode commited on
Commit
9d0ed52
Β·
verified Β·
1 Parent(s): 94205e1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -9
main.py CHANGED
@@ -47,7 +47,8 @@ app.add_middleware(
47
 
48
  # Single shared environment instance (stateful, per-session)
49
  _env = APIGatewayDefender()
50
-
 
51
 
52
  # ─── Routes ──────────────────────────────────────────────────────────────────────
53
 
@@ -81,18 +82,13 @@ def root() -> Dict[str, Any]:
81
 
82
 
83
  @app.post("/reset")
84
- def reset(body: Dict[str, str] = None) -> Dict[str, Any]:
85
  """
86
  Start a new episode.
87
-
88
- Request body (JSON):
89
- {"task_id": "easy" | "medium" | "hard"}
90
-
91
- Returns the initial Observation.
92
  """
93
- task_id = (body or {}).get("task_id", "easy")
94
  try:
95
- obs: Observation = _env.reset(task_id=task_id)
96
  except ValueError as exc:
97
  raise HTTPException(status_code=422, detail=str(exc))
98
  return obs.model_dump()
 
47
 
48
  # Single shared environment instance (stateful, per-session)
49
  _env = APIGatewayDefender()
50
+ class ResetRequest(BaseModel):
51
+ task_id: str = "easy"
52
 
53
  # ─── Routes ──────────────────────────────────────────────────────────────────────
54
 
 
82
 
83
 
84
  @app.post("/reset")
85
+ def reset(req: ResetRequest = ResetRequest()) -> Dict[str, Any]:
86
  """
87
  Start a new episode.
88
+ Request body (JSON): {"task_id": "easy" | "medium" | "hard"}
 
 
 
 
89
  """
 
90
  try:
91
+ obs: Observation = _env.reset(task_id=req.task_id)
92
  except ValueError as exc:
93
  raise HTTPException(status_code=422, detail=str(exc))
94
  return obs.model_dump()