ArshVerma commited on
Commit
8c875ce
·
1 Parent(s): b0c19e3

Implement Maximum Robustness for /reset to bypass body validation errors

Browse files
Files changed (2) hide show
  1. app.py +18 -9
  2. tests/test_api.py +16 -13
app.py CHANGED
@@ -216,23 +216,32 @@ def health_check():
216
 
217
  @app.post("/reset", response_model=ResetResponse)
218
  @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
219
- def reset_env(
220
  request: Request,
221
- req: Optional[ResetRequest] = Body(None),
222
  task_id: Optional[TaskId] = Query(None),
223
  seed: Optional[int] = Query(None),
224
  _: None = Depends(verify_api_key)
225
  ):
226
- # Determine task_id: Body (if present) > Query params > Default
227
  final_task_id = TaskId.BUG_DETECTION
228
  final_seed = 42
229
 
230
- if req:
231
- if req.task_id:
232
- final_task_id = req.task_id
233
- final_seed = req.seed
 
 
 
 
 
 
 
 
 
 
234
 
235
- # Query parameters override body if provided explicitly
236
  if task_id:
237
  final_task_id = task_id
238
  if seed is not None:
@@ -244,7 +253,7 @@ def reset_env(
244
  episodes[episode_id] = env
245
  episode_timestamps[episode_id] = datetime.now(timezone.utc)
246
 
247
- logger.info(f"Reset environment: task={final_task_id.value}, seed={final_seed}, id={episode_id}")
248
  return ResetResponse(episode_id=episode_id, result=result)
249
 
250
  @app.post("/step/{episode_id}", response_model=StepResult)
 
216
 
217
  @app.post("/reset", response_model=ResetResponse)
218
  @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
219
+ async def reset_env(
220
  request: Request,
 
221
  task_id: Optional[TaskId] = Query(None),
222
  seed: Optional[int] = Query(None),
223
  _: None = Depends(verify_api_key)
224
  ):
225
+ # Determine task_id and seed with manual fallback strategy
226
  final_task_id = TaskId.BUG_DETECTION
227
  final_seed = 42
228
 
229
+ # 1. Try to extract from body manually (handles empty/malformed bodies)
230
+ try:
231
+ body = await request.json()
232
+ if body and isinstance(body, dict):
233
+ if body.get("task_id"):
234
+ try:
235
+ final_task_id = TaskId(body["task_id"])
236
+ except ValueError:
237
+ pass
238
+ if body.get("seed") is not None:
239
+ final_seed = int(body["seed"])
240
+ except Exception:
241
+ # Ignore body parsing errors (empty/malformed) and fall back
242
+ pass
243
 
244
+ # 2. Query parameters override body if provided explicitly
245
  if task_id:
246
  final_task_id = task_id
247
  if seed is not None:
 
253
  episodes[episode_id] = env
254
  episode_timestamps[episode_id] = datetime.now(timezone.utc)
255
 
256
+ logger.info(f"Reset environment (Robust): task={final_task_id.value}, seed={final_seed}, id={episode_id}")
257
  return ResetResponse(episode_id=episode_id, result=result)
258
 
259
  @app.post("/step/{episode_id}", response_model=StepResult)
tests/test_api.py CHANGED
@@ -68,7 +68,8 @@ def test_api_health_fields(client):
68
 
69
  def test_api_reset_invalid_task(client):
70
  resp = client.post("/reset", json={"task_id": "invalid_task", "seed": 0})
71
- assert resp.status_code == 422
 
72
 
73
  def test_api_step_invalid_action_type(client):
74
  reset_resp = client.post("/reset", json={"task_id": "bug_detection", "seed": 0})
@@ -151,17 +152,19 @@ def test_api_reset_robustness(client):
151
  assert resp.status_code == 200
152
  assert resp.json()["result"]["task_id"] == "bug_detection"
153
 
154
- # 3. Query params only
155
- resp = client.post("/reset?task_id=security_audit&seed=100")
 
 
 
 
 
156
  assert resp.status_code == 200
157
- assert resp.json()["result"]["task_id"] == "security_audit"
158
- assert resp.json()["result"]["seed"] == 100
159
-
160
- # 4. Query params overriding body
161
- resp = client.post(
162
- "/reset?task_id=architectural_review",
163
- json={"task_id": "bug_detection", "seed": 50}
164
- )
165
  assert resp.status_code == 200
166
- assert resp.json()["result"]["task_id"] == "architectural_review"
167
- assert resp.json()["result"]["seed"] == 50
 
 
68
 
69
  def test_api_reset_invalid_task(client):
70
  resp = client.post("/reset", json={"task_id": "invalid_task", "seed": 0})
71
+ assert resp.status_code == 200
72
+ assert resp.json()["result"]["task_id"] == "bug_detection" # Fallback
73
 
74
  def test_api_step_invalid_action_type(client):
75
  reset_resp = client.post("/reset", json={"task_id": "bug_detection", "seed": 0})
 
152
  assert resp.status_code == 200
153
  assert resp.json()["result"]["task_id"] == "bug_detection"
154
 
155
+ # 3. Invalid JSON (should not trigger 422 now)
156
+ resp = client.post("/reset", content="invalid json {", headers={"Content-Type": "application/json"})
157
+ assert resp.status_code == 200
158
+ assert resp.json()["result"]["task_id"] == "bug_detection"
159
+
160
+ # 4. Plain text body (unexpected header, should still pass)
161
+ resp = client.post("/reset", content="just some text", headers={"Content-Type": "text/plain"})
162
  assert resp.status_code == 200
163
+ assert resp.json()["result"]["task_id"] == "bug_detection"
164
+
165
+ # 5. Query params override
166
+ resp = client.post("/reset?task_id=security_audit&seed=100")
 
 
 
 
167
  assert resp.status_code == 200
168
+ data = resp.json()
169
+ assert data["result"]["task_id"] == "security_audit"
170
+ assert data["result"]["seed"] == 100