# dispatch-triage / server/environment.py
# Uploaded by muskanp via huggingface_hub (commit 86d3c2a, verified).
"""
server/environment.py — Enhanced 911 Dispatch Triage Environment v2
WHAT IS HAPPENING HERE
----------------------
This is a genuine multi-step RL environment. Here is the full episode lifecycle:
1. Episode starts → every incident of the scenario is on the board (with people counts, call descriptions); no new incidents arrive mid-episode
2. Agent observes the full board: who needs help, which units are free
3. Agent dispatches ONE unit to ONE incident per step
4. Time ticks every step:
- Unresolved incidents accumulate steps_waiting (their eventual reward contribution decays)
- Fire incidents spread (severity increases every FIRE_SPREAD_INTERVAL steps)
- En-route units count down their travel time → become available again
5. When units return, agent dispatches them again
6. Episode ends when all incidents resolved OR max_steps reached
REWARD MATHEMATICS (always in [0, 1])
--------------------------------------
For each dispatched incident:
contribution = severity
× log(1 + people_count) ← people multiplier
× e^(-DECAY_LAMBDA × wait) ← time decay
× match_quality(type, unit) ← unit type effectiveness
For each unresolved incident at episode end:
penalty = severity × log(1 + people_count) × UNRESOLVED_PENALTY_FRACTION
For cascade violations (hard mode):
penalty += CASCADE_PENALTY per violation
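final_score = min(1, max(0, sum(contributions) - sum(penalties)) / max_possible)
∈ [0, 1] always (clamped at both ends)
max_possible = sum over incidents, sorted by severity × log(1+people) descending, of
severity × log(1+people) × e^(-DECAY_LAMBDA × k), where k is the dispatch order index —
i.e. one perfectly matched dispatch per step with no wasted steps. It is a target rather
than a strict upper bound (spreading fires gain severity while they wait), hence the clamp at 1.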
WHY THIS IS REAL RL (not just sorting)
----------------------------------------
- Agent must learn that people_count matters more than raw severity
- Agent must learn optimal dispatch timing (with few units and fixed travel times,
sending a unit to a medium-priority call can leave a higher-priority incident waiting for it)
- Agent must learn cascade dependencies (gas_leak before cardiac)
- Agent must learn fire is a time bomb (severity grows each step)
- None of this is told to the agent — it discovers it from reward signals
"""
import uuid
import math
from typing import List, Optional
from copy import deepcopy
try:
from openenv.core.env_server import Environment
except ModuleNotFoundError:
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from openenv_stubs import Environment
try:
from models import (
DispatchAction, DispatchObservation, DispatchState,
Incident, Unit,
DECAY_LAMBDA, FIRE_SPREAD_INTERVAL,
UNRESOLVED_PENALTY_FRACTION, CASCADE_PENALTY,
get_match_quality, effective_priority,
)
except ModuleNotFoundError:
from ..models import (
DispatchAction, DispatchObservation, DispatchState,
Incident, Unit,
DECAY_LAMBDA, FIRE_SPREAD_INTERVAL,
UNRESOLVED_PENALTY_FRACTION, CASCADE_PENALTY,
get_match_quality, effective_priority,
)
# ─────────────────────────────────────────────────────────────────────────────
# Scenario definitions
# ─────────────────────────────────────────────────────────────────────────────
#
# Each incident has:
# call_description — raw 911 call text (LLM reads this to infer severity)
# people_count — how many people at risk
# fire_spreads — whether severity grows over time
#
# The agent sees severity + people_count as numbers and learns the weighting.
# The LLM can additionally re-assess severity from call_description text.
SCENARIOS = {
    # ──────────────────────────────────────────────────────────────────
    # EASY — 3 incidents, 3 units, travel_time=1, max_steps=6
    #
    # Lesson: dispatch highest effective priority first, where
    # EP = severity × ln(1 + people_count) (as in _compute_contribution).
    # The cardiac (sev 9, 1 person: EP ≈ 6.24) looks most urgent, but the
    # crash (sev 5, 3 people: EP ≈ 6.93) actually outranks it. The fire
    # (sev 3, 1 person: EP ≈ 2.08) ranks last but gains severity while
    # it waits (fire_spreads=True).
    # ──────────────────────────────────────────────────────────────────
    "easy": {
        "max_steps": 6,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=9, location="Block 2A",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Caller reports elderly man collapsed, not breathing. "
                    "CPR being attempted. Address: 12 Oak Street, Block 2A."
                ),
            ),
            Incident(
                id=1, type="car_crash", severity=5, location="Block 7C",
                people_count=3, fire_spreads=False,
                call_description=(
                    "Three-car collision at Block 7C. Three people injured, "
                    "one trapped in vehicle. No fire visible."
                ),
            ),
            Incident(
                id=2, type="fire", severity=3, location="Block 1B",
                people_count=1, fire_spreads=True,
                call_description=(
                    "Small kitchen fire at Block 1B apartment. One resident "
                    "evacuated. Fire contained to one room so far."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=1),
            Unit(id=1, type="police", travel_time=1),
            Unit(id=2, type="fire_truck", travel_time=1),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # MEDIUM — 5 incidents, 3 units, travel_time=2, max_steps=12
    #
    # Lessons:
    #   1. Units return after 2 steps — agent must plan across waves
    #   2. car_crash with 8 people (EP = 5 × ln 9 ≈ 10.99) beats
    #      gas_leak with 0 people (EP = 6 × ln 1 = 0.0) despite
    #      gas_leak having higher raw severity
    #   3. Fire spreads — the 6-person fire (EP = 7 × ln 7 ≈ 13.62, the
    #      highest on the board) gains severity the longer it waits
    # ──────────────────────────────────────────────────────────────────
    "medium": {
        "max_steps": 12,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=9, location="Block 3A",
                people_count=2, fire_spreads=False,
                call_description=(
                    "Two people collapsed at Block 3A community centre, "
                    "possible carbon monoxide poisoning. Both unresponsive."
                ),
            ),
            Incident(
                id=1, type="fire", severity=7, location="Block 5D",
                people_count=6, fire_spreads=True,
                call_description=(
                    "Large building fire at Block 5D. Flames visible from "
                    "third floor. At least 6 residents unable to evacuate."
                ),
            ),
            Incident(
                id=2, type="car_crash", severity=5, location="Block 9B",
                people_count=8, fire_spreads=False,
                call_description=(
                    "Major crash on Block 9B expressway. Multiple vehicles. "
                    "Caller reports 8 people injured, one vehicle on its side."
                ),
            ),
            Incident(
                id=3, type="gas_leak", severity=6, location="Block 2C",
                people_count=0, fire_spreads=False,
                call_description=(
                    "Strong gas smell reported at Block 2C. Building evacuated. "
                    "No injuries yet but area needs to be secured."
                ),
            ),
            Incident(
                id=4, type="car_crash", severity=3, location="Block 11E",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Minor fender-bender at Block 11E. One driver with "
                    "minor cuts. No serious injuries reported."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
    # ──────────────────────────────────────────────────────────────────
    # HARD — 7 incidents, 3 units, travel_time=2, max_steps=18
    #
    # Lessons:
    #   1. Cascade: cardiac (id=0) depends on gas_leak (id=1).
    #      Gas leak at Block 4B is adjacent. Dispatching the cardiac before
    #      the gas leak has been resolved triggers CASCADE_PENALTY.
    #   2. Fire at Block 9C has 12 people AND spreads (EP = 8 × ln 13 ≈ 20.52).
    #      Every FIRE_SPREAD_INTERVAL steps ignored, severity +1 (capped at
    #      10); with an interval of 2 it reaches sev=10 after 4 steps.
    #   3. Three waves of dispatch needed — agent must plan 6 steps ahead.
    #   4. car_crash id=3 has 5 people. Despite sev=5, effective priority
    #      = 5 × ln(6) ≈ 8.96 — well above the cardiac (7 × ln 2 ≈ 4.85).
    #   5. Incidents with people_count=0 (ids 1 and 5) have EP 0: they earn
    #      no contribution and carry no unresolved penalty. The gas leak's
    #      only value is avoiding the cascade penalty on id=0.
    # ──────────────────────────────────────────────────────────────────
    "hard": {
        "max_steps": 18,
        "incidents": [
            Incident(
                id=0, type="cardiac_arrest", severity=7, location="Block 4A",
                people_count=1, fire_spreads=False, depends_on=[1],
                call_description=(
                    "Man having heart attack at Block 4A, next to the building "
                    "with the reported gas leak. Caller is panicking. "
                    "Address same block as the gas incident."
                ),
            ),
            Incident(
                id=1, type="gas_leak", severity=6, location="Block 4B",
                people_count=0, fire_spreads=False,
                call_description=(
                    "Major gas leak at Block 4B, adjacent to Block 4A. "
                    "Strong smell reported. Area not yet evacuated. "
                    "Risk of explosion if ignition source present."
                ),
            ),
            Incident(
                id=2, type="fire", severity=8, location="Block 9C",
                people_count=12, fire_spreads=True,
                call_description=(
                    "Warehouse fire at Block 9C, spreading rapidly. "
                    "Night shift workers trapped inside, approximately 12 people. "
                    "Flames visible from street."
                ),
            ),
            Incident(
                id=3, type="car_crash", severity=5, location="Block 2D",
                people_count=5, fire_spreads=False,
                call_description=(
                    "Head-on collision at Block 2D intersection. Five occupants, "
                    "multiple injuries. One child among the injured. "
                    "Vehicles blocking traffic."
                ),
            ),
            Incident(
                id=4, type="car_crash", severity=4, location="Block 6E",
                people_count=2, fire_spreads=False,
                call_description=(
                    "Vehicle hit a lamp post at Block 6E. Driver and passenger "
                    "injured. Airbags deployed. Both conscious."
                ),
            ),
            Incident(
                id=5, type="fire", severity=3, location="Block 1F",
                people_count=0, fire_spreads=True,
                call_description=(
                    "Bin fire at Block 1F alley. No people involved. "
                    "Risk of spreading to nearby building if not contained."
                ),
            ),
            Incident(
                id=6, type="cardiac_arrest", severity=2, location="Block 12G",
                people_count=1, fire_spreads=False,
                call_description=(
                    "Elderly woman feeling chest pains at Block 12G. "
                    "Conscious and breathing. Not a confirmed cardiac event yet."
                ),
            ),
        ],
        "units": [
            Unit(id=0, type="ambulance", travel_time=2),
            Unit(id=1, type="fire_truck", travel_time=2),
            Unit(id=2, type="police", travel_time=2),
        ],
    },
}
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _compute_max_possible(incidents: List[Incident]) -> float:
    """
    Normalisation constant for the episode score.

    Models an idealised agent that:
      1. makes exactly one dispatch per step and never wastes a step,
      2. always sends a perfectly matched unit (match_quality = 1.0), and
      3. dispatches in descending effective-priority order, where
         EP = severity × log(1 + people_count).

    Every step() ticks time once for each unresolved incident, so the k-th
    dispatch (0-based) happens after at least k ticks of waiting:

        max_possible = Σ_k EP_(k) × e^(-DECAY_LAMBDA × k),  EP_(0) ≥ EP_(1) ≥ …

    Pairing the largest EP with the smallest decay maximises this sum
    (rearrangement inequality), which is why the EPs are sorted descending.

    Caveats: this is a target, not a strict upper bound. Spreading fires gain
    severity while they wait, so a real episode can exceed it; the score is
    clamped at 1.0 when normalised. Unit availability is not modelled.

    Args:
        incidents: The episode's incidents in their initial state.

    Returns:
        The sum above, or 1.0 when it is zero (e.g. every incident has
        people_count == 0), so callers can always divide by it.
    """
    effective_priorities = sorted(
        (inc.severity * math.log(1 + inc.people_count) for inc in incidents),
        reverse=True,
    )
    total = sum(
        ep * math.exp(-DECAY_LAMBDA * idx)  # minimum wait = dispatch order index
        for idx, ep in enumerate(effective_priorities)
    )
    # Substitute 1.0 only for a zero total. Flooring every total at 1.0 would
    # also shrink scores whenever the true optimum is a small positive value,
    # making 1.0 unreachable.
    return total if total > 0 else 1.0
def _compute_contribution(inc: Incident, unit_type: str) -> float:
    """
    Reward earned by sending a ``unit_type`` unit to ``inc`` right now.

    The three factors are:
        severity × log(1 + people_count)      effective priority
        × e^(-DECAY_LAMBDA × steps_waiting)   time decay at the current wait
        × get_match_quality(type, unit_type)  how suitable the unit is
    """
    weighted_severity = inc.severity * math.log(1 + inc.people_count)
    return (
        weighted_severity
        * math.exp(-DECAY_LAMBDA * inc.steps_waiting)
        * get_match_quality(inc.type, unit_type)
    )
# ─────────────────────────────────────────────────────────────────────────────
# Environment
# ─────────────────────────────────────────────────────────────────────────────
class DispatchEnvironment(Environment):
    """
    Multi-step 911 dispatch triage environment.

    One step = one dispatch action + one time tick (invalid actions still tick).
    One episode = multiple steps, until one of these happens:
      - all incidents are resolved,
      - max_steps is reached, or
      - unresolved incidents remain but no unit is free or on its way back.

    The agent interacts via:
        obs = env.reset(difficulty="easy"|"medium"|"hard")
        obs = env.step(DispatchAction(incident_id=X, unit_id=Y))
        state = env.state

    Reward is terminal. ``obs.reward`` is None until the episode ends, then
    holds the normalised score in [0, 1]. ``obs.score_so_far`` reports the
    running score on every step. Once an episode has ended, further step()
    calls leave the score untouched until reset() is called.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        self._incidents: List[Incident] = []
        self._units: List[Unit] = []
        self._step_count: int = 0
        self._max_steps: int = 10
        self._dispatch_count: int = 0
        self._raw_score: float = 0.0     # sum of dispatch contributions
        self._penalties: float = 0.0     # cascade + unresolved penalties
        self._max_possible: float = 1.0  # score normaliser, set in reset()
        self._done: bool = False         # True once terminal penalties are applied
        self._state = DispatchState()

    # ──────────────────────────────────────────────────────────────────
    # reset
    # ──────────────────────────────────────────────────────────────────
    def reset(
        self,
        seed=None,
        episode_id=None,
        difficulty: str = "easy",
        **kwargs,
    ) -> DispatchObservation:
        """
        Start a new episode from a fresh copy of a scenario.

        Args:
            seed: Accepted for interface compatibility. The scenarios are
                fixed, so it is not used.
            episode_id: Id recorded in ``state``. A random UUID4 is used if
                omitted.
            difficulty: "easy", "medium" or "hard". Unknown values fall back
                to "easy".
            **kwargs: Ignored.

        Returns:
            The initial observation (``done=False``, ``reward=None``).
        """
        if difficulty not in SCENARIOS:
            difficulty = "easy"
        scenario = SCENARIOS[difficulty]

        # Deep copies: the episode mutates incidents/units in place and must
        # never touch the shared SCENARIOS templates.
        self._incidents = [i.model_copy(deep=True) for i in scenario["incidents"]]
        self._units = [u.model_copy(deep=True) for u in scenario["units"]]
        self._max_steps = scenario["max_steps"]
        self._step_count = 0
        self._dispatch_count = 0
        self._raw_score = 0.0
        self._penalties = 0.0
        self._done = False
        self._max_possible = _compute_max_possible(self._incidents)
        self._state = DispatchState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            difficulty=difficulty,
            total_incidents=len(self._incidents),
            total_units=len(self._units),
            max_steps=self._max_steps,
            max_possible_score=self._max_possible,
        )
        return self._make_obs(
            done=False,
            message=(
                f"[{difficulty.upper()}] {len(self._incidents)} incidents, "
                f"{len(self._units)} units available, "
                f"{self._max_steps} steps budget. "
                f"Dispatch wisely — people count and time decay matter."
            ),
        )

    # ──────────────────────────────────────────────────────────────────
    # step
    # ──────────────────────────────────────────────────────────────────
    def step(
        self,
        action: DispatchAction,
        timeout_s=None,
        **kwargs,
    ) -> DispatchObservation:
        """
        Apply one dispatch action, then advance time by one tick.

        These actions are invalid, dispatch nothing, but still consume the
        step and tick time:
          - unknown incident id or unknown unit id,
          - unit still en route,
          - incident already resolved.

        Whichever path ends the episode, unresolved-incident penalties are
        applied exactly once.

        Args:
            action: Which unit to send to which incident.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Ignored.

        Returns:
            The observation after the tick. ``reward`` is set only when
            ``done`` is True.
        """
        if self._done:
            # A finished episode is frozen: stepping it must neither add
            # contributions past the budget nor re-apply terminal penalties.
            return self._make_obs(
                done=True,
                message="Episode already finished. Call reset() to start a new episode.",
            )

        self._step_count += 1
        self._state.step_count = self._step_count

        incident = self._find_incident(action.incident_id)
        unit = self._find_unit(action.unit_id)

        # ── Validate ─────────────────────────────────────────────────
        # Each rejection message is built after the tick, so an en-route
        # unit's countdown is reported as of the new time step.
        if incident is None:
            tick_notes = self._tick_time()
            return self._finalize_step([
                f"Invalid incident id {action.incident_id}. Time still ticked.",
                *tick_notes,
            ])
        if unit is None:
            tick_notes = self._tick_time()
            return self._finalize_step([
                f"Invalid unit id {action.unit_id}. Time still ticked.",
                *tick_notes,
            ])
        if not unit.available:
            tick_notes = self._tick_time()
            return self._finalize_step([
                f"Unit {action.unit_id} ({unit.type}) is en route — "
                f"returns in {unit.steps_until_free} step(s). Time ticked.",
                *tick_notes,
            ])
        if incident.resolved:
            tick_notes = self._tick_time()
            return self._finalize_step([
                f"Incident {action.incident_id} already resolved. Time ticked.",
                *tick_notes,
            ])

        # ── Dispatch, then tick ──────────────────────────────────────
        notes = self._dispatch_unit(incident, unit)
        notes.extend(self._tick_time())
        return self._finalize_step(notes)

    # ──────────────────────────────────────────────────────────────────
    # state
    # ──────────────────────────────────────────────────────────────────
    @property
    def state(self) -> DispatchState:
        """Episode metadata (id, difficulty, step count, score normaliser)."""
        return self._state

    # ──────────────────────────────────────────────────────────────────
    # Private helpers
    # ──────────────────────────────────────────────────────────────────
    def _dispatch_unit(self, incident: Incident, unit: Unit) -> List[str]:
        """
        Send ``unit`` to ``incident``, score it, and mark both as busy/resolved.

        Applies CASCADE_PENALTY if any incident this one depends on is still
        unresolved. Returns the notes describing the dispatch.
        """
        notes: List[str] = []

        # ── Check cascade ─────────────────────────────────────────────
        unresolved_deps = self._unresolved_dependencies(incident)
        if unresolved_deps:
            self._penalties += CASCADE_PENALTY
            notes.append(
                f"CASCADE PENALTY -{CASCADE_PENALTY}: "
                f"dependency incident(s) {unresolved_deps} unresolved!"
            )

        # ── Dispatch ──────────────────────────────────────────────────
        # The contribution uses steps_waiting as of this dispatch, before
        # the tick that follows.
        self._dispatch_count += 1
        contribution = _compute_contribution(incident, unit.type)
        self._raw_score += contribution
        unit.available = False
        unit.steps_until_free = unit.travel_time
        incident.resolved = True
        incident.assigned_unit_id = unit.id

        # Break the contribution into its factors for the agent-facing note.
        ep = incident.severity * math.log(1 + incident.people_count)
        decay = math.exp(-DECAY_LAMBDA * incident.steps_waiting)
        match = get_match_quality(incident.type, unit.type)
        notes.append(
            f"Dispatched {unit.type} -> {incident.type} at {incident.location} "
            f"[sev={incident.severity} people={incident.people_count} "
            f"wait={incident.steps_waiting}s] "
            f"contribution={contribution:.3f} "
            f"(ep={ep:.2f} x decay={decay:.2f} x match={match:.1f})"
        )
        if match < 0.5:
            notes.append(f"WRONG UNIT TYPE — match quality only {match:.1f}")
        return notes

    def _unresolved_dependencies(self, incident: Incident) -> List[int]:
        """
        Return the ids in ``incident.depends_on`` whose incident is unresolved.

        Ids that match no incident are skipped. Looking up ``.resolved`` on
        the missing (None) incident would raise AttributeError, and such an
        id could never be resolved anyway.
        """
        unresolved: List[int] = []
        for dep_id in incident.depends_on or []:
            dep = self._find_incident(dep_id)
            if dep is not None and not dep.resolved:
                unresolved.append(dep_id)
        return unresolved

    def _apply_unresolved_penalties(self) -> List[str]:
        """
        Charge the end-of-episode penalty for every unresolved incident.

        Each penalty is severity × log(1 + people) × UNRESOLVED_PENALTY_FRACTION,
        using the incident's current (possibly fire-grown) severity.
        Returns one note per penalty.
        """
        notes: List[str] = []
        for inc in self._incidents:
            if inc.resolved:
                continue
            pen = (
                inc.severity
                * math.log(1 + inc.people_count)
                * UNRESOLVED_PENALTY_FRACTION
            )
            self._penalties += pen
            notes.append(
                f"UNRESOLVED PENALTY -{pen:.3f}: "
                f"{inc.type} at {inc.location} "
                f"[sev={inc.severity} people={inc.people_count}]"
            )
        return notes

    def _finalize_step(self, notes: List[str]) -> DispatchObservation:
        """
        Check for episode end, apply terminal penalties once, build the obs.

        Shared by every step() path, so an episode that ends on an invalid
        action is penalised exactly like one that ends on a dispatch.
        """
        done = self._is_done()
        if done:
            notes.extend(self._apply_unresolved_penalties())
            self._done = True
            notes.append(f"EPISODE DONE. Final score: {self._current_score():.4f}")
        return self._make_obs(done=done, message=" | ".join(notes))

    def _tick_time(self) -> List[str]:
        """
        Advance the world by one time step.

        - Units en route: decrement the countdown. A unit whose countdown is
          (or reaches) 0 returns to base and becomes available.
        - Unresolved incidents: accumulate waiting time.
        - Fire incidents: spread (severity +1, capped at 10) every
          FIRE_SPREAD_INTERVAL waiting steps.

        Returns:
            A list of notable event strings.
        """
        notes: List[str] = []

        # Units returning to base
        for u in self._units:
            if u.available:
                continue
            if u.steps_until_free > 0:
                u.steps_until_free -= 1
            # Checked even without a decrement, so a unit dispatched with
            # travel_time=0 is freed here instead of staying busy forever.
            if u.steps_until_free == 0:
                u.available = True
                notes.append(f"{u.type} (id={u.id}) returned to base — available.")

        # Incidents waiting
        for inc in self._incidents:
            if inc.resolved:
                continue
            inc.steps_waiting += 1
            # Fire spreads
            if (
                inc.fire_spreads
                and inc.steps_waiting > 0
                and inc.steps_waiting % FIRE_SPREAD_INTERVAL == 0
                and inc.severity < 10
            ):
                inc.severity += 1
                notes.append(
                    f"FIRE SPREAD at {inc.location}: "
                    f"severity now {inc.severity}!"
                )
        return notes

    def _is_done(self) -> bool:
        """
        Return True when the episode is over, i.e. when any of these holds:
          - every incident is resolved,
          - the step budget is spent, or
          - incidents remain but no unit is available or counting down to
            return (a defensive guard against units that can never come back).
        """
        all_resolved = all(i.resolved for i in self._incidents)
        max_steps_hit = self._step_count >= self._max_steps
        no_help_possible = (
            any(not i.resolved for i in self._incidents)
            and not any(u.available for u in self._units)
            and not any(u.steps_until_free > 0 for u in self._units)
        )
        return all_resolved or max_steps_hit or no_help_possible

    def _current_score(self) -> float:
        """Net score (contributions minus penalties) normalised and clamped to [0, 1]."""
        net = max(0.0, self._raw_score - self._penalties)
        return min(1.0, net / self._max_possible)

    def _make_obs(self, *, done: bool, message: str) -> DispatchObservation:
        """Snapshot the world into an observation. ``reward`` is set only when done."""
        return DispatchObservation(
            done=done,
            reward=self._current_score() if done else None,
            # Deep copies so the caller cannot mutate live episode state.
            incidents=deepcopy(self._incidents),
            units=deepcopy(self._units),
            step_count=self._step_count,
            max_steps=self._max_steps,
            dispatch_count=self._dispatch_count,
            score_so_far=self._current_score(),
            message=message,
        )

    def _find_incident(self, iid: int) -> Optional[Incident]:
        """Return the incident with id ``iid``, or None if there is none."""
        return next((i for i in self._incidents if i.id == iid), None)

    def _find_unit(self, uid: int) -> Optional[Unit]:
        """Return the unit with id ``uid``, or None if there is none."""
        return next((u for u in self._units if u.id == uid), None)