Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.insert(0, ".") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import matplotlib.animation as animation | |
| import numpy as np | |
| from server.trafficops_environment import TrafficOpsEnvironment | |
| from models import TrafficOpsAction | |
# Road the scripted demo agent knows how to bypass, and the detour it requests.
# These IDs are hard-coded for the demo grid layout (presumably the default
# "incident_corridor" map) — TODO confirm they exist in other tasks.
_DEMO_BLOCKED_ROAD = "R_h_1_1"
_DEMO_DETOUR = ("R_v_1_1", "R_h_2_1", "R_v_2_2")


def _reroute_action(obs):
    """Return a reroute around the known demo incident, or None.

    Fires only for an active incident on ``_DEMO_BLOCKED_ROAD`` and only while
    no reroute plan is already active.
    """
    for inc in obs.incidents or ():
        if not (inc.active and inc.road_id == _DEMO_BLOCKED_ROAD):
            continue
        if any(plan.op == "reroute" for plan in obs.active_plans):
            return None
        return TrafficOpsAction(
            op="reroute",
            targets=[inc.road_id],
            params={
                "blocked_road": inc.road_id,
                "detour": list(_DEMO_DETOUR),
                "duration_ticks": 200,
            },
            reason=f"reroute around {inc.id}",
        )
    return None


def _preempt_action(obs, lookahead=3):
    """Return a signal preempt for the first uncleared emergency, or None.

    Looks at the next ``lookahead`` roads of the vehicle's remaining route and
    targets the intersections (``I_*`` nodes) those roads lead into. A preempt
    carries a single ``direction``, so only the leading run of intersections
    entered from the same approach direction is targeted.
    """
    roads_by_id = {road.id: road for road in obs.roads}
    for em in obs.emergencies or ():
        if em.cleared or not em.remaining_route:
            continue
        # (intersection id, approach direction) for each upcoming hop that ends
        # at an intersection. Hops into sinks or other non-"I_" nodes are
        # dropped here so their direction cannot leak into the request.
        upcoming = (roads_by_id.get(rid) for rid in em.remaining_route[:lookahead])
        hops = [
            (road.to_node, road.approach_direction)
            for road in upcoming
            if road is not None and road.to_node.startswith("I_")
        ]
        if not hops:
            continue
        direction = hops[0][1]
        if not direction:
            continue
        targets = []
        for node, approach in hops:
            if approach != direction:
                break  # vehicle turns: later intersections need another direction
            targets.append(node)
        return TrafficOpsAction(
            op="preempt",
            targets=targets,
            params={"direction": direction, "duration_ticks": 15},
            reason=f"clear path for {em.id}",
        )
    return None


def run_episode(task="incident_corridor", seed=42):
    """Run one episode with a simple scripted agent and return all observations.

    Each tick the agent picks, in priority order:
      1. a reroute around the active demo incident (see ``_reroute_action``);
      2. otherwise a signal preempt ahead of the first uncleared emergency
         vehicle (see ``_preempt_action``);
      3. otherwise ``noop``.

    Args:
        task: Task name passed to ``TrafficOpsEnvironment.reset``.
        seed: Seed passed to ``TrafficOpsEnvironment.reset``.

    Returns:
        List of observations: the reset observation followed by one per step,
        ending with the first observation whose ``done`` flag is set.
    """
    env = TrafficOpsEnvironment()
    obs = env.reset(seed=seed, task=task)
    frames = [obs]
    while not obs.done:
        action = _reroute_action(obs)
        if action is None:
            action = _preempt_action(obs)
        if action is None:
            action = TrafficOpsAction(op="noop")
        obs = env.step(action)
        frames.append(obs)
    return frames
# Queue-bar geometry (data units): length per unit of queue, and the cap.
_QUEUE_BAR_SCALE = 0.08
_QUEUE_BAR_MAX = 0.8

# Icon per emergency type; unknown types fall back to a siren.
_EMERGENCY_ICONS = {"ambulance": "🚑", "fire": "🚒", "police": "🚓"}
_EMERGENCY_FALLBACK_ICON = "🚨"

# (symbol, colour, label) rows of the legend, laid out three per column.
_LEGEND_ITEMS = (
    ("●", "#22c55e", "N-S Green"),
    ("●", "#3b82f6", "E-W Green"),
    ("●", "#fbbf24", "Preempt"),
    ("■", "#ef4444", "Queue S"),
    ("■", "#f97316", "Queue W"),
    ("━", "#dc2626", "Blocked"),
)


