"""Unit tests for shift_surge episode grading."""
from __future__ import annotations
from src.models import (
IncidentSeverity,
IncidentState,
IncidentStatus,
IncidentType,
State,
UnitState,
UnitStatus,
UnitType,
)
from src.tasks.shift_surge import ShiftSurgeGrader
def _base_state() -> State:
    """Build a fresh baseline ``State`` fixture for the shift_surge grader tests.

    The fixture contains three units (``MED-1``, ``ENG-1``, ``PAT-1``), all
    ``AVAILABLE`` and unassigned, and two ``PENDING`` incidents (``INC-001``
    and ``INC-002``) with no units assigned. Every call constructs new
    objects, so tests may mutate the returned state freely.
    """
    # (unit_id, unit_type, location_x, location_y, crew_count)
    unit_specs = (
        ("MED-1", UnitType.MEDIC, 10.0, 10.0, 2),
        ("ENG-1", UnitType.ENGINE, 50.0, 50.0, 4),
        ("PAT-1", UnitType.PATROL, 90.0, 10.0, 2),
    )
    fleet = {
        unit_id: UnitState(
            unit_id=unit_id,
            unit_type=unit_type,
            status=UnitStatus.AVAILABLE,
            location_x=pos_x,
            location_y=pos_y,
            assigned_incident_id=None,
            eta_seconds=0.0,
            crew_count=crew,
        )
        for unit_id, unit_type, pos_x, pos_y, crew in unit_specs
    }

    # (incident_id, incident_type, severity, location_x, location_y, survival_clock)
    incident_specs = (
        (
            "INC-001",
            IncidentType.CARDIAC_ARREST,
            IncidentSeverity.PRIORITY_1,
            12.0,
            12.0,
            600.0,
        ),
        (
            "INC-002",
            IncidentType.STRUCTURE_FIRE,
            IncidentSeverity.PRIORITY_2,
            55.0,
            48.0,
            1200.0,
        ),
    )
    open_incidents = {
        incident_id: IncidentState(
            incident_id=incident_id,
            incident_type=incident_type,
            severity=severity,
            location_x=pos_x,
            location_y=pos_y,
            reported_at_step=0,
            units_assigned=[],  # evaluated per iteration: a new list for each incident
            status=IncidentStatus.PENDING,
            survival_clock=clock,
        )
        for incident_id, incident_type, severity, pos_x, pos_y, clock in incident_specs
    }

    episode_metadata = {
        "districts": ["a", "b", "c"],
        "grid_size": [100, 100],
        "p1_seen": ["INC-001"],
        "resolved_incidents": [],
        "failed_incidents": [],
    }

    return State(
        units=fleet,
        incidents=open_incidents,
        episode_id="ep",
        step_count=10,
        task_id="shift_surge",
        city_time=300.0,
        metadata=episode_metadata,
    )
def test_shift_surge_grader_rewards_good_outcome() -> None:
    """With both incidents resolved and ten rewards of 0.9, the grade is in [0.8, 1.0]."""
    resolved_ids = ("INC-001", "INC-002")

    episode = _base_state()
    for incident_id in resolved_ids:
        episode.incidents[incident_id].status = IncidentStatus.RESOLVED
    episode.metadata["resolved_incidents"] = list(resolved_ids)

    step_rewards = [0.9] * 10
    grade = ShiftSurgeGrader().grade(episode, rewards=step_rewards)

    lower_bound, upper_bound = 0.8, 1.0
    assert lower_bound <= grade <= upper_bound
def test_shift_surge_grader_penalizes_failures_and_backlog() -> None:
    """With one incident escalated and marked failed, one still responding, and
    ten rewards of 0.2, the grade is in [0.0, 0.4]."""
    status_overrides = {
        "INC-001": IncidentStatus.ESCALATED,
        "INC-002": IncidentStatus.RESPONDING,
    }

    episode = _base_state()
    for incident_id, new_status in status_overrides.items():
        episode.incidents[incident_id].status = new_status
    episode.metadata["failed_incidents"] = ["INC-001"]

    step_rewards = [0.2] * 10
    grade = ShiftSurgeGrader().grade(episode, rewards=step_rewards)

    lower_bound, upper_bound = 0.0, 0.4
    assert lower_bound <= grade <= upper_bound
|