fixes
Browse files- server/app.py +3 -2
- submission_tests/test_server_api.py +7 -0
server/app.py
CHANGED
|
@@ -85,9 +85,10 @@ def schema() -> SchemaResponse:
|
|
| 85 |
|
| 86 |
|
| 87 |
@app.post("/reset", response_model=Observation)
|
| 88 |
-
def reset(request: ResetRequest) -> Observation:
|
| 89 |
try:
|
| 90 |
-
|
|
|
|
| 91 |
except ValueError as exc:
|
| 92 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 93 |
except Exception as exc:
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
@app.post("/reset", response_model=Observation)
|
| 88 |
+
def reset(request: Optional[ResetRequest] = Body(default=None)) -> Observation:
|
| 89 |
try:
|
| 90 |
+
payload = request or ResetRequest()
|
| 91 |
+
return env.reset(difficulty=payload.difficulty, seed=payload.seed)
|
| 92 |
except ValueError as exc:
|
| 93 |
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 94 |
except Exception as exc:
|
submission_tests/test_server_api.py
CHANGED
|
@@ -64,6 +64,13 @@ class ServerApiTests(unittest.TestCase):
|
|
| 64 |
self.assertIn("policy", payload["info"])
|
| 65 |
self.assertIn("grader_inputs", payload["info"])
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def test_invalid_reset_payload_returns_422(self) -> None:
|
| 68 |
response = self.client.post("/reset", json={"difficulty": "legendary"})
|
| 69 |
self.assertEqual(response.status_code, 422)
|
|
|
|
| 64 |
self.assertIn("policy", payload["info"])
|
| 65 |
self.assertIn("grader_inputs", payload["info"])
|
| 66 |
|
| 67 |
+
def test_reset_without_body_uses_defaults(self) -> None:
|
| 68 |
+
response = self.client.post("/reset")
|
| 69 |
+
self.assertEqual(response.status_code, 200)
|
| 70 |
+
observation = response.json()
|
| 71 |
+
self.assertEqual(observation["task_difficulty"], "easy")
|
| 72 |
+
self.assertEqual(observation["step_count"], 0)
|
| 73 |
+
|
| 74 |
def test_invalid_reset_payload_returns_422(self) -> None:
|
| 75 |
response = self.client.post("/reset", json={"difficulty": "legendary"})
|
| 76 |
self.assertEqual(response.status_code, 422)
|