Spaces:
Sleeping
Sleeping
File size: 3,809 Bytes
6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 4904e85 6172160 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | """Tests for DispatchAPI client wrapper."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.api import APIError, ATCAircraftAPI, DispatchAPI
from src.models import Action, DispatchAction, Observation, State
class TestAPIError:
    """Unit tests for the ``APIError`` exception type."""

    def test_fields(self) -> None:
        """Constructor keyword arguments are exposed as attributes."""
        error = APIError(status_code=404, detail="Not found")
        expected_code = 404
        assert error.status_code == expected_code
        # ``in`` rather than ``==``: only require the message to be carried through.
        assert "Not found" in error.detail
class TestDispatchAPIInit:
    """Construction-time behaviour of ``DispatchAPI``."""

    def test_default_base_url(self) -> None:
        """With no arguments, ``base_url`` is ``http://localhost:8000``."""
        client = DispatchAPI()
        expected = "http://localhost:8000"
        assert client.base_url == expected

    def test_alias_exists(self) -> None:
        """The ``ATCAircraftAPI`` name yields a ``DispatchAPI`` instance."""
        legacy = ATCAircraftAPI()
        assert isinstance(legacy, DispatchAPI)

    def test_uses_httpx_async_client(self) -> None:
        """``_get_client`` constructs its client via ``httpx.AsyncClient``."""
        # Patch the name as looked up inside ``src.api`` so no real client is built.
        with patch("src.api.httpx.AsyncClient") as fake_client_cls:
            dispatch = DispatchAPI()
            dispatch._get_client()
        fake_client_cls.assert_called_once()
class TestDispatchAPIReset:
    """``DispatchAPI.reset`` response parsing and error handling."""

    @staticmethod
    def _api_posting(response: MagicMock) -> DispatchAPI:
        """Return a ``DispatchAPI`` whose client's ``post`` resolves to *response*."""
        dispatch = DispatchAPI()
        # Replace the HTTP client outright so no network access happens.
        dispatch._client = AsyncMock()
        dispatch._client.post = AsyncMock(return_value=response)
        return dispatch

    def test_reset_returns_observation(self) -> None:
        """A 200 response is turned into an ``Observation``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {
            "result": "dispatch center online",
            "score": 0.0,
            "protocol_ok": True,
            "issues": [],
        }
        dispatch = self._api_posting(response)

        observation = asyncio.run(dispatch.reset(task_id="single_incident", seed=42))

        assert isinstance(observation, Observation)
        assert observation.protocol_ok is True

    def test_reset_raises_on_non_200(self) -> None:
        """A non-200 response makes ``reset`` raise ``APIError``."""
        response = MagicMock()
        response.status_code = 500
        response.text = "boom"
        dispatch = self._api_posting(response)

        with pytest.raises(APIError):
            asyncio.run(dispatch.reset(task_id="single_incident", seed=1))
class TestDispatchAPIStep:
    """``DispatchAPI.step`` request serialisation and response unpacking."""

    def test_step_sends_action_payload(self) -> None:
        """``step`` POSTs the action once and returns ``(observation, reward, done)``."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "observation": {
                "result": "ok",
                "score": 0.8,
                "protocol_ok": True,
                "issues": [],
            },
            "reward": 0.8,
            "done": False,
        }
        api = DispatchAPI()
        # Replace the HTTP client outright so no network access happens.
        api._client = AsyncMock()
        api._client.post = AsyncMock(return_value=mock_response)
        action = Action(
            action_type=DispatchAction.DISPATCH,
            unit_id="MED-1",
            incident_id="INC-001",
        )

        obs, reward, done = asyncio.run(api.step(action))

        assert isinstance(obs, Observation)
        assert isinstance(reward, float)
        assert isinstance(done, bool)
        # Pin the values, not just the types: a dropped or defaulted reward
        # (e.g. 0.0) would otherwise still satisfy the isinstance check.
        assert reward == pytest.approx(0.8)
        assert done is False

        # ``call_args`` only reflects the most recent call, so first require
        # that exactly one request was made.
        api._client.post.assert_awaited_once()
        sent_action = api._client.post.call_args.kwargs["json"]["action"]
        assert sent_action["action_type"] == "DISPATCH"
        assert sent_action["unit_id"] == "MED-1"
        assert sent_action["incident_id"] == "INC-001"
class TestDispatchAPIState:
    """``DispatchAPI.state`` response parsing."""

    def test_state_returns_state(self) -> None:
        """A 200 response to the GET request is parsed into a ``State``."""
        snapshot = {
            "units": {},
            "incidents": {},
            "episode_id": "ep",
            "step_count": 0,
            "task_id": "single_incident",
            "city_time": 0.0,
            "metadata": {},
        }
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = snapshot

        dispatch = DispatchAPI()
        # Replace the HTTP client outright so no network access happens.
        dispatch._client = AsyncMock()
        dispatch._client.get = AsyncMock(return_value=response)

        result = asyncio.run(dispatch.state())

        assert isinstance(result, State)
        assert result.task_id == "single_incident"
|