"""Reward engine and grader primitives.""" from pydantic import BaseModel, Field from src.models import ( Action, DispatchAction, IncidentSeverity, Observation, State, UnitStatus, ) from src.phraseology import PhraseologyJudge def _clamp01(value: float) -> float: return max(0.0, min(1.0, float(value))) def _normalize_enumish_key(value: object) -> str: """Normalize keys that may be stored as Enum-ish strings. We accept forms like: - "CARDIAC_ARREST" - "IncidentType.CARDIAC_ARREST" - "src.models.IncidentType.CARDIAC_ARREST" - Enum members (IncidentType.CARDIAC_ARREST) """ if isinstance(value, str): text = value else: text = getattr(value, "value", None) or str(value) # If the value looks like a qualified enum name, use the trailing segment. if "." in text: return text.split(".")[-1] return text def _normalize_str_list(values: object) -> list[str]: if values is None: return [] if not isinstance(values, (list, tuple, set)): return [_normalize_enumish_key(values)] return [_normalize_enumish_key(v) for v in values] class RewardSignal(BaseModel): """Signal components for reward breakdown.""" model_config = {"extra": "forbid"} response_time: float = Field(..., ge=0.0, le=1.0) triage: float = Field(..., ge=0.0, le=1.0) survival: float = Field(..., ge=0.0, le=1.0) coverage: float = Field(..., ge=0.0, le=1.0) protocol: float = Field(..., ge=0.0, le=1.0) class RewardCalculator: """Evaluates dispatcher decisions with response-time, triage, survival, coverage, protocol.""" weights: dict[str, float] = { "response_time": 0.30, "triage": 0.25, "survival": 0.25, "coverage": 0.12, "protocol": 0.08, } def compute_reward(self, state: State, action: Action, obs: Observation) -> tuple[RewardSignal, float]: """Compute reward signal and total weighted score. Args: state: Current lifecycle state action: Action taken by agent obs: Observation returned by environment Returns: Tuple of (reward signal components, total weighted score clamped to [0.0, 1.0]) """ response_time = self._compute_response_time(state, action) triage = self._compute_triage(state, action) survival = self._compute_survival(state) coverage = self._compute_coverage(state) protocol = self._compute_protocol(action, obs) signal = RewardSignal( response_time=response_time, triage=triage, survival=survival, coverage=coverage, protocol=protocol, ) total = self._compute_weighted_total(signal, state) return signal, total def _compute_response_time(self, state: State, action: Action) -> float: """Score dispatch timeliness via ETA benchmarks. If no dispatch occurs this step, return a neutral 0.5. """ if action.action_type != DispatchAction.DISPATCH: return 0.5 unit = state.units.get(action.unit_id) incident = state.incidents.get(action.incident_id) if unit is None or incident is None: return 0.0 benchmark: float if incident.severity == IncidentSeverity.PRIORITY_1: benchmark = 240.0 elif incident.severity == IncidentSeverity.PRIORITY_2: benchmark = 480.0 else: benchmark = 900.0 eta = max(float(unit.eta_seconds), 1e-6) return _clamp01(benchmark / eta) def _compute_triage(self, state: State, action: Action) -> float: """Score whether dispatched unit type matches the incident's required types.""" if action.action_type != DispatchAction.DISPATCH: return 0.5 unit = state.units.get(action.unit_id) incident = state.incidents.get(action.incident_id) if unit is None or incident is None: return 0.0 required_map_raw = state.metadata.get("default_required_units", {}) if not isinstance(required_map_raw, dict): return 0.5 # Normalize metadata so lookups work across serialization styles. required_map: dict[str, list[str]] = { _normalize_enumish_key(k): _normalize_str_list(v) for k, v in required_map_raw.items() } incident_key = _normalize_enumish_key(incident.incident_type) required_types = required_map.get(incident_key, []) if not required_types: return 0.5 # required_types are stored as strings in metadata (often with enum qualifiers). if _normalize_enumish_key(unit.unit_type) in set(required_types): return 1.0 return 0.0 def _compute_survival(self, state: State) -> float: """Score survival outcomes for Priority-1 incidents. Uses state.metadata bookkeeping written by the state machine. """ p1_seen: list[str] = list(state.metadata.get("p1_seen", [])) if not p1_seen: return 1.0 resolved: set[str] = set(state.metadata.get("resolved_incidents", [])) failed: set[str] = set(state.metadata.get("failed_incidents", [])) ok = 0 for incident_id in p1_seen: if incident_id in resolved and incident_id not in failed: ok += 1 return _clamp01(ok / max(len(p1_seen), 1)) def _compute_coverage(self, state: State) -> float: """Score geographic coverage of AVAILABLE units across districts. Districts are derived by slicing the x-axis into equal bins. """ districts: list[str] = list(state.metadata.get("districts", [])) grid_size = state.metadata.get("grid_size") if not districts or not grid_size: return 1.0 width = float(grid_size[0]) if width <= 0.0: return 1.0 covered: set[int] = set() bin_width = width / len(districts) for unit in state.units.values(): if unit.status != UnitStatus.AVAILABLE: continue idx = int(min(len(districts) - 1, max(0.0, unit.location_x) // max(bin_width, 1e-6))) covered.add(idx) return _clamp01(len(covered) / len(districts)) def _compute_protocol(self, action: Action, obs: Observation) -> float: """Score action protocol + phraseology quality. - If the action is illegal, protocol score is 0.0. - If action is legal and no phraseology is provided (`Action.notes`), return neutral 0.5. - If phraseology is provided, use PhraseologyJudge to score correctness/readback. """ if not obs.protocol_ok: return 0.0 candidate = (action.notes or "").strip() if not candidate: return 0.5 judge = PhraseologyJudge() phrase_score = float(judge.score(action, candidate)) readback_score = 1.0 if judge.check_readback(candidate, action) else 0.0 return _clamp01(0.6 * phrase_score + 0.4 * readback_score) def _compute_weighted_total(self, signal: RewardSignal, state: State) -> float: total = ( signal.response_time * self.weights["response_time"] + signal.triage * self.weights["triage"] + signal.survival * self.weights["survival"] + signal.coverage * self.weights["coverage"] + signal.protocol * self.weights["protocol"] ) total = _clamp01(total) # Dominance rule: if any Priority-1 incidents existed and survival == 0.0, cap score. if state.metadata.get("p1_seen") and signal.survival == 0.0: total = min(total, 0.2) return total class TaskGrader: """Aggregates episode rewards and returns final normalized score.""" def grade_episode(self, episode_rewards: list[float], task_id: str) -> float: """Aggregate rewards over episode and return final score. Args: episode_rewards: List of per-step reward values task_id: Task identifier (unused in base grader) Returns: Final score in [0.0, 1.0] """ if not episode_rewards: return 0.0 total = sum(episode_rewards) avg = total / len(episode_rewards) return max(0.0, min(1.0, avg))