"""Integration tests for the FastAPI HTTP server (HF Space entrypoint).

The suite drives the application through FastAPI's TestClient (backed by
httpx) and covers all five endpoints:

    GET  /health
    GET  /tasks
    POST /reset
    POST /step
    GET  /state

Every test receives its own TestClient through the ``client`` fixture, so
no mutable state is shared between tests.
"""

import pytest
from fastapi.testclient import TestClient

from grid_env.Server.app import app

TASK_IDS = [
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
]

ALL_ACTIONS = [
    "turn_left",
    "turn_right",
    "move_forward",
    "scan_bin",
    "pick_item",
    "pack_item",
    "recharge",
    "rest",
    "wait",
]


@pytest.fixture()
def client():
    """Yield a fresh TestClient (and thus fresh WarehouseEnvService) per test."""
    import importlib

    from grid_env.Server import app as app_module

    # Reloading the module rebuilds the app object, so each test starts from
    # a newly constructed application instead of one touched by earlier tests.
    importlib.reload(app_module)
    with TestClient(app_module.app) as test_client:
        yield test_client


# ---------------------------------------------------------------------------
# Request helpers (private; not collected by pytest)
# ---------------------------------------------------------------------------


def _reset(client, **payload):
    """POST /reset with ``payload`` as the JSON body and return the response."""
    return client.post("/reset", json=payload)


def _step(client, command):
    """POST /step with the given command string and return the response."""
    return client.post("/step", json={"command": command})


def _wait_step_bodies(client, limit):
    """Lazily yield the JSON body of up to ``limit`` consecutive "wait" steps.

    Each request is only sent when the consumer asks for the next item, so a
    caller that stops early sends no further requests.
    """
    for _ in range(limit):
        yield _step(client, "wait").json()


# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------


def test_health_returns_200(client):
    response = client.get("/health")
    assert response.status_code == 200


def test_health_has_required_keys(client):
    payload = client.get("/health").json()
    for key in ("status", "task_id", "episode_id"):
        assert key in payload


def test_health_status_is_ok(client):
    payload = client.get("/health").json()
    assert payload["status"] == "ok"


# ---------------------------------------------------------------------------
# GET /tasks
# ---------------------------------------------------------------------------


def test_tasks_returns_200(client):
    response = client.get("/tasks")
    assert response.status_code == 200


def test_tasks_has_tasks_key(client):
    payload = client.get("/tasks").json()
    assert "tasks" in payload
    assert isinstance(payload["tasks"], list)


def test_tasks_returns_all(client):
    payload = client.get("/tasks").json()
    returned_ids = {entry["task_id"] for entry in payload["tasks"]}
    assert returned_ids == set(TASK_IDS)


# ---------------------------------------------------------------------------
# POST /reset
# ---------------------------------------------------------------------------


def test_reset_default_returns_200(client):
    response = _reset(client)
    assert response.status_code == 200


def test_reset_response_has_observation_and_state(client):
    payload = _reset(client).json()
    for key in ("observation", "state"):
        assert key in payload


def test_reset_observation_has_task_id(client):
    payload = _reset(client, task_id="easy_single_pick").json()
    assert payload["observation"]["task_id"] == "easy_single_pick"


@pytest.mark.parametrize("task_id", TASK_IDS)
def test_reset_each_task(client, task_id):
    payload = _reset(client, task_id=task_id).json()
    assert payload["observation"]["task_id"] == task_id


def test_reset_with_seed_is_deterministic(client):
    """Same seed must produce the same initial battery level."""
    first = _reset(client, task_id="easy_single_pick", seed=42).json()
    second = _reset(client, task_id="easy_single_pick", seed=42).json()
    first_battery = first["observation"]["battery_level"]
    second_battery = second["observation"]["battery_level"]
    assert first_battery == second_battery


def test_reset_unknown_task_returns_404(client):
    response = _reset(client, task_id="nonexistent_task")
    assert response.status_code == 404


# ---------------------------------------------------------------------------
# POST /step
# ---------------------------------------------------------------------------


def test_step_after_reset_returns_200(client):
    _reset(client, task_id="easy_single_pick")
    response = _step(client, "wait")
    assert response.status_code == 200


def test_step_response_has_required_keys(client):
    _reset(client, task_id="easy_single_pick")
    payload = _step(client, "wait").json()
    for key in ("observation", "reward", "done", "info", "state"):
        assert key in payload


def test_step_done_is_bool(client):
    _reset(client, task_id="easy_single_pick")
    payload = _step(client, "wait").json()
    assert isinstance(payload["done"], bool)


@pytest.mark.parametrize("command", ALL_ACTIONS)
def test_step_all_commands_accepted(client, command):
    """Every valid command string should return 200."""
    _reset(client, task_id="easy_single_pick")
    response = _step(client, command)
    assert response.status_code == 200


def test_step_invalid_command_returns_400(client):
    _reset(client, task_id="easy_single_pick")
    response = _step(client, "fly_to_mars")
    assert response.status_code == 400


def test_step_increments_step_count(client):
    _reset(client, task_id="easy_single_pick")
    _step(client, "wait")
    payload = _step(client, "wait").json()
    assert payload["state"]["step_count"] == 2


def test_step_episode_terminates(client):
    """Enough wait steps must eventually set done=True."""
    _reset(client, task_id="easy_single_pick")
    # easy task max_steps is 40, so 60 waits leaves headroom.
    # any() stops consuming (and therefore stepping) at the first done=True.
    assert any(payload["done"] for payload in _wait_step_bodies(client, 60))


def test_step_score_in_range_after_done(client):
    """score in info must be in [0, 1] after the episode ends."""
    _reset(client, task_id="easy_single_pick")
    # Take the first step body that reports done, or None if none did.
    final_payload = next(
        (payload for payload in _wait_step_bodies(client, 60) if payload["done"]),
        None,
    )
    if final_payload is None:
        score = None
    else:
        score = final_payload["info"].get("score")
    assert score is not None
    assert 0.0 <= score <= 1.0


# ---------------------------------------------------------------------------
# GET /state
# ---------------------------------------------------------------------------


def test_state_returns_200(client):
    response = client.get("/state")
    assert response.status_code == 200


def test_state_has_task_id(client):
    payload = client.get("/state").json()
    assert "task_id" in payload


def test_state_reflects_reset(client):
    _reset(client, task_id="medium_multi_item")
    payload = client.get("/state").json()
    assert payload["task_id"] == "medium_multi_item"


def test_state_step_count_matches_steps_taken(client):
    _reset(client, task_id="easy_single_pick")
    steps_taken = 3
    for _ in range(steps_taken):
        _step(client, "wait")
    payload = client.get("/state").json()
    assert payload["step_count"] == 3