def _draw_roads(ax, roads, positions):
    """Draw roads between two grid intersections; blocked roads are thick red."""
    for road in roads:
        from_pos = positions.get(road.from_node)
        to_pos = positions.get(road.to_node)
        # Roads with an endpoint that is not an intersection (e.g. "SRC*"
        # sources) have no grid position and are not drawn here.
        if from_pos is None or to_pos is None:
            continue
        color = "#dc2626" if road.blocked else "#4a4a6a"
        lw = 3 if road.blocked else 1.5
        ax.plot([from_pos[0], to_pos[0]], [from_pos[1], to_pos[1]],
                color=color, linewidth=lw, zorder=1)
        # Label the road midpoint with its occupancy count.
        if road.occupancy > 0:
            mid = ((from_pos[0] + to_pos[0]) / 2, (from_pos[1] + to_pos[1]) / 2)
            ax.annotate(str(road.occupancy), mid,
                        color="#94a3b8", fontsize=6, ha="center", va="center")


def _draw_source_arrows(ax, roads, positions):
    """Draw an inflow arrow beside intersections fed by a west or south source."""
    for road in roads:
        if not road.from_node.startswith("SRC"):
            continue
        to_pos = positions.get(road.to_node)
        if to_pos is None:
            continue
        x, y = to_pos
        # Test only the text after the "SRC" prefix: the prefix itself
        # contains "S", so testing the full name would match every source.
        side = road.from_node[len("SRC"):]
        if "W" in side:
            ax.annotate("→", (x - 1.2, y), color="#64748b", fontsize=10,
                        ha="center", va="center")
        elif "S" in side:
            ax.annotate("↑", (x, y - 1.2), color="#64748b", fontsize=10,
                        ha="center", va="center")


def _draw_intersection(ax, iv):
    """Draw one intersection: phase disc, per-approach queue bars, label, preempt tag."""
    x, y = iv.position
    # Yellow while preempted; otherwise green when the phase name contains
    # "N" (N-S green) and blue for E-W.
    if iv.preempt_direction:
        color, edge = "#fbbf24", "#f59e0b"
    elif "N" in iv.current_phase:
        color, edge = "#22c55e", "#16a34a"
    else:
        color, edge = "#3b82f6", "#2563eb"
    ax.add_patch(plt.Circle((x, y), 0.35, facecolor=color, edgecolor=edge,
                            linewidth=2, zorder=5))

    # Queue bars grow outward from the disc on the side of each approach.
    q_s, q_w, q_n, q_e = (iv.queues.get(k, 0) for k in ("S", "W", "N", "E"))
    if q_s > 0:
        h = min(q_s * _QUEUE_BAR_SCALE, _QUEUE_BAR_MAX)
        ax.add_patch(patches.Rectangle((x - 0.1, y - 0.4 - h), 0.2, h,
                                       facecolor="#ef4444", alpha=0.8, zorder=4))
    if q_w > 0:
        w = min(q_w * _QUEUE_BAR_SCALE, _QUEUE_BAR_MAX)
        ax.add_patch(patches.Rectangle((x - 0.4 - w, y - 0.1), w, 0.2,
                                       facecolor="#f97316", alpha=0.8, zorder=4))
    if q_n > 0:
        h = min(q_n * _QUEUE_BAR_SCALE, _QUEUE_BAR_MAX)
        ax.add_patch(patches.Rectangle((x - 0.1, y + 0.4), 0.2, h,
                                       facecolor="#ef4444", alpha=0.6, zorder=4))
    if q_e > 0:
        w = min(q_e * _QUEUE_BAR_SCALE, _QUEUE_BAR_MAX)
        ax.add_patch(patches.Rectangle((x + 0.4, y - 0.1), w, 0.2,
                                       facecolor="#f97316", alpha=0.6, zorder=4))

    ax.text(x, y, iv.id.replace("I_", ""), color="white", fontsize=7,
            ha="center", va="center", fontweight="bold", zorder=6)
    if iv.preempt_direction:
        ax.text(x, y + 0.55, f"⚡{iv.preempt_direction}", color="#fbbf24",
                fontsize=7, ha="center", va="center", zorder=6)


def _draw_emergencies(ax, emergencies, roads, positions):
    """Place an icon beside the intersection each uncleared emergency's current road leads into."""
    roads_by_id = {road.id: road for road in roads}
    for em in emergencies:
        if em.cleared:
            continue
        road = roads_by_id.get(em.current_road)
        pos = positions.get(road.to_node) if road is not None else None
        if pos is None:
            continue
        x, y = pos
        icon = _EMERGENCY_ICONS.get(em.type, _EMERGENCY_FALLBACK_ICON)
        ax.text(x + 0.5, y + 0.5, icon, fontsize=12, ha="center", va="center", zorder=10)


