Spaces:
Sleeping
Sleeping
File size: 8,456 Bytes
4904e85 13517a8 4904e85 4dc3d0a 4904e85 13517a8 4904e85 4dc3d0a 4904e85 4dc3d0a 4904e85 13517a8 4904e85 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 | """Reward engine and grader primitives."""
import math

from pydantic import BaseModel, Field

from src.models import (
    Action,
    DispatchAction,
    IncidentSeverity,
    Observation,
    State,
    UnitStatus,
)
from src.phraseology import PhraseologyJudge
def _clamp01(value: float) -> float:
return max(0.0, min(1.0, float(value)))
def _normalize_enumish_key(value: object) -> str:
"""Normalize keys that may be stored as Enum-ish strings.
We accept forms like:
- "CARDIAC_ARREST"
- "IncidentType.CARDIAC_ARREST"
- "src.models.IncidentType.CARDIAC_ARREST"
- Enum members (IncidentType.CARDIAC_ARREST)
"""
if isinstance(value, str):
text = value
else:
text = getattr(value, "value", None) or str(value)
# If the value looks like a qualified enum name, use the trailing segment.
if "." in text:
return text.split(".")[-1]
return text
def _normalize_str_list(values: object) -> list[str]:
    """Coerce ``values`` into a list of normalized enum-ish key strings.

    ``None`` yields an empty list. A list, tuple or set is normalized element
    by element. Any other object is treated as a single item.
    """
    if values is None:
        return []
    if isinstance(values, (list, tuple, set)):
        return list(map(_normalize_enumish_key, values))
    return [_normalize_enumish_key(values)]
class RewardSignal(BaseModel):
    """Per-component reward breakdown for a single step.

    Every component is a required float that pydantic constrains to [0.0, 1.0].
    """
    # Reject unexpected fields at validation time instead of ignoring them.
    model_config = {"extra": "forbid"}
    # Timeliness of a dispatch.
    response_time: float = Field(..., ge=0.0, le=1.0)
    # Whether the dispatched unit type matches what the incident requires.
    triage: float = Field(..., ge=0.0, le=1.0)
    # Outcome score for Priority-1 incidents.
    survival: float = Field(..., ge=0.0, le=1.0)
    # Geographic spread of available units.
    coverage: float = Field(..., ge=0.0, le=1.0)
    # Action legality plus phraseology quality.
    protocol: float = Field(..., ge=0.0, le=1.0)
