File size: 7,071 Bytes
7f83247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa79f0
 
 
 
 
 
 
 
 
 
7f83247
 
 
 
 
 
 
 
eaa79f0
7f83247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa79f0
7f83247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Integration tests for the FastAPI HTTP server (HF Space entrypoint).

Uses FastAPI's TestClient (backed by httpx) to exercise all five endpoints:
  GET  /health
  GET  /tasks
  POST /reset
  POST /step
  GET  /state

These tests are self-contained: the ``client`` fixture reloads the app module
and builds a new TestClient for every test, so there is no shared mutable
state between tests.
"""

import pytest
from fastapi.testclient import TestClient

from grid_env.Server.app import app


# Task identifiers the server is expected to list under GET /tasks; also used
# to parametrize the per-task POST /reset test.
TASK_IDS = [
    "easy_single_pick", "medium_multi_item", "hard_restock_priority",
    "obstacle_course", "heavy_lifting", "stamina_run",
    "budget_run", "gauntlet",
]
# Command strings that POST /step is expected to accept with a 200 response.
ALL_ACTIONS = [
    "turn_left", "turn_right", "move_forward",
    "scan_bin", "pick_item", "pack_item",
    "recharge", "rest", "wait",
]


@pytest.fixture()
def client():
    """Yield a TestClient bound to a freshly reloaded app module.

    The server module is reloaded before every test so that its module-level
    code runs again and a new ``app`` object is served (presumably together
    with a fresh WarehouseEnvService; confirm against the server module).
    The client is used as a context manager, so it is closed when the test
    finishes.
    """
    import importlib
    from grid_env.Server import app as server_module

    # Reload in place, then read ``app`` off the same (now refreshed) module.
    importlib.reload(server_module)
    with TestClient(server_module.app) as test_client:
        yield test_client


# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------

def test_health_returns_200(client):
    """GET /health must answer with HTTP 200."""
    assert client.get("/health").status_code == 200


def test_health_has_required_keys(client):
    """The /health JSON body must contain status, task_id and episode_id."""
    payload = client.get("/health").json()
    for required_key in ("status", "task_id", "episode_id"):
        assert required_key in payload


def test_health_status_is_ok(client):
    """The /health body must report a status of "ok"."""
    health_response = client.get("/health")
    reported_status = health_response.json()["status"]
    assert reported_status == "ok"


# ---------------------------------------------------------------------------
# GET /tasks
# ---------------------------------------------------------------------------

def test_tasks_returns_200(client):
    """GET /tasks must answer with HTTP 200."""
    assert client.get("/tasks").status_code == 200


def test_tasks_has_tasks_key(client):
    """The /tasks body must expose a "tasks" entry holding a list."""
    tasks_response = client.get("/tasks")
    payload = tasks_response.json()
    assert "tasks" in payload
    task_entries = payload["tasks"]
    assert isinstance(task_entries, list)


def test_tasks_returns_all(client):
    """The task_ids listed by /tasks must be exactly those in TASK_IDS."""
    payload = client.get("/tasks").json()
    listed_ids = set()
    for task_entry in payload["tasks"]:
        listed_ids.add(task_entry["task_id"])
    # Set comparison: ordering and duplicates in the listing are not checked.
    assert listed_ids == set(TASK_IDS)


# ---------------------------------------------------------------------------
# POST /reset
# ---------------------------------------------------------------------------

def test_reset_default_returns_200(client):
    """POST /reset with an empty JSON object must answer with HTTP 200."""
    assert client.post("/reset", json={}).status_code == 200


def test_reset_response_has_observation_and_state(client):
    """The /reset body must contain both "observation" and "state"."""
    reset_response = client.post("/reset", json={})
    payload = reset_response.json()
    for required_key in ("observation", "state"):
        assert required_key in payload


def test_reset_observation_has_task_id(client):
    """The observation returned by /reset must echo the requested task_id."""
    requested_task = "easy_single_pick"
    reset_response = client.post("/reset", json={"task_id": requested_task})
    assert reset_response.json()["observation"]["task_id"] == requested_task


@pytest.mark.parametrize("task_id", TASK_IDS)
def test_reset_each_task(client, task_id):
    """Resetting to any task in TASK_IDS must echo that id in the observation."""
    reset_response = client.post("/reset", json={"task_id": task_id})
    observation = reset_response.json()["observation"]
    assert observation["task_id"] == task_id


def test_reset_with_seed_is_deterministic(client):
    """Two resets with the same seed must report the same initial battery level."""
    reset_request = {"task_id": "easy_single_pick", "seed": 42}
    # Issue both resets before inspecting either body.
    first_body, second_body = (
        client.post("/reset", json=reset_request).json() for _ in range(2)
    )
    first_battery = first_body["observation"]["battery_level"]
    second_battery = second_body["observation"]["battery_level"]
    assert first_battery == second_battery


def test_reset_unknown_task_returns_404(client):
    """POST /reset with an unrecognised task_id must answer with HTTP 404."""
    bad_request = {"task_id": "nonexistent_task"}
    assert client.post("/reset", json=bad_request).status_code == 404


# ---------------------------------------------------------------------------
# POST /step
# ---------------------------------------------------------------------------

def test_step_after_reset_returns_200(client):
    """A "wait" step issued right after a reset must answer with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    assert client.post("/step", json={"command": "wait"}).status_code == 200


