trafficops / server /tasks.py
Kunalsinghh's picture
Upload folder using huggingface_hub
92107a5 verified
from typing import Callable
from .sim.builders import (
add_corridor,
add_intersection,
add_road,
connect_neighbors,
new_world,
schedule_incident,
spawn,
spawn_stream,
wire,
)
from .sim.rl_controller import get_controller
from .sim.world import World
# Public task identifiers; each has a builder registered in _BUILDERS.
TASK_IDS = [
    "grid_balanced",      # Easy
    "demand_shift",       # Medium
    "incident_corridor",  # Hard
    "rush_hour_wave",     # Hard
    "multi_crisis",       # Expert
]
# Every task runs on the same GRID_ROWS x GRID_COLS intersection grid.
GRID_ROWS = 4
GRID_COLS = 4
# Road lengths passed to add_road (unit defined by the sim — TODO confirm).
ROAD_LEN = 8    # link between two neighbouring intersections
SOURCE_LEN = 6  # SRC_* entry point -> first intersection of a row/column
SINK_LEN = 5    # last intersection of a row/column -> SINK_* exit
def build(task: str, seed: int) -> World:
    """Construct the simulation world for *task*, seeded with *seed*.

    Raises:
        ValueError: if *task* has no registered builder.
    """
    if task not in _BUILDERS:
        raise ValueError(f"unknown task: {task}")
    return _BUILDERS[task](seed)
def _iid(r: int, c: int) -> str:
return f"I_{r}_{c}"
def _build_grid(
    task: str,
    seed: int,
    horizon: int,
    budget: int,
) -> World:
    """Create the shared GRID_ROWS x GRID_COLS signalised grid for *task*.

    The world uses controller mode ``"dqn"`` with the controller returned by
    ``get_controller()``. Each intersection has min/max phase lengths of 6/45
    ticks. Traffic is one-way: west-to-east along rows (approach ``"W"``) and
    south-to-north along columns (approach ``"S"``). ``SRC_*`` roads feed the
    grid edges and ``SINK_*`` roads drain them. One corridor is registered per
    row and one per column.
    """
    world = new_world(task, horizon=horizon, seed=seed, interventions_budget=budget, controller_mode="dqn")
    world.rl_controller = get_controller()
    last_row = GRID_ROWS - 1
    last_col = GRID_COLS - 1

    # Intersections in row-major order; position is (x=column, y=row).
    for row in range(GRID_ROWS):
        for col in range(GRID_COLS):
            add_intersection(world, _iid(row, col), position=(col, row), min_phase_ticks=6, max_phase_ticks=45)

    # Row roads: source -> inner west-to-east links -> sink.
    for row in range(GRID_ROWS):
        add_road(world, f"R_src_W_{row}", f"SRC_W_{row}", _iid(row, 0), approach="W", length=SOURCE_LEN)
        for col in range(last_col):
            add_road(world, f"R_h_{row}_{col}", _iid(row, col), _iid(row, col + 1), approach="W", length=ROAD_LEN)
        add_road(world, f"R_sink_E_{row}", _iid(row, last_col), f"SINK_E_{row}", approach="W", length=SINK_LEN)

    # Column roads: source -> inner south-to-north links -> sink.
    for col in range(GRID_COLS):
        add_road(world, f"R_src_S_{col}", f"SRC_S_{col}", _iid(0, col), approach="S", length=SOURCE_LEN)
        for row in range(last_row):
            add_road(world, f"R_v_{row}_{col}", _iid(row, col), _iid(row + 1, col), approach="S", length=ROAD_LEN)
        add_road(world, f"R_sink_N_{col}", _iid(last_row, col), f"SINK_N_{col}", approach="S", length=SINK_LEN)

    # Attach each intersection's inbound (W, S) and outbound (E, N) roads;
    # edge cells use the source/sink roads instead of inner links.
    for row in range(GRID_ROWS):
        for col in range(GRID_COLS):
            west_in = f"R_h_{row}_{col - 1}" if col > 0 else f"R_src_W_{row}"
            east_out = f"R_h_{row}_{col}" if col < last_col else f"R_sink_E_{row}"
            south_in = f"R_v_{row - 1}_{col}" if row > 0 else f"R_src_S_{col}"
            north_out = f"R_v_{row}_{col}" if row < last_row else f"R_sink_N_{col}"
            wire(
                world,
                _iid(row, col),
                incoming={"W": west_in, "S": south_in},
                outgoing={"E": east_out, "N": north_out},
            )
    connect_neighbors(world)

    # One corridor per row (direction "W") and per column (direction "S").
    for row in range(GRID_ROWS):
        row_members = [_iid(row, col) for col in range(GRID_COLS)]
        add_corridor(world, f"ew_row_{row}", row_members, direction="W")
    for col in range(GRID_COLS):
        col_members = [_iid(row, col) for row in range(GRID_ROWS)]
        add_corridor(world, f"ns_col_{col}", col_members, direction="S")
    return world
