# trafficops/visualize.py
# Uploaded by Kunalsinghh via huggingface_hub (commit 96cf49e).
import sys
sys.path.insert(0, ".")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
import numpy as np
from server.trafficops_environment import TrafficOpsEnvironment
from models import TrafficOpsAction
def _reroute_action(obs):
    """Return a reroute around the scripted R_h_1_1 incident, or None if not applicable.

    No new reroute is issued while ``obs.active_plans`` already holds a plan
    whose ``op`` is ``"reroute"``.
    """
    for inc in obs.incidents or ():
        # The detour below is hand-written for road R_h_1_1 only; incidents on
        # any other road are ignored by this scripted agent.
        if not (inc.active and inc.road_id == "R_h_1_1"):
            continue
        if any(plan.op == "reroute" for plan in obs.active_plans):
            return None
        return TrafficOpsAction(
            op="reroute",
            targets=[inc.road_id],
            params={
                "blocked_road": inc.road_id,
                "detour": ["R_v_1_1", "R_h_2_1", "R_v_2_2"],
                "duration_ticks": 200,
            },
            reason=f"reroute around {inc.id}",
        )
    return None


def _preempt_action(obs):
    """Return a preempt for the first uncleared emergency with a usable route, or None."""
    for em in obs.emergencies or ():
        if em.cleared or not em.remaining_route:
            continue
        targets = []
        direction = None
        # Look at the next three roads on the emergency's route. Every road whose
        # id matches contributes its downstream node. ``direction`` ends up as the
        # approach direction of the LAST matching road and is used for all targets.
        for rid in em.remaining_route[:3]:
            for road in obs.roads:
                if road.id == rid:
                    direction = road.approach_direction
                    targets.append(road.to_node)
        # Only intersection nodes ("I_" prefix) can be preempted; drop the rest.
        targets = [t for t in targets if t.startswith("I_")][:3]
        if targets and direction:
            return TrafficOpsAction(
                op="preempt",
                targets=targets,
                params={"direction": direction, "duration_ticks": 15},
                reason=f"clear path for {em.id}",
            )
    return None


def run_episode(task="incident_corridor", seed=42):
    """Run one episode with a simple scripted agent and return every observation.

    On each tick the agent tries, in order:
    1. a reroute around the known incident;
    2. signal preemption for an emergency vehicle;
    3. a noop.

    Args:
        task: task name forwarded to ``TrafficOpsEnvironment.reset``.
        seed: seed forwarded to ``TrafficOpsEnvironment.reset``.

    Returns:
        List of observations: the one returned by ``reset`` followed by one per
        ``step``, ending with the first observation whose ``done`` is true.
    """
    env = TrafficOpsEnvironment()
    obs = env.reset(seed=seed, task=task)
    frames = [obs]
    while not obs.done:
        action = _reroute_action(obs)
        if action is None:
            action = _preempt_action(obs)
        if action is None:
            action = TrafficOpsAction(op="noop")
        obs = env.step(action)
        frames.append(obs)
    return frames
# Signal-phase colours as (face, edge) pairs.
_PREEMPT_COLORS = ("#fbbf24", "#f59e0b")
_NS_COLORS = ("#22c55e", "#16a34a")
_EW_COLORS = ("#3b82f6", "#2563eb")

# Queue bar length per unit of queue, and the cap on bar length (data units).
_BAR_SCALE = 0.08
_MAX_BAR = 0.8

# (side, colour, alpha) for each queue bar, in drawing order.
# N/S bars share red and E/W bars share orange; the far side is drawn fainter.
_QUEUE_BARS = (
    ("S", "#ef4444", 0.8),
    ("W", "#f97316", 0.8),
    ("N", "#ef4444", 0.6),
    ("E", "#f97316", 0.6),
)

# NOTE(review): matplotlib's default font has no emoji glyphs, so these icons
# may render as empty boxes unless an emoji-capable font is configured.
_EMERGENCY_ICONS = {"ambulance": "🚑", "fire": "🚒", "police": "🚓"}

_LEGEND_ITEMS = (
    ("●", "#22c55e", "N-S Green"),
    ("●", "#3b82f6", "E-W Green"),
    ("●", "#fbbf24", "Preempt"),
    ("■", "#ef4444", "Queue N/S"),
    ("■", "#f97316", "Queue E/W"),
    ("—", "#dc2626", "Blocked"),
)


