File size: 3,195 Bytes
13517a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Unit tests for shift_surge episode grading."""

from __future__ import annotations

from src.models import (
    IncidentSeverity,
    IncidentState,
    IncidentStatus,
    IncidentType,
    State,
    UnitState,
    UnitStatus,
    UnitType,
)
from src.tasks.shift_surge import ShiftSurgeGrader


def _base_state() -> State:
    """Build a fresh ``shift_surge`` state fixture for the grader tests.

    The returned state holds three AVAILABLE units (medic, engine, patrol)
    and two PENDING incidents (a priority-1 cardiac arrest and a priority-2
    structure fire), at step 10 with ``city_time`` 300.0. All containers are
    constructed anew on every call, so a test can mutate the result without
    affecting other tests.
    """
    # (unit_id, unit_type, x, y, crew_count) -- every unit starts idle.
    unit_specs = (
        ("MED-1", UnitType.MEDIC, 10.0, 10.0, 2),
        ("ENG-1", UnitType.ENGINE, 50.0, 50.0, 4),
        ("PAT-1", UnitType.PATROL, 90.0, 10.0, 2),
    )
    fleet = {
        uid: UnitState(
            unit_id=uid,
            unit_type=kind,
            status=UnitStatus.AVAILABLE,
            location_x=pos_x,
            location_y=pos_y,
            assigned_incident_id=None,
            eta_seconds=0.0,
            crew_count=crew,
        )
        for uid, kind, pos_x, pos_y, crew in unit_specs
    }

    # (incident_id, incident_type, severity, x, y, survival_clock) -- both
    # incidents are reported at step 0 and have no units assigned yet.
    incident_specs = (
        (
            "INC-001",
            IncidentType.CARDIAC_ARREST,
            IncidentSeverity.PRIORITY_1,
            12.0,
            12.0,
            600.0,
        ),
        (
            "INC-002",
            IncidentType.STRUCTURE_FIRE,
            IncidentSeverity.PRIORITY_2,
            55.0,
            48.0,
            1200.0,
        ),
    )
    open_calls = {
        iid: IncidentState(
            incident_id=iid,
            incident_type=kind,
            severity=level,
            location_x=pos_x,
            location_y=pos_y,
            reported_at_step=0,
            units_assigned=[],
            status=IncidentStatus.PENDING,
            survival_clock=clock,
        )
        for iid, kind, level, pos_x, pos_y, clock in incident_specs
    }

    episode_metadata = {
        "districts": ["a", "b", "c"],
        "grid_size": [100, 100],
        "p1_seen": ["INC-001"],
        "resolved_incidents": [],
        "failed_incidents": [],
    }

    return State(
        units=fleet,
        incidents=open_calls,
        episode_id="ep",
        step_count=10,
        task_id="shift_surge",
        city_time=300.0,
        metadata=episode_metadata,
    )


def test_shift_surge_grader_rewards_good_outcome() -> None:
    """Resolving both incidents with 0.9 step rewards scores within [0.8, 1.0]."""
    scenario = _base_state()

    # Mark every incident in the fixture as resolved, both on the incident
    # objects and in the episode metadata.
    resolved_ids = ["INC-001", "INC-002"]
    for incident_id in resolved_ids:
        scenario.incidents[incident_id].status = IncidentStatus.RESOLVED
    scenario.metadata["resolved_incidents"] = list(resolved_ids)

    grader = ShiftSurgeGrader()
    high_rewards = [0.9] * 10
    score = grader.grade(scenario, rewards=high_rewards)

    assert score >= 0.8
    assert score <= 1.0


def test_shift_surge_grader_penalizes_failures_and_backlog() -> None:
    """An escalated incident, one still responding, and 0.2 rewards score within [0.0, 0.4]."""
    scenario = _base_state()

    # INC-001 is escalated and recorded as failed; INC-002 is left in the
    # RESPONDING status rather than resolved.
    escalated = scenario.incidents["INC-001"]
    still_open = scenario.incidents["INC-002"]
    escalated.status = IncidentStatus.ESCALATED
    still_open.status = IncidentStatus.RESPONDING
    scenario.metadata["failed_incidents"] = ["INC-001"]

    grader = ShiftSurgeGrader()
    low_rewards = [0.2] * 10
    score = grader.grade(scenario, rewards=low_rewards)

    assert score >= 0.0
    assert score <= 0.4