trafficops / server /sim /engine.py
Kunalsinghh's picture
Upload folder using huggingface_hub
92107a5 verified
from typing import Optional
from .world import (
Direction,
Intersection,
Road,
TickStats,
Vehicle,
World,
)
# Number of consecutive ticks in which no vehicle moved while uncleared
# vehicles remain before tick() records a gridlock event.
GRIDLOCK_STALL_THRESHOLD = 20
# Switching margin for the bias-driven controller: a competing phase must
# exceed the current phase's pressure multiplied by this factor to take over.
HYSTERESIS = 1.15
def phase_serves(intersection: Intersection, direction: Direction) -> bool:
    """Return whether ``direction`` currently has right of way.

    While a preemption is set on the intersection, only the preempted
    direction is served. Otherwise a direction is served when it is part
    of the intersection's current phase.
    """
    preempted = intersection.preempt_direction
    if preempted is None:
        return direction in intersection.current_phase()
    return preempted == direction
def tick(world: World, on_tick_start=None) -> tuple[float, TickStats]:
    """Advance the simulation by one tick.

    Order of operations: bump the tick counter, invoke the optional
    ``on_tick_start(world)`` hook, process incidents, spawns and preempt
    expiry, run the signal controllers, move vehicles, update gridlock
    and throughput metrics, and finally compute the reward.

    Returns:
        A ``(reward, stats)`` pair for this tick.
    """
    world.tick += 1
    stats = TickStats()
    if on_tick_start is not None:
        on_tick_start(world)

    _activate_incidents(world)
    _spawn_scheduled(world)
    _expire_preempts(world)

    # The DQN controller only decides once every `dqn_decision_interval` ticks.
    dqn_enabled = world.controller_mode == "dqn" and world.rl_controller is not None
    if dqn_enabled:
        world.dqn_tick_counter += 1
        if world.dqn_tick_counter >= world.dqn_decision_interval:
            world.dqn_tick_counter = 0
            # Imported here rather than at module level, as in the original.
            from .rl_controller import rl_step

            rl_step(world.rl_controller, world)

    for intersection in world.intersections.values():
        _local_controller_step(intersection, world, stats)
    _move_vehicles(world, stats)

    metrics = world.metrics
    if stats.moved_any or not _has_waiting_vehicles(world):
        # Something moved (or nothing is left to move): not a stall.
        metrics.stalled_streak = 0
    else:
        metrics.stalled_streak += 1
        if metrics.stalled_streak >= GRIDLOCK_STALL_THRESHOLD:
            stats.gridlock = 1
            metrics.gridlock_events += 1
            metrics.stalled_streak = 0
            world.log("GRIDLOCK detected, resetting stall counter")

    metrics.cleared_civilian += stats.cleared_civ
    metrics.cleared_emergency += stats.cleared_em
    metrics.wasted_green_ticks += stats.wasted_green

    return _compute_tick_reward(world, stats), stats
def _activate_incidents(world: World) -> None:
    """Activate incidents that are due and clear incidents that have ended.

    Scheduled incidents whose tick has been reached are marked active,
    block their road, and are moved from ``world.incident_schedule`` to
    ``world.incidents``. Active incidents whose ``end_tick`` has been
    reached are deactivated.

    A road is only unblocked when no other active incident remains on it;
    previously, clearing one of two overlapping incidents on the same road
    reopened the road while the other incident was still active.
    """
    still_pending = []
    for ev in world.incident_schedule:
        if ev.tick > world.tick:
            still_pending.append(ev)
            continue
        ev.incident.active = True
        world.roads[ev.incident.road_id].blocked = True
        world.incidents.append(ev.incident)
        world.log(
            f"INCIDENT {ev.incident.id} {ev.incident.kind} blocks {ev.incident.road_id}"
        )
    world.incident_schedule = still_pending

    for inc in world.incidents:
        if not inc.active or inc.end_tick is None or world.tick < inc.end_tick:
            continue
        inc.active = False
        # Only reopen the road if this was the last active incident on it.
        road_still_affected = any(
            other.active and other.road_id == inc.road_id
            for other in world.incidents
        )
        if not road_still_affected:
            world.roads[inc.road_id].blocked = False
        world.log(f"INCIDENT {inc.id} cleared on {inc.road_id}")
def _spawn_scheduled(world: World) -> None:
    """Spawn every scheduled vehicle that is due and can enter its first road.

    Events that are not yet due stay in the schedule untouched. A due
    event whose entry road is missing, blocked, or has its first cell
    occupied is pushed to the next tick and retried. Reroute overrides
    are applied to a copy of the event's route before spawning.
    """
    deferred = []
    for ev in world.spawn_schedule:
        if ev.tick > world.tick:
            deferred.append(ev)
            continue

        route = _apply_reroute_overrides(world, list(ev.route))
        first_road = world.roads.get(route[0]) if route else None
        can_enter = (
            first_road is not None
            and first_road.cells[0] is None
            and not first_road.blocked
        )
        if not can_enter:
            # Try again next tick.
            ev.tick = world.tick + 1
            deferred.append(ev)
            continue

        v = Vehicle(
            id=ev.vehicle_id,
            type=ev.vehicle_type,
            route=route,
            route_idx=0,
            position_in_road=0,
            spawn_tick=world.tick,
        )
        world.vehicles[v.id] = v
        first_road.cells[0] = v.id

        if not v.is_emergency():
            world.metrics.spawned_civilian += 1
            continue
        world.metrics.spawned_emergency += 1
        world.log(f"SPAWN emergency {v.id} type={v.type} on {first_road.id}")

    world.spawn_schedule = deferred