def _ew_route(row: int) -> list[str]:
    """Return the road ids for a full west-to-east trip along *row*."""
    inner_links = [f"R_h_{row}_{col}" for col in range(GRID_COLS - 1)]
    return [f"R_src_W_{row}", *inner_links, f"R_sink_E_{row}"]
def _ns_route(col: int) -> list[str]:
    """Return the road ids for a full south-to-north trip along column *col*."""
    inner_links = [f"R_v_{row}_{col}" for row in range(GRID_ROWS - 1)]
    return [f"R_src_S_{col}", *inner_links, f"R_sink_N_{col}"]
# ── Task 1: grid_balanced (Easy) ─────────────────────────────────────────
def _build_grid_balanced(seed: int) -> World:
    """Build the ``grid_balanced`` task: steady civilian flow plus one ambulance.

    Each row gets an east-west civilian stream (interval 7) and each column a
    north-south stream (interval 9). All streams run until ``horizon - 30``.
    One ambulance crosses row 1 west-to-east. It spawns at tick 80 plus a
    seeded offset from ``rng.integers(0, 30)``.
    """
    world = _build_grid("grid_balanced", seed, horizon=250, budget=6)
    stop = world.horizon - 30
    # Start ticks are staggered: rows from tick 2, columns from tick 4.
    for row in range(GRID_ROWS):
        spawn_stream(world, row + 2, stop, 7, f"EW_{row}", "civilian", _ew_route(row), jitter=0.3)
    for col in range(GRID_COLS):
        spawn_stream(world, col + 4, stop, 9, f"NS_{col}", "civilian", _ns_route(col), jitter=0.3)
    # Draw the ambulance offset after the streams are registered, in case
    # spawn_stream also consumes world.rng.
    ambulance_tick = int(world.rng.integers(0, 30)) + 80
    spawn(world, ambulance_tick, "AMB_1", "ambulance", _ew_route(1))
    return world
# ── Task 2: demand_shift (Medium) ────────────────────────────────────────
def _build_demand_shift(seed: int) -> World:
    """Build the ``demand_shift`` task: the dominant traffic axis flips mid-episode.

    ``flip_tick`` is 140 plus a seeded offset from ``rng.integers(0, 20)``.

    * Before ``flip_tick``, north-south streams use interval 4 (heavy) and
      east-west streams use interval 18 (light).
    * From ``flip_tick`` until ``horizon - 30``, the intervals swap.

    The interval is the fourth ``spawn_stream`` argument; a smaller value is
    treated as heavier demand. One ambulance runs along column 2 shortly after
    the flip.
    """
    w = _build_grid("demand_shift", seed, horizon=300, budget=6)
    flip_tick = 140 + int(w.rng.integers(0, 20))
    # Phase A: heavy north-south (interval 4), light east-west (interval 18)
    for c in range(GRID_COLS):
        spawn_stream(w, 2 + c, flip_tick, 4, f"NS_A_{c}", "civilian", _ns_route(c), jitter=0.25)
    for r in range(GRID_ROWS):
        spawn_stream(w, 4 + r, flip_tick, 18, f"EW_A_{r}", "civilian", _ew_route(r), jitter=0.3)
    # Phase B: demand flips — heavy east-west (interval 4), light north-south (interval 18)
    for r in range(GRID_ROWS):
        spawn_stream(w, flip_tick, w.horizon - 30, 4, f"EW_B_{r}", "civilian", _ew_route(r), jitter=0.25)
    for c in range(GRID_COLS):
        spawn_stream(w, flip_tick, w.horizon - 30, 18, f"NS_B_{c}", "civilian", _ns_route(c), jitter=0.3)
    # Ambulance on column 2 during the transition, 5+ ticks after the flip
    amb_tick = flip_tick + int(w.rng.integers(5, 20))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ns_route(2))
    return w
# ── Task 3: incident_corridor (Hard) ─────────────────────────────────────
def _build_incident_corridor(seed: int) -> World:
    """Build the ``incident_corridor`` task: an accident blocks part of row 1.

    Background civilian streams run until ``horizon - 30``: east-west at
    interval 8 and north-south at interval 9. An accident on ``R_h_1_1``
    (between I_1_1 and I_1_2) starts at tick 50 plus a seeded offset and lasts
    at least 150 ticks. An ambulance is later routed along row 1, through the
    blocked road. A fire truck runs up column 3, clear of the incident.
    """
    world = _build_grid("incident_corridor", seed, horizon=280, budget=8)
    stop = world.horizon - 30
    for row in range(GRID_ROWS):
        spawn_stream(world, row + 2, stop, 8, f"EW_{row}", "civilian", _ew_route(row), jitter=0.3)
    for col in range(GRID_COLS):
        spawn_stream(world, col + 4, stop, 9, f"NS_{col}", "civilian", _ns_route(col), jitter=0.3)

    # The incident window is drawn first; later draws depend on it, so the
    # RNG call order below must stay as-is.
    incident_start = int(world.rng.integers(0, 15)) + 50
    incident_stop = incident_start + 150 + int(world.rng.integers(0, 30))
    schedule_incident(world, incident_start, "INC_1", "R_h_1_1", "accident", incident_stop)

    # The ambulance's row-1 route crosses the blocked road R_h_1_1.
    ambulance_tick = incident_start + 20 + int(world.rng.integers(0, 10))
    spawn(world, ambulance_tick, "AMB_1", "ambulance", _ew_route(1))

    # The fire truck on column 3 avoids the incident but still needs signal preemption.
    fire_tick = incident_start + 60 + int(world.rng.integers(0, 20))
    spawn(world, fire_tick, "FIRE_1", "fire", _ns_route(3))
    return world
