"""
server/environment.py — Enhanced 911 Dispatch Triage Environment v2

WHAT IS HAPPENING HERE
----------------------
This is a genuine multi-step RL environment. Here is the full episode lifecycle:

1. Episode starts → incidents arrive (with people counts, call descriptions)
2. Agent observes the full board: who needs help, which units are free
3. Agent dispatches ONE unit to ONE incident per step
4. Time ticks every step (including steps whose action is rejected):
   - Unresolved incidents accumulate steps_waiting (the contribution they
     will eventually earn decays exponentially with that wait)
   - Spreading fires gain +1 severity every FIRE_SPREAD_INTERVAL steps of
     waiting (capped at 10)
   - En-route units count down their travel time → become available again
5. When units return, agent dispatches them again
6. Episode ends when all incidents are resolved, max_steps is reached, or
   no unit is available or returning

REWARD MATHEMATICS (always in [0, 1])
-------------------------------------
For each dispatched incident:

    contribution = severity × ln(1 + people_count)   ← people multiplier
                 × e^(-DECAY_LAMBDA × wait)           ← time decay
                 × match_quality(type, unit)          ← unit type effectiveness

For each unresolved incident at episode end (applied exactly once):

    penalty = severity × ln(1 + people_count) × UNRESOLVED_PENALTY_FRACTION

For cascade violations (hard mode):

    penalty += CASCADE_PENALTY per violation

    final_score = min(1, max(0, Σ contributions − Σ penalties) / max_possible)

Note that an incident with people_count == 0 has weight ln(1) = 0, so it
contributes nothing and costs nothing when left unresolved.

max_possible is an optimistic sequential bound: incidents sorted by
severity × ln(1 + people) descending, the k-th (0-based) one dispatched
after k steps of waiting with a perfect unit match. It ignores unit
availability, travel time and fire spread, so 1.0 is only reachable when
enough units are free to follow that ordering.

WHY THIS IS REAL RL (not just sorting)
--------------------------------------
- Agent must learn that people_count matters more than raw severity
- Agent must learn optimal dispatch timing (units returning from medium
  priority calls might be better held for an incoming high priority call)
- Agent must learn cascade dependencies (gas_leak before cardiac)
- Agent must learn fire is a time bomb (severity grows each step)
- None of this is told to the agent — it discovers it from reward signals
"""

import uuid
import math
from typing import List, Optional
from copy import deepcopy

try:
    from openenv.core.env_server import Environment
except ModuleNotFoundError:
    import os
    import sys

    sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
    from openenv_stubs import Environment

try:
    from models import (
        DispatchAction,
        DispatchObservation,
        DispatchState,
        Incident,
        Unit,
        DECAY_LAMBDA,
        FIRE_SPREAD_INTERVAL,
        UNRESOLVED_PENALTY_FRACTION,
        CASCADE_PENALTY,
        get_match_quality,
        effective_priority,
    )
except ModuleNotFoundError:
    from ..models import (
        DispatchAction,
        DispatchObservation,
        DispatchState,
        Incident,
        Unit,
        DECAY_LAMBDA,
        FIRE_SPREAD_INTERVAL,
        UNRESOLVED_PENALTY_FRACTION,
        CASCADE_PENALTY,
        get_match_quality,
        effective_priority,
    )


# ─────────────────────────────────────────────────────────────────────────────
# Scenario definitions
# ─────────────────────────────────────────────────────────────────────────────
#
# Each incident has:
#   call_description — raw 911 call text (LLM reads this to infer severity)
#   people_count     — how many people at risk
#   fire_spreads     — whether severity grows over time
#   depends_on       — (optional) incident ids that should be resolved first
#
# The agent sees severity + people_count as numbers and learns the weighting.
# The LLM can additionally re-assess severity from call_description text.
#
# "EP" in the comments below is severity × ln(1 + people_count), the
# undecayed weight this module uses for scoring.

