Spaces:
Sleeping
Sleeping
| """Tests for the dispatch state machine.""" | |
| from __future__ import annotations | |
| from src.city_schema import CitySchemaLoader | |
| from src.models import Action, DispatchAction | |
| from src.state_machine import DispatchStateMachine | |
def test_reset_sets_ids_and_has_entities() -> None:
    """Check that ``reset`` echoes the requested ids and populates the state.

    The state returned by ``reset`` must carry the ``task_id`` and
    ``episode_id`` that were passed in, start with ``step_count`` at 0, and
    have truthy (non-empty) ``units`` and ``incidents``.
    """
    machine = DispatchStateMachine(schema=CitySchemaLoader.load("metro_city"), seed=42)
    initial = machine.reset(task_id="single_incident", episode_id="ep-1")

    # Identifiers are passed through unchanged.
    assert initial.task_id == "single_incident"
    assert initial.episode_id == "ep-1"

    # A fresh episode has taken no steps yet.
    assert initial.step_count == 0

    # The scenario must contain something to dispatch and something to respond to.
    assert initial.units
    assert initial.incidents
def test_legal_actions_non_empty_initially() -> None:
    """A freshly reset single-incident episode offers at least one DISPATCH."""
    city = CitySchemaLoader.load("metro_city")
    machine = DispatchStateMachine(schema=city, seed=42)
    start = machine.reset(task_id="single_incident", episode_id="ep-1")

    options = machine.get_legal_actions(start)
    assert options

    # Scan for the first DISPATCH option; one is enough.
    dispatch_found = False
    for option in options:
        if option.action_type == DispatchAction.DISPATCH:
            dispatch_found = True
            break
    assert dispatch_found
def test_additional_actions_become_reachable() -> None:
    """Check which action types are legal before and after a first dispatch.

    On the initial ``multi_incident`` state STAGE, UPGRADE and DOWNGRADE must
    all be offered. After stepping with one DISPATCH action, CANCEL and
    REASSIGN must be offered as well.
    """

    def offers(actions, kind) -> bool:
        """Return True if any action in *actions* has the given action type."""
        return any(candidate.action_type == kind for candidate in actions)

    city = CitySchemaLoader.load("metro_city")
    machine = DispatchStateMachine(schema=city, seed=42)

    # The multi-incident task is used because it exposes more of the action
    # space (several incidents, including a P2 one).
    current = machine.reset(task_id="multi_incident", episode_id="ep-1")
    before = machine.get_legal_actions(current)
    assert offers(before, DispatchAction.STAGE)
    assert offers(before, DispatchAction.UPGRADE)
    assert offers(before, DispatchAction.DOWNGRADE)

    # Once a unit has been dispatched, it should be possible to cancel it or
    # reassign it to the other active incident.
    first_dispatch = next(
        candidate
        for candidate in before
        if candidate.action_type == DispatchAction.DISPATCH
    )
    current, _observation = machine.step(current, first_dispatch)
    after = machine.get_legal_actions(current)
    assert offers(after, DispatchAction.CANCEL)
    assert offers(after, DispatchAction.REASSIGN)
def test_mutual_aid_appears_when_type_exhausted() -> None:
    """MUTUAL_AID must be offered once every MEDIC unit is unavailable.

    The test marks every unit whose type value is ``"MEDIC"`` as DISPATCHED
    directly on the state, then checks that the legal actions include at
    least one MUTUAL_AID action.
    """
    # Local import kept inside the test so the block does not depend on the
    # module-level import list.
    from src.models import UnitStatus

    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="multi_incident", episode_id="ep-1")

    medics = [
        unit for unit in state.units.values() if unit.unit_type.value == "MEDIC"
    ]
    # Guard against a vacuous test: without any MEDIC units the loop below
    # would do nothing and the scenario would never actually be exercised.
    assert medics, "fixture has no MEDIC units to exhaust"

    # Exhaust all MEDIC availability.
    for unit in medics:
        unit.status = UnitStatus.DISPATCHED
        # NOTE(review): assumes "INC-001" is an incident id in this scenario.
        unit.assigned_incident_id = "INC-001"

    legal = sm.get_legal_actions(state)
    assert any(a.action_type == DispatchAction.MUTUAL_AID for a in legal)
def test_invalid_action_yields_protocol_ok_false() -> None:
    """An invalid action is flagged in the observation but still counts as a step.

    Stepping with a DISPATCH for unit id ``"NOPE"`` must produce an
    observation with ``protocol_ok`` False and result ``"invalid action"``,
    and the returned state must have ``step_count`` equal to 1.
    """
    machine = DispatchStateMachine(schema=CitySchemaLoader.load("metro_city"), seed=42)
    start = machine.reset(task_id="single_incident", episode_id="ep-1")

    # "NOPE" is presumably not a unit id in this scenario — that is what
    # makes the action invalid.
    bogus = Action(
        action_type=DispatchAction.DISPATCH,
        unit_id="NOPE",
        incident_id="INC-001",
    )
    next_state, observation = machine.step(start, bogus)

    assert observation.protocol_ok is False
    assert observation.result == "invalid action"
    assert next_state.step_count == 1
def test_dispatch_progresses_incident() -> None:
    """Stepping with a legal DISPATCH is accepted and leaves a unit non-AVAILABLE.

    The action is selected by type rather than taken from position 0 of the
    legal-action list, so the test exercises a DISPATCH regardless of the
    order in which legal actions are returned.
    """
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="single_incident", episode_id="ep-1")
    legal = sm.get_legal_actions(state)

    # Pick an explicit DISPATCH; a default of None turns "no dispatch
    # offered" into a readable assertion failure instead of StopIteration.
    dispatch = next(
        (a for a in legal if a.action_type == DispatchAction.DISPATCH),
        None,
    )
    assert dispatch is not None, "no DISPATCH action among the legal actions"

    state, obs = sm.step(state, dispatch)
    assert obs.protocol_ok is True
    assert any(u.status.value != "AVAILABLE" for u in state.units.values())