Spaces:
Sleeping
Sleeping
"""
Integration tests for the FastAPI HTTP server (HF Space entrypoint).

Uses FastAPI's TestClient (backed by httpx) to exercise all five endpoints:

    GET  /health
    GET  /tasks
    POST /reset
    POST /step
    GET  /state

These tests are self-contained: each test creates its own TestClient so there
is no shared mutable state between tests.
"""
| import pytest | |
| from fastapi.testclient import TestClient | |
| from grid_env.Server.app import app | |
# Task identifiers these tests expect the server to know about; GET /tasks
# is asserted to return exactly this set.
TASK_IDS = [
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
]
# Command strings that POST /step is expected to accept.
ALL_ACTIONS = [
    "turn_left",
    "turn_right",
    "move_forward",
    "scan_bin",
    "pick_item",
    "pack_item",
    "recharge",
    "rest",
    "wait",
]
@pytest.fixture
def client():
    """Yield a TestClient bound to a freshly reloaded app module.

    Registered as a pytest fixture so that every test naming a ``client``
    parameter receives it; without the decorator pytest cannot resolve that
    parameter. The app module is reloaded before each test so that whatever
    it builds at import time is rebuilt (per the original note, this is what
    gives each test a fresh WarehouseEnvService).

    Yields:
        A ``TestClient`` used as a context manager, so it is closed once the
        requesting test has finished.
    """
    # Imported locally so the reload targets the module object currently
    # registered for the package at the time the fixture runs.
    import importlib

    from grid_env.Server import app as app_module

    importlib.reload(app_module)
    with TestClient(app_module.app) as test_client:
        yield test_client
# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------
def test_health_returns_200(client):
    """GET /health answers with HTTP 200."""
    assert client.get("/health").status_code == 200
def test_health_has_required_keys(client):
    """The /health payload carries status, task_id and episode_id."""
    payload = client.get("/health").json()
    for required_key in ("status", "task_id", "episode_id"):
        assert required_key in payload
def test_health_status_is_ok(client):
    """The /health payload reports status "ok"."""
    health = client.get("/health").json()
    reported_status = health["status"]
    assert reported_status == "ok"
# ---------------------------------------------------------------------------
# GET /tasks
# ---------------------------------------------------------------------------
def test_tasks_returns_200(client):
    """GET /tasks answers with HTTP 200."""
    assert client.get("/tasks").status_code == 200
def test_tasks_has_tasks_key(client):
    """The /tasks payload has a "tasks" entry and that entry is a list."""
    payload = client.get("/tasks").json()
    assert "tasks" in payload
    task_entries = payload["tasks"]
    assert isinstance(task_entries, list)
def test_tasks_returns_all(client):
    """The task ids listed by /tasks are exactly those in TASK_IDS."""
    payload = client.get("/tasks").json()
    listed_ids = set()
    for entry in payload["tasks"]:
        listed_ids.add(entry["task_id"])
    assert listed_ids == set(TASK_IDS)
# ---------------------------------------------------------------------------
# POST /reset
# ---------------------------------------------------------------------------
def test_reset_default_returns_200(client):
    """POST /reset with an empty JSON body answers with HTTP 200."""
    assert client.post("/reset", json={}).status_code == 200
def test_reset_response_has_observation_and_state(client):
    """The /reset payload contains both "observation" and "state"."""
    payload = client.post("/reset", json={}).json()
    for required_key in ("observation", "state"):
        assert required_key in payload
def test_reset_observation_has_task_id(client):
    """Resetting to easy_single_pick echoes that task id in the observation."""
    reset_response = client.post("/reset", json={"task_id": "easy_single_pick"})
    observation = reset_response.json()["observation"]
    assert observation["task_id"] == "easy_single_pick"
@pytest.mark.parametrize("task_id", TASK_IDS)
def test_reset_each_task(client, task_id):
    """Resetting to any known task echoes that task id in the observation.

    Parametrized over ``TASK_IDS`` so ``task_id`` is supplied by pytest
    (one test case per task) instead of being looked up as a fixture.
    """
    body = client.post("/reset", json={"task_id": task_id}).json()
    assert body["observation"]["task_id"] == task_id
