dispatch_arena_v0 / tests /test_server_client.py
Freakdivi's picture
Upload folder using huggingface_hub
c71bf62 verified
import tempfile
import unittest
from pathlib import Path
from fastapi.testclient import TestClient
from dispatch_arena.client import DispatchArenaClient
from dispatch_arena.models import Config, Mode, VerifierVerdict
from dispatch_arena.server.app import DispatchArenaServerApp, create_app, run_local_server_in_thread
from dispatch_arena.server.env import DispatchArenaEnvironment
from dispatch_arena.server.replay_store import ReplayStore
class DispatchArenaServerClientTests(unittest.TestCase):
    """End-to-end tests for the Dispatch Arena environment, server app, replay store and client."""

    # Safety cap on step requests per episode. It only exists so that a server
    # which never reports ``done`` makes the test fail instead of hang; the
    # episodes below use small tick budgets (max_ticks=12 where configured
    # explicitly), so a correct server finishes far below this limit.
    _MAX_EPISODE_STEPS = 10_000

    # Scripted policy for mini mode: on every step take the first action in
    # this order that the server currently lists as legal.
    _ACTION_PRIORITY = ("pickup", "dropoff", "go_pickup", "go_dropoff", "wait")

    @classmethod
    def _pick_action(cls, legal_actions):
        """Return the first action from ``_ACTION_PRIORITY`` present in *legal_actions*.

        Falls back to ``"wait"`` when none of the preferred actions is legal.
        """
        return next(
            (action for action in cls._ACTION_PRIORITY if action in legal_actions),
            "wait",
        )

    def test_imports_and_object_creation(self):
        """The environment, FastAPI app and client can be constructed with defaults."""
        env = DispatchArenaEnvironment()
        app = create_app()
        client = DispatchArenaClient()
        self.assertIsInstance(env, DispatchArenaEnvironment)
        # create_app exposes its session manager on the app state.
        self.assertIsInstance(app.state.dispatch_arena, DispatchArenaServerApp)
        self.assertIsInstance(client, DispatchArenaClient)

    def test_fastapi_session_replay_and_openenv_paths(self):
        """Play a mini episode over ``/api/sessions`` and probe the OpenEnv-style routes."""
        app = create_app(max_concurrent_envs=4)
        client = TestClient(app)

        health = client.get("/healthz").json()
        self.assertEqual(health["service"], "dispatch_arena")

        created = client.post(
            "/api/sessions",
            json={"mode": "mini", "seed": 7, "config": {"max_ticks": 12}},
        ).json()
        session_id = created["session_id"]
        obs = created["observation"]
        steps = 0
        while not obs["done"]:
            # Fail loudly rather than spin forever if the episode never ends.
            self.assertLess(steps, self._MAX_EPISODE_STEPS, "episode never reported done")
            action = self._pick_action(obs["legal_actions"])
            response = client.post(f"/api/sessions/{session_id}/step", json={"action": action})
            obs = response.json()["observation"]
            steps += 1
        self.assertEqual(obs["verifier_status"], "delivered_successfully")

        state = client.get(f"/api/sessions/{session_id}/state").json()["state"]
        self.assertTrue(state["done"])

        # The replay must hold at least ``tick + 1`` records, opening with the
        # ``reset`` record and closing with the ``summary`` record.
        replay = client.get(f"/api/sessions/{session_id}/replay").json()["records"]
        self.assertGreaterEqual(len(replay), state["tick"] + 1)
        self.assertEqual(replay[0]["type"], "reset")
        self.assertEqual(replay[-1]["type"], "summary")

        # OpenEnv-style routes: /reset returns a session id that /state accepts.
        reset = client.post(
            "/reset",
            json={"seed": 1, "config": {"mode": "mini", "max_ticks": 12}},
        ).json()
        openenv_session = reset["session_id"]
        openenv_state = client.get("/state", params={"session_id": openenv_session}).json()["state"]
        self.assertEqual(openenv_state["mode"], "mini")

    def test_replay_store_persists_reward_components(self):
        """A step's replay record carries a reward breakdown whose total equals the step reward."""
        with tempfile.TemporaryDirectory() as tmp:
            store = ReplayStore(root=Path(tmp))
            manager = DispatchArenaServerApp(replay_store=store)
            session_id, obs = manager.create_session(Config(mode=Mode.MINI, max_ticks=12), seed=7)
            self.assertFalse(obs.done)

            obs = manager.step(session_id, "go_pickup")
            # Read the replay while the temporary store directory still exists.
            records = manager.replay(session_id)
            self.assertEqual(records[0]["type"], "reset")
            self.assertEqual(records[1]["type"], "step")
            self.assertIn("reward_breakdown", records[1])
            self.assertEqual(records[1]["reward_breakdown"]["total_reward"], obs.reward)

    def test_one_episode_over_client(self):
        """Play one full episode against a real local server through ``DispatchArenaClient``."""
        try:
            # port=0 lets the OS choose a free port; the bound address is read back below.
            server, thread = run_local_server_in_thread(port=0, max_concurrent_envs=4)
        except PermissionError:
            # skipTest raises unittest.SkipTest, so execution does not continue past it.
            self.skipTest("Socket bind not permitted in current sandbox")
        host, port = server.server_address
        client = DispatchArenaClient(base_url=f"http://{host}:{port}")
        try:
            self.assertEqual(client.health()["service"], "dispatch_arena")
            obs = client.reset(seed=7)
            steps = 0
            while not obs.done:
                # Fail loudly rather than spin forever if the episode never ends.
                self.assertLess(steps, self._MAX_EPISODE_STEPS, "episode never reported done")
                obs = client.step(obs.legal_actions[0])
                steps += 1
            self.assertEqual(obs.verifier_status, VerifierVerdict.DELIVERED_SUCCESSFULLY)
            self.assertEqual(client.fetch_summary()["final_verdict"], "delivered_successfully")
            self.assertGreaterEqual(len(client.fetch_replay()), obs.state.tick + 1)
        finally:
            # Always release the listening socket and background thread, even on failure.
            server.shutdown()
            server.server_close()
            thread.join(timeout=2)
if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line
    # runner for the test cases defined in this module.
    unittest.main(module=__name__)