def _draw_incidents(ax, incidents):
    """List active incidents, stacked upward from the axes' bottom-left corner."""
    active = [inc for inc in incidents if inc.active]
    for row, inc in enumerate(active):
        ax.text(0.02, 0.02 + row * 0.04,
                f"🚧 {inc.id}: {inc.kind} blocks {inc.road_id}",
                transform=ax.transAxes, color="#ef4444", fontsize=8, va="bottom")


def _draw_metrics(ax, m):
    """Write the episode metrics line just below the axes (skipped if no metrics)."""
    if not m:
        return
    metrics_text = (
        f"Cleared: {m.cleared_civilian}/{m.spawned_civilian} civ "
        f"{m.cleared_emergency}/{m.spawned_emergency} em | "
        f"Wait: avg={m.mean_wait_ticks:.0f} max={m.max_wait_ticks} | "
        f"Wasted: {m.wasted_green_ticks} | Gridlocks: {m.gridlock_events}"
    )
    ax.text(0.5, -0.08, metrics_text, transform=ax.transAxes,
            color="#94a3b8", fontsize=8, ha="center", va="top")


def _draw_legend(ax):
    """Draw the colour legend in the lower-right area, three rows per column."""
    for i, (sym, col, label) in enumerate(_LEGEND_ITEMS):
        ax.text(4.5 + (i // 3) * 1.5, -1.0 - (i % 3) * 0.3, f"{sym} {label}",
                color=col, fontsize=7, va="center")


def draw_frame(ax, obs, frame_num):
    """Render one observation of the traffic grid onto ``ax``.

    Clears ``ax`` and redraws, in order: title, roads, source arrows,
    intersections, active emergencies, active incidents, the metrics line and
    the legend.

    Args:
        ax: Matplotlib Axes to draw on; it is cleared first.
        obs: Observation returned by ``TrafficOpsEnvironment.reset``/``step``.
        frame_num: Frame index supplied by the animation; currently unused.
    """
    ax.clear()
    ax.set_xlim(-1.5, 5.5)
    ax.set_ylim(-1.5, 5.5)
    ax.set_aspect("equal")
    ax.set_facecolor("#1a1a2e")
    # ax.axis("off") would also stop the Axes background patch from being
    # drawn (making the facecolor above invisible), so hide only the ticks
    # and spines.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # A final score of 0.0 is still a score: test against None, not truthiness.
    score_str = f" Score: {obs.final_score:.3f}" if obs.final_score is not None else ""
    ax.set_title(
        f"TrafficOps — {obs.task} | Tick {obs.tick}/{obs.horizon} | "
        f"Budget: {obs.interventions_used}/{obs.interventions_budget}{score_str}",
        color="white", fontsize=11, fontweight="bold", pad=10,
    )

    # Intersection id -> (x, y), built once instead of rescanning per road.
    positions = {iv.id: iv.position for iv in obs.intersections}
    _draw_roads(ax, obs.roads, positions)
    _draw_source_arrows(ax, obs.roads, positions)
    for iv in obs.intersections:
        _draw_intersection(ax, iv)
    _draw_emergencies(ax, obs.emergencies, obs.roads, positions)
    _draw_incidents(ax, obs.incidents)
    _draw_metrics(ax, obs.metrics)
    _draw_legend(ax)
def main():
    """Simulate one episode, save it as an animated GIF, then show it.

    Command line: ``[task] [seed]`` (defaults ``incident_corridor`` and ``42``).
    Writes ``trafficops_demo.gif`` to the current working directory, then
    opens an interactive Matplotlib window replaying the animation.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a TrafficOps episode as an animated GIF.")
    parser.add_argument("task", nargs="?", default="incident_corridor",
                        help="task name (default: %(default)s)")
    parser.add_argument("seed", nargs="?", type=int, default=42,
                        help="environment reset seed (default: %(default)s)")
    args = parser.parse_args()
    task, seed = args.task, args.seed

    print(f"Running {task} (seed={seed})...")
    frames = run_episode(task, seed)
    print(f"Got {len(frames)} frames, final_score={frames[-1].final_score}")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_facecolor("#0f0f23")

    def update(frame_num):
        # frames=len(frames) makes FuncAnimation pass only indices in
        # range(len(frames)), so no bounds check is needed.
        draw_frame(ax, frames[frame_num], frame_num)
        return []

    # Use one frame rate for both the on-screen playback and the saved GIF.
    fps = 4
    ani = animation.FuncAnimation(fig, update, frames=len(frames),
                                  interval=1000 // fps, repeat=True)

    print("Saving trafficops_demo.gif...")
    ani.save("trafficops_demo.gif", writer="pillow", fps=fps, dpi=100)
    print("Done! Open trafficops_demo.gif")
    plt.show()
if __name__ == "__main__":
    # Script entry point: python <this file> [task] [seed]
    main()