Spaces:
Sleeping
Sleeping
| """ | |
| server/environment.py — Enhanced 911 Dispatch Triage Environment v2 | |
| WHAT IS HAPPENING HERE | |
| ---------------------- | |
| This is a genuine multi-step RL environment. Here is the full episode lifecycle: | |
| 1. Episode starts → incidents arrive (with people counts, call descriptions) | |
| 2. Agent observes the full board: who needs help, which units are free | |
| 3. Agent dispatches ONE unit to ONE incident per step | |
| 4. Time ticks every step: | |
| - Unresolved incidents accumulate steps_waiting (their reward contribution decays) | |
| - Fire incidents spread (severity increases every FIRE_SPREAD_INTERVAL steps) | |
| - En-route units count down their travel time → become available again | |
| 5. When units return, agent dispatches them again | |
| 6. Episode ends when all incidents are resolved, max_steps is reached, or no unit can ever become free again | |
| REWARD MATHEMATICS (always in [0, 1]) | |
| -------------------------------------- | |
| For each dispatched incident: | |
| contribution = severity | |
| × log(1 + people_count) ← people multiplier | |
| × e^(-DECAY_LAMBDA × wait) ← time decay | |
| × match_quality(type, unit) ← unit type effectiveness | |
| For each unresolved incident at episode end: | |
| penalty = severity × log(1 + people_count) × UNRESOLVED_PENALTY_FRACTION | |
| For cascade violations (hard mode): | |
| penalty += CASCADE_PENALTY per violation | |
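| final_score = min(1, max(0, sum(contributions) - sum(penalties)) / max_possible) | |
| ∈ [0, 1] always (clamped at both ends) | |
| max_possible = sum(severity × log(1+people) × e^(-DECAY_LAMBDA × k)) over ALL incidents, | |
| taken in descending order of severity × log(1+people) (k = 0, 1, 2, … is the dispatch order), | |
| with perfect unit match. It is the best one-dispatch-per-step schedule (ignoring unit | |
| availability and fire spread), so a well-planned episode can score 1.0. | |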
| WHY THIS IS REAL RL (not just sorting) | |
| ---------------------------------------- | |
| - Agent must learn that people_count matters more than raw severity | |
| - Agent must learn optimal dispatch timing (units returning from medium priority | |
| calls might be better held for an incoming high priority call) | |
| - Agent must learn cascade dependencies (gas_leak before cardiac) | |
| - Agent must learn fire is a time bomb (severity grows each step) | |
| - None of this is told to the agent — it discovers it from reward signals | |
| """ | |
| import uuid | |
| import math | |
| from typing import List, Optional | |
| from copy import deepcopy | |
| try: | |
| from openenv.core.env_server import Environment | |
| except ModuleNotFoundError: | |
| import sys, os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from openenv_stubs import Environment | |
| try: | |
| from models import ( | |
| DispatchAction, DispatchObservation, DispatchState, | |
| Incident, Unit, | |
| DECAY_LAMBDA, FIRE_SPREAD_INTERVAL, | |
| UNRESOLVED_PENALTY_FRACTION, CASCADE_PENALTY, | |
| get_match_quality, effective_priority, | |
| ) | |
| except ModuleNotFoundError: | |
| from ..models import ( | |
| DispatchAction, DispatchObservation, DispatchState, | |
| Incident, Unit, | |
| DECAY_LAMBDA, FIRE_SPREAD_INTERVAL, | |
| UNRESOLVED_PENALTY_FRACTION, CASCADE_PENALTY, | |
| get_match_quality, effective_priority, | |
| ) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Scenario definitions | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # | |
| # Each incident has: | |
| # call_description — raw 911 call text (LLM reads this to infer severity) | |
| # people_count — how many people at risk | |
| # fire_spreads — whether severity grows over time | |
| # | |
| # The agent sees severity + people_count as numbers and learns the weighting. | |
| # The LLM can additionally re-assess severity from call_description text. | |
SCENARIOS = {
    # ──────────────────────────────────────────────────────────────────
    # EASY — 3 incidents, 3 units, travel_time=1, max_steps=6
    #
    # Lesson: dispatch highest effective priority first, where
    # EP = severity × log(1 + people_count).
    #   cardiac  (sev 9, 1 person)   EP = 9 × log(2) ≈ 6.24
    #   car_crash (sev 5, 3 people)  EP = 5 × log(4) ≈ 6.93
    #   fire     (sev 3, 1 person)   EP = 3 × log(2) ≈ 2.08
    # The cardiac arrest looks like the obvious first pick, but the crash
    # with 3 people is worth more despite its lower raw severity. The fire
    # ranks last but spreads while it waits.
    # ──────────────────────────────────────────────────────────────────
    "easy": {
        "max_steps": 6,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=9, location="Block 2A",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Caller reports elderly man collapsed, not breathing. "
                    "CPR being attempted. Address: 12 Oak Street, Block 2A."
                ),
            ),
            Incident(
                id=1, type="car_crash", severity=5, location="Block 7C",
                people_count=3, fire_spreads=False,
                call_description=(
                    "Three-car collision at Block 7C. Three people injured, "
                    "one trapped in vehicle. No fire visible."
                ),
            ),
            Incident(
                id=2, type="fire", severity=3, location="Block 1B",
                people_count=1, fire_spreads=True,
                call_description=(
                    "Small kitchen fire at Block 1B apartment. One resident "
                    "evacuated. Fire contained to one room so far."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=1),
            Unit(id=1, type="police", travel_time=1),
            Unit(id=2, type="fire_truck", travel_time=1),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # MEDIUM — 5 incidents, 3 units, travel_time=2, max_steps=12
    #
    # Lessons:
    #   1. Units return after 2 steps — agent must plan across waves
    #   2. car_crash with 8 people (EP = 5 × log(9) ≈ 10.99) beats
    #      gas_leak with 0 people (EP = 6 × log(1) = 0.0) despite
    #      gas_leak having higher raw severity. Any incident with
    #      people_count=0 has EP 0: dispatching to it earns nothing and
    #      leaving it unresolved costs nothing.
    #   3. Fire (EP = 7 × log(7) ≈ 13.62, the highest here) spreads —
    #      delay costs grow non-linearly
    # ──────────────────────────────────────────────────────────────────
    "medium": {
        "max_steps": 12,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=9, location="Block 3A",
                people_count=2, fire_spreads=False,
                call_description=(
                    "Two people collapsed at Block 3A community centre, "
                    "possible carbon monoxide poisoning. Both unresponsive."
                ),
            ),
            Incident(
                id=1, type="fire", severity=7, location="Block 5D",
                people_count=6, fire_spreads=True,
                call_description=(
                    "Large building fire at Block 5D. Flames visible from "
                    "third floor. At least 6 residents unable to evacuate."
                ),
            ),
            Incident(
                id=2, type="car_crash", severity=5, location="Block 9B",
                people_count=8, fire_spreads=False,
                call_description=(
                    "Major crash on Block 9B expressway. Multiple vehicles. "
                    "Caller reports 8 people injured, one vehicle on its side."
                ),
            ),
            Incident(
                id=3, type="gas_leak", severity=6, location="Block 2C",
                people_count=0, fire_spreads=False,
                call_description=(
                    "Strong gas smell reported at Block 2C. Building evacuated. "
                    "No injuries yet but area needs to be secured."
                ),
            ),
            Incident(
                id=4, type="car_crash", severity=3, location="Block 11E",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Minor fender-bender at Block 11E. One driver with "
                    "minor cuts. No serious injuries reported."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # HARD — 7 incidents, 3 units, travel_time=2, max_steps=18
    #
    # Lessons:
    #   1. Cascade: cardiac (id=0) depends on gas_leak (id=1).
    #      Gas leak at Block 4B is adjacent. Dispatching to the cardiac
    #      while the gas leak is still unresolved triggers CASCADE_PENALTY.
    #      The gas leak has 0 people (EP 0), so its only value is
    #      avoiding that penalty.
    #   2. Fire at Block 9C has 12 people (EP = 8 × log(13) ≈ 20.52) AND
    #      spreads: severity +1 every FIRE_SPREAD_INTERVAL steps it waits
    #      (capped at 10). Assuming an interval of 2 (confirm in models),
    #      4 steps of waiting take it to sev=10.
    #   3. Three waves of dispatch needed — agent must plan 6 steps ahead.
    #   4. car_crash id=3 has 5 people. Despite sev=5, its EP
    #      = 5 × log(6) ≈ 8.96 — well above the cardiac arrest
    #      (id=0, EP = 7 × log(2) ≈ 4.85).
    # ──────────────────────────────────────────────────────────────────
    "hard": {
        "max_steps": 18,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=7, location="Block 4A",
                people_count=1, fire_spreads=False, depends_on=[1],
                call_description=(
                    "Man having heart attack at Block 4A, next to the building "
                    "with the reported gas leak. Caller is panicking. "
                    "Address same block as the gas incident."
                ),
            ),
            Incident(
                id=1, type="gas_leak", severity=6, location="Block 4B",
                people_count=0, fire_spreads=False,
                call_description=(
                    "Major gas leak at Block 4B, adjacent to Block 4A. "
                    "Strong smell reported. Area not yet evacuated. "
                    "Risk of explosion if ignition source present."
                ),
            ),
            Incident(
                id=2, type="fire", severity=8, location="Block 9C",
                people_count=12, fire_spreads=True,
                call_description=(
                    "Warehouse fire at Block 9C, spreading rapidly. "
                    "Night shift workers trapped inside, approximately 12 people. "
                    "Flames visible from street."
                ),
            ),
            Incident(
                id=3, type="car_crash", severity=5, location="Block 2D",
                people_count=5, fire_spreads=False,
                call_description=(
                    "Head-on collision at Block 2D intersection. Five occupants, "
                    "multiple injuries. One child among the injured. "
                    "Vehicles blocking traffic."
                ),
            ),
            Incident(
                id=4, type="car_crash", severity=4, location="Block 6E",
                people_count=2, fire_spreads=False,
                call_description=(
                    "Vehicle hit a lamp post at Block 6E. Driver and passenger "
                    "injured. Airbags deployed. Both conscious."
                ),
            ),
            Incident(
                id=5, type="fire", severity=3, location="Block 1F",
                people_count=0, fire_spreads=True,
                call_description=(
                    "Bin fire at Block 1F alley. No people involved. "
                    "Risk of spreading to nearby building if not contained."
                ),
            ),
            Incident(
                id=6, type="cardiac_arrest", severity=2, location="Block 12G",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Elderly woman feeling chest pains at Block 12G. "
                    "Conscious and breathing. Not a confirmed cardiac event yet."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
}
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Helpers | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _compute_max_possible(incidents: List[Incident]) -> float:
    """
    Return the score normaliser: a realistic sequential optimum.

    Each incident's effective priority is
    EP = severity × log(1 + people_count), using the incident's current
    severity. EPs are sorted in descending order. The k-th one (0-based) is
    charged the minimum achievable wait of k steps, because each ``step()``
    call dispatches at most one unit and then ticks time once for every
    unresolved incident. Pairing the largest EP with the smallest decay
    e^(-DECAY_LAMBDA × k) maximises the sum (rearrangement inequality). A
    perfect unit match (quality 1.0) is assumed.

    Caveats:
      - Unit availability is not modelled. With few units or long travel
        times this schedule may be unachievable.
      - Fires that spread while waiting gain severity. The raw score can
        therefore exceed this value, and the environment clamps the final
        score to 1.0.

    Args:
        incidents: Incidents of the episode (only ``severity`` and
            ``people_count`` are read).

    Returns:
        The optimum, or 1.0 when it is not positive. The fallback only
        guards against dividing by zero; small positive optima are returned
        unchanged so a perfect episode still normalises to 1.0.
    """
    # EPs, best first.
    priorities = sorted(
        (inc.severity * math.log(1 + inc.people_count) for inc in incidents),
        reverse=True,
    )
    # The k-th dispatch has waited at least k steps.
    total = sum(
        ep * math.exp(-DECAY_LAMBDA * k) for k, ep in enumerate(priorities)
    )
    return total if total > 0 else 1.0
def _compute_contribution(inc: Incident, unit_type: str) -> float:
    """
    Return the reward earned by sending a ``unit_type`` unit to ``inc`` now.

        contribution = severity × log(1 + people_count)
                       × e^(-DECAY_LAMBDA × steps_waiting)
                       × match_quality(incident type, unit type)

    The incident's current ``severity`` and ``steps_waiting`` are read. Any
    growth in severity and any time already spent waiting are therefore
    both reflected in the value.
    """
    weighted_severity = inc.severity * math.log(1 + inc.people_count)
    time_factor = math.exp(-DECAY_LAMBDA * inc.steps_waiting)
    return weighted_severity * time_factor * get_match_quality(inc.type, unit_type)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Environment | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class DispatchEnvironment(Environment):
    """
    Multi-step 911 dispatch triage environment.

    One step = one dispatch action + one time tick.
    One episode = multiple steps until one of these happens:
      - every incident is resolved,
      - max_steps is reached, or
      - no unit can ever become free again.

    The agent interacts via:
        obs = env.reset(difficulty="easy"|"medium"|"hard")
        obs = env.step(DispatchAction(incident_id=X, unit_id=Y))
        state = env.state

    Reward is sparse. ``obs.reward`` is None until the episode is done and
    then equals the final normalised score in [0, 1]. ``obs.score_so_far``
    reports the running score on every step.

    End-of-episode penalties for unresolved incidents are applied exactly
    once, whatever kind of action ends the episode. Once the episode is
    finished, further ``step()`` calls do not change state until
    ``reset()``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        # NOTE(review): the base Environment may initialise attributes of
        # its own; calling it keeps this subclass safe either way.
        super().__init__()
        self._incidents: List[Incident] = []
        self._units: List[Unit] = []
        self._step_count: int = 0
        self._max_steps: int = 10
        self._dispatch_count: int = 0
        self._raw_score: float = 0.0
        self._penalties: float = 0.0
        self._max_possible: float = 1.0
        # True once end-of-episode penalties have been applied.
        self._finalized: bool = False
        self._state = DispatchState()

    # ──────────────────────────────────────────────────────────────────
    # reset
    # ──────────────────────────────────────────────────────────────────
    def reset(
        self,
        seed=None,
        episode_id=None,
        difficulty: str = "easy",
        **kwargs,
    ) -> DispatchObservation:
        """
        Start a new episode from a fresh deep copy of a scenario.

        Args:
            seed: Accepted for interface compatibility; unused because the
                scenarios are fixed.
            episode_id: Id recorded in the state; a random UUID4 string is
                used when omitted.
            difficulty: Key into SCENARIOS. Unknown values fall back to
                "easy".
            **kwargs: Ignored.

        Returns:
            The initial observation (done=False, reward=None).
        """
        if difficulty not in SCENARIOS:
            difficulty = "easy"
        scenario = SCENARIOS[difficulty]
        # Deep copies keep episode mutations out of the shared templates.
        self._incidents = [i.model_copy(deep=True) for i in scenario["incidents"]]
        self._units = [u.model_copy(deep=True) for u in scenario["units"]]
        self._max_steps = scenario["max_steps"]
        self._step_count = 0
        self._dispatch_count = 0
        self._raw_score = 0.0
        self._penalties = 0.0
        self._finalized = False
        self._max_possible = _compute_max_possible(self._incidents)
        self._state = DispatchState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            difficulty=difficulty,
            total_incidents=len(self._incidents),
            total_units=len(self._units),
            max_steps=self._max_steps,
            max_possible_score=self._max_possible,
        )
        return self._make_obs(
            done=False,
            message=(
                f"[{difficulty.upper()}] {len(self._incidents)} incidents, "
                f"{len(self._units)} units available, "
                f"{self._max_steps} steps budget. "
                f"Dispatch wisely — people count and time decay matter."
            ),
        )

    # ──────────────────────────────────────────────────────────────────
    # step
    # ──────────────────────────────────────────────────────────────────
    def step(
        self,
        action: DispatchAction,
        timeout_s=None,
        **kwargs,
    ) -> DispatchObservation:
        """
        Apply one dispatch action, then advance time by one tick.

        An action that cannot be executed still ticks time and costs a
        step. This covers an unknown incident or unit, a unit that is en
        route, or an incident that is already resolved.

        Args:
            action: Which unit to send to which incident.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Ignored.

        Returns:
            The observation after the tick. When it ends the episode,
            unresolved penalties are applied and ``reward`` is set.
        """
        if self._finalized:
            # Nothing may change after the episode ended; in particular the
            # unresolved penalties must not be charged again.
            return self._make_obs(
                done=True,
                message="Episode already finished. Call reset() to start a new one.",
            )

        self._step_count += 1
        self._state.step_count = self._step_count

        # ── Validate ─────────────────────────────────────────────────
        incident = self._find_incident(action.incident_id)
        unit = self._find_unit(action.unit_id)
        if incident is None:
            return self._reject(
                f"Invalid incident id {action.incident_id}. Time still ticked."
            )
        if unit is None:
            return self._reject(
                f"Invalid unit id {action.unit_id}. Time still ticked."
            )
        if not unit.available:
            # Tick before formatting so the reported countdown matches the
            # unit state in the returned observation.
            spread_notes = self._tick_time()
            message = (
                f"Unit {action.unit_id} ({unit.type}) is en route — "
                f"returns in {unit.steps_until_free} step(s). Time ticked."
            )
            return self._finish_step([message, *spread_notes])
        if incident.resolved:
            return self._reject(
                f"Incident {action.incident_id} already resolved. Time ticked."
            )

        notes: List[str] = []

        # ── Check cascade ─────────────────────────────────────────────
        unresolved_deps = self._unresolved_dependencies(incident)
        if unresolved_deps:
            self._penalties += CASCADE_PENALTY
            notes.append(
                f"CASCADE PENALTY -{CASCADE_PENALTY}: "
                f"dependency incident(s) {unresolved_deps} unresolved!"
            )

        # ── Dispatch ──────────────────────────────────────────────────
        self._dispatch_count += 1
        contribution = _compute_contribution(incident, unit.type)
        self._raw_score += contribution
        unit.available = False
        unit.steps_until_free = unit.travel_time
        incident.resolved = True
        incident.assigned_unit_id = unit.id

        # Break the contribution into its factors for the agent-facing note.
        ep = incident.severity * math.log(1 + incident.people_count)
        decay = math.exp(-DECAY_LAMBDA * incident.steps_waiting)
        match = get_match_quality(incident.type, unit.type)
        notes.append(
            f"Dispatched {unit.type} -> {incident.type} at {incident.location} "
            f"[sev={incident.severity} people={incident.people_count} "
            f"wait={incident.steps_waiting}s] "
            f"contribution={contribution:.3f} "
            f"(ep={ep:.2f} x decay={decay:.2f} x match={match:.1f})"
        )
        if match < 0.5:
            notes.append(f"WRONG UNIT TYPE — match quality only {match:.1f}")

        # ── Tick time ─────────────────────────────────────────────────
        notes.extend(self._tick_time())
        return self._finish_step(notes)

    # ──────────────────────────────────────────────────────────────────
    # state
    # ──────────────────────────────────────────────────────────────────
    @property
    def state(self) -> DispatchState:
        """Current episode metadata, read as ``env.state``."""
        return self._state

    # ──────────────────────────────────────────────────────────────────
    # Private helpers
    # ──────────────────────────────────────────────────────────────────
    def _reject(self, message: str) -> DispatchObservation:
        """Finish a step whose action could not be executed (time still ticks)."""
        return self._finish_step([message, *self._tick_time()])

    def _finish_step(self, notes: List[str]) -> DispatchObservation:
        """Check termination, finalize if done, and build the observation."""
        done = self._is_done()
        if done:
            notes.extend(self._finalize())
        return self._make_obs(done=done, message=" | ".join(notes))

    def _finalize(self) -> List[str]:
        """
        Apply unresolved-incident penalties (once) and report the final score.

        Each unresolved incident costs
        severity × log(1 + people_count) × UNRESOLVED_PENALTY_FRACTION.

        Returns:
            Notes describing each penalty and the final score, or an empty
            list if the episode was already finalized.
        """
        if self._finalized:
            return []
        self._finalized = True
        notes = []
        for inc in self._incidents:
            if inc.resolved:
                continue
            pen = (
                inc.severity
                * math.log(1 + inc.people_count)
                * UNRESOLVED_PENALTY_FRACTION
            )
            self._penalties += pen
            notes.append(
                f"UNRESOLVED PENALTY -{pen:.3f}: "
                f"{inc.type} at {inc.location} "
                f"[sev={inc.severity} people={inc.people_count}]"
            )
        notes.append(f"EPISODE DONE. Final score: {self._current_score():.4f}")
        return notes

    def _unresolved_dependencies(self, incident: Incident) -> List[int]:
        """
        Return the ids in ``incident.depends_on`` that are still unresolved.

        Ids that match no incident are ignored rather than raising, since
        nothing could ever resolve them.
        """
        blocked = []
        for dep_id in incident.depends_on or []:
            dep = self._find_incident(dep_id)
            if dep is not None and not dep.resolved:
                blocked.append(dep_id)
        return blocked

    def _tick_time(self) -> List[str]:
        """
        Advance the world by one time step.

        - En-route units count down; a unit whose countdown reaches 0 (or
          was never positive, e.g. travel_time <= 0) returns to base.
        - Unresolved incidents accumulate one more step of waiting.
        - Spreading fires gain +1 severity (capped at 10) every
          FIRE_SPREAD_INTERVAL waiting steps.

        Returns:
            Notable event strings (unit returns, fire spread).
        """
        notes = []
        # Units returning to base
        for u in self._units:
            if u.available:
                continue
            if u.steps_until_free > 0:
                u.steps_until_free -= 1
            if u.steps_until_free <= 0:
                u.steps_until_free = 0
                u.available = True
                notes.append(f"{u.type} (id={u.id}) returned to base — available.")

        # Incidents waiting
        for inc in self._incidents:
            if inc.resolved:
                continue
            inc.steps_waiting += 1  # always >= 1 from here on
            # Fire spreads
            if (
                inc.fire_spreads
                and inc.steps_waiting % FIRE_SPREAD_INTERVAL == 0
                and inc.severity < 10
            ):
                inc.severity += 1
                notes.append(
                    f"FIRE SPREAD at {inc.location}: "
                    f"severity now {inc.severity}!"
                )
        return notes

    def _is_done(self) -> bool:
        """True when all incidents are resolved, max_steps is hit, or no help is possible."""
        all_resolved = all(i.resolved for i in self._incidents)
        max_steps_hit = self._step_count >= self._max_steps
        # No unit available AND none counting down AND something unresolved.
        # Because _tick_time returns every unit whose countdown is <= 0, after
        # a tick this can only hold when the scenario has no units at all.
        no_help_possible = (
            any(not i.resolved for i in self._incidents)
            and not any(u.available for u in self._units)
            and not any(u.steps_until_free > 0 for u in self._units)
        )
        return all_resolved or max_steps_hit or no_help_possible

    def _current_score(self) -> float:
        """Net score (raw minus penalties, floored at 0) over max_possible, capped at 1."""
        net = max(0.0, self._raw_score - self._penalties)
        return min(1.0, net / self._max_possible)

    def _make_obs(self, *, done: bool, message: str) -> DispatchObservation:
        """Snapshot the world (deep copies) into an observation."""
        score = self._current_score()
        return DispatchObservation(
            done=done,
            reward=score if done else None,
            incidents=deepcopy(self._incidents),
            units=deepcopy(self._units),
            step_count=self._step_count,
            max_steps=self._max_steps,
            dispatch_count=self._dispatch_count,
            score_so_far=score,
            message=message,
        )

    def _find_incident(self, iid: int) -> Optional[Incident]:
        """Return the incident with id ``iid``, or None."""
        return next((i for i in self._incidents if i.id == iid), None)

    def _find_unit(self, uid: int) -> Optional[Unit]:
        """Return the unit with id ``uid``, or None."""
        return next((u for u in self._units if u.id == uid), None)