SCENARIOS = {
    # ──────────────────────────────────────────────────────────────────
    # EASY — 3 incidents, 3 units, travel_time=1, max_steps=6
    #
    # Lesson: dispatch highest effective priority first.
    # The cardiac arrest (sev 9, 1 person) looks like the obvious first
    # pick, but the car crash (sev 5, 3 people) has the higher EP:
    # 5 × ln 4 ≈ 6.93 vs 9 × ln 2 ≈ 6.24 — despite its lower raw severity.
    # ──────────────────────────────────────────────────────────────────
    "easy": {
        "max_steps": 6,
        "incidents": [
            Incident(
                id=0,
                type="cardiac_arrest",
                severity=9,
                location="Block 2A",
                people_count=1,
                fire_spreads=False,
                call_description=(
                    "Caller reports elderly man collapsed, not breathing. "
                    "CPR being attempted. Address: 12 Oak Street, Block 2A."
                ),
            ),
            Incident(
                id=1,
                type="car_crash",
                severity=5,
                location="Block 7C",
                people_count=3,
                fire_spreads=False,
                call_description=(
                    "Three-car collision at Block 7C. Three people injured, "
                    "one trapped in vehicle. No fire visible."
                ),
            ),
            Incident(
                id=2,
                type="fire",
                severity=3,
                location="Block 1B",
                people_count=1,
                fire_spreads=True,
                call_description=(
                    "Small kitchen fire at Block 1B apartment. One resident "
                    "evacuated. Fire contained to one room so far."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=1),
            Unit(id=1, type="police", travel_time=1),
            Unit(id=2, type="fire_truck", travel_time=1),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # MEDIUM — 5 incidents, 3 units, travel_time=2, max_steps=12
    #
    # Lessons:
    #   1. Units return after 2 steps — agent must plan across waves
    #   2. car_crash with 8 people (EP = 5 × ln 9 ≈ 10.99) beats
    #      gas_leak with 0 people (EP = 6 × ln 1 = 0.0) despite
    #      gas_leak having higher raw severity
    #   3. Fire spreads — delay costs grow non-linearly
    # ──────────────────────────────────────────────────────────────────
    "medium": {
        "max_steps": 12,
        "incidents": [
            Incident(
                id=0,
                type="cardiac_arrest",
                severity=9,
                location="Block 3A",
                people_count=2,
                fire_spreads=False,
                call_description=(
                    "Two people collapsed at Block 3A community centre, "
                    "possible carbon monoxide poisoning. Both unresponsive."
                ),
            ),
            Incident(
                id=1,
                type="fire",
                severity=7,
                location="Block 5D",
                people_count=6,
                fire_spreads=True,
                call_description=(
                    "Large building fire at Block 5D. Flames visible from "
                    "third floor. At least 6 residents unable to evacuate."
                ),
            ),
            Incident(
                id=2,
                type="car_crash",
                severity=5,
                location="Block 9B",
                people_count=8,
                fire_spreads=False,
                call_description=(
                    "Major crash on Block 9B expressway. Multiple vehicles. "
                    "Caller reports 8 people injured, one vehicle on its side."
                ),
            ),
            Incident(
                id=3,
                type="gas_leak",
                severity=6,
                location="Block 2C",
                people_count=0,
                fire_spreads=False,
                call_description=(
                    "Strong gas smell reported at Block 2C. Building evacuated. "
                    "No injuries yet but area needs to be secured."
                ),
            ),
            Incident(
                id=4,
                type="car_crash",
                severity=3,
                location="Block 11E",
                people_count=1,
                fire_spreads=False,
                call_description=(
                    "Minor fender-bender at Block 11E. One driver with "
                    "minor cuts. No serious injuries reported."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # HARD — 7 incidents, 3 units, travel_time=2, max_steps=18
    #
    # Lessons:
    #   1. Cascade: cardiac (id=0) depends on gas_leak (id=1).
    #      Gas leak at Block 4B is adjacent. Dispatching cardiac first
    #      without clearing the gas leak triggers CASCADE_PENALTY.
    #   2. Fire at Block 9C has 12 people AND spreads. Every
    #      FIRE_SPREAD_INTERVAL steps ignored, severity +1 (capped at 10);
    #      assuming an interval of 2, after 4 steps: sev=10.
    #   3. Three waves of dispatch needed — agent must plan 6 steps ahead.
    #   4. car_crash id=3 has 5 people. Despite sev=5, its
    #      EP = 5 × ln 6 ≈ 8.96 — almost double the cardiac's
    #      EP = 7 × ln 2 ≈ 4.85.
    # ──────────────────────────────────────────────────────────────────
    "hard": {
        "max_steps": 18,
        "incidents": [
            Incident(
                id=0,
                type="cardiac_arrest",
                severity=7,
                location="Block 4A",
                people_count=1,
                fire_spreads=False,
                depends_on=[1],
                call_description=(
                    "Man having heart attack at Block 4A, next to the building "
                    "with the reported gas leak. Caller is panicking. "
                    "Address same block as the gas incident."
                ),
            ),
            Incident(
                id=1,
                type="gas_leak",
                severity=6,
                location="Block 4B",
                people_count=0,
                fire_spreads=False,
                call_description=(
                    "Major gas leak at Block 4B, adjacent to Block 4A. "
                    "Strong smell reported. Area not yet evacuated. "
                    "Risk of explosion if ignition source present."
                ),
            ),
            Incident(
                id=2,
                type="fire",
                severity=8,
                location="Block 9C",
                people_count=12,
                fire_spreads=True,
                call_description=(
                    "Warehouse fire at Block 9C, spreading rapidly. "
                    "Night shift workers trapped inside, approximately 12 people. "
                    "Flames visible from street."
                ),
            ),
            Incident(
                id=3,
                type="car_crash",
                severity=5,
                location="Block 2D",
                people_count=5,
                fire_spreads=False,
                call_description=(
                    "Head-on collision at Block 2D intersection. Five occupants, "
                    "multiple injuries. One child among the injured. "
                    "Vehicles blocking traffic."
                ),
            ),
            Incident(
                id=4,
                type="car_crash",
                severity=4,
                location="Block 6E",
                people_count=2,
                fire_spreads=False,
                call_description=(
                    "Vehicle hit a lamp post at Block 6E. Driver and passenger "
                    "injured. Airbags deployed. Both conscious."
                ),
            ),
            Incident(
                id=5,
                type="fire",
                severity=3,
                location="Block 1F",
                people_count=0,
                fire_spreads=True,
                call_description=(
                    "Bin fire at Block 1F alley. No people involved. "
                    "Risk of spreading to nearby building if not contained."
                ),
            ),
            Incident(
                id=6,
                type="cardiac_arrest",
                severity=2,
                location="Block 12G",
                people_count=1,
                fire_spreads=False,
                call_description=(
                    "Elderly woman feeling chest pains at Block 12G. "
                    "Conscious and breathing. Not a confirmed cardiac event yet."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
}


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _base_priority(inc: Incident) -> float:
    """Return the undecayed weight ``severity × ln(1 + people_count)`` of *inc*.

    This single factor is shared by the dispatch contribution, the unresolved
    penalty and the max-possible normaliser, so all three stay consistent.
    Incidents with ``people_count == 0`` therefore weigh exactly 0.
    """
    return inc.severity * math.log(1 + inc.people_count)


def _compute_max_possible(incidents: List[Incident]) -> float:
    """
    Return an optimistic sequential upper bound used to normalise the score.

    Incident weights (severity × ln(1 + people)) are sorted descending and the
    k-th (0-based) one is assumed to be dispatched after exactly k steps of
    waiting, with a perfect unit match (match_quality = 1.0). Because every
    ``step()`` ticks time once, k is the minimum achievable wait for the k-th
    dispatch, and pairing the largest weights with the smallest waits
    maximises the decayed sum.

    The bound ignores unit availability, travel time and fire spread, so a
    score of 1.0 is only reachable when enough units are free to follow the
    ideal ordering. It is floored at 1.0 to avoid dividing by (near) zero.
    """
    weights = sorted((_base_priority(inc) for inc in incidents), reverse=True)
    total = sum(
        weight * math.exp(-DECAY_LAMBDA * idx)  # minimum wait = dispatch order index
        for idx, weight in enumerate(weights)
    )
    return max(total, 1.0)


def _compute_contribution(inc: Incident, unit_type: str) -> float:
    """
    Score contribution for dispatching ``unit_type`` to ``inc`` at its current
    ``steps_waiting``:

        severity × ln(1 + people) × e^(-λ × wait) × match_quality
    """
    decay = math.exp(-DECAY_LAMBDA * inc.steps_waiting)
    match = get_match_quality(inc.type, unit_type)
    return _base_priority(inc) * decay * match


# ─────────────────────────────────────────────────────────────────────────────
# Environment
# ─────────────────────────────────────────────────────────────────────────────

class DispatchEnvironment(Environment):
    """
    Multi-step 911 dispatch triage environment.

    One step    = one dispatch action + one time tick.
    One episode = multiple steps until all resolved or max_steps reached.
    Once an episode has ended, further ``step()`` calls are no-ops that
    return the final observation; call ``reset()`` to start a new episode.

    The agent interacts via:
        obs = env.reset(difficulty="easy"|"medium"|"hard")
        obs = env.step(DispatchAction(incident_id=X, unit_id=Y))
        state = env.state
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        self._incidents: List[Incident] = []
        self._units: List[Unit] = []
        self._step_count: int = 0
        self._max_steps: int = 10
        self._dispatch_count: int = 0
        self._raw_score: float = 0.0
        self._penalties: float = 0.0
        self._max_possible: float = 1.0
        # True once end-of-episode penalties have been applied; guards
        # against applying them twice and against stepping a finished episode.
        self._finalized: bool = False
        self._state = DispatchState()

    # ──────────────────────────────────────────────────────────────────
    # reset
    # ──────────────────────────────────────────────────────────────────

    def reset(
        self,
        seed=None,
        episode_id=None,
        difficulty: str = "easy",
        **kwargs,
    ) -> DispatchObservation:
        """Start a new episode from a deep copy of the named scenario.

        Unknown ``difficulty`` values fall back to ``"easy"``. ``seed`` is
        accepted for interface compatibility; scenarios are deterministic.
        """
        if difficulty not in SCENARIOS:
            difficulty = "easy"

        scenario = SCENARIOS[difficulty]
        # Deep copies so that episode mutations never leak into SCENARIOS.
        self._incidents = [i.model_copy(deep=True) for i in scenario["incidents"]]
        self._units = [u.model_copy(deep=True) for u in scenario["units"]]
        self._max_steps = scenario["max_steps"]
        self._step_count = 0
        self._dispatch_count = 0
        self._raw_score = 0.0
        self._penalties = 0.0
        self._finalized = False
        self._max_possible = _compute_max_possible(self._incidents)

        self._state = DispatchState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            difficulty=difficulty,
            total_incidents=len(self._incidents),
            total_units=len(self._units),
            max_steps=self._max_steps,
            max_possible_score=self._max_possible,
        )

        return self._make_obs(
            done=False,
            message=(
                f"[{difficulty.upper()}] {len(self._incidents)} incidents, "
                f"{len(self._units)} units available, "
                f"{self._max_steps} steps budget. "
                f"Dispatch wisely — people count and time decay matter."
            ),
        )

    # ──────────────────────────────────────────────────────────────────
    # step
    # ──────────────────────────────────────────────────────────────────

    def step(
        self,
        action: DispatchAction,
        timeout_s=None,
        **kwargs,
    ) -> DispatchObservation:
        """Apply one dispatch action, then advance time by one tick.

        A rejected action (unknown ids, busy unit, already-resolved incident)
        still consumes the step and ticks time. Whenever the episode ends —
        whether on a valid or a rejected action — unresolved-incident
        penalties are applied exactly once.
        """
        if self._finalized:
            # Episode is over: do not tick, dispatch, or re-apply penalties.
            return self._make_obs(
                done=True,
                message="Episode already finished. Call reset() to start a new episode.",
            )

        self._step_count += 1
        self._state.step_count = self._step_count

        incident = self._find_incident(action.incident_id)
        unit = self._find_unit(action.unit_id)

        rejection = self._validate_action(action, incident, unit)
        if rejection is not None:
            notes = [rejection]
        else:
            notes = self._dispatch(incident, unit)

        notes.extend(self._tick_time())

        done = self._is_done()
        if done:
            notes.extend(self._finalize())

        return self._make_obs(done=done, message=" | ".join(notes))

    # ──────────────────────────────────────────────────────────────────
    # state
    # ──────────────────────────────────────────────────────────────────

    @property
    def state(self) -> DispatchState:
        return self._state

    # ──────────────────────────────────────────────────────────────────
    # Private helpers
    # ──────────────────────────────────────────────────────────────────

    def _validate_action(
        self,
        action: DispatchAction,
        incident: Optional[Incident],
        unit: Optional[Unit],
    ) -> Optional[str]:
        """Return a rejection message for an unusable action, else ``None``."""
        if incident is None:
            return f"Invalid incident id {action.incident_id}. Time still ticked."
        if unit is None:
            return f"Invalid unit id {action.unit_id}. Time still ticked."
        if not unit.available:
            return (
                f"Unit {action.unit_id} ({unit.type}) is en route — "
                f"returns in {unit.steps_until_free} step(s). Time ticked."
            )
        if incident.resolved:
            return f"Incident {action.incident_id} already resolved. Time ticked."
        return None

    def _dispatch(self, incident: Incident, unit: Unit) -> List[str]:
        """Send ``unit`` to ``incident``, update score/penalties, return notes.

        The contribution is computed from the incident's wait *before* this
        step's time tick.
        """
        notes: List[str] = []

        # Cascade check: ids that do not exist are ignored rather than crashing.
        dependencies = [self._find_incident(d) for d in (incident.depends_on or [])]
        unresolved_deps = [
            dep.id for dep in dependencies if dep is not None and not dep.resolved
        ]
        if unresolved_deps:
            self._penalties += CASCADE_PENALTY
            notes.append(
                f"CASCADE PENALTY -{CASCADE_PENALTY}: "
                f"dependency incident(s) {unresolved_deps} unresolved!"
            )

        ep = _base_priority(incident)
        decay = math.exp(-DECAY_LAMBDA * incident.steps_waiting)
        match = get_match_quality(incident.type, unit.type)
        contribution = _compute_contribution(incident, unit.type)

        self._dispatch_count += 1
        self._raw_score += contribution

        unit.available = False
        unit.steps_until_free = unit.travel_time
        incident.resolved = True
        incident.assigned_unit_id = unit.id

        notes.append(
            f"Dispatched {unit.type} -> {incident.type} at {incident.location} "
            f"[sev={incident.severity} people={incident.people_count} "
            f"wait={incident.steps_waiting}s] "
            f"contribution={contribution:.3f} "
            f"(ep={ep:.2f} x decay={decay:.2f} x match={match:.1f})"
        )
        if match < 0.5:
            notes.append(f"WRONG UNIT TYPE — match quality only {match:.1f}")
        return notes

    def _finalize(self) -> List[str]:
        """Apply end-of-episode penalties once and return the summary notes."""
        if self._finalized:
            return []
        self._finalized = True

        notes: List[str] = []
        for inc in self._incidents:
            if inc.resolved:
                continue
            pen = _base_priority(inc) * UNRESOLVED_PENALTY_FRACTION
            self._penalties += pen
            notes.append(
                f"UNRESOLVED PENALTY -{pen:.3f}: "
                f"{inc.type} at {inc.location} "
                f"[sev={inc.severity} people={inc.people_count}]"
            )
        notes.append(f"EPISODE DONE. Final score: {self._current_score():.4f}")
        return notes

    def _tick_time(self) -> List[str]:
        """
        Advance the world by one time step.

        - Units en route: decrement countdown, return to base when it reaches 0
          (a unit with a non-positive countdown returns on the next tick)
        - Unresolved incidents: accumulate waiting time
        - Fire incidents: spread (severity +1, capped at 10) every
          FIRE_SPREAD_INTERVAL steps of waiting

        Returns list of notable event strings.
        """
        notes: List[str] = []

        # Units returning to base
        for u in self._units:
            if u.available:
                continue
            if u.steps_until_free > 0:
                u.steps_until_free -= 1
            # "<= 0" so a zero travel_time unit cannot stay busy forever.
            if u.steps_until_free <= 0:
                u.steps_until_free = 0
                u.available = True
                notes.append(f"{u.type} (id={u.id}) returned to base — available.")

        # Incidents waiting
        for inc in self._incidents:
            if inc.resolved:
                continue
            inc.steps_waiting += 1

            # Fire spreads
            if (
                inc.fire_spreads
                and inc.steps_waiting % FIRE_SPREAD_INTERVAL == 0
                and inc.severity < 10
            ):
                inc.severity += 1
                notes.append(
                    f"FIRE SPREAD at {inc.location}: "
                    f"severity now {inc.severity}!"
                )

        return notes

    def _is_done(self) -> bool:
        """True when all incidents are resolved, the step budget is spent,
        or incidents remain but no unit is available or on its way back."""
        all_resolved = all(i.resolved for i in self._incidents)
        max_steps_hit = self._step_count >= self._max_steps
        no_help_possible = (
            any(not i.resolved for i in self._incidents)
            and not any(u.available for u in self._units)
            and not any(u.steps_until_free > 0 for u in self._units)
        )
        return all_resolved or max_steps_hit or no_help_possible

    def _current_score(self) -> float:
        """Net score (contributions − penalties) normalised and clamped to [0, 1]."""
        net = max(0.0, self._raw_score - self._penalties)
        return min(1.0, net / self._max_possible)

    def _make_obs(self, *, done: bool, message: str) -> DispatchObservation:
        """Build an observation; the reward is only populated at episode end."""
        score = self._current_score()
        return DispatchObservation(
            done=done,
            reward=score if done else None,
            incidents=deepcopy(self._incidents),
            units=deepcopy(self._units),
            step_count=self._step_count,
            max_steps=self._max_steps,
            dispatch_count=self._dispatch_count,
            score_so_far=score,
            message=message,
        )

    def _find_incident(self, iid: int) -> Optional[Incident]:
        return next((i for i in self._incidents if i.id == iid), None)

    def _find_unit(self, uid: int) -> Optional[Unit]:
        return next((u for u in self._units if u.id == uid), None)