# mini-rl-env / tests /test_server.py
# sohambose98's picture
# updated the tests and graders
# eaa79f0
"""
Integration tests for the FastAPI HTTP server (HF Space entrypoint).
Uses FastAPI's TestClient (backed by httpx) to exercise all five endpoints:
GET /health
GET /tasks
POST /reset
POST /step
GET /state
These tests are self-contained: each test creates its own TestClient so there
is no shared mutable state between tests.
"""
import pytest
from fastapi.testclient import TestClient
from grid_env.Server.app import app
# Task identifiers the suite expects: compared against the GET /tasks listing
# in test_tasks_returns_all and used to parametrize test_reset_each_task.
TASK_IDS = [
    "easy_single_pick",
    "medium_multi_item",
    "hard_restock_priority",
    "obstacle_course",
    "heavy_lifting",
    "stamina_run",
    "budget_run",
    "gauntlet",
]

# Command strings fed one by one to POST /step by
# test_step_all_commands_accepted.
ALL_ACTIONS = [
    "turn_left",
    "turn_right",
    "move_forward",
    "scan_bin",
    "pick_item",
    "pack_item",
    "recharge",
    "rest",
    "wait",
]
@pytest.fixture()
def client():
    """Yield a TestClient bound to a freshly reloaded server module.

    The server module is reloaded before every test so that the ``app``
    object handed to the TestClient is rebuilt rather than reused from an
    earlier test.
    """
    import importlib

    from grid_env.Server import app as server_module

    # Reloading re-executes the module body, which re-creates ``app``.
    importlib.reload(server_module)
    with TestClient(server_module.app) as test_client:
        yield test_client
# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------
def test_health_returns_200(client):
    """GET /health must answer with HTTP 200."""
    assert client.get("/health").status_code == 200
def test_health_has_required_keys(client):
    """The /health payload must contain status, task_id and episode_id."""
    payload = client.get("/health").json()
    for key in ("status", "task_id", "episode_id"):
        assert key in payload
def test_health_status_is_ok(client):
    """The /health payload must report a status of "ok"."""
    health_response = client.get("/health")
    assert health_response.json()["status"] == "ok"
# ---------------------------------------------------------------------------
# GET /tasks
# ---------------------------------------------------------------------------
def test_tasks_returns_200(client):
    """GET /tasks must answer with HTTP 200."""
    assert client.get("/tasks").status_code == 200
def test_tasks_has_tasks_key(client):
    """The /tasks payload must carry a "tasks" entry that is a list."""
    payload = client.get("/tasks").json()
    assert "tasks" in payload
    task_entries = payload["tasks"]
    assert isinstance(task_entries, list)
def test_tasks_returns_all(client):
    """The task_ids listed by /tasks must match TASK_IDS exactly (as a set)."""
    listed_tasks = client.get("/tasks").json()["tasks"]
    expected_ids = set(TASK_IDS)
    reported_ids = {entry["task_id"] for entry in listed_tasks}
    assert reported_ids == expected_ids
# ---------------------------------------------------------------------------
# POST /reset
# ---------------------------------------------------------------------------
def test_reset_default_returns_200(client):
    """POST /reset with an empty JSON body must answer with HTTP 200."""
    assert client.post("/reset", json={}).status_code == 200
def test_reset_response_has_observation_and_state(client):
    """The /reset payload must contain both "observation" and "state"."""
    reset_response = client.post("/reset", json={})
    payload = reset_response.json()
    for key in ("observation", "state"):
        assert key in payload
def test_reset_observation_has_task_id(client):
    """The observation returned by /reset must echo the requested task_id."""
    requested_task = "easy_single_pick"
    payload = client.post("/reset", json={"task_id": requested_task}).json()
    assert payload["observation"]["task_id"] == requested_task
@pytest.mark.parametrize("task_id", TASK_IDS)
def test_reset_each_task(client, task_id):
    """Resetting to any id in TASK_IDS must echo that id in the observation."""
    reset_response = client.post("/reset", json={"task_id": task_id})
    observation = reset_response.json()["observation"]
    assert observation["task_id"] == task_id
