Spaces:
Sleeping
Sleeping
| from typing import Callable | |
| from .sim.builders import ( | |
| add_corridor, | |
| add_intersection, | |
| add_road, | |
| connect_neighbors, | |
| new_world, | |
| schedule_incident, | |
| spawn, | |
| spawn_stream, | |
| wire, | |
| ) | |
| from .sim.rl_controller import get_controller | |
| from .sim.world import World | |
# Scenario identifiers; each one has a matching entry in _BUILDERS.
TASK_IDS = [
    "grid_balanced",
    "demand_shift",
    "incident_corridor",
    "rush_hour_wave",
    "multi_crisis",
]

# Number of intersection rows / columns in the grid built by _build_grid.
GRID_ROWS = 4
GRID_COLS = 4

# Values passed as `length=` to add_road.
ROAD_LEN = 8    # roads joining two neighbouring intersections
SOURCE_LEN = 6  # entry roads (SRC_* node -> edge intersection)
SINK_LEN = 5    # exit roads (edge intersection -> SINK_* node)
def build(task: str, seed: int) -> World:
    """Construct the world for the scenario named *task*.

    Args:
        task: Scenario identifier; must be a key of ``_BUILDERS``.
        seed: Seed forwarded unchanged to the scenario builder.

    Returns:
        Whatever the matching scenario builder returns (a ``World``).

    Raises:
        ValueError: If *task* has no registered builder.
    """
    builder = _BUILDERS.get(task)
    if builder is None:
        raise ValueError(f"unknown task: {task}")
    return builder(seed)
def _iid(r: int, c: int) -> str:
    """Return the intersection id for grid row *r* and column *c* (``I_<r>_<c>``)."""
    return f"I_{r}_{c}"
def _build_grid(
    task: str,
    seed: int,
    horizon: int,
    budget: int,
) -> World:
    """Create the shared GRID_ROWS x GRID_COLS road network, without traffic.

    The grid is one-way in both axes: every row carries west -> east traffic
    (approach "W") and every column carries south -> north traffic
    (approach "S"). Each row/column gets an entry road from a ``SRC_*`` node
    and an exit road to a ``SINK_*`` node.

    Args:
        task: Task name forwarded to ``new_world``.
        seed: Seed forwarded to ``new_world``.
        horizon: Horizon forwarded to ``new_world``.
        budget: Forwarded to ``new_world`` as ``interventions_budget``.

    Returns:
        The world with intersections, roads, wiring, neighbour links and
        one corridor per row and per column already registered.
    """
    w = new_world(task, horizon=horizon, seed=seed, interventions_budget=budget, controller_mode="dqn")
    w.rl_controller = get_controller()

    for r in range(GRID_ROWS):
        for c in range(GRID_COLS):
            add_intersection(w, _iid(r, c), position=(c, r), min_phase_ticks=6, max_phase_ticks=45)

    # Horizontal roads (west -> east), approach direction W
    for r in range(GRID_ROWS):
        add_road(w, f"R_src_W_{r}", f"SRC_W_{r}", _iid(r, 0), approach="W", length=SOURCE_LEN)
        for c in range(GRID_COLS - 1):
            add_road(w, f"R_h_{r}_{c}", _iid(r, c), _iid(r, c + 1), approach="W", length=ROAD_LEN)
        add_road(w, f"R_sink_E_{r}", _iid(r, GRID_COLS - 1), f"SINK_E_{r}", approach="W", length=SINK_LEN)

    # Vertical roads (south -> north), approach direction S
    for c in range(GRID_COLS):
        add_road(w, f"R_src_S_{c}", f"SRC_S_{c}", _iid(0, c), approach="S", length=SOURCE_LEN)
        for r in range(GRID_ROWS - 1):
            add_road(w, f"R_v_{r}_{c}", _iid(r, c), _iid(r + 1, c), approach="S", length=ROAD_LEN)
        add_road(w, f"R_sink_N_{c}", _iid(GRID_ROWS - 1, c), f"SINK_N_{c}", approach="S", length=SINK_LEN)

    # Wire each intersection: edge intersections use the source/sink roads,
    # interior ones use the road from/to the neighbouring intersection.
    for r in range(GRID_ROWS):
        for c in range(GRID_COLS):
            inc = {
                "W": f"R_src_W_{r}" if c == 0 else f"R_h_{r}_{c - 1}",
                "S": f"R_src_S_{c}" if r == 0 else f"R_v_{r - 1}_{c}",
            }
            out = {
                "E": f"R_sink_E_{r}" if c == GRID_COLS - 1 else f"R_h_{r}_{c}",
                "N": f"R_sink_N_{c}" if r == GRID_ROWS - 1 else f"R_v_{r}_{c}",
            }
            wire(w, _iid(r, c), incoming=inc, outgoing=out)

    connect_neighbors(w)

    # Corridors: one per row (east-west) and one per column (north-south)
    for r in range(GRID_ROWS):
        add_corridor(w, f"ew_row_{r}", [_iid(r, c) for c in range(GRID_COLS)], direction="W")
    for c in range(GRID_COLS):
        add_corridor(w, f"ns_col_{c}", [_iid(r, c) for r in range(GRID_ROWS)], direction="S")

    return w