def test_step_response_has_required_keys(client):
    """The /step body must contain observation, reward, done, info and state."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_response = client.post("/step", json={"command": "wait"})
    payload = step_response.json()
    for required_key in ("observation", "reward", "done", "info", "state"):
        assert required_key in payload


def test_step_done_is_bool(client):
    """The "done" field of a /step response must be a JSON boolean."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_response = client.post("/step", json={"command": "wait"})
    done_flag = step_response.json()["done"]
    assert isinstance(done_flag, bool)


@pytest.mark.parametrize("command", ALL_ACTIONS)
def test_step_all_commands_accepted(client, command):
    """Each command in ALL_ACTIONS must be accepted by /step with HTTP 200."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    step_request = {"command": command}
    assert client.post("/step", json=step_request).status_code == 200


def test_step_invalid_command_returns_400(client):
    """An unrecognised command sent to /step must answer with HTTP 400."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    bad_request = {"command": "fly_to_mars"}
    assert client.post("/step", json=bad_request).status_code == 400


def test_step_increments_step_count(client):
    """After two "wait" steps the reported state.step_count must be 2."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    wait_request = {"command": "wait"}
    client.post("/step", json=wait_request)
    second_response = client.post("/step", json=wait_request)
    reported_count = second_response.json()["state"]["step_count"]
    assert reported_count == 2


def test_step_episode_terminates(client):
    """Repeated "wait" steps must report done=True within 60 steps."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    # any() stops at the first truthy "done", so no further steps are sent
    # once the episode has ended. easy task max_steps is 40, so 60 is ample.
    episode_finished = any(
        client.post("/step", json={"command": "wait"}).json()["done"]
        for _ in range(60)
    )
    assert episode_finished


def test_step_score_in_range_after_done(client):
    """Once the episode ends, info["score"] must be present and within [0, 1]."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    final_score = None
    for _ in range(60):
        step_body = client.post("/step", json={"command": "wait"}).json()
        if not step_body["done"]:
            continue
        # First terminal step: capture its score and stop stepping.
        final_score = step_body["info"].get("score")
        break
    assert final_score is not None
    assert 0.0 <= final_score <= 1.0


# ---------------------------------------------------------------------------
# GET /state
# ---------------------------------------------------------------------------

def test_state_returns_200(client):
    """GET /state must answer with HTTP 200."""
    assert client.get("/state").status_code == 200


def test_state_has_task_id(client):
    """The /state JSON body must contain a "task_id" entry."""
    state_response = client.get("/state")
    assert "task_id" in state_response.json()


def test_state_reflects_reset(client):
    """After resetting to a task, /state must report that task's id."""
    chosen_task = "medium_multi_item"
    client.post("/reset", json={"task_id": chosen_task})
    assert client.get("/state").json()["task_id"] == chosen_task


def test_state_step_count_matches_steps_taken(client):
    """After three "wait" steps, /state must report a step_count of 3."""
    client.post("/reset", json={"task_id": "easy_single_pick"})
    steps_sent = 0
    while steps_sent < 3:
        client.post("/step", json={"command": "wait"})
        steps_sent += 1
    state_body = client.get("/state").json()
    assert state_body["step_count"] == 3