"""Render a TrafficOps episode as an animated GIF.

A simple scripted agent plays one ``TrafficOpsEnvironment`` episode; every
observation is recorded and drawn as one frame of a matplotlib animation,
which is saved to ``trafficops_demo.gif`` and then shown on screen.

Usage: ``python <this script> [task] [seed]``
(defaults: ``incident_corridor`` and ``42``).
"""

import sys

# Make the project-local packages (``server``, ``models``) importable when the
# script is launched from the directory that contains them.
sys.path.insert(0, ".")

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
import numpy as np

from server.trafficops_environment import TrafficOpsEnvironment
from models import TrafficOpsAction

# Detours known to the scripted agent, keyed by the blocked road id.
_DETOUR_ROUTES = {"R_h_1_1": ("R_v_1_1", "R_h_2_1", "R_v_2_2")}

# Number of upcoming roads on an emergency's remaining route considered for preemption.
_PREEMPT_LOOKAHEAD = 3

# Queue bars: length added per queued vehicle, and the maximum bar length.
_BAR_SCALE = 0.08
_MAX_BAR = 0.8

# (approach, color, alpha) for each queue bar, in drawing order.
_QUEUE_BAR_STYLES = (
    ("S", "#ef4444", 0.8),
    ("W", "#f97316", 0.8),
    ("N", "#ef4444", 0.6),
    ("E", "#f97316", 0.6),
)

_EMERGENCY_ICONS = {"ambulance": "🚑", "fire": "🚒", "police": "🚓"}

# (symbol, color, label) legend entries; colors match those used when drawing.
_LEGEND_ITEMS = (
    ("●", "#22c55e", "N-S Green"),
    ("●", "#3b82f6", "E-W Green"),
    ("●", "#fbbf24", "Preempt"),
    ("■", "#ef4444", "Queue N/S"),
    ("■", "#f97316", "Queue E/W"),
    ("—", "#dc2626", "Blocked"),
)


def run_episode(task="incident_corridor", seed=42):
    """Play one episode with the scripted agent and return every observation.

    Args:
        task: Task name forwarded to ``env.reset``.
        seed: Random seed forwarded to ``env.reset``.

    Returns:
        The observations in order, starting with the one returned by
        ``reset`` and ending with the first one whose ``done`` is true.
    """
    env = TrafficOpsEnvironment()
    obs = env.reset(seed=seed, task=task)
    frames = [obs]
    while not obs.done:
        obs = env.step(_choose_action(obs))
        frames.append(obs)
    return frames


def _choose_action(obs):
    """Pick the scripted action: reroute if possible, else preempt, else noop."""
    action = _reroute_action(obs)
    if action is None:
        action = _preempt_action(obs)
    return action if action is not None else TrafficOpsAction(op="noop")


def _reroute_action(obs):
    """Return a reroute around the first active incident with a known detour, or None.

    No new reroute is issued while any reroute plan is still active.
    """
    for inc in obs.incidents or ():
        if not inc.active or inc.road_id not in _DETOUR_ROUTES:
            continue
        if any(plan.op == "reroute" for plan in obs.active_plans):
            return None
        return TrafficOpsAction(
            op="reroute",
            targets=[inc.road_id],
            params={
                "blocked_road": inc.road_id,
                "detour": list(_DETOUR_ROUTES[inc.road_id]),
                "duration_ticks": 200,
            },
            reason=f"reroute around {inc.id}",
        )
    return None


def _preempt_action(obs):
    """Return a preempt clearing the path of the first uncleared emergency, or None.

    Targets are the distinct intersection nodes (ids starting with "I_") at
    the end of the next few roads on the emergency's remaining route.

    NOTE(review): this is re-issued every tick until the emergency clears,
    which presumably consumes intervention budget each time; confirm.
    """
    if not obs.emergencies:
        return None
    roads_by_id = {road.id: road for road in obs.roads}
    for em in obs.emergencies:
        if em.cleared or not em.remaining_route:
            continue
        targets = []
        direction = None
        for rid in em.remaining_route[:_PREEMPT_LOOKAHEAD]:
            road = roads_by_id.get(rid)
            if road is None:
                continue
            # A preempt carries a single direction; the last matched road in
            # the look-ahead window supplies it.
            direction = road.approach_direction
            # De-duplicate so each intersection along the path is targeted once.
            if road.to_node.startswith("I_") and road.to_node not in targets:
                targets.append(road.to_node)
        if targets and direction:
            return TrafficOpsAction(
                op="preempt",
                targets=targets,
                params={"direction": direction, "duration_ticks": 15},
                reason=f"clear path for {em.id}",
            )
    return None