def _ew_route(row: int) -> list[str]:
    """Return the ordered road ids of the west -> east route along *row*.

    The route is the row's source road, its ``GRID_COLS - 1`` horizontal
    roads in column order, then its east sink road.
    """
    return [
        f"R_src_W_{row}",
        *(f"R_h_{row}_{c}" for c in range(GRID_COLS - 1)),
        f"R_sink_E_{row}",
    ]
def _ns_route(col: int) -> list[str]:
    """Return the ordered road ids of the south -> north route along *col*.

    The route is the column's source road, its ``GRID_ROWS - 1`` vertical
    roads in row order, then its north sink road.
    """
    return [
        f"R_src_S_{col}",
        *(f"R_v_{r}_{col}" for r in range(GRID_ROWS - 1)),
        f"R_sink_N_{col}",
    ]
# ── Task 1: grid_balanced (Easy) ─────────────────────────────────────────
def _build_grid_balanced(seed: int) -> World:
    """Build the "grid_balanced" scenario (horizon 250, budget 6).

    Adds one civilian stream per row and per column, all stopping 30 ticks
    before the horizon, plus a single ambulance on the row-1 east-west route.

    Args:
        seed: Seed forwarded to ``_build_grid``.

    Returns:
        The populated world.
    """
    w = _build_grid("grid_balanced", seed, horizon=250, budget=6)

    # Start ticks are staggered per row/column (2 + r, 4 + c).
    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, w.horizon - 30, 7, f"EW_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, w.horizon - 30, 9, f"NS_{c}", "civilian", _ns_route(c), jitter=0.3)

    # One ambulance crossing east-west through the middle (row 1), at a
    # seeded random tick offset from 80.
    amb_tick = 80 + int(w.rng.integers(0, 30))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ew_route(1))

    return w
# ── Task 2: demand_shift (Medium) ────────────────────────────────────────
def _build_demand_shift(seed: int) -> World:
    """Build the "demand_shift" scenario (horizon 300, budget 6).

    Traffic runs in two phases separated by a seeded ``flip_tick``: the
    stream parameters for the north-south and east-west directions swap
    (4 <-> 18) at the flip. An ambulance is spawned on the column-2
    north-south route shortly after the flip.

    Args:
        seed: Seed forwarded to ``_build_grid``.

    Returns:
        The populated world.
    """
    w = _build_grid("demand_shift", seed, horizon=300, budget=6)
    flip_tick = 140 + int(w.rng.integers(0, 20))

    # Phase A: heavy north-south, light east-west
    for c in range(GRID_COLS):
        spawn_stream(w, 2 + c, flip_tick, 4, f"NS_A_{c}", "civilian", _ns_route(c), jitter=0.25)
    for r in range(GRID_ROWS):
        spawn_stream(w, 4 + r, flip_tick, 18, f"EW_A_{r}", "civilian", _ew_route(r), jitter=0.3)

    # Phase B: demand flips - heavy east-west, light north-south
    for r in range(GRID_ROWS):
        spawn_stream(w, flip_tick, w.horizon - 30, 4, f"EW_B_{r}", "civilian", _ew_route(r), jitter=0.25)
    for c in range(GRID_COLS):
        spawn_stream(w, flip_tick, w.horizon - 30, 18, f"NS_B_{c}", "civilian", _ns_route(c), jitter=0.3)

    # Ambulance during transition
    amb_tick = flip_tick + int(w.rng.integers(5, 20))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ns_route(2))

    return w
# ── Task 3: incident_corridor (Hard) ─────────────────────────────────────
def _build_incident_corridor(seed: int) -> World:
    """Build the "incident_corridor" scenario (horizon 280, budget 8).

    Adds steady civilian streams on every row and column, an accident on
    road ``R_h_1_1`` (row 1), an ambulance whose east-west route includes
    that road, and a fire truck on the column-3 north-south route.

    Args:
        seed: Seed forwarded to ``_build_grid``.

    Returns:
        The populated world.
    """
    w = _build_grid("incident_corridor", seed, horizon=280, budget=8)

    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, w.horizon - 30, 8, f"EW_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, w.horizon - 30, 9, f"NS_{c}", "civilian", _ns_route(c), jitter=0.3)

    # Incident blocks row 1 mid-corridor
    inc_tick = 50 + int(w.rng.integers(0, 15))
    inc_end = inc_tick + 150 + int(w.rng.integers(0, 30))
    schedule_incident(w, inc_tick, "INC_1", "R_h_1_1", "accident", inc_end)

    # Ambulance whose route goes through blocked road
    amb_tick = inc_tick + 20 + int(w.rng.integers(0, 10))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ew_route(1))

    # Fire truck on north-south (doesn't hit incident but needs preempt)
    fire_tick = inc_tick + 60 + int(w.rng.integers(0, 20))
    spawn(w, fire_tick, "FIRE_1", "fire", _ns_route(3))

    return w