def _apply_reroute_overrides(world: World, route: list[str]) -> list[str]:
    """Splice the first matching reroute override into ``route``.

    Overrides are checked in the iteration order of
    ``world.reroute_overrides``. The first overridden road found in the
    route has its first occurrence replaced by the detour; later
    overrides are not applied. Without a match, ``route`` is returned
    as-is (the same list object).
    """
    overrides = world.reroute_overrides
    matched = next((road_id for road_id in overrides if road_id in route), None)
    if matched is None:
        return route
    cut = route.index(matched)
    return route[:cut] + overrides[matched] + route[cut + 1 :]
def _expire_preempts(world: World) -> None:
    """Drop preemptions whose expiry tick has been reached.

    Intersections without a preempt direction, or with a preemption that
    has no expiry tick, are left untouched.
    """
    for junction in world.intersections.values():
        expires_at = junction.preempt_expires_tick
        if junction.preempt_direction is None or expires_at is None:
            continue
        if world.tick >= expires_at:
            junction.preempt_direction = None
            junction.preempt_expires_tick = None
def _has_nondefault_bias(I: Intersection) -> bool:
    """Return True if any bias weight on the intersection differs from 1.0."""
    for weight in I.bias.values():
        if weight != 1.0:
            return True
    return False
def _get_corridor_offset(iid: str, world: World) -> int:
    """Return the phase offset of ``iid`` in its first coordinated corridor.

    Corridors are scanned in the iteration order of ``world.corridors``;
    0 is returned when no coordinated corridor lists the intersection.
    """
    return next(
        (
            corridor.phase_offsets[iid]
            for corridor in world.corridors.values()
            if corridor.coordinated and iid in corridor.phase_offsets
        ),
        0,
    )
def _local_controller_step(I: Intersection, world: World, stats: TickStats) -> None:
    """Run one tick of signal control for a single intersection.

    Steps, in order:
      1. Advance the phase timer.
      2. If a preemption is active, jump to the phase containing the
         preempted direction (if any) and stop.
      3. Compute bias-weighted pressure per phase and count a wasted
         green tick when the current phase has no demand while other
         approaches have queued vehicles.
      4. Once the minimum phase time has elapsed, pick the next phase:
         bias-driven pressure control when any bias is non-default,
         otherwise max-pressure when that mode is selected, otherwise
         plain cycling after ``max_phase_ticks`` (used for both the
         "dqn" safety fallback and fixed-time control).
    """
    I.phase_timer += 1

    # Preemption overrides every other controller.
    if I.preempt_direction is not None:
        preempt_idx = I.phase_idx_containing(I.preempt_direction)
        if preempt_idx is not None and preempt_idx != I.current_phase_idx:
            I.current_phase_idx = preempt_idx
            I.phase_timer = 0
        return

    active_phase = I.current_phase()

    # Bias-weighted demand per phase; raw demand of the current phase is
    # tracked separately for wasted-green accounting.
    served_demand = 0
    pressures: list[float] = []
    for phase_idx, phase in enumerate(I.phases):
        weighted = 0.0
        for d in phase:
            rid = I.incoming.get(d)
            if rid is None:
                continue
            approach = world.roads[rid]
            demand = approach.occupancy() + approach.queue_at_tail() * 2
            weighted += I.bias.get(d, 1.0) * demand
            if phase_idx == I.current_phase_idx:
                served_demand += demand
        pressures.append(weighted)

    waiting_elsewhere = sum(
        world.roads[rid].queue_at_tail()
        for d, rid in I.incoming.items()
        if d not in active_phase
    )

    if served_demand == 0 and waiting_elsewhere > 0:
        stats.wasted_green += 1

    if I.phase_timer < I.min_phase_ticks:
        return

    # Each branch below only decides `next_idx`; the switch itself is
    # committed once at the end. None means "stay on the current phase".
    next_idx = None

    if _has_nondefault_bias(I):
        # Bias has been set: pressure-responsive control, with the
        # corridor offset added to the timer for the forced-switch check.
        offset = _get_corridor_offset(I.id, world)
        force_switch = (
            I.phase_timer + offset >= I.max_phase_ticks and waiting_elsewhere > 0
        )
        best = max(range(len(I.phases)), key=pressures.__getitem__)
        if force_switch and best == I.current_phase_idx:
            best = (I.current_phase_idx + 1) % len(I.phases)
        threshold = pressures[I.current_phase_idx] * HYSTERESIS
        if best != I.current_phase_idx and (
            force_switch or pressures[best] > threshold
        ):
            next_idx = best
    elif world.controller_mode == "max_pressure":
        # Max-pressure: upstream queue minus downstream occupancy, per phase.
        mp: list[float] = []
        for phase in I.phases:
            total = 0.0
            for d in phase:
                upstream = I.incoming.get(d)
                downstream = I.outgoing.get(d)
                queue = 0
                if upstream and upstream in world.roads:
                    queue = world.roads[upstream].queue_at_tail()
                occupied = 0
                if downstream and downstream in world.roads:
                    occupied = world.roads[downstream].occupancy()
                total += max(0, queue - occupied)
            mp.append(total)
        best = max(range(len(I.phases)), key=mp.__getitem__)
        if best != I.current_phase_idx and mp[best] > mp[I.current_phase_idx]:
            next_idx = best
        elif I.phase_timer >= I.max_phase_ticks:
            next_idx = (I.current_phase_idx + 1) % len(I.phases)
    else:
        # "dqn" mode (switching is handled externally; this is only the
        # max-phase safety fallback) and fixed-time mode behave the same
        # here: cycle to the next phase once max_phase_ticks is reached.
        if I.phase_timer >= I.max_phase_ticks:
            next_idx = (I.current_phase_idx + 1) % len(I.phases)

    if next_idx is not None:
        I.current_phase_idx = next_idx
        I.phase_timer = 0