# ── Task 4: rush_hour_wave (Hard) ────────────────────────────────────────
def _build_rush_hour_wave(seed: int) -> World:
    """Build the ``rush_hour_wave`` task: a north-south surge after ``surge_tick``.

    ``surge_tick`` is 90 plus a seeded offset from ``rng.integers(0, 20)``.

    * Before ``surge_tick``, every stream uses interval 14 (light, balanced).
    * From ``surge_tick`` until ``horizon - 30``, north-south streams tighten
      to interval 3. That is roughly 4.7x the earlier rate, assuming the
      argument is a spawn interval. East-west streams tighten only to
      interval 10.

    A police car crosses row 2 during the surge.
    """
    w = _build_grid("rush_hour_wave", seed, horizon=280, budget=8)
    surge_tick = 90 + int(w.rng.integers(0, 20))
    # Phase 1: light balanced traffic (interval 14 in both directions)
    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, surge_tick, 14, f"EW_L_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, surge_tick, 14, f"NS_L_{c}", "civilian", _ns_route(c), jitter=0.3)
    # Phase 2: north-south demand surges (interval 14 -> 3, ~4.7x), entering
    # from the southern SRC_S_* sources and travelling north
    for c in range(GRID_COLS):
        spawn_stream(w, surge_tick, w.horizon - 30, 3, f"NS_H_{c}", "civilian", _ns_route(c), jitter=0.25)
    # East-west rises only modestly (interval 14 -> 10)
    for r in range(GRID_ROWS):
        spawn_stream(w, surge_tick, w.horizon - 30, 10, f"EW_H_{r}", "civilian", _ew_route(r), jitter=0.3)
    # Police car along row 2 during the peak, 30+ ticks after the surge starts
    pol_tick = surge_tick + 30 + int(w.rng.integers(0, 20))
    spawn(w, pol_tick, "POLICE_1", "police", _ew_route(2))
    return w
# ── Task 5: multi_crisis (Expert) ────────────────────────────────────────
def _build_multi_crisis(seed: int) -> World:
    """Build the ``multi_crisis`` task: two incidents and three emergency vehicles.

    Background civilian traffic runs until ``horizon - 40``. It is heavier
    east-west (interval 6) than north-south (interval 10).

    Incidents:

    * An accident blocks ``R_h_0_1`` (row 0), starting near tick 50.
    * Construction blocks ``R_v_2_2`` (column 2), starting near tick 130.

    Emergency vehicles:

    * An ambulance on row 0 is routed through the first closure.
    * A fire truck on column 2 is routed through the second closure.
    * A police car later crosses straight along row 3.
    """
    w = _build_grid("multi_crisis", seed, horizon=320, budget=12)
    # Moderate asymmetric traffic (heavier east-west)
    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, w.horizon - 40, 6, f"EW_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, w.horizon - 40, 10, f"NS_{c}", "civilian", _ns_route(c), jitter=0.3)
    # Incident 1: blocks row 0 at tick ~50
    inc1_tick = 45 + int(w.rng.integers(0, 15))
    inc1_end = inc1_tick + 120 + int(w.rng.integers(0, 30))
    schedule_incident(w, inc1_tick, "INC_1", "R_h_0_1", "accident", inc1_end)
    # Incident 2: blocks column 2 at tick ~130
    inc2_tick = 120 + int(w.rng.integers(0, 20))
    inc2_end = inc2_tick + 100 + int(w.rng.integers(0, 20))
    schedule_incident(w, inc2_tick, "INC_2", "R_v_2_2", "construction", inc2_end)
    # Emergency 1: ambulance through incident 1 zone
    amb_tick = inc1_tick + 20 + int(w.rng.integers(0, 10))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ew_route(0))
    # Emergency 2: fire truck through incident 2 zone
    fire_tick = inc2_tick + 15 + int(w.rng.integers(0, 10))
    spawn(w, fire_tick, "FIRE_1", "fire", _ns_route(2))
    # Emergency 3: police car (late), straight west-to-east along row 3;
    # _ew_route gives a single-row path, not a diagonal one
    pol_tick = fire_tick + 40 + int(w.rng.integers(0, 15))
    spawn(w, pol_tick, "POLICE_1", "police", _ew_route(3))
    return w
# Task id -> builder(seed) registry consulted by build(); lists the same five
# ids as TASK_IDS, in the same order.
_BUILDERS: dict[str, Callable[[int], World]] = dict(
    grid_balanced=_build_grid_balanced,
    demand_shift=_build_demand_shift,
    incident_corridor=_build_incident_corridor,
    rush_hour_wave=_build_rush_hour_wave,
    multi_crisis=_build_multi_crisis,
)