def draw_frame(ax, obs, frame_num):
    """Draw one observation onto ``ax``, replacing whatever was drawn before.

    Args:
        ax: Matplotlib axes to draw on; it is cleared first.
        obs: Observation to render.
        frame_num: Frame index supplied by the animation (not used for drawing).
    """
    ax.clear()
    ax.set_xlim(-1.5, 5.5)
    ax.set_ylim(-1.5, 5.5)
    ax.set_aspect("equal")
    ax.set_facecolor("#1a1a2e")

    # Build the id -> position lookup once per frame instead of scanning
    # the intersection list for every road and emergency.
    positions = {iv.id: iv.position for iv in obs.intersections}

    _draw_title(ax, obs)
    _draw_roads(ax, obs.roads, positions)
    _draw_sources(ax, obs.roads, positions)
    _draw_intersections(ax, obs.intersections)
    _draw_emergencies(ax, obs, positions)
    _draw_incidents(ax, obs.incidents)
    _draw_metrics(ax, obs.metrics)
    _draw_legend(ax)
    ax.axis("off")


def _draw_title(ax, obs):
    """Draw the header: task, tick, intervention budget and, once done, the score."""
    # Test against None rather than truthiness so a final score of 0.0 is
    # still displayed. final_score is presumably only meaningful once the
    # episode is done.
    has_score = obs.done and obs.final_score is not None
    score_str = f" Score: {obs.final_score:.3f}" if has_score else ""
    ax.set_title(
        f"TrafficOps — {obs.task} | Tick {obs.tick}/{obs.horizon} | "
        f"Budget: {obs.interventions_used}/{obs.interventions_budget}{score_str}",
        color="white", fontsize=11, fontweight="bold", pad=10,
    )


def _draw_roads(ax, roads, positions):
    """Draw each intersection-to-intersection road with its occupancy count.

    Blocked roads are drawn thicker and in red. Roads with an endpoint that
    is not an intersection (e.g. those leaving a source node) are skipped.
    """
    for road in roads:
        start = positions.get(road.from_node)
        end = positions.get(road.to_node)
        if start is None or end is None:
            continue
        ax.plot(
            [start[0], end[0]],
            [start[1], end[1]],
            color="#dc2626" if road.blocked else "#4a4a6a",
            linewidth=3 if road.blocked else 1.5,
            zorder=1,
        )
        if road.occupancy > 0:
            # Label the vehicle count at the road's midpoint.
            midpoint = ((start[0] + end[0]) / 2, (start[1] + end[1]) / 2)
            ax.annotate(str(road.occupancy), midpoint, color="#94a3b8",
                        fontsize=6, ha="center", va="center")


def _draw_sources(ax, roads, positions):
    """Draw an inflow arrow beside each intersection fed by a west or south source."""
    for road in roads:
        if not road.from_node.startswith("SRC"):
            continue
        pos = positions.get(road.to_node)
        if pos is None:
            continue
        # Inspect only what follows the "SRC" prefix: "SRC" itself contains an
        # "S", so testing the full id would classify every source as south.
        side = road.from_node[len("SRC"):]
        if "W" in side:
            ax.annotate("→", (pos[0] - 1.2, pos[1]), color="#64748b",
                        fontsize=10, ha="center", va="center")
        elif "S" in side:
            ax.annotate("↑", (pos[0], pos[1] - 1.2), color="#64748b",
                        fontsize=10, ha="center", va="center")


def _draw_intersections(ax, intersections):
    """Draw each intersection as a signal-colored disc with queue bars and labels."""
    for iv in intersections:
        x, y = iv.position
        # Preemption overrides the phase color.
        # NOTE(review): the N-S test is a substring check for "N" in the phase
        # name; a name such as "EW_GREEN" would also match. Confirm the phase
        # naming used by the environment.
        if iv.preempt_direction:
            face, edge = "#fbbf24", "#f59e0b"  # yellow: preempted
        elif "N" in iv.current_phase:
            face, edge = "#22c55e", "#16a34a"  # green: N-S phase
        else:
            face, edge = "#3b82f6", "#2563eb"  # blue: E-W phase
        ax.add_patch(plt.Circle((x, y), 0.35, facecolor=face, edgecolor=edge,
                                linewidth=2, zorder=5))
        _draw_queue_bars(ax, x, y, iv.queues)
        ax.text(x, y, iv.id.replace("I_", ""), color="white", fontsize=7,
                ha="center", va="center", fontweight="bold", zorder=6)
        if iv.preempt_direction:
            ax.text(x, y + 0.55, f"⚡{iv.preempt_direction}", color="#fbbf24",
                    fontsize=7, ha="center", va="center", zorder=6)