# ── Task 4: rush_hour_wave (Hard) ────────────────────────────────────────
def _build_rush_hour_wave(seed: int) -> World:
    """Build the "rush_hour_wave" scenario (horizon 280, budget 8).

    Light, balanced streams run until a seeded ``surge_tick``; from then on
    the north-south streams use a much smaller stream parameter (14 -> 3)
    while east-west changes only slightly (14 -> 10). A police car is
    spawned on the row-2 east-west route during the surge.

    Args:
        seed: Seed forwarded to ``_build_grid``.

    Returns:
        The populated world.
    """
    w = _build_grid("rush_hour_wave", seed, horizon=280, budget=8)
    surge_tick = 90 + int(w.rng.integers(0, 20))

    # Phase 1: light balanced traffic
    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, surge_tick, 14, f"EW_L_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, surge_tick, 14, f"NS_L_{c}", "civilian", _ns_route(c), jitter=0.3)

    # Phase 2: demand surges from the south - wave ripples north
    for c in range(GRID_COLS):
        spawn_stream(w, surge_tick, w.horizon - 30, 3, f"NS_H_{c}", "civilian", _ns_route(c), jitter=0.25)

    # East-west stays moderate
    for r in range(GRID_ROWS):
        spawn_stream(w, surge_tick, w.horizon - 30, 10, f"EW_H_{r}", "civilian", _ew_route(r), jitter=0.3)

    # Police car during peak
    pol_tick = surge_tick + 30 + int(w.rng.integers(0, 20))
    spawn(w, pol_tick, "POLICE_1", "police", _ew_route(2))

    return w
# ── Task 5: multi_crisis (Expert) ────────────────────────────────────────
def _build_multi_crisis(seed: int) -> World:
    """Build the "multi_crisis" scenario (horizon 320, budget 12).

    Adds civilian streams on every row and column (stopping 40 ticks before
    the horizon), two overlapping incidents (an accident on ``R_h_0_1`` and
    construction on ``R_v_2_2``) and three emergency vehicles: an ambulance
    on row 0, a fire truck on column 2 and a police car on row 3.

    The order of the ``w.rng`` draws below determines the seeded timings, so
    statements must not be reordered.

    Args:
        seed: Seed forwarded to ``_build_grid``.

    Returns:
        The populated world.
    """
    w = _build_grid("multi_crisis", seed, horizon=320, budget=12)

    # Moderate asymmetric traffic (heavier east-west)
    for r in range(GRID_ROWS):
        spawn_stream(w, 2 + r, w.horizon - 40, 6, f"EW_{r}", "civilian", _ew_route(r), jitter=0.3)
    for c in range(GRID_COLS):
        spawn_stream(w, 4 + c, w.horizon - 40, 10, f"NS_{c}", "civilian", _ns_route(c), jitter=0.3)

    # Incident 1: blocks row 0 at tick ~50
    inc1_tick = 45 + int(w.rng.integers(0, 15))
    inc1_end = inc1_tick + 120 + int(w.rng.integers(0, 30))
    schedule_incident(w, inc1_tick, "INC_1", "R_h_0_1", "accident", inc1_end)

    # Incident 2: blocks column 2 at tick ~130
    inc2_tick = 120 + int(w.rng.integers(0, 20))
    inc2_end = inc2_tick + 100 + int(w.rng.integers(0, 20))
    schedule_incident(w, inc2_tick, "INC_2", "R_v_2_2", "construction", inc2_end)

    # Emergency 1: ambulance through incident 1 zone
    amb_tick = inc1_tick + 20 + int(w.rng.integers(0, 10))
    spawn(w, amb_tick, "AMB_1", "ambulance", _ew_route(0))

    # Emergency 2: fire truck through incident 2 zone
    fire_tick = inc2_tick + 15 + int(w.rng.integers(0, 10))
    spawn(w, fire_tick, "FIRE_1", "fire", _ns_route(2))

    # Emergency 3: police car along the row-3 east-west route (late)
    pol_tick = fire_tick + 40 + int(w.rng.integers(0, 15))
    spawn(w, pol_tick, "POLICE_1", "police", _ew_route(3))

    return w
# Dispatch table consulted by build(): task id -> scenario builder that takes
# a seed and returns a World. Keys mirror TASK_IDS.
_BUILDERS: dict[str, Callable[[int], World]] = {
    "grid_balanced": _build_grid_balanced,
    "demand_shift": _build_demand_shift,
    "incident_corridor": _build_incident_corridor,
    "rush_hour_wave": _build_rush_hour_wave,
    "multi_crisis": _build_multi_crisis,
}