class RewardCalculator:
    """Score dispatcher decisions on five signals and combine them.

    The signals are response time, triage, survival, coverage and protocol.
    Each is scored in [0.0, 1.0], and ``compute_reward`` folds them into a
    weighted total that is also clamped to [0.0, 1.0].
    """

    # Default per-signal weights (they sum to 1.0). Each instance copies this
    # table in ``__init__``, so per-instance edits never leak to other instances.
    weights: dict[str, float] = {
        "response_time": 0.30,
        "triage": 0.25,
        "survival": 0.25,
        "coverage": 0.12,
        "protocol": 0.08,
    }

    def __init__(self, weights: "dict[str, float] | None" = None) -> None:
        """Create a calculator with its own copy of the weight table.

        Args:
            weights: Optional overrides merged over the class defaults. Any
                signal not listed keeps its default weight.
        """
        # The class-level dict is a single object shared by every instance.
        # Copying it means ``calc.weights["triage"] = ...`` reweights only
        # ``calc`` rather than every calculator in the process.
        merged = dict(type(self).weights)
        if weights:
            merged.update(weights)
        self.weights = merged

    def compute_reward(self, state: State, action: Action, obs: Observation) -> tuple[RewardSignal, float]:
        """Compute the reward signal and the total weighted score.

        Args:
            state: Current lifecycle state.
            action: Action taken by the agent.
            obs: Observation returned by the environment.

        Returns:
            Tuple of (reward signal components, total weighted score clamped
            to [0.0, 1.0]).
        """
        response_time = self._compute_response_time(state, action)
        triage = self._compute_triage(state, action)
        survival = self._compute_survival(state)
        coverage = self._compute_coverage(state)
        protocol = self._compute_protocol(action, obs)
        signal = RewardSignal(
            response_time=response_time,
            triage=triage,
            survival=survival,
            coverage=coverage,
            protocol=protocol,
        )
        total = self._compute_weighted_total(signal, state)
        return signal, total

    def _compute_response_time(self, state: State, action: Action) -> float:
        """Score dispatch timeliness against a severity-based ETA benchmark.

        Returns 0.5 (neutral) when the action is not a dispatch, and 0.0 when
        the referenced unit or incident is unknown. Otherwise returns
        ``benchmark / eta`` clamped to [0.0, 1.0]: meeting or beating the
        benchmark scores 1.0, and the score falls as the ETA grows.
        """
        if action.action_type != DispatchAction.DISPATCH:
            return 0.5
        unit = state.units.get(action.unit_id)
        incident = state.incidents.get(action.incident_id)
        if unit is None or incident is None:
            return 0.0
        # Benchmark ETA in seconds (compared against unit.eta_seconds).
        benchmark: float
        if incident.severity == IncidentSeverity.PRIORITY_1:
            benchmark = 240.0
        elif incident.severity == IncidentSeverity.PRIORITY_2:
            benchmark = 480.0
        else:
            benchmark = 900.0
        # Floor the ETA so a zero ETA does not divide by zero.
        eta = max(float(unit.eta_seconds), 1e-6)
        return _clamp01(benchmark / eta)

    def _compute_triage(self, state: State, action: Action) -> float:
        """Score whether the dispatched unit type is one the incident requires.

        Returns 0.5 (neutral) for non-dispatch actions, for missing or
        malformed ``default_required_units`` metadata, or when no types are
        listed for the incident. Returns 0.0 for an unknown unit or incident,
        1.0 for a matching type, and 0.0 otherwise.
        """
        if action.action_type != DispatchAction.DISPATCH:
            return 0.5
        unit = state.units.get(action.unit_id)
        incident = state.incidents.get(action.incident_id)
        if unit is None or incident is None:
            return 0.0
        required_map_raw = state.metadata.get("default_required_units", {})
        if not isinstance(required_map_raw, dict):
            return 0.5
        # Normalize metadata so lookups work across serialization styles
        # (plain names, qualified enum names, or Enum members).
        required_map: dict[str, list[str]] = {
            _normalize_enumish_key(k): _normalize_str_list(v) for k, v in required_map_raw.items()
        }
        incident_key = _normalize_enumish_key(incident.incident_type)
        required_types = required_map.get(incident_key, [])
        if not required_types:
            return 0.5
        if _normalize_enumish_key(unit.unit_type) in set(required_types):
            return 1.0
        return 0.0

    def _compute_survival(self, state: State) -> float:
        """Score the fraction of seen Priority-1 incidents resolved without failure.

        Reads the ``p1_seen``, ``resolved_incidents`` and ``failed_incidents``
        metadata lists. Returns 1.0 when no Priority-1 incident has been seen.
        """
        # ``or []`` also covers keys that are present but hold None.
        p1_seen: list[str] = list(state.metadata.get("p1_seen") or [])
        if not p1_seen:
            return 1.0
        resolved: set[str] = set(state.metadata.get("resolved_incidents") or [])
        failed: set[str] = set(state.metadata.get("failed_incidents") or [])
        ok = sum(
            1
            for incident_id in p1_seen
            if incident_id in resolved and incident_id not in failed
        )
        return _clamp01(ok / len(p1_seen))

    def _compute_coverage(self, state: State) -> float:
        """Score the fraction of districts holding at least one AVAILABLE unit.

        Districts are equal-width bins along the x-axis, one per entry in
        ``metadata["districts"]``, and ``grid_size[0]`` is taken as the
        x-extent. Units left of 0 fall in the first bin, and units beyond the
        extent fall in the last. Returns 1.0 when districts or grid size are
        missing, or when the width is not positive.
        """
        districts: list[str] = list(state.metadata.get("districts") or [])
        grid_size = state.metadata.get("grid_size")
        if not districts or not grid_size:
            return 1.0
        width = float(grid_size[0])
        if width <= 0.0:
            return 1.0
        covered: set[int] = set()
        bin_width = width / len(districts)
        last_bin = len(districts) - 1
        for unit in state.units.values():
            if unit.status != UnitStatus.AVAILABLE:
                continue
            idx = int(min(last_bin, max(0.0, unit.location_x) // max(bin_width, 1e-6)))
            covered.add(idx)
        return _clamp01(len(covered) / len(districts))

    def _compute_protocol(self, action: Action, obs: Observation) -> float:
        """Score action protocol plus phraseology quality.

        - If the observation flags the action as not protocol-compliant, the
          score is 0.0.
        - If the action is compliant but ``Action.notes`` is empty or blank,
          return a neutral 0.5.
        - Otherwise blend the PhraseologyJudge score (60%) with its readback
          check (40%).
        """
        if not obs.protocol_ok:
            return 0.0
        candidate = (action.notes or "").strip()
        if not candidate:
            return 0.5
        judge = PhraseologyJudge()
        phrase_score = float(judge.score(action, candidate))
        readback_score = 1.0 if judge.check_readback(candidate, action) else 0.0
        return _clamp01(0.6 * phrase_score + 0.4 * readback_score)

    def _compute_weighted_total(self, signal: RewardSignal, state: State) -> float:
        """Combine signal components with ``self.weights`` into a clamped total."""
        total = (
            signal.response_time * self.weights["response_time"]
            + signal.triage * self.weights["triage"]
            + signal.survival * self.weights["survival"]
            + signal.coverage * self.weights["coverage"]
            + signal.protocol * self.weights["protocol"]
        )
        total = _clamp01(total)
        # Dominance rule: if any Priority-1 incidents existed and survival is
        # exactly 0.0, cap the score so other signals cannot mask the failure.
        if state.metadata.get("p1_seen") and signal.survival == 0.0:
            total = min(total, 0.2)
        return total
class TaskGrader:
    """Aggregate per-step rewards into a final normalized episode score."""

    def grade_episode(self, episode_rewards: list[float], task_id: str) -> float:
        """Return the mean per-step reward of an episode, clamped to [0.0, 1.0].

        Args:
            episode_rewards: Per-step reward values. Any iterable of numbers
                is accepted; it is consumed once.
            task_id: Task identifier (unused by this base grader).

        Returns:
            Final score in [0.0, 1.0]. An empty episode scores 0.0, and so
            does an episode whose mean is NaN.
        """
        rewards = [float(r) for r in episode_rewards]
        if not rewards:
            return 0.0
        avg = sum(rewards) / len(rewards)
        # ``min(1.0, nan)`` evaluates to 1.0, so without this guard a single
        # NaN reward would turn the episode into a perfect score.
        if math.isnan(avg):
            return 0.0
        return max(0.0, min(1.0, avg))
|