def test_reset_with_seed_is_deterministic(client):
    """Two resets with an identical seed must report equal battery levels."""
    reset_request = {"task_id": "easy_single_pick", "seed": 42}
    # Both resets are issued (and decoded) before anything is compared.
    first, second = (
        client.post("/reset", json=reset_request).json() for _ in range(2)
    )
    first_battery = first["observation"]["battery_level"]
    second_battery = second["observation"]["battery_level"]
    assert first_battery == second_battery
def test_reset_unknown_task_returns_404(client):
    """Resetting to a task_id the server does not know must yield HTTP 404."""
    unknown_request = {"task_id": "nonexistent_task"}
    assert client.post("/reset", json=unknown_request).status_code == 404
# ---------------------------------------------------------------------------
# POST /step
# ---------------------------------------------------------------------------
def test_step_after_reset_returns_200(client):
    """A "wait" step issued right after a reset must answer with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    assert client.post("/step", json={"command": "wait"}).status_code == 200
def test_step_response_has_required_keys(client):
    """The /step payload must expose observation, reward, done, info, state."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    payload = client.post("/step", json={"command": "wait"}).json()
    for key in ("observation", "reward", "done", "info", "state"):
        assert key in payload
def test_step_done_is_bool(client):
    """The "done" field of a /step payload must be a bool."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_payload = client.post("/step", json={"command": "wait"}).json()
    done_flag = step_payload["done"]
    assert isinstance(done_flag, bool)
@pytest.mark.parametrize("command", ALL_ACTIONS)
def test_step_all_commands_accepted(client, command):
    """Each command in ALL_ACTIONS must be answered with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_request = {"command": command}
    assert client.post("/step", json=step_request).status_code == 200
def test_step_invalid_command_returns_400(client):
    """A command string outside the known set must yield HTTP 400."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    bogus_request = {"command": "fly_to_mars"}
    assert client.post("/step", json=bogus_request).status_code == 400
def test_step_increments_step_count(client):
    """After two steps the reported state must show a step_count of 2."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    wait_request = {"command": "wait"}
    client.post("/step", json=wait_request)
    second_step = client.post("/step", json=wait_request)
    assert second_step.json()["state"]["step_count"] == 2
def test_step_episode_terminates(client):
    """Repeated "wait" steps must set done to a truthy value within 60 tries."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    # 60 requests leaves headroom over the easy task's max_steps (40).
    # any() short-circuits, so no further steps are sent once done is seen.
    terminated = any(
        client.post("/step", json={"command": "wait"}).json()["done"]
        for _ in range(60)
    )
    assert terminated
def test_step_score_in_range_after_done(client):
    """Once the episode is done, info["score"] must exist and lie in [0, 1]."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    final_score = None
    for _ in range(60):
        step_payload = client.post("/step", json={"command": "wait"}).json()
        if not step_payload["done"]:
            continue
        # First terminal step: capture its score (may be absent) and stop.
        final_score = step_payload["info"].get("score")
        break
    assert final_score is not None
    assert 0.0 <= final_score <= 1.0
# ---------------------------------------------------------------------------
# GET /state
# ---------------------------------------------------------------------------
def test_state_returns_200(client):
    """GET /state must answer with HTTP 200."""
    assert client.get("/state").status_code == 200
def test_state_has_task_id(client):
    """The /state payload must contain a "task_id" entry."""
    state_payload = client.get("/state").json()
    assert "task_id" in state_payload
def test_state_reflects_reset(client):
    """After a reset, /state must report the task_id that was requested."""
    requested_task = "medium_multi_item"
    client.post("/reset", json={"task_id": requested_task})
    assert client.get("/state").json()["task_id"] == requested_task
def test_state_step_count_matches_steps_taken(client):
    """After three steps, /state must report a step_count of 3."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    steps_taken = 3
    for _ in range(steps_taken):
        client.post("/step", json={"command": "wait"})
    assert client.get("/state").json()["step_count"] == steps_taken