Spaces:
Sleeping
Sleeping
| """Tests for reward engine and grader primitives (dispatch domain).""" | |
| from __future__ import annotations | |
| import pytest | |
| from src.models import ( | |
| Action, | |
| DispatchAction, | |
| IncidentSeverity, | |
| IncidentState, | |
| IncidentStatus, | |
| IncidentType, | |
| Observation, | |
| State, | |
| UnitState, | |
| UnitStatus, | |
| UnitType, | |
| ) | |
| from src.rewards import RewardCalculator, RewardSignal | |
| def _state_with_one_dispatch() -> State: | |
| unit = UnitState( | |
| unit_id="MED-1", | |
| unit_type=UnitType.MEDIC, | |
| status=UnitStatus.DISPATCHED, | |
| location_x=0.0, | |
| location_y=0.0, | |
| assigned_incident_id="INC-001", | |
| eta_seconds=200.0, | |
| crew_count=2, | |
| ) | |
| inc = IncidentState( | |
| incident_id="INC-001", | |
| incident_type=IncidentType.CARDIAC_ARREST, | |
| severity=IncidentSeverity.PRIORITY_1, | |
| location_x=10.0, | |
| location_y=10.0, | |
| reported_at_step=0, | |
| units_assigned=["MED-1"], | |
| status=IncidentStatus.RESPONDING, | |
| survival_clock=100.0, | |
| ) | |
| return State( | |
| units={"MED-1": unit}, | |
| incidents={"INC-001": inc}, | |
| episode_id="ep", | |
| step_count=1, | |
| task_id="single_incident", | |
| city_time=30.0, | |
| metadata={ | |
| "default_required_units": {"IncidentType.CARDIAC_ARREST": ["UnitType.MEDIC"]}, | |
| "districts": ["a", "b"], | |
| "grid_size": [100, 100], | |
| }, | |
| ) | |
| def test_reward_signal_requires_fields() -> None: | |
| with pytest.raises(Exception): | |
| RewardSignal() # type: ignore[call-arg] | |
| def test_compute_reward_returns_tuple() -> None: | |
| calc = RewardCalculator() | |
| state = _state_with_one_dispatch() | |
| action = Action( | |
| action_type=DispatchAction.DISPATCH, | |
| unit_id="MED-1", | |
| incident_id="INC-001", | |
| notes="DISPATCH MED-1 -> INC-001", | |
| ) | |
| obs = Observation(result="ok", score=0.8, protocol_ok=True, issues=[]) | |
| signal, total = calc.compute_reward(state, action, obs) | |
| assert isinstance(signal, RewardSignal) | |
| # Fixture metadata stores enum-ish strings (e.g. "IncidentType.CARDIAC_ARREST"). | |
| # Triage should still award full credit for a correct match. | |
| assert signal.triage == 1.0 | |
| assert signal.protocol == 1.0 | |
| assert 0.0 <= total <= 1.0 | |
| def test_protocol_reward_uses_phraseology_notes() -> None: | |
| calc = RewardCalculator() | |
| state = _state_with_one_dispatch() | |
| obs = Observation(result="ok", score=0.8, protocol_ok=True, issues=[]) | |
| # No notes => neutral. | |
| action_no_notes = Action(action_type=DispatchAction.DISPATCH, unit_id="MED-1", incident_id="INC-001") | |
| signal, _ = calc.compute_reward(state, action_no_notes, obs) | |
| assert signal.protocol == 0.5 | |
| # Wrong notes => poor phraseology score. | |
| action_bad_notes = Action( | |
| action_type=DispatchAction.DISPATCH, | |
| unit_id="MED-1", | |
| incident_id="INC-001", | |
| notes="send car 12", | |
| ) | |
| signal2, _ = calc.compute_reward(state, action_bad_notes, obs) | |
| assert signal2.protocol == 0.0 | |
| def test_weights_sum_to_one() -> None: | |
| calc = RewardCalculator() | |
| assert abs(sum(calc.weights.values()) - 1.0) < 1e-9 | |