Spaces:
Sleeping
Sleeping
| """Tests for DispatchAPI client wrapper.""" | |
| from __future__ import annotations | |
| import asyncio | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from src.api import APIError, ATCAircraftAPI, DispatchAPI | |
| from src.models import Action, DispatchAction, Observation, State | |
| class TestAPIError: | |
| def test_fields(self) -> None: | |
| err = APIError(status_code=404, detail="Not found") | |
| assert err.status_code == 404 | |
| assert "Not found" in err.detail | |
| class TestDispatchAPIInit: | |
| def test_default_base_url(self) -> None: | |
| api = DispatchAPI() | |
| assert api.base_url == "http://localhost:8000" | |
| def test_alias_exists(self) -> None: | |
| api = ATCAircraftAPI() | |
| assert isinstance(api, DispatchAPI) | |
| def test_uses_httpx_async_client(self) -> None: | |
| with patch("src.api.httpx.AsyncClient") as mock_client_class: | |
| api = DispatchAPI() | |
| api._get_client() | |
| mock_client_class.assert_called_once() | |
| class TestDispatchAPIReset: | |
| def test_reset_returns_observation(self) -> None: | |
| mock_response = MagicMock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "result": "dispatch center online", | |
| "score": 0.0, | |
| "protocol_ok": True, | |
| "issues": [], | |
| } | |
| api = DispatchAPI() | |
| api._client = AsyncMock() | |
| api._client.post = AsyncMock(return_value=mock_response) | |
| obs = asyncio.run(api.reset(task_id="single_incident", seed=42)) | |
| assert isinstance(obs, Observation) | |
| assert obs.protocol_ok is True | |
| def test_reset_raises_on_non_200(self) -> None: | |
| mock_response = MagicMock() | |
| mock_response.status_code = 500 | |
| mock_response.text = "boom" | |
| api = DispatchAPI() | |
| api._client = AsyncMock() | |
| api._client.post = AsyncMock(return_value=mock_response) | |
| with pytest.raises(APIError): | |
| asyncio.run(api.reset(task_id="single_incident", seed=1)) | |
| class TestDispatchAPIStep: | |
| def test_step_sends_action_payload(self) -> None: | |
| mock_response = MagicMock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "observation": { | |
| "result": "ok", | |
| "score": 0.8, | |
| "protocol_ok": True, | |
| "issues": [], | |
| }, | |
| "reward": 0.8, | |
| "done": False, | |
| } | |
| api = DispatchAPI() | |
| api._client = AsyncMock() | |
| api._client.post = AsyncMock(return_value=mock_response) | |
| action = Action( | |
| action_type=DispatchAction.DISPATCH, | |
| unit_id="MED-1", | |
| incident_id="INC-001", | |
| ) | |
| obs, reward, done = asyncio.run(api.step(action)) | |
| assert isinstance(obs, Observation) | |
| assert isinstance(reward, float) | |
| assert isinstance(done, bool) | |
| call_kwargs = api._client.post.call_args.kwargs | |
| assert call_kwargs["json"]["action"]["action_type"] == "DISPATCH" | |
| assert call_kwargs["json"]["action"]["unit_id"] == "MED-1" | |
| class TestDispatchAPIState: | |
| def test_state_returns_state(self) -> None: | |
| mock_response = MagicMock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "units": {}, | |
| "incidents": {}, | |
| "episode_id": "ep", | |
| "step_count": 0, | |
| "task_id": "single_incident", | |
| "city_time": 0.0, | |
| "metadata": {}, | |
| } | |
| api = DispatchAPI() | |
| api._client = AsyncMock() | |
| api._client.get = AsyncMock(return_value=mock_response) | |
| state = asyncio.run(api.state()) | |
| assert isinstance(state, State) | |
| assert state.task_id == "single_incident" | |