Spaces:
Sleeping
Sleeping
File size: 7,071 Bytes
"""
Integration tests for the FastAPI HTTP server (HF Space entrypoint).

Uses FastAPI's TestClient (backed by httpx) to exercise all five endpoints:

    GET  /health
    GET  /tasks
    POST /reset
    POST /step
    GET  /state

These tests are self-contained: each test creates its own TestClient so there
is no shared mutable state between tests.
"""
import pytest
from fastapi.testclient import TestClient
from grid_env.Server.app import app
# Task identifiers the tests expect the server to expose; used both for
# parametrized resets and for comparing against the /tasks listing.
TASK_IDS = [
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
]
# Command strings sent to POST /step by the parametrized acceptance test.
ALL_ACTIONS = [
    "turn_left",
    "turn_right",
    "move_forward",
    "scan_bin",
    "pick_item",
    "pack_item",
    "recharge",
    "rest",
    "wait",
]
@pytest.fixture()
def client():
    """Yield a TestClient bound to a freshly reloaded app module.

    The server module is reloaded before each test so that the ``app``
    object handed to the TestClient is rebuilt for that test (presumably
    giving each test a fresh WarehouseEnvService as well).
    """
    import importlib

    from grid_env.Server import app as server_module

    # Reload in place, then read ``app`` off the reloaded module.
    importlib.reload(server_module)
    with TestClient(server_module.app) as test_client:
        yield test_client
# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------
def test_health_returns_200(client):
    """GET /health responds with HTTP 200."""
    assert client.get("/health").status_code == 200
def test_health_has_required_keys(client):
    """The /health payload contains status, task_id and episode_id."""
    payload = client.get("/health").json()
    for required_key in ("status", "task_id", "episode_id"):
        assert required_key in payload
def test_health_status_is_ok(client):
    """The /health payload reports a status of "ok"."""
    health = client.get("/health")
    reported_status = health.json()["status"]
    assert reported_status == "ok"
# ---------------------------------------------------------------------------
# GET /tasks
# ---------------------------------------------------------------------------
def test_tasks_returns_200(client):
    """GET /tasks responds with HTTP 200."""
    assert client.get("/tasks").status_code == 200
def test_tasks_has_tasks_key(client):
    """The /tasks payload has a "tasks" entry holding a list."""
    payload = client.get("/tasks").json()
    assert "tasks" in payload
    task_entries = payload["tasks"]
    assert isinstance(task_entries, list)
def test_tasks_returns_all(client):
    """The task_ids listed by /tasks are exactly those in TASK_IDS."""
    task_entries = client.get("/tasks").json()["tasks"]
    listed_ids = set()
    for entry in task_entries:
        listed_ids.add(entry["task_id"])
    assert listed_ids == set(TASK_IDS)
# ---------------------------------------------------------------------------
# POST /reset
# ---------------------------------------------------------------------------
def test_reset_default_returns_200(client):
    """POST /reset with an empty JSON body responds with HTTP 200."""
    empty_request = {}
    assert client.post("/reset", json=empty_request).status_code == 200
def test_reset_response_has_observation_and_state(client):
    """The /reset payload contains both "observation" and "state"."""
    payload = client.post("/reset", json={}).json()
    for required_key in ("observation", "state"):
        assert required_key in payload
def test_reset_observation_has_task_id(client):
    """Resetting to a named task echoes that task_id in the observation."""
    requested_task = "easy_single_pick"
    reply = client.post("/reset", json={"task_id": requested_task})
    observation = reply.json()["observation"]
    assert observation["task_id"] == requested_task
@pytest.mark.parametrize("task_id", TASK_IDS)
def test_reset_each_task(client, task_id):
    """Each entry of TASK_IDS can be reset to and is echoed back."""
    reply = client.post("/reset", json={"task_id": task_id})
    observation = reply.json()["observation"]
    assert observation["task_id"] == task_id
