"""Streamlit dashboard for DistrictFlow.

The page has two sections:

* a viewer that draws a generated city's ``roadnet.json`` with optional
  district / gateway overlays read from ``district_map.json``;
* a "Run Simulation" button that runs one episode on a small hard-coded
  four-intersection environment and shows summary metrics, per-step charts,
  district directives and the final info dict.
"""

from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path

import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt

from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from dashboard.metrics import flatten_directives, summarize_history
from training.rollout import run_episode

# Columns extracted from each history row by build_history_frames, in order.
_METRIC_COLUMNS = ["step", "total_waiting", "total_queue", "mean_reward"]

# (column, y-axis label, title) for each per-step metric chart.
_METRIC_CHARTS = (
    ("total_waiting", "Total Waiting", "Total Waiting Over Time"),
    ("total_queue", "Total Queue", "Total Queue Over Time"),
    ("mean_reward", "Mean Reward", "Mean Reward Over Time"),
)


def make_env(
    *,
    config_path: str = "data/cityflow/config.json",
    max_steps: int = 200,
):
    """Build the demo ``TrafficEnv`` used by the "Run Simulation" button.

    The layout is four intersections chained I1-I2-I3-I4, split into
    district D0 (I1, I2) and district D1 (I3, I4); I2 and I3 are the border
    intersections linking the two districts.

    Args:
        config_path: Path of the simulator config file handed to ``TrafficEnv``.
        max_steps: Episode length handed to ``TrafficEnv``.

    Returns:
        A freshly constructed ``TrafficEnv``.
    """
    # Imported lazily -- presumably so the city viewer works even when the
    # simulator backend is not installed; confirm before hoisting.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # (intersection_id, district_id, neighbors, is_border)
    layout = [
        ("I1", "D0", ["I2"], False),
        ("I2", "D0", ["I1", "I3"], True),
        ("I3", "D1", ["I2", "I4"], True),
        ("I4", "D1", ["I3"], False),
    ]
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=did,
            incoming_lanes=[f"{iid}_{heading}" for heading in "NSEW"],
            outgoing_lanes=[],
            neighbors=neighbors,
            is_border=is_border,
        )
        for iid, did, neighbors, is_border in layout
    }
    districts = {
        "D0": DistrictConfig(
            district_id="D0",
            intersection_ids=["I1", "I2"],
            neighbor_districts=["D1"],
        ),
        "D1": DistrictConfig(
            district_id="D1",
            intersection_ids=["I3", "I4"],
            neighbor_districts=["D0"],
        ),
    }
    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=20,
        max_steps=max_steps,
    )


def _as_float(value) -> float:
    """Convert a metric value to float, treating ``None`` as ``0.0``."""
    return 0.0 if value is None else float(value)


