"""Integration tests for the Dispatch Arena server app, HTTP API and client."""

import tempfile
import unittest
from pathlib import Path

from fastapi.testclient import TestClient

from dispatch_arena.client import DispatchArenaClient
from dispatch_arena.models import Config, Mode, VerifierVerdict
from dispatch_arena.server.app import (
    DispatchArenaServerApp,
    create_app,
    run_local_server_in_thread,
)
from dispatch_arena.server.env import DispatchArenaEnvironment
from dispatch_arena.server.replay_store import ReplayStore


class DispatchArenaServerClientTests(unittest.TestCase):
    """Exercise the server app, its HTTP routes, the replay store and the client."""

    # Actions tried in order when driving an episode over HTTP; the first one
    # present in the observation's ``legal_actions`` is sent.
    _ACTION_PRIORITY = ("pickup", "dropoff", "go_pickup", "go_dropoff", "wait")

    def test_imports_and_object_creation(self):
        """The environment, FastAPI app and client can be built with defaults."""
        env = DispatchArenaEnvironment()
        app = create_app()
        client = DispatchArenaClient()

        self.assertIsInstance(env, DispatchArenaEnvironment)
        self.assertIsInstance(app.state.dispatch_arena, DispatchArenaServerApp)
        self.assertIsInstance(client, DispatchArenaClient)

    def test_fastapi_session_replay_and_openenv_paths(self):
        """Play one session via ``/api/sessions`` and then probe ``/reset`` and ``/state``."""
        app = create_app(max_concurrent_envs=4)
        client = TestClient(app)

        health = client.get("/healthz").json()
        self.assertEqual(health["service"], "dispatch_arena")

        created = client.post(
            "/api/sessions",
            json={"mode": "mini", "seed": 7, "config": {"max_ticks": 12}},
        ).json()
        session_id = created["session_id"]
        obs = created["observation"]

        while not obs["done"]:
            legal_actions = obs["legal_actions"]
            # Fall back to the last priority entry ("wait") when none of the
            # preferred actions is legal, so a step is always submitted.
            action = next(
                (name for name in self._ACTION_PRIORITY if name in legal_actions),
                self._ACTION_PRIORITY[-1],
            )
            obs = client.post(
                f"/api/sessions/{session_id}/step",
                json={"action": action},
            ).json()["observation"]
        self.assertEqual(obs["verifier_status"], "delivered_successfully")

        state = client.get(f"/api/sessions/{session_id}/state").json()["state"]
        self.assertTrue(state["done"])

        replay = client.get(f"/api/sessions/{session_id}/replay").json()["records"]
        self.assertGreaterEqual(len(replay), state["tick"] + 1)
        self.assertEqual(replay[0]["type"], "reset")
        self.assertEqual(replay[-1]["type"], "summary")

        # Top-level routes: /reset returns a session id that /state accepts.
        reset = client.post(
            "/reset",
            json={"seed": 1, "config": {"mode": "mini", "max_ticks": 12}},
        ).json()
        openenv_session = reset["session_id"]
        openenv_state = client.get(
            "/state",
            params={"session_id": openenv_session},
        ).json()["state"]
        self.assertEqual(openenv_state["mode"], "mini")

    def test_replay_store_persists_reward_components(self):
        """A step record in the replay carries a reward breakdown matching the observation."""
        with tempfile.TemporaryDirectory() as tmp:
            store = ReplayStore(root=Path(tmp))
            manager = DispatchArenaServerApp(replay_store=store)

            session_id, obs = manager.create_session(
                Config(mode=Mode.MINI, max_ticks=12),
                seed=7,
            )
            self.assertFalse(obs.done)

            obs = manager.step(session_id, "go_pickup")
            # Read the replay while the temporary directory still exists.
            records = manager.replay(session_id)

            self.assertEqual(records[0]["type"], "reset")
            self.assertEqual(records[1]["type"], "step")
            self.assertIn("reward_breakdown", records[1])
            self.assertEqual(records[1]["reward_breakdown"]["total_reward"], obs.reward)

    def test_one_episode_over_client(self):
        """Run a full episode through ``DispatchArenaClient`` against a local server.

        The test is skipped when starting the server raises ``PermissionError``.
        """
        try:
            server, thread = run_local_server_in_thread(port=0, max_concurrent_envs=4)
        except PermissionError:
            # skipTest() raises unittest.SkipTest, so nothing after it runs.
            self.skipTest("Socket bind not permitted in current sandbox")

        host, port = server.server_address
        client = DispatchArenaClient(base_url=f"http://{host}:{port}")
        try:
            self.assertEqual(client.health()["service"], "dispatch_arena")

            obs = client.reset(seed=7)
            while not obs.done:
                obs = client.step(obs.legal_actions[0])

            self.assertEqual(obs.verifier_status, VerifierVerdict.DELIVERED_SUCCESSFULLY)
            self.assertEqual(client.fetch_summary()["final_verdict"], "delivered_successfully")
            self.assertGreaterEqual(len(client.fetch_replay()), obs.state.tick + 1)
        finally:
            # Always stop the background server, even when an assertion fails.
            server.shutdown()
            server.server_close()
            thread.join(timeout=2)


if __name__ == "__main__":
    unittest.main()