File size: 3,809 Bytes
6172160
4904e85
 
 
6172160
4904e85
 
 
 
6172160
 
4904e85
 
 
6172160
 
 
 
4904e85
 
6172160
4904e85
6172160
4904e85
 
6172160
 
 
4904e85
 
 
6172160
4904e85
 
 
 
6172160
 
4904e85
 
 
6172160
4904e85
6172160
4904e85
 
 
6172160
4904e85
 
 
6172160
 
 
4904e85
6172160
4904e85
 
6172160
4904e85
6172160
4904e85
 
 
6172160
 
4904e85
 
6172160
 
4904e85
 
 
 
6172160
 
 
4904e85
 
6172160
4904e85
 
 
6172160
4904e85
 
 
 
6172160
 
 
4904e85
6172160
4904e85
 
 
 
 
6172160
 
 
4904e85
 
6172160
 
4904e85
 
 
6172160
 
 
4904e85
6172160
 
4904e85
 
 
6172160
4904e85
 
 
6172160
4904e85
6172160
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Tests for DispatchAPI client wrapper."""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.api import APIError, ATCAircraftAPI, DispatchAPI
from src.models import Action, DispatchAction, Observation, State


class TestAPIError:
    """Checks that APIError keeps the values it is constructed with."""

    def test_fields(self) -> None:
        """status_code and detail given to the constructor are readable back."""
        error = APIError(
            status_code=404,
            detail="Not found",
        )

        # detail is checked by containment, not equality, mirroring the
        # looseness of the original assertion.
        assert error.status_code == 404
        assert "Not found" in error.detail


class TestDispatchAPIInit:
    """Construction-time checks for DispatchAPI and the ATCAircraftAPI name."""

    def test_default_base_url(self) -> None:
        """A client built without arguments reports the localhost base URL."""
        assert DispatchAPI().base_url == "http://localhost:8000"

    def test_alias_exists(self) -> None:
        """Objects created through ATCAircraftAPI are DispatchAPI instances."""
        assert isinstance(ATCAircraftAPI(), DispatchAPI)

    def test_uses_httpx_async_client(self) -> None:
        """Building the API and calling _get_client() hits AsyncClient once."""
        # The patch covers both construction and _get_client(), so the
        # single-call assertion spans the two steps together.
        with patch("src.api.httpx.AsyncClient") as async_client_cls:
            DispatchAPI()._get_client()
            async_client_cls.assert_called_once()


class TestDispatchAPIReset:
    """Exercises DispatchAPI.reset with the HTTP client replaced by mocks."""

    @staticmethod
    def _api_answering_post(response: MagicMock) -> DispatchAPI:
        """Return a DispatchAPI whose mocked client.post resolves to *response*."""
        api = DispatchAPI()
        api._client = AsyncMock()
        api._client.post = AsyncMock(return_value=response)
        return api

    def test_reset_returns_observation(self) -> None:
        """A 200 reply with observation fields comes back as an Observation."""
        payload = {
            "result": "dispatch center online",
            "score": 0.0,
            "protocol_ok": True,
            "issues": [],
        }
        ok_response = MagicMock(status_code=200)
        ok_response.json.return_value = payload
        api = self._api_answering_post(ok_response)

        observation = asyncio.run(api.reset(task_id="single_incident", seed=42))

        assert isinstance(observation, Observation)
        assert observation.protocol_ok is True

    def test_reset_raises_on_non_200(self) -> None:
        """A reply with status 500 makes reset() raise APIError."""
        failing_response = MagicMock(status_code=500, text="boom")
        api = self._api_answering_post(failing_response)

        with pytest.raises(APIError):
            asyncio.run(api.reset(task_id="single_incident", seed=1))


class TestDispatchAPIStep:
    """Exercises DispatchAPI.step with the HTTP client replaced by mocks."""

    def test_step_sends_action_payload(self) -> None:
        """step() returns (Observation, float, bool) and posts the action as JSON."""
        step_payload = {
            "observation": {
                "result": "ok",
                "score": 0.8,
                "protocol_ok": True,
                "issues": [],
            },
            "reward": 0.8,
            "done": False,
        }
        ok_response = MagicMock(status_code=200)
        ok_response.json.return_value = step_payload

        api = DispatchAPI()
        api._client = AsyncMock()
        api._client.post = AsyncMock(return_value=ok_response)

        dispatch = Action(
            action_type=DispatchAction.DISPATCH,
            unit_id="MED-1",
            incident_id="INC-001",
        )
        outcome = asyncio.run(api.step(dispatch))
        observation, reward, done = outcome

        assert isinstance(observation, Observation)
        assert isinstance(reward, float)
        assert isinstance(done, bool)

        # The request body is read from the keyword arguments of the
        # recorded post() call; the action sits under json["action"].
        sent_action = api._client.post.call_args.kwargs["json"]["action"]
        assert sent_action["action_type"] == "DISPATCH"
        assert sent_action["unit_id"] == "MED-1"


class TestDispatchAPIState:
    """Exercises DispatchAPI.state with the HTTP client replaced by mocks."""

    def test_state_returns_state(self) -> None:
        """A 200 reply to the mocked GET comes back as a State with its task_id."""
        state_payload = {
            "units": {},
            "incidents": {},
            "episode_id": "ep",
            "step_count": 0,
            "task_id": "single_incident",
            "city_time": 0.0,
            "metadata": {},
        }
        ok_response = MagicMock(status_code=200)
        ok_response.json.return_value = state_payload

        # Only get() is stubbed with a response here; this test does not
        # configure post().
        mocked_client = AsyncMock()
        mocked_client.get = AsyncMock(return_value=ok_response)

        api = DispatchAPI()
        api._client = mocked_client

        snapshot = asyncio.run(api.state())

        assert isinstance(snapshot, State)
        assert snapshot.task_id == "single_incident"