def build_history_frames(history: list[dict]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Turn an episode history into (metrics, directives) DataFrames.

    Args:
        history: Per-step records; each may carry ``"step"`` and a
            ``"metrics"`` dict with ``total_waiting`` / ``total_queue`` /
            ``mean_reward``. Missing or ``None`` values become ``0``.

    Returns:
        ``(metrics_df, directives_df)``. ``metrics_df`` always has the columns
        in ``_METRIC_COLUMNS``, even when ``history`` is empty.
        ``directives_df`` is built from ``flatten_directives(history)``.
    """
    metric_rows = []
    for row in history:
        # ``or {}`` also covers an explicit ``"metrics": None``.
        metrics = row.get("metrics") or {}
        metric_rows.append(
            {
                "step": row.get("step", 0),
                "total_waiting": _as_float(metrics.get("total_waiting")),
                "total_queue": _as_float(metrics.get("total_queue")),
                "mean_reward": _as_float(metrics.get("mean_reward")),
            }
        )
    metrics_df = pd.DataFrame(metric_rows, columns=_METRIC_COLUMNS)
    directives_df = pd.DataFrame(flatten_directives(history))
    return metrics_df, directives_df


def _load_json(path: Path) -> dict:
    """Read and parse a UTF-8 JSON file.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the content is not valid UTF-8 or not valid JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    return json.loads(path.read_text(encoding="utf-8"))


def _list_generated_cities(root: Path) -> list[str]:
    """Return the sorted names of ``city_*`` subdirectories of ``root``.

    Returns an empty list when ``root`` is missing or is not a directory.
    """
    if not root.is_dir():
        return []
    return sorted(
        p.name for p in root.iterdir() if p.is_dir() and p.name.startswith("city_")
    )


def _district_color_map(district_ids: list[str]) -> dict[str, tuple[float, float, float, float]]:
    """Assign one RGBA color from the ``tab20`` colormap to each district id.

    The colormap is resampled to ``len(district_ids)`` entries, so colors
    repeat once there are more than 20 districts.
    """
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap``
    # with a ``lut`` size is the supported equivalent.
    cmap = plt.get_cmap("tab20", max(1, len(district_ids)))
    return {did: cmap(idx) for idx, did in enumerate(district_ids)}


def _road_style(
    road: dict,
    *,
    show_gateways: bool,
    show_districts: bool,
    gateway_roads: set,
    intersection_to_district: dict,
    district_colors: dict,
):
    """Return ``(color, linewidth, alpha)`` for one road.

    The checks are applied in priority order:

    1. Gateway roads are highlighted orange when gateways are shown.
    2. With the district overlay on, a road inside one district uses that
       district's color.
    3. A road joining two different districts is drawn dark.
    4. Everything else gets the neutral grey style.
    """
    if show_gateways and road.get("id") in gateway_roads:
        return "#f39c12", 1.8, 0.95
    if show_districts:
        ds = intersection_to_district.get(road.get("startIntersection"))
        de = intersection_to_district.get(road.get("endIntersection"))
        if ds and ds == de:
            return district_colors.get(ds, "#7f8c8d"), 1.0, 0.7
        if ds and de:
            return "#2c3e50", 1.25, 0.9
    return "#7f8c8d", 0.8, 0.45


def _plot_city_geometry(
    roadnet: dict,
    district_map: dict | None,
    show_gateways: bool,
    show_districts: bool,
    show_labels: bool,
):
    """Draw a roadnet with optional district and gateway overlays.

    Args:
        roadnet: Dict whose ``intersections`` entries have ``id`` and
            ``point.x`` / ``point.y``, and whose ``roads`` entries have
            ``points`` (list of ``{x, y}``) plus optional ``id``,
            ``startIntersection`` and ``endIntersection``.
        district_map: Optional dict with ``intersection_to_district``,
            ``gateway_intersections`` and ``gateway_roads``. ``None`` means no
            district or gateway information is available.
        show_gateways: Highlight gateway roads and intersections.
        show_districts: Color intra-district roads and nodes by district.
        show_labels: Write the ids of gateway intersections next to them.

    Returns:
        The matplotlib Figure. The caller owns it and should close it.
    """
    district_map = district_map or {}
    intersections = {
        node["id"]: (float(node["point"]["x"]), float(node["point"]["y"]))
        for node in roadnet.get("intersections", [])
    }
    intersection_to_district = district_map.get("intersection_to_district", {})
    district_ids = sorted(set(intersection_to_district.values()))
    district_colors = _district_color_map(district_ids) if district_ids else {}
    gateway_nodes = set(district_map.get("gateway_intersections", []))
    gateway_roads = set(district_map.get("gateway_roads", []))

    fig, ax = plt.subplots(figsize=(10, 10))

    for road in roadnet.get("roads", []):
        points = road.get("points", [])
        if len(points) < 2:
            continue
        color, width, alpha = _road_style(
            road,
            show_gateways=show_gateways,
            show_districts=show_districts,
            gateway_roads=gateway_roads,
            intersection_to_district=intersection_to_district,
            district_colors=district_colors,
        )
        # Plot every point so curved / multi-segment roads keep their shape.
        ax.plot(
            [float(p["x"]) for p in points],
            [float(p["y"]) for p in points],
            color=color,
            linewidth=width,
            alpha=alpha,
            solid_capstyle="round",
        )

    if show_districts and district_colors:
        nodes_by_district: dict = defaultdict(list)
        unassigned = []
        for nid in intersections:
            did = intersection_to_district.get(nid)
            if did in district_colors:
                nodes_by_district[did].append(nid)
            else:
                unassigned.append(nid)
        for district_id in district_ids:
            nodes = nodes_by_district.get(district_id)
            if not nodes:
                continue
            ax.scatter(
                [intersections[n][0] for n in nodes],
                [intersections[n][1] for n in nodes],
                s=14,
                color=district_colors[district_id],
                alpha=0.8,
                label=district_id,
            )
        # Intersections missing from the district map would otherwise disappear.
        if unassigned:
            ax.scatter(
                [intersections[n][0] for n in unassigned],
                [intersections[n][1] for n in unassigned],
                s=8,
                color="#34495e",
                alpha=0.65,
            )
    else:
        ax.scatter(
            [p[0] for p in intersections.values()],
            [p[1] for p in intersections.values()],
            s=8,
            color="#34495e",
            alpha=0.65,
        )

    visible_gateways = [n for n in sorted(gateway_nodes) if n in intersections]
    if show_gateways and visible_gateways:
        ax.scatter(
            [intersections[n][0] for n in visible_gateways],
            [intersections[n][1] for n in visible_gateways],
            s=44,
            color="#d35400",
            edgecolors="#1c2833",
            linewidths=0.6,
            zorder=10,
            label="gateway",
        )

    if show_labels:
        for nid in visible_gateways:
            x, y = intersections[nid]
            ax.text(x, y, nid, fontsize=6, color="#922b21")

    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Roadnet Viewer")
    ax.grid(True, alpha=0.15)
    # Only build a legend when something was labeled; otherwise matplotlib
    # warns "No artists with labels found to put in legend".
    handles, _ = ax.get_legend_handles_labels()
    if (show_districts or show_gateways) and handles:
        ax.legend(loc="upper right", fontsize=7, frameon=True)
    return fig


def _show_figure(fig: plt.Figure, **pyplot_kwargs) -> None:
    """Render ``fig`` with ``st.pyplot``, then close it.

    Streamlit reruns the whole script on every interaction. Unclosed pyplot
    figures stay registered with pyplot and accumulate in memory.
    """
    try:
        st.pyplot(fig, **pyplot_kwargs)
    finally:
        plt.close(fig)


def _plot_metric(metrics_df: pd.DataFrame, column: str, ylabel: str, title: str) -> plt.Figure:
    """Return a line chart of ``metrics_df[column]`` against ``step``."""
    fig, ax = plt.subplots()
    ax.plot(metrics_df["step"], metrics_df[column])
    ax.set_xlabel("Step")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    return fig


def _render_city_viewer() -> None:
    """Render the "Generated City Viewer" section.

    Shows info/warning/error messages instead of a plot when no city, no
    roadnet file, or an unreadable file is found.
    """
    st.subheader("Generated City Viewer")
    root_dir = Path(st.text_input("Generated dataset dir", value="data/generated"))
    cities = _list_generated_cities(root_dir)
    if not cities:
        st.info("No generated cities found in the selected directory.")
        return

    selected_city = st.selectbox("City", options=cities, index=0)
    show_districts = st.checkbox("Show district overlay", value=True)
    show_gateways = st.checkbox("Show perimeter gateways", value=True)
    show_gateway_labels = st.checkbox("Label gateways", value=False)

    city_dir = root_dir / selected_city
    roadnet_path = city_dir / "roadnet.json"
    district_map_path = city_dir / "district_map.json"
    if not roadnet_path.exists():
        st.warning(f"Missing roadnet file: {roadnet_path}")
        return

    try:
        roadnet = _load_json(roadnet_path)
        district_map = _load_json(district_map_path) if district_map_path.exists() else None
    except (OSError, ValueError) as exc:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        st.error(f"Could not load city files from {city_dir}: {exc}")
        return

    fig = _plot_city_geometry(
        roadnet=roadnet,
        district_map=district_map,
        show_gateways=show_gateways,
        show_districts=show_districts,
        show_labels=show_gateway_labels,
    )
    _show_figure(fig, use_container_width=True)


def _run_simulation(seed: int, max_steps: int, use_district_coordination: bool):
    """Run one recorded episode on the ``make_env`` environment.

    Returns the result of ``run_episode``. Its ``history`` and ``final_info``
    attributes are read by ``_render_simulation_results``.
    """
    env = make_env(max_steps=max_steps)
    local_policy = SharedHeuristicLocalPolicy()
    district_coordinators = {}
    if use_district_coordination:
        # Keys must match the district ids configured in make_env().
        district_coordinators = {
            "D0": RuleBasedDistrictCoordinator(),
            "D1": RuleBasedDistrictCoordinator(),
        }
    return run_episode(
        env=env,
        local_policy=local_policy,
        district_coordinators=district_coordinators,
        seed=seed,
        max_steps=max_steps,
        record_history=True,
        policy_update=False,
    )


def _render_simulation_results(result) -> None:
    """Render the summary metrics, charts, tables and final info of an episode."""
    summary = summarize_history(result.history)
    metrics_df, directives_df = build_history_frames(result.history)

    s1, s2, s3, s4 = st.columns(4)
    s1.metric("Avg Waiting", f"{summary['avg_total_waiting']:.2f}")
    s2.metric("Avg Queue", f"{summary['avg_total_queue']:.2f}")
    s3.metric("Avg Reward", f"{summary['avg_mean_reward']:.2f}")
    s4.metric("Steps", int(summary["num_steps"]))

    if not metrics_df.empty:
        st.subheader("Simulation Metrics")
        for column, ylabel, title in _METRIC_CHARTS:
            _show_figure(_plot_metric(metrics_df, column, ylabel, title))
        st.subheader("Raw Metrics")
        st.dataframe(metrics_df, use_container_width=True)

    if not directives_df.empty:
        st.subheader("District Directives")
        st.dataframe(directives_df, use_container_width=True)

    st.subheader("Final Info")
    st.json(result.final_info)


def main():
    """Entry point: lay out the controls, city viewer and simulation runner."""
    st.set_page_config(page_title="DistrictFlow Dashboard", layout="wide")
    st.title("DistrictFlow Dashboard")
    st.caption("Multi-agent traffic control with district-level coordination")

    col1, col2, col3 = st.columns(3)
    seed = col1.number_input("Seed", min_value=0, max_value=100000, value=0, step=1)
    max_steps = col2.slider("Max steps", min_value=50, max_value=500, value=200, step=10)
    use_district_coordination = col3.checkbox("Enable district coordination", value=True)

    _render_city_viewer()

    if st.button("Run Simulation", use_container_width=True):
        result = _run_simulation(
            seed=int(seed),
            max_steps=max_steps,
            use_district_coordination=use_district_coordination,
        )
        _render_simulation_results(result)


if __name__ == "__main__":
    main()