def test_reset_with_seed_is_deterministic(client):
    """Two resets with the same seed report the same initial battery level."""
    request = {"task_id": "easy_single_pick", "seed": 42}
    # Issue both resets first, then compare, mirroring a back-to-back reset.
    first, second = (client.post("/reset", json=request).json() for _ in range(2))
    first_level = first["observation"]["battery_level"]
    second_level = second["observation"]["battery_level"]
    assert first_level == second_level
def test_reset_unknown_task_returns_404(client):
    """Resetting to a task_id the server does not know yields HTTP 404."""
    unknown_request = {"task_id": "nonexistent_task"}
    assert client.post("/reset", json=unknown_request).status_code == 404
# ---------------------------------------------------------------------------
# POST /step
# ---------------------------------------------------------------------------
def test_step_after_reset_returns_200(client):
    """A "wait" step issued after a reset responds with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    assert client.post("/step", json={"command": "wait"}).status_code == 200
def test_step_response_has_required_keys(client):
    """The /step payload has observation, reward, done, info and state."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    payload = client.post("/step", json={"command": "wait"}).json()
    for required_key in ("observation", "reward", "done", "info", "state"):
        assert required_key in payload
def test_step_done_is_bool(client):
    """The "done" field of a /step payload is a bool."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_reply = client.post("/step", json={"command": "wait"})
    done_flag = step_reply.json()["done"]
    assert isinstance(done_flag, bool)
@pytest.mark.parametrize("command", ALL_ACTIONS)
def test_step_all_commands_accepted(client, command):
    """Every command in ALL_ACTIONS is answered with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_request = {"command": command}
    assert client.post("/step", json=step_request).status_code == 200
def test_step_invalid_command_returns_400(client):
    """An unrecognised command string is rejected with HTTP 400."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    bogus_request = {"command": "fly_to_mars"}
    assert client.post("/step", json=bogus_request).status_code == 400
def test_step_increments_step_count(client):
    """After two "wait" steps the returned state reports step_count == 2."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    for _ in range(2):
        last_reply = client.post("/step", json={"command": "wait"})
    # Only the payload of the second step is inspected.
    assert last_reply.json()["state"]["step_count"] == 2
def test_step_episode_terminates(client):
    """Repeated "wait" steps eventually report done=True."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    # Up to 60 steps are tried (easy task max_steps is 40); any() stops
    # issuing steps as soon as one reports a truthy "done".
    terminated = any(
        client.post("/step", json={"command": "wait"}).json()["done"]
        for _ in range(60)
    )
    assert terminated
def test_step_score_in_range_after_done(client):
    """Once the episode ends, info["score"] is present and within [0, 1]."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    final_score = None
    for _ in range(60):
        payload = client.post("/step", json={"command": "wait"}).json()
        if not payload["done"]:
            continue
        # First terminal step: capture its score and stop stepping.
        final_score = payload["info"].get("score")
        break
    assert final_score is not None
    assert 0.0 <= final_score <= 1.0
# ---------------------------------------------------------------------------
# GET /state
# ---------------------------------------------------------------------------
def test_state_returns_200(client):
    """GET /state responds with HTTP 200."""
    assert client.get("/state").status_code == 200
def test_state_has_task_id(client):
    """The /state payload contains a "task_id" entry."""
    state_payload = client.get("/state").json()
    assert "task_id" in state_payload
def test_state_reflects_reset(client):
    """After resetting to a task, /state reports that task_id."""
    chosen_task = "medium_multi_item"
    client.post("/reset", json={"task_id": chosen_task})
    assert client.get("/state").json()["task_id"] == chosen_task
def test_state_step_count_matches_steps_taken(client):
    """After three "wait" steps, /state reports step_count == 3."""
    steps_taken = 3
    client.post("/reset", json={"task_id": "easy_single_pick"})
    for _ in range(steps_taken):
        client.post("/step", json={"command": "wait"})
    assert client.get("/state").json()["step_count"] == steps_taken
|