ArshVerma commited on
Commit
b0c19e3
Β·
1 Parent(s): 1ea1bca

Allow /reset params via body or query; add tests

Browse files

Make the /reset endpoint more robust by accepting task_id and seed from either a JSON body or query parameters. ResetRequest.task_id is now optional (defaults to BUG_DETECTION). The endpoint resolves final values with precedence: default < body < query, then calls env.reset(final_task_id, final_seed). Added fastapi Body import, an info log on reset, and a test (test_api_reset_robustness) that verifies no-body, empty-body, query-only, and query-overrides-body behavior.

Files changed (2) hide show
  1. app.py +27 -4
  2. tests/test_api.py +26 -0
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, AsyncGenerator
6
  from datetime import datetime, timezone
7
  from contextlib import asynccontextmanager
8
 
9
- from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request
10
  from fastapi.responses import JSONResponse, FileResponse
11
  from fastapi.exceptions import RequestValidationError
12
  from fastapi.security.api_key import APIKeyHeader
@@ -148,7 +148,7 @@ async def cleanup_expired_episodes():
148
 
149
  # ── Models ────────────────────────────────────────────────────────────────────
150
  class ResetRequest(BaseModel):
151
- task_id: TaskId
152
  seed: int = 42
153
 
154
  class ResetResponse(BaseModel):
@@ -216,12 +216,35 @@ def health_check():
216
 
217
  @app.post("/reset", response_model=ResetResponse)
218
  @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
219
- def reset_env(request: Request, req: ResetRequest, _: None = Depends(verify_api_key)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  episode_id = str(uuid.uuid4())
221
  env = CodeLensEnv()
222
- result = env.reset(req.task_id, req.seed)
223
  episodes[episode_id] = env
224
  episode_timestamps[episode_id] = datetime.now(timezone.utc)
 
 
225
  return ResetResponse(episode_id=episode_id, result=result)
226
 
227
  @app.post("/step/{episode_id}", response_model=StepResult)
 
6
  from datetime import datetime, timezone
7
  from contextlib import asynccontextmanager
8
 
9
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request, Body
10
  from fastapi.responses import JSONResponse, FileResponse
11
  from fastapi.exceptions import RequestValidationError
12
  from fastapi.security.api_key import APIKeyHeader
 
148
 
149
  # ── Models ────────────────────────────────────────────────────────────────────
150
  class ResetRequest(BaseModel):
151
+ task_id: Optional[TaskId] = TaskId.BUG_DETECTION
152
  seed: int = 42
153
 
154
  class ResetResponse(BaseModel):
 
216
 
217
  @app.post("/reset", response_model=ResetResponse)
218
  @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
219
+ def reset_env(
220
+ request: Request,
221
+ req: Optional[ResetRequest] = Body(None),
222
+ task_id: Optional[TaskId] = Query(None),
223
+ seed: Optional[int] = Query(None),
224
+ _: None = Depends(verify_api_key)
225
+ ):
226
+ # Determine task_id: Body (if present) > Query params > Default
227
+ final_task_id = TaskId.BUG_DETECTION
228
+ final_seed = 42
229
+
230
+ if req:
231
+ if req.task_id:
232
+ final_task_id = req.task_id
233
+ final_seed = req.seed
234
+
235
+ # Query parameters override body if provided explicitly
236
+ if task_id:
237
+ final_task_id = task_id
238
+ if seed is not None:
239
+ final_seed = seed
240
+
241
  episode_id = str(uuid.uuid4())
242
  env = CodeLensEnv()
243
+ result = env.reset(final_task_id, final_seed)
244
  episodes[episode_id] = env
245
  episode_timestamps[episode_id] = datetime.now(timezone.utc)
246
+
247
+ logger.info(f"Reset environment: task={final_task_id.value}, seed={final_seed}, id={episode_id}")
248
  return ResetResponse(episode_id=episode_id, result=result)
249
 
250
  @app.post("/step/{episode_id}", response_model=StepResult)
tests/test_api.py CHANGED
@@ -139,3 +139,29 @@ def test_api_leaderboard_pagination(client):
139
 
140
  # Test ordering (best first)
141
  assert data["entries"][0]["score"] >= data["entries"][1]["score"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  # Test ordering (best first)
141
  assert data["entries"][0]["score"] >= data["entries"][1]["score"]
142
+
143
+ def test_api_reset_robustness(client):
144
+ # 1. No body at all
145
+ resp = client.post("/reset")
146
+ assert resp.status_code == 200
147
+ assert resp.json()["result"]["task_id"] == "bug_detection"
148
+
149
+ # 2. Empty JSON body
150
+ resp = client.post("/reset", json={})
151
+ assert resp.status_code == 200
152
+ assert resp.json()["result"]["task_id"] == "bug_detection"
153
+
154
+ # 3. Query params only
155
+ resp = client.post("/reset?task_id=security_audit&seed=100")
156
+ assert resp.status_code == 200
157
+ assert resp.json()["result"]["task_id"] == "security_audit"
158
+ assert resp.json()["result"]["seed"] == 100
159
+
160
+ # 4. Query params overriding body
161
+ resp = client.post(
162
+ "/reset?task_id=architectural_review",
163
+ json={"task_id": "bug_detection", "seed": 50}
164
+ )
165
+ assert resp.status_code == 200
166
+ assert resp.json()["result"]["task_id"] == "architectural_review"
167
+ assert resp.json()["result"]["seed"] == 50