CystronCode commited on
Commit
f4bfb15
Β·
verified Β·
1 Parent(s): 9d0ed52

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -6
main.py CHANGED
@@ -14,9 +14,9 @@ Endpoints
14
  GET /health β€” Liveness probe (required for HF Spaces ping)
15
  """
16
 
17
- from fastapi import FastAPI, HTTPException
18
  from fastapi.middleware.cors import CORSMiddleware
19
- from typing import Any, Dict
20
  from pydantic import BaseModel
21
 
22
  from env import (
@@ -47,9 +47,12 @@ app.add_middleware(
47
 
48
  # Single shared environment instance (stateful, per-session)
49
  _env = APIGatewayDefender()
 
 
50
  class ResetRequest(BaseModel):
51
  task_id: str = "easy"
52
 
 
53
  # ─── Routes ──────────────────────────────────────────────────────────────────────
54
 
55
  @app.get("/health")
@@ -82,13 +85,23 @@ def root() -> Dict[str, Any]:
82
 
83
 
84
  @app.post("/reset")
85
- def reset(req: ResetRequest = ResetRequest()) -> Dict[str, Any]:
 
 
 
86
  """
87
  Start a new episode.
88
- Request body (JSON): {"task_id": "easy" | "medium" | "hard"}
 
 
 
 
 
89
  """
 
 
90
  try:
91
- obs: Observation = _env.reset(task_id=req.task_id)
92
  except ValueError as exc:
93
  raise HTTPException(status_code=422, detail=str(exc))
94
  return obs.model_dump()
@@ -193,4 +206,4 @@ def baseline() -> Dict[str, Any]:
193
  "Heuristic baseline: reads visible logs, identifies the attack pattern, "
194
  "applies the optimal rule. No LLM required."
195
  ),
196
- }
 
14
  GET /health β€” Liveness probe (required for HF Spaces ping)
15
  """
16
 
17
+ from fastapi import FastAPI, HTTPException, Query
18
  from fastapi.middleware.cors import CORSMiddleware
19
+ from typing import Any, Dict, Optional
20
  from pydantic import BaseModel
21
 
22
  from env import (
 
47
 
48
  # Single shared environment instance (stateful, per-session)
49
  _env = APIGatewayDefender()
50
+
51
+
52
  class ResetRequest(BaseModel):
53
  task_id: str = "easy"
54
 
55
+
56
  # ─── Routes ──────────────────────────────────────────────────────────────────────
57
 
58
  @app.get("/health")
 
85
 
86
 
87
  @app.post("/reset")
88
+ async def reset(
89
+ req: Optional[ResetRequest] = None,
90
+ task_id: Optional[str] = Query(default=None),
91
+ ) -> Dict[str, Any]:
92
  """
93
  Start a new episode.
94
+
95
+ Accepts ALL of these formats (validator may use any):
96
+ - JSON body: {"task_id": "easy"}
97
+ - Query param: POST /reset?task_id=easy
98
+ - Empty body: POST /reset (defaults to "easy")
99
+ - No body at all: POST /reset (defaults to "easy")
100
  """
101
+ # Priority: JSON body > query param > default "easy"
102
+ tid = (req.task_id if req else None) or task_id or "easy"
103
  try:
104
+ obs: Observation = _env.reset(task_id=tid)
105
  except ValueError as exc:
106
  raise HTTPException(status_code=422, detail=str(exc))
107
  return obs.model_dump()
 
206
  "Heuristic baseline: reads visible logs, identifies the attack pattern, "
207
  "applies the optimal rule. No LLM required."
208
  ),
209
+ }