def _move_vehicles(world: World, stats: TickStats) -> None:
    """Attempt to move every uncleared vehicle once.

    Vehicles are processed in descending order of route index, then of
    position within their road, so those further ahead get the first
    chance to move. The sort is stable, so ties keep the iteration order
    of ``world.vehicles``.
    """
    active = [v for v in world.vehicles.values() if not v.cleared]
    active.sort(key=lambda v: (v.route_idx, v.position_in_road), reverse=True)
    for vehicle in active:
        _try_move(vehicle, world, stats)
def _try_move(v: Vehicle, world: World, stats: TickStats) -> None:
    """Try to advance one vehicle by a single cell.

    Three situations are handled:
      * mid-road: step into the next cell if it is free;
      * end of the final road: remove the vehicle and record it as cleared;
      * end of any other road: cross into the first cell of the next road
        when the signal serves this approach and that road is neither
        blocked nor occupied at its entry cell.

    A vehicle that cannot advance has its wait counter incremented and
    the global maximum wait metric is updated.
    """
    road = world.roads[v.route[v.route_idx]]
    here = v.position_in_road
    advanced = False

    if here < road.length - 1:
        ahead = here + 1
        if road.cells[ahead] is None:
            road.cells[here] = None
            v.position_in_road = ahead
            road.cells[ahead] = v.id
            advanced = True
    elif v.route_idx == len(v.route) - 1:
        # End of the last road on the route: the vehicle leaves the network.
        road.cells[here] = None
        v.cleared = True
        v.clear_tick = world.tick
        stats.moved_any = True
        if v.is_emergency():
            stats.cleared_em += 1
            world.metrics.emergency_clear_times.append(world.tick - v.spawn_tick)
            world.log(f"CLEAR emergency {v.id} in {world.tick - v.spawn_tick} ticks")
        else:
            stats.cleared_civ += 1
        return
    else:
        junction = world.intersections[road.to_node]
        if phase_serves(junction, road.approach_direction):
            next_road = world.roads[v.route[v.route_idx + 1]]
            if not next_road.blocked and next_road.cells[0] is None:
                road.cells[here] = None
                v.route_idx += 1
                v.position_in_road = 0
                next_road.cells[0] = v.id
                advanced = True

    if advanced:
        stats.moved_any = True
        return

    # Blocked by a vehicle ahead, a red signal, or an unavailable next road.
    v.wait_ticks += 1
    world.metrics.max_wait_ticks_seen = max(
        world.metrics.max_wait_ticks_seen, v.wait_ticks
    )
def _has_waiting_vehicles(world: World) -> bool:
    """Return True if at least one vehicle has not been cleared yet."""
    return not all(v.cleared for v in world.vehicles.values())
def _compute_tick_reward(world: World, stats: TickStats) -> float:
    """Compute the scalar reward for the tick that just ran.

    Cleared vehicles add to the reward (emergency vehicles weigh five
    times as much as civilian ones). The accumulated wait of all
    uncleared vehicles, wasted green ticks and gridlock events subtract
    from it.
    """
    pending_wait = sum(
        v.wait_ticks for v in world.vehicles.values() if not v.cleared
    )
    cleared_bonus = 1.0 * stats.cleared_civ + 5.0 * stats.cleared_em
    return (
        cleared_bonus
        - 0.002 * pending_wait
        - 0.3 * stats.wasted_green
        - 12.0 * stats.gridlock
    )