def _draw_queue_bars(ax, x, y, queues):
    """Draw one bar per non-empty approach queue, extending outward from (x, y)."""
    for approach, color, alpha in _QUEUE_BAR_STYLES:
        queued = queues.get(approach, 0)
        if queued <= 0:
            continue
        length = min(queued * _BAR_SCALE, _MAX_BAR)
        xy, width, height = _queue_bar_geometry(approach, x, y, length)
        ax.add_patch(patches.Rectangle(xy, width, height, facecolor=color,
                                       alpha=alpha, zorder=4))


def _queue_bar_geometry(approach, x, y, length):
    """Return ``(xy, width, height)`` for a queue bar on the given approach side."""
    if approach == "S":
        return (x - 0.1, y - 0.4 - length), 0.2, length
    if approach == "N":
        return (x - 0.1, y + 0.4), 0.2, length
    if approach == "W":
        return (x - 0.4 - length, y - 0.1), length, 0.2
    return (x + 0.4, y - 0.1), length, 0.2  # "E"


def _draw_emergencies(ax, obs, positions):
    """Draw an icon next to the intersection each uncleared emergency is heading to."""
    roads_by_id = {road.id: road for road in obs.roads}
    for em in obs.emergencies:
        if em.cleared:
            continue
        road = roads_by_id.get(em.current_road)
        pos = positions.get(road.to_node) if road is not None else None
        if pos is None:
            continue
        x, y = pos
        ax.text(x + 0.5, y + 0.5, _EMERGENCY_ICONS.get(em.type, "🚨"),
                fontsize=12, ha="center", va="center", zorder=10)


def _draw_incidents(ax, incidents):
    """List the active incidents, one per row, in the bottom-left corner."""
    active = [inc for inc in incidents if inc.active]
    for row, inc in enumerate(active):
        ax.text(0.02, 0.02 + row * 0.04,
                f"🚧 {inc.id}: {inc.kind} blocks {inc.road_id}",
                transform=ax.transAxes, color="#ef4444", fontsize=8, va="bottom")


def _draw_metrics(ax, m):
    """Draw the one-line metrics summary below the axes, if metrics are present."""
    if not m:
        return
    metrics_text = (
        f"Cleared: {m.cleared_civilian}/{m.spawned_civilian} civ "
        f"{m.cleared_emergency}/{m.spawned_emergency} em | "
        f"Wait: avg={m.mean_wait_ticks:.0f} max={m.max_wait_ticks} | "
        f"Wasted: {m.wasted_green_ticks} | Gridlocks: {m.gridlock_events}"
    )
    ax.text(0.5, -0.08, metrics_text, transform=ax.transAxes,
            color="#94a3b8", fontsize=8, ha="center", va="top")


def _draw_legend(ax):
    """Draw the color legend as two columns of three entries at the bottom right."""
    for i, (sym, col, label) in enumerate(_LEGEND_ITEMS):
        ax.text(4.5 + (i // 3) * 1.5, -1.0 - (i % 3) * 0.3, f"{sym} {label}",
                color=col, fontsize=7, va="center")


def main():
    """Run an episode (task and seed from argv), save it as a GIF, then show it."""
    task = sys.argv[1] if len(sys.argv) > 1 else "incident_corridor"
    seed = int(sys.argv[2]) if len(sys.argv) > 2 else 42

    print(f"Running {task} (seed={seed})...")
    frames = run_episode(task, seed)
    print(f"Got {len(frames)} frames, final_score={frames[-1].final_score}")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_facecolor("#0f0f23")

    def update(frame_num):
        # Redraw the whole scene each frame; nothing is blitted.
        if frame_num < len(frames):
            draw_frame(ax, frames[frame_num], frame_num)
        return []

    ani = animation.FuncAnimation(fig, update, frames=len(frames),
                                  interval=300, repeat=True)

    print("Saving trafficops_demo.gif...")
    ani.save("trafficops_demo.gif", writer="pillow", fps=4, dpi=100)
    print("Done! Open trafficops_demo.gif")
    plt.show()


if __name__ == "__main__":
    main()