def test_reset_with_seed_is_deterministic(client):
    """Two resets with the same task and seed report the same battery level."""
    first = client.post(
        "/reset", json={"task_id": "easy_single_pick", "seed": 42}
    ).json()
    second = client.post(
        "/reset", json={"task_id": "easy_single_pick", "seed": 42}
    ).json()
    first_battery = first["observation"]["battery_level"]
    second_battery = second["observation"]["battery_level"]
    assert first_battery == second_battery
def test_reset_unknown_task_returns_404(client):
    """POST /reset with an unrecognised task id answers with HTTP 404."""
    unknown_task = {"task_id": "nonexistent_task"}
    assert client.post("/reset", json=unknown_task).status_code == 404
# ---------------------------------------------------------------------------
# POST /step
# ---------------------------------------------------------------------------
def test_step_after_reset_returns_200(client):
    """A "wait" step issued after a reset answers with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_response = client.post("/step", json={"command": "wait"})
    assert step_response.status_code == 200
def test_step_response_has_required_keys(client):
    """The /step payload has observation, reward, done, info and state."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    payload = client.post("/step", json={"command": "wait"}).json()
    for required_key in ("observation", "reward", "done", "info", "state"):
        assert required_key in payload
def test_step_done_is_bool(client):
    """The "done" field of a /step payload is a bool."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_payload = client.post("/step", json={"command": "wait"}).json()
    done_flag = step_payload["done"]
    assert isinstance(done_flag, bool)
@pytest.mark.parametrize("command", ALL_ACTIONS)
def test_step_all_commands_accepted(client, command):
    """Every valid command string should return 200.

    Parametrized over ``ALL_ACTIONS`` so ``command`` is supplied by pytest
    (one test case per command) instead of being looked up as a fixture.
    Each case resets to easy_single_pick first and then issues the command
    as the first step.
    """
    client.post("/reset", json={"task_id": "easy_single_pick"})
    resp = client.post("/step", json={"command": command})
    assert resp.status_code == 200
def test_step_invalid_command_returns_400(client):
    """A /step request with an unrecognised command answers with HTTP 400."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    bad_step = client.post("/step", json={"command": "fly_to_mars"})
    assert bad_step.status_code == 400
def test_step_increments_step_count(client):
    """After two "wait" steps the returned state reports step_count == 2."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    client.post("/step", json={"command": "wait"})
    second_step = client.post("/step", json={"command": "wait"}).json()
    state = second_step["state"]
    assert state["step_count"] == 2
def test_step_episode_terminates(client):
    """Repeated "wait" steps must report done=True within 60 attempts."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    # The original note states the easy task's max_steps is 40, so 60 tries
    # leaves headroom. any() stops issuing requests at the first truthy
    # "done", just like breaking out of an explicit loop.
    episode_ended = any(
        client.post("/step", json={"command": "wait"}).json()["done"]
        for _ in range(60)
    )
    assert episode_ended
def test_step_score_in_range_after_done(client):
    """Once the episode ends, info["score"] is present and lies in [0, 1]."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    final_score = None
    for _ in range(60):
        step_payload = client.post("/step", json={"command": "wait"}).json()
        if not step_payload["done"]:
            continue
        # First step that reports done: capture its score and stop stepping.
        final_score = step_payload["info"].get("score")
        break
    assert final_score is not None
    assert 0.0 <= final_score <= 1.0
# ---------------------------------------------------------------------------
# GET /state
# ---------------------------------------------------------------------------
def test_state_returns_200(client):
    """GET /state answers with HTTP 200."""
    assert client.get("/state").status_code == 200
def test_state_has_task_id(client):
    """The /state payload includes a "task_id" entry."""
    state_payload = client.get("/state").json()
    assert "task_id" in state_payload
def test_state_reflects_reset(client):
    """After resetting to medium_multi_item, /state reports that task id."""
    client.post("/reset", json={"task_id": "medium_multi_item"})
    current_task = client.get("/state").json()["task_id"]
    assert current_task == "medium_multi_item"
def test_state_step_count_matches_steps_taken(client):
    """After three "wait" steps, /state reports step_count == 3."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    steps_taken = 0
    while steps_taken < 3:
        client.post("/step", json={"command": "wait"})
        steps_taken += 1
    state_payload = client.get("/state").json()
    assert state_payload["step_count"] == 3