File size: 8,456 Bytes
4904e85
 
 
 
 
 
 
 
 
 
 
 
13517a8
4904e85
 
 
 
 
 
4dc3d0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13517a8
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc3d0a
 
 
 
 
 
 
 
 
 
 
4904e85
 
 
4dc3d0a
 
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13517a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Reward engine and grader primitives."""

from pydantic import BaseModel, Field

from src.models import (
    Action,
    DispatchAction,
    IncidentSeverity,
    Observation,
    State,
    UnitStatus,
)
from src.phraseology import PhraseologyJudge


def _clamp01(value: float) -> float:
    """Coerce *value* to ``float`` and limit it to the closed range [0.0, 1.0]."""
    number = float(value)
    # Apply the upper bound first, then the lower bound.
    capped = number if number < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0


def _normalize_enumish_key(value: object) -> str:
    """Normalize keys that may be stored as Enum-ish strings.

    We accept forms like:
    - "CARDIAC_ARREST"
    - "IncidentType.CARDIAC_ARREST"
    - "src.models.IncidentType.CARDIAC_ARREST"
    - Enum members (IncidentType.CARDIAC_ARREST)

    Args:
        value: A string, an Enum member, or any other object.

    Returns:
        The text after the last ``.`` of the key's string form, or the whole
        string form when it contains no ``.``.
    """
    if isinstance(value, str):
        text = value
    else:
        member_value = getattr(value, "value", None)
        # Only a non-empty *string* ``.value`` can serve as the key text.
        # Anything else (int-valued enums, None, objects without ``.value``)
        # falls back to ``str(value)`` so that ``text`` is always a string;
        # for a plain Enum member that is "Class.NAME", which the split
        # below reduces to "NAME".
        if isinstance(member_value, str) and member_value:
            text = member_value
        else:
            text = str(value)

    # If the value looks like a qualified enum name, use the trailing segment.
    return text.rsplit(".", 1)[-1]


def _normalize_str_list(values: object) -> list[str]:
    """Return *values* as a list of normalized enum-ish keys.

    ``None`` yields an empty list, a list/tuple/set is normalized element by
    element, and any other object is treated as a single key.
    """
    if values is None:
        return []
    if isinstance(values, (list, tuple, set)):
        return list(map(_normalize_enumish_key, values))
    # A lone value (for example a single string) becomes a one-element list.
    return [_normalize_enumish_key(values)]


class RewardSignal(BaseModel):
    """Signal components for reward breakdown.

    Every component is a required float validated to lie in [0.0, 1.0].
    """

    # Reject unknown fields at construction time rather than ignoring them.
    model_config = {"extra": "forbid"}

    # Dispatch timeliness score.
    response_time: float = Field(..., ge=0.0, le=1.0)
    # Score for matching the dispatched unit type to the incident.
    triage: float = Field(..., ge=0.0, le=1.0)
    # Score for Priority-1 incident outcomes.
    survival: float = Field(..., ge=0.0, le=1.0)
    # Score for how many districts have an available unit.
    coverage: float = Field(..., ge=0.0, le=1.0)
    # Score for action legality and phraseology quality.
    protocol: float = Field(..., ge=0.0, le=1.0)


class RewardCalculator:
    """Evaluates dispatcher decisions with response-time, triage, survival, coverage, protocol.

    Each component is scored in [0.0, 1.0]; ``weights`` holds the share of
    each component in the combined score.
    """

    # Component weights used by ``_compute_weighted_total``; they sum to 1.0.
    weights: dict[str, float] = {
        "response_time": 0.30,
        "triage": 0.25,
        "survival": 0.25,
        "coverage": 0.12,
        "protocol": 0.08,
    }

    def compute_reward(self, state: State, action: Action, obs: Observation) -> tuple[RewardSignal, float]:
        """Compute reward signal and total weighted score.

        Args:
            state: Current lifecycle state
            action: Action taken by agent
            obs: Observation returned by environment

        Returns:
            Tuple of (reward signal components, total weighted score clamped to [0.0, 1.0])
        """
        signal = RewardSignal(
            response_time=self._compute_response_time(state, action),
            triage=self._compute_triage(state, action),
            survival=self._compute_survival(state),
            coverage=self._compute_coverage(state),
            protocol=self._compute_protocol(action, obs),
        )
        return signal, self._compute_weighted_total(signal, state)

    def _compute_response_time(self, state: State, action: Action) -> float:
        """Score dispatch timeliness via ETA benchmarks.

        Returns:
            0.5 (neutral) when the action is not a dispatch, 0.0 when the
            referenced unit or incident is unknown, otherwise
            ``benchmark / eta`` clamped to [0.0, 1.0].
        """
        if action.action_type != DispatchAction.DISPATCH:
            return 0.5

        unit = state.units.get(action.unit_id)
        incident = state.incidents.get(action.incident_id)
        if unit is None or incident is None:
            return 0.0

        # Benchmark ETA per severity, in the same unit as ``unit.eta_seconds``.
        # Kept as equality checks (not a dict lookup) so that severities that
        # compare equal but hash differently still match.
        benchmark: float
        if incident.severity == IncidentSeverity.PRIORITY_1:
            benchmark = 240.0
        elif incident.severity == IncidentSeverity.PRIORITY_2:
            benchmark = 480.0
        else:
            benchmark = 900.0

        # Floor the ETA at a tiny positive value to avoid dividing by zero.
        eta = max(float(unit.eta_seconds), 1e-6)
        return _clamp01(benchmark / eta)

    def _compute_triage(self, state: State, action: Action) -> float:
        """Score whether dispatched unit type matches the incident's required types.

        Returns:
            0.5 (neutral) when the action is not a dispatch or when no
            requirement is configured for the incident type, 0.0 when the
            unit or incident is unknown or the unit type is not required,
            and 1.0 when the unit type is one of the required types.
        """
        if action.action_type != DispatchAction.DISPATCH:
            return 0.5

        unit = state.units.get(action.unit_id)
        incident = state.incidents.get(action.incident_id)
        if unit is None or incident is None:
            return 0.0

        required_map_raw = state.metadata.get("default_required_units", {})
        if not isinstance(required_map_raw, dict):
            return 0.5

        # Normalize metadata so lookups work across serialization styles.
        required_map: dict[str, list[str]] = {
            _normalize_enumish_key(k): _normalize_str_list(v) for k, v in required_map_raw.items()
        }

        incident_key = _normalize_enumish_key(incident.incident_type)
        required_types = required_map.get(incident_key, [])
        if not required_types:
            return 0.5

        # required_types are stored as strings in metadata (often with enum qualifiers).
        if _normalize_enumish_key(unit.unit_type) in required_types:
            return 1.0
        return 0.0

    def _compute_survival(self, state: State) -> float:
        """Score survival outcomes for Priority-1 incidents.

        Reads the ``p1_seen``, ``resolved_incidents`` and ``failed_incidents``
        entries of ``state.metadata``. A missing entry and an entry explicitly
        set to ``None`` are both treated as empty.

        Returns:
            1.0 when no Priority-1 incident was seen, otherwise the fraction
            of seen Priority-1 incidents that are resolved and not failed.
        """
        # ``or []`` also covers a key that is present but holds None, which
        # the ``.get`` default alone would not.
        p1_seen: list[str] = list(state.metadata.get("p1_seen") or [])
        if not p1_seen:
            return 1.0

        resolved: set[str] = set(state.metadata.get("resolved_incidents") or [])
        failed: set[str] = set(state.metadata.get("failed_incidents") or [])

        survived = sum(
            1 for incident_id in p1_seen if incident_id in resolved and incident_id not in failed
        )
        return _clamp01(survived / len(p1_seen))

    def _compute_coverage(self, state: State) -> float:
        """Score geographic coverage of AVAILABLE units across districts.

        Districts are derived by slicing the x-axis into equal bins.

        Returns:
            1.0 when ``districts`` or ``grid_size`` is missing/empty in
            ``state.metadata`` or the grid width is not positive, otherwise
            the fraction of bins that contain at least one AVAILABLE unit.
        """
        # ``or []`` keeps an explicit None from reaching ``list()``.
        districts: list[str] = list(state.metadata.get("districts") or [])
        grid_size = state.metadata.get("grid_size")

        if not districts or not grid_size:
            return 1.0

        width = float(grid_size[0])
        if width <= 0.0:
            return 1.0

        district_count = len(districts)
        bin_width = max(width / district_count, 1e-6)
        covered: set[int] = set()
        for unit in state.units.values():
            if unit.status != UnitStatus.AVAILABLE:
                continue
            # Negative x is treated as 0 and positions past the right edge
            # fall into the last bin.
            idx = int(min(district_count - 1, max(0.0, unit.location_x) // bin_width))
            covered.add(idx)

        return _clamp01(len(covered) / district_count)

    def _compute_protocol(self, action: Action, obs: Observation) -> float:
        """Score action protocol + phraseology quality.

        - If the action is illegal, protocol score is 0.0.
        - If action is legal and no phraseology is provided (`Action.notes`), return neutral 0.5.
        - If phraseology is provided, use PhraseologyJudge to score correctness/readback.
        """
        if not obs.protocol_ok:
            return 0.0

        candidate = (action.notes or "").strip()
        if not candidate:
            return 0.5

        judge = PhraseologyJudge()
        phrase_score = float(judge.score(action, candidate))
        readback_score = 1.0 if judge.check_readback(candidate, action) else 0.0
        # Phraseology quality counts for 60%, a passing readback check for 40%.
        return _clamp01(0.6 * phrase_score + 0.4 * readback_score)

    def _compute_weighted_total(self, signal: RewardSignal, state: State) -> float:
        """Combine the components into one score in [0.0, 1.0].

        Args:
            signal: Per-component scores.
            state: Current state; only ``metadata["p1_seen"]`` is read.

        Returns:
            The weighted sum clamped to [0.0, 1.0], capped at 0.2 when
            Priority-1 incidents were seen and the survival score is 0.0.
        """
        total = (
            signal.response_time * self.weights["response_time"]
            + signal.triage * self.weights["triage"]
            + signal.survival * self.weights["survival"]
            + signal.coverage * self.weights["coverage"]
            + signal.protocol * self.weights["protocol"]
        )

        total = _clamp01(total)

        # Dominance rule: if any Priority-1 incidents existed and survival == 0.0, cap score.
        if state.metadata.get("p1_seen") and signal.survival == 0.0:
            total = min(total, 0.2)

        return total


class TaskGrader:
    """Turns a sequence of per-step rewards into one normalized episode score."""

    def grade_episode(self, episode_rewards: list[float], task_id: str) -> float:
        """Return the mean step reward, limited to [0.0, 1.0].

        Args:
            episode_rewards: Per-step reward values for the episode.
            task_id: Task identifier; not used by this base grader.

        Returns:
            0.0 for an empty episode, otherwise the clamped mean reward.
        """
        if not episode_rewards:
            return 0.0

        mean_reward = sum(episode_rewards) / len(episode_rewards)

        # Apply the upper bound first, then the lower bound.
        capped = mean_reward if mean_reward < 1.0 else 1.0
        return capped if capped > 0.0 else 0.0