# agentic-traffic/dashboard/streamlit_app.py
# Uploaded to Hugging Face by Aditya2162 via huggingface_hub (commit 3d2dbcf, verified).
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from dashboard.metrics import flatten_directives, summarize_history
from training.rollout import run_episode
def make_env():
    """Build the fixed demo TrafficEnv used by the dashboard's simulation run.

    Topology is a chain I1 - I2 - I3 - I4. I1 and I2 belong to district D0, and
    I3 and I4 belong to district D1. I2 and I3 are flagged as border
    intersections. Every intersection gets four incoming lanes (suffixes
    N/S/E/W) and an empty outgoing-lane list. The environment reads
    ``data/cityflow/config.json`` and is created with
    ``coordination_interval=20`` and ``max_steps=200``.
    """
    # Imported lazily so the env package is only loaded when an env is built.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # (intersection id, district id, incoming lanes, neighbours, is_border)
    intersection_specs = (
        ("I1", "D0", ["I1_N", "I1_S", "I1_E", "I1_W"], ["I2"], False),
        ("I2", "D0", ["I2_N", "I2_S", "I2_E", "I2_W"], ["I1", "I3"], True),
        ("I3", "D1", ["I3_N", "I3_S", "I3_E", "I3_W"], ["I2", "I4"], True),
        ("I4", "D1", ["I4_N", "I4_S", "I4_E", "I4_W"], ["I3"], False),
    )
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=owner,
            incoming_lanes=lanes,
            outgoing_lanes=[],
            neighbors=adjacent,
            is_border=border,
        )
        for iid, owner, lanes, adjacent, border in intersection_specs
    }

    # (district id, member intersections, neighbouring districts)
    district_specs = (
        ("D0", ["I1", "I2"], ["D1"]),
        ("D1", ["I3", "I4"], ["D0"]),
    )
    districts = {
        did: DistrictConfig(
            district_id=did,
            intersection_ids=members,
            neighbor_districts=next_door,
        )
        for did, members, next_door in district_specs
    }

    return TrafficEnv(
        config_path="data/cityflow/config.json",
        intersections=intersections,
        districts=districts,
        coordination_interval=20,
        max_steps=200,
    )