def _queue_bar_rect(x, y, side, length):
    """Return ((x0, y0), width, height) of the queue bar on ``side`` of an intersection at (x, y)."""
    if side == "S":
        return (x - 0.1, y - 0.4 - length), 0.2, length
    if side == "N":
        return (x - 0.1, y + 0.4), 0.2, length
    if side == "W":
        return (x - 0.4 - length, y - 0.1), length, 0.2
    return (x + 0.4, y - 0.1), length, 0.2  # "E"


def _draw_roads(ax, obs, positions):
    """Draw each intersection-to-intersection road, red and thick when blocked, with its occupancy."""
    for road in obs.roads:
        from_pos = positions.get(road.from_node)
        to_pos = positions.get(road.to_node)
        # Roads with an endpoint that is not an intersection (e.g. starting at
        # an SRC node) have no position to draw from.
        if not (from_pos and to_pos):
            continue
        color = "#dc2626" if road.blocked else "#4a4a6a"
        lw = 3 if road.blocked else 1.5
        ax.plot([from_pos[0], to_pos[0]], [from_pos[1], to_pos[1]],
                color=color, linewidth=lw, zorder=1)
        if road.occupancy > 0:
            mid_x = (from_pos[0] + to_pos[0]) / 2
            mid_y = (from_pos[1] + to_pos[1]) / 2
            ax.annotate(str(road.occupancy), (mid_x, mid_y),
                        color="#94a3b8", fontsize=6, ha="center", va="center")


def _draw_source_arrows(ax, obs, positions):
    """Mark entry roads (from_node starting with "SRC") with an arrow beside their first intersection."""
    for road in obs.roads:
        if not road.from_node.startswith("SRC"):
            continue
        to_pos = positions.get(road.to_node)
        if not to_pos:
            continue
        # Inspect only the part after the "SRC" prefix. "SRC" itself contains an
        # "S", so testing the whole id treated every non-west source as south.
        side = road.from_node[len("SRC"):]
        if "W" in side:
            ax.annotate("→", (to_pos[0] - 1.2, to_pos[1]), color="#64748b",
                        fontsize=10, ha="center", va="center")
        elif "S" in side:
            ax.annotate("↑", (to_pos[0], to_pos[1] - 1.2), color="#64748b",
                        fontsize=10, ha="center", va="center")


def _draw_intersections(ax, obs):
    """Draw each intersection as a phase-coloured disc with queue bars, label and preempt marker."""
    for iv in obs.intersections:
        x, y = iv.position
        if iv.preempt_direction:
            face, edge = _PREEMPT_COLORS
        elif "N" in iv.current_phase:
            face, edge = _NS_COLORS
        else:
            face, edge = _EW_COLORS
        ax.add_patch(plt.Circle((x, y), 0.35, facecolor=face, edgecolor=edge,
                                linewidth=2, zorder=5))
        for side, color, alpha in _QUEUE_BARS:
            queued = iv.queues.get(side, 0)
            if queued > 0:
                length = min(queued * _BAR_SCALE, _MAX_BAR)
                xy, width, height = _queue_bar_rect(x, y, side, length)
                ax.add_patch(patches.Rectangle(xy, width, height, facecolor=color,
                                               alpha=alpha, zorder=4))
        ax.text(x, y, iv.id.replace("I_", ""), color="white", fontsize=7,
                ha="center", va="center", fontweight="bold", zorder=6)
        if iv.preempt_direction:
            ax.text(x, y + 0.55, f"⚡{iv.preempt_direction}", color="#fbbf24",
                    fontsize=7, ha="center", va="center", zorder=6)


def _draw_emergencies(ax, obs, positions):
    """Place an icon near the downstream intersection of each uncleared emergency's current road."""
    for em in obs.emergencies or ():
        if em.cleared:
            continue
        icon = _EMERGENCY_ICONS.get(em.type, "🚨")
        for road in obs.roads:
            if road.id != em.current_road or road.to_node not in positions:
                continue
            x, y = positions[road.to_node]
            ax.text(x + 0.5, y + 0.5, icon, fontsize=12, ha="center", va="center", zorder=10)