def build_history_frames(history: list[dict]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Turn an episode history into a per-step metrics table and a directives table.

    Args:
        history: Per-step records. Each record is read for an optional
            ``"step"`` value (default 0) and an optional ``"metrics"`` mapping
            with ``total_waiting``, ``total_queue`` and ``mean_reward``.
            Missing or ``None`` values, including a ``None`` metrics mapping,
            are treated as 0.

    Returns:
        ``(metrics_df, directives_df)``.

        ``metrics_df`` always has the columns ``step``, ``total_waiting``,
        ``total_queue`` and ``mean_reward``, even when *history* is empty.

        ``directives_df`` is a DataFrame built from
        ``flatten_directives(history)``.
    """
    metric_keys = ("total_waiting", "total_queue", "mean_reward")
    metric_rows = []
    for row in history:
        # ``or {}`` also covers records whose "metrics" key exists but is None.
        metrics = row.get("metrics") or {}
        record = {"step": row.get("step", 0)}
        for key in metric_keys:
            value = metrics.get(key)
            record[key] = float(value) if value is not None else 0.0
        metric_rows.append(record)
    # Explicit columns keep the schema stable for an empty history, so callers
    # can index e.g. metrics_df["step"] without a KeyError.
    metrics_df = pd.DataFrame(metric_rows, columns=["step", *metric_keys])
    directives_df = pd.DataFrame(flatten_directives(history))
    return metrics_df, directives_df
def _load_json(path: Path) -> dict:
return json.loads(path.read_text(encoding="utf-8"))
def _list_generated_cities(root: Path) -> list[str]:
if not root.exists():
return []
return sorted(
p.name for p in root.iterdir() if p.is_dir() and p.name.startswith("city_")
)
def _district_color_map(district_ids: list[str]) -> dict[str, tuple[float, float, float, float]]:
    """Map each district id to an RGBA colour taken from the ``tab20`` colormap.

    The colormap is resampled to ``len(district_ids)`` entries (minimum 1), and
    colours are assigned in the order the ids appear in *district_ids*.
    """
    # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    # ``pyplot.get_cmap(name, lut)`` takes the same arguments and remains public.
    cmap = plt.get_cmap("tab20", max(1, len(district_ids)))
    return {did: cmap(idx) for idx, did in enumerate(district_ids)}
def _road_style(
    road: dict,
    intersection_to_district: dict,
    district_colors: dict,
    gateway_roads: set,
    show_gateways: bool,
    show_districts: bool,
) -> tuple[object, float, float]:
    """Return the ``(color, linewidth, alpha)`` used to draw one road.

    Styling rules, in priority order:

    - With ``show_gateways`` on, gateway roads are orange.
    - Otherwise, with ``show_districts`` on, a road whose two endpoints share a
      district takes that district's colour.
    - A road joining two different (both known) districts is dark slate.
    - Every other road keeps the neutral grey default.
    """
    default = ("#7f8c8d", 0.8, 0.45)
    if show_gateways and road.get("id") in gateway_roads:
        return "#f39c12", 1.8, 0.95
    if not show_districts:
        return default
    start_district = intersection_to_district.get(road.get("startIntersection"))
    end_district = intersection_to_district.get(road.get("endIntersection"))
    if start_district and start_district == end_district:
        return district_colors.get(start_district, default[0]), 1.0, 0.7
    if start_district and end_district:
        # Both endpoints are mapped and, given the branch above, differ.
        return "#2c3e50", 1.25, 0.9
    return default


def _plot_city_geometry(
    roadnet: dict,
    district_map: dict | None,
    show_gateways: bool,
    show_districts: bool,
    show_labels: bool,
):
    """Draw a roadnet with optional district / gateway overlays and return the figure.

    Args:
        roadnet: Parsed roadnet JSON. It is read for ``intersections`` (each
            with ``id`` and ``point.x`` / ``point.y``) and ``roads`` (each with
            ``points`` and, optionally, ``id``, ``startIntersection`` and
            ``endIntersection``).
        district_map: Optional parsed district map. It is read for
            ``intersection_to_district``, ``gateway_intersections`` and
            ``gateway_roads``.
        show_gateways: Highlight gateway roads and gateway intersections.
        show_districts: Colour roads and nodes by district.
        show_labels: Write the id next to each gateway intersection.

    Returns:
        A new Matplotlib ``Figure``. The caller owns it and should close it
        once it has been rendered.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    intersections = {
        node["id"]: (float(node["point"]["x"]), float(node["point"]["y"]))
        for node in roadnet.get("intersections", [])
    }
    mapping = district_map or {}
    intersection_to_district = mapping.get("intersection_to_district", {})
    district_ids = sorted(set(intersection_to_district.values()))
    district_colors = _district_color_map(district_ids) if district_ids else {}
    gateway_nodes = set(mapping.get("gateway_intersections", []))
    gateway_roads = set(mapping.get("gateway_roads", []))

    # Each road is drawn as a straight segment from its first to its last point.
    for road in roadnet.get("roads", []):
        points = road.get("points", [])
        if len(points) < 2:
            continue
        color, width, alpha = _road_style(
            road,
            intersection_to_district,
            district_colors,
            gateway_roads,
            show_gateways,
            show_districts,
        )
        ax.plot(
            [points[0]["x"], points[-1]["x"]],
            [points[0]["y"], points[-1]["y"]],
            color=color,
            linewidth=width,
            alpha=alpha,
            solid_capstyle="round",
        )

    if show_districts and district_colors:
        # Group node coordinates by district in one pass, preserving map order.
        by_district: dict = {}
        for nid, did in intersection_to_district.items():
            if nid in intersections:
                by_district.setdefault(did, []).append(intersections[nid])
        for district_id in district_ids:
            coords = by_district.get(district_id)
            if not coords:
                continue
            ax.scatter(
                [c[0] for c in coords],
                [c[1] for c in coords],
                s=14,
                color=district_colors[district_id],
                alpha=0.8,
                label=district_id,
            )
    else:
        xs = [p[0] for p in intersections.values()]
        ys = [p[1] for p in intersections.values()]
        ax.scatter(xs, ys, s=8, color="#34495e", alpha=0.65)

    if show_gateways and gateway_nodes:
        gateway_points = [intersections[n] for n in sorted(gateway_nodes) if n in intersections]
        ax.scatter(
            [p[0] for p in gateway_points],
            [p[1] for p in gateway_points],
            s=44,
            color="#d35400",
            edgecolors="#1c2833",
            linewidths=0.6,
            zorder=10,
            label="gateway",
        )

    if show_labels:
        for nid, (x, y) in intersections.items():
            if nid in gateway_nodes:
                ax.text(x, y, nid, fontsize=6, color="#922b21")

    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Roadnet Viewer")
    ax.grid(True, alpha=0.15)
    if show_districts or show_gateways:
        # Only add a legend when something was labelled. Otherwise Matplotlib
        # warns and draws an empty legend box.
        _, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend(loc="upper right", fontsize=7, frameon=True)
    return fig
def _render_city_viewer() -> None:
    """Render the "Generated City Viewer" section.

    Lists the ``city_*`` folders under a user-supplied directory. For the
    selected city it plots ``roadnet.json``, with ``district_map.json`` as an
    overlay when that file exists.
    """
    st.subheader("Generated City Viewer")
    root_dir = Path(st.text_input("Generated dataset dir", value="data/generated"))
    cities = _list_generated_cities(root_dir)
    if not cities:
        st.info("No generated cities found in the selected directory.")
        return

    selected_city = st.selectbox("City", options=cities, index=0)
    show_districts = st.checkbox("Show district overlay", value=True)
    show_gateways = st.checkbox("Show perimeter gateways", value=True)
    show_gateway_labels = st.checkbox("Label gateways", value=False)

    city_dir = root_dir / selected_city
    roadnet_path = city_dir / "roadnet.json"
    district_map_path = city_dir / "district_map.json"
    if not roadnet_path.exists():
        st.warning(f"Missing roadnet file: {roadnet_path}")
        return

    roadnet = _load_json(roadnet_path)
    district_map = _load_json(district_map_path) if district_map_path.exists() else None
    fig = _plot_city_geometry(
        roadnet=roadnet,
        district_map=district_map,
        show_gateways=show_gateways,
        show_districts=show_districts,
        show_labels=show_gateway_labels,
    )
    try:
        st.pyplot(fig, use_container_width=True)
    finally:
        # Streamlit reruns the script on every interaction. Unclosed figures
        # would accumulate in pyplot's global figure registry.
        plt.close(fig)


def _run_simulation(seed: int, max_steps: int, use_district_coordination: bool):
    """Run one recorded episode on the demo env and return run_episode's result.

    One rule-based coordinator is created for each of districts D0 and D1 when
    *use_district_coordination* is true; otherwise no coordinators are used.
    """
    env = make_env()
    env.max_steps = max_steps
    local_policy = SharedHeuristicLocalPolicy()
    district_coordinators = {}
    if use_district_coordination:
        district_coordinators = {
            "D0": RuleBasedDistrictCoordinator(),
            "D1": RuleBasedDistrictCoordinator(),
        }
    return run_episode(
        env=env,
        local_policy=local_policy,
        district_coordinators=district_coordinators,
        seed=seed,
        max_steps=max_steps,
        record_history=True,
        policy_update=False,
    )


def _plot_metric_series(metrics_df: pd.DataFrame, column: str, ylabel: str, title: str) -> None:
    """Render *column* of *metrics_df* against ``step`` as a line chart."""
    # Object-oriented API instead of the implicit "current figure": pyplot's
    # global state is shared across Streamlit session threads.
    fig, ax = plt.subplots()
    try:
        ax.plot(metrics_df["step"], metrics_df[column])
        ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _render_simulation_results(result) -> None:
    """Render summary metrics, time-series charts and tables for an episode result.

    *result* is read for ``history`` and ``final_info``.
    """
    summary = summarize_history(result.history)
    metrics_df, directives_df = build_history_frames(result.history)

    s1, s2, s3, s4 = st.columns(4)
    s1.metric("Avg Waiting", f"{summary['avg_total_waiting']:.2f}")
    s2.metric("Avg Queue", f"{summary['avg_total_queue']:.2f}")
    s3.metric("Avg Reward", f"{summary['avg_mean_reward']:.2f}")
    s4.metric("Steps", int(summary["num_steps"]))

    if not metrics_df.empty:
        st.subheader("Simulation Metrics")
        _plot_metric_series(metrics_df, "total_waiting", "Total Waiting", "Total Waiting Over Time")
        _plot_metric_series(metrics_df, "total_queue", "Total Queue", "Total Queue Over Time")
        _plot_metric_series(metrics_df, "mean_reward", "Mean Reward", "Mean Reward Over Time")
        st.subheader("Raw Metrics")
        st.dataframe(metrics_df, use_container_width=True)

    if not directives_df.empty:
        st.subheader("District Directives")
        st.dataframe(directives_df, use_container_width=True)

    st.subheader("Final Info")
    st.json(result.final_info)


def main():
    """Render the DistrictFlow dashboard.

    Layout, top to bottom:

    - Episode controls (seed, max steps, district coordination).
    - The generated-city viewer.
    - A "Run Simulation" button. When pressed, it runs one episode and
      renders its results.
    """
    st.set_page_config(page_title="DistrictFlow Dashboard", layout="wide")
    st.title("DistrictFlow Dashboard")
    st.caption("Multi-agent traffic control with district-level coordination")

    col1, col2, col3 = st.columns(3)
    seed = col1.number_input("Seed", min_value=0, max_value=100000, value=0, step=1)
    max_steps = col2.slider(
        "Max steps", min_value=50, max_value=500, value=200, step=10
    )
    use_district_coordination = col3.checkbox(
        "Enable district coordination", value=True
    )

    _render_city_viewer()

    if st.button("Run Simulation", use_container_width=True):
        result = _run_simulation(int(seed), max_steps, use_district_coordination)
        _render_simulation_results(result)
# Script entry point: render the dashboard when this file is executed directly
# rather than imported as a module.
if __name__ == "__main__":
    main()