def _draw_incidents(ax, obs):
    """List active incidents in the lower-left corner of the axes, one row each."""
    active = [inc for inc in obs.incidents or () if inc.active]
    for row, inc in enumerate(active):
        ax.text(0.02, 0.02 + row * 0.04,
                f"🚧 {inc.id}: {inc.kind} blocks {inc.road_id}",
                transform=ax.transAxes, color="#ef4444", fontsize=8, va="bottom")


def _draw_metrics(ax, obs):
    """Write the episode metrics summary just below the axes, if metrics are present."""
    m = obs.metrics
    if not m:
        return
    metrics_text = (
        f"Cleared: {m.cleared_civilian}/{m.spawned_civilian} civ "
        f"{m.cleared_emergency}/{m.spawned_emergency} em | "
        f"Wait: avg={m.mean_wait_ticks:.0f} max={m.max_wait_ticks} | "
        f"Wasted: {m.wasted_green_ticks} | Gridlocks: {m.gridlock_events}"
    )
    ax.text(0.5, -0.08, metrics_text, transform=ax.transAxes,
            color="#94a3b8", fontsize=8, ha="center", va="top")


def _draw_legend(ax):
    """Draw the colour legend in the lower-right, three entries per column."""
    for i, (symbol, color, label) in enumerate(_LEGEND_ITEMS):
        column, row = divmod(i, 3)
        ax.text(4.5 + column * 1.5, -1.0 - row * 0.3, f"{symbol} {label}",
                color=color, fontsize=7, va="center")


def draw_frame(ax, obs, frame_num):
    """Clear ``ax`` and redraw it to show a single observation.

    Args:
        ax: matplotlib Axes. It is cleared and fully redrawn on every call.
        obs: one observation as produced by ``run_episode``.
        frame_num: frame index. Unused; kept for callers that pass it,
            such as the animation callback in ``main``.
    """
    ax.clear()
    ax.set_xlim(-1.5, 5.5)
    ax.set_ylim(-1.5, 5.5)
    ax.set_aspect("equal")
    ax.set_facecolor("#1a1a2e")

    # A final_score of exactly 0 is treated the same as "no score yet".
    score_str = f"  Score: {obs.final_score:.3f}" if obs.final_score else ""
    ax.set_title(
        f"TrafficOps — {obs.task}  |  Tick {obs.tick}/{obs.horizon}  |  "
        f"Budget: {obs.interventions_used}/{obs.interventions_budget}{score_str}",
        color="white", fontsize=11, fontweight="bold", pad=10
    )

    # Intersection id -> position, built once instead of rescanning per road.
    # Later duplicates win, as with the previous linear scan.
    positions = {iv.id: iv.position for iv in obs.intersections}

    _draw_roads(ax, obs, positions)
    _draw_source_arrows(ax, obs, positions)
    _draw_intersections(ax, obs)
    _draw_emergencies(ax, obs, positions)
    _draw_incidents(ax, obs)
    _draw_metrics(ax, obs)
    _draw_legend(ax)
    ax.axis("off")
def main():
    """Command-line entry point: run a scripted episode, save it as a GIF, then show it.

    Usage: ``python visualize.py [task] [seed] [output]``

    The defaults are ``incident_corridor``, ``42`` and ``trafficops_demo.gif``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a scripted TrafficOps episode as an animated GIF.")
    parser.add_argument("task", nargs="?", default="incident_corridor",
                        help="task name passed to the environment")
    parser.add_argument("seed", nargs="?", type=int, default=42,
                        help="integer seed passed to the environment")
    parser.add_argument("output", nargs="?", default="trafficops_demo.gif",
                        help="path of the GIF to write")
    args = parser.parse_args()

    print(f"Running {args.task} (seed={args.seed})...")
    frames = run_episode(args.task, args.seed)
    print(f"Got {len(frames)} frames, final_score={frames[-1].final_score}")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_facecolor("#0f0f23")

    def update(frame_num):
        # FuncAnimation only yields indices in range(len(frames)).
        draw_frame(ax, frames[frame_num], frame_num)
        return []

    # Keep a reference to the animation. If it is garbage-collected, its timer
    # stops and plt.show() would display a static figure.
    ani = animation.FuncAnimation(fig, update, frames=len(frames), interval=300, repeat=True)

    print(f"Saving {args.output}...")
    ani.save(args.output, writer="pillow", fps=4, dpi=100)
    print(f"Done! Open {args.output}")
    plt.show()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()