| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| from dataclasses import asdict, dataclass |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from statistics import median |
| from time import perf_counter |
| from typing import Any |
| import sys |
|
|
| import numpy as np |
| from tqdm.auto import tqdm |
|
|
# Put the directory two levels above this file at the front of sys.path so
# the project packages imported further down (district_llm, env, training)
# can be resolved when this file is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Only insert once; repeated imports of this module must not grow sys.path.
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from district_llm.heuristic_guidance import HeuristicGuidanceConfig |
| from district_llm.inference import DistrictLLMInference |
| from district_llm.repair import RepairConfig |
| from district_llm.rl_guidance_wrapper import ( |
| BIAS_DECAY_SCHEDULES, |
| DistrictGuidedRLController, |
| FixedRLPolicyAdapter, |
| GATING_MODES, |
| GuidanceInfluenceConfig, |
| HeuristicGuidanceProvider, |
| LLMGuidanceProvider, |
| WRAPPER_MODES, |
| guidance_config_payload, |
| ) |
| from district_llm.summary_builder import DistrictStateSummaryBuilder |
| from env.observation_builder import ObservationConfig |
| from env.reward import RewardConfig |
| from env.traffic_env import EnvConfig |
| from env.utils import load_json |
| from training.cityflow_dataset import CityFlowDataset, ScenarioSpec |
| from training.train_local_policy import build_env |
|
|
|
|
# Guidance-source modes that may be passed to the --modes command-line option.
MODE_CHOICES: tuple[str, ...] = ("rl_only", "rl_heuristic", "rl_llm")
|
|
|
|
@dataclass(frozen=True)
class EpisodePlan:
    """Immutable description of a single evaluation episode.

    A plan fixes the city, scenario, seed and episode index, plus the
    simulator seed derived from them, so the same episode can be replayed
    under several controllers and matched up afterwards.
    """

    city_id: str
    scenario: str
    seed: int
    episode_id: int
    simulator_seed: int
    # Scenario spec as originally resolved (unseeded config).
    scenario_spec: ScenarioSpec
    # Scenario spec whose ``config_path`` is reported by ``to_dict``.
    seeded_scenario_spec: ScenarioSpec

    def pairing_key(self) -> tuple[str, str, int, int]:
        """Return the ``(city_id, scenario, seed, episode_id)`` identity tuple."""
        key = (self.city_id, self.scenario, self.seed, self.episode_id)
        return key

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of the plan with ints and a string path."""
        seeded_config_path = self.seeded_scenario_spec.config_path
        return dict(
            city_id=self.city_id,
            scenario=self.scenario,
            seed=int(self.seed),
            episode_id=int(self.episode_id),
            simulator_seed=int(self.simulator_seed),
            config_path=str(seeded_config_path),
        )
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface of the evaluation and parse it.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate a fixed DQN checkpoint under rl_only, rl_heuristic, and "
            "rl_llm district-guidance modes without changing the RL weights."
        )
    )
    add = parser.add_argument

    def add_toggle(flag: str, default: bool) -> None:
        # BooleanOptionalAction registers the flag together with its --no-... form.
        add(flag, action=argparse.BooleanOptionalAction, default=default)

    # Checkpoints and mode selection.
    add(
        "--rl-checkpoint",
        required=True,
        help="Path to the fixed DQN checkpoint used for all modes.",
    )
    add(
        "--llm-model-path",
        default=None,
        help="Model or adapter path used when rl_llm modes are enabled.",
    )
    add(
        "--modes",
        nargs="+",
        choices=MODE_CHOICES,
        default=["rl_only", "rl_heuristic", "rl_llm"],
    )

    # Dataset, scenario and seed selection.
    add("--generated-root", default="data/generated")
    add("--splits-root", default="data/splits")
    add("--split", default="val", choices=("train", "val", "test"))
    add("--cities", nargs="+", default=None)
    add("--scenarios", nargs="+", default=None)
    add("--num-episodes", type=int, default=1)
    add("--seeds", nargs="+", type=int, default=[7, 11, 13])
    add(
        "--max-episode-seconds",
        type=int,
        default=None,
        help="Optional override for scenario horizon. Useful for cheap smoke tests.",
    )

    # Guidance influence settings.
    add("--guidance-refresh-steps", type=int, default=10)
    add("--guidance-persistence-steps", type=int, default=3)
    add("--bias-strength", type=float, default=0.12)
    add(
        "--targeted-bias-strength",
        "--target-only-bias-strength",
        dest="targeted_bias_strength",
        type=float,
        default=0.18,
    )
    add("--corridor-bias-strength", type=float, default=0.05)
    add("--max-guidance-duration", type=int, default=10)
    add("--max-intersections-affected", type=int, default=3)
    add("--gating-mode", choices=GATING_MODES, default="always_on")
    add("--min-avg-queue-for-guidance", type=float, default=150.0)
    add("--min-queue-imbalance-for-guidance", type=float, default=20.0)
    add_toggle("--require-incident-or-spillback", False)
    add_toggle("--allow-guidance-in-normal-conditions", True)
    add_toggle("--enable-bias-decay", True)
    add("--bias-decay-schedule", choices=BIAS_DECAY_SCHEDULES, default="linear")
    add_toggle("--apply-global-bias", False)
    add_toggle("--apply-target-only", True)
    add(
        "--wrapper-modes",
        "--wrapper-mode",
        dest="wrapper_modes",
        nargs="+",
        choices=WRAPPER_MODES,
        default=["target_only_soft"],
    )

    # Target selection and fallback behaviour.
    add_toggle("--allow-only-visible-candidates", True)
    add("--max-target-intersections", type=int, default=3)
    add_toggle("--fallback-on-empty-targets", True)
    add("--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic")
    add(
        "--fallback-policy",
        choices=("no_op", "hold_previous", "heuristic_weak"),
        default="hold_previous",
    )
    add_toggle("--log-guidance-debug", False)

    # Runtime and output options.
    add("--max-new-tokens", type=int, default=128)
    add("--device", default=None)
    add("--output-dir", default="artifacts/rl_guidance_eval")
    add_toggle("--save-step-metrics", False)
    add_toggle("--save-guidance-traces", False)

    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Run every episode plan under every configured controller and save the results.

    In order: parse the CLI, create the output directories, resolve the
    scenarios into seeded episode plans, load the fixed RL policy, run each
    (plan, controller) pair, then write the config, summary and metric files
    into the output directory and print the summary as JSON.

    Raises:
        ValueError: If ``rl_llm`` is among the modes but no
            ``--llm-model-path`` was given.
    """
    args = parse_args()
    llm_requested = "rl_llm" in args.modes
    if llm_requested and not args.llm_model_path:
        raise ValueError("--llm-model-path is required when rl_llm is selected.")

    output_dir = Path(args.output_dir)
    seeded_config_root = output_dir / "seeded_configs"
    for directory in (output_dir, seeded_config_root):
        directory.mkdir(parents=True, exist_ok=True)

    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    dataset.generate_default_splits()
    episode_plans = build_episode_plans(
        scenario_specs=resolve_scenario_specs(dataset=dataset, args=args),
        seeds=args.seeds,
        num_episodes=args.num_episodes,
        seeded_config_root=seeded_config_root,
    )

    rl_policy = FixedRLPolicyAdapter(
        checkpoint_path=args.rl_checkpoint,
        device=args.device,
    )
    # Prefer the environment config exposed by the policy adapter, if any.
    env_config = rl_policy.env_config or default_env_config()
    horizon_override = args.max_episode_seconds
    if horizon_override is not None:
        # Rebuild the config with only the episode horizon replaced.
        env_config = EnvConfig(
            simulator_interval=env_config.simulator_interval,
            decision_interval=env_config.decision_interval,
            min_green_time=env_config.min_green_time,
            thread_num=env_config.thread_num,
            max_episode_seconds=int(horizon_override),
            observation=env_config.observation,
            reward=env_config.reward,
        )
    controller_specs = list(build_mode_controllers(args=args, rl_policy=rl_policy).items())

    episode_rows: list[dict[str, Any]] = []
    step_rows: list[dict[str, Any]] = []
    guidance_trace_rows: list[dict[str, Any]] = []
    run_count = len(episode_plans) * len(controller_specs)
    # The context manager closes the bar even if an episode raises.
    with tqdm(total=run_count, desc="RL guidance eval", unit="run") as progress:
        for plan_index, plan in enumerate(episode_plans, start=1):
            plan_banner = (
                "[episode-plan] "
                f"{plan_index}/{len(episode_plans)} "
                f"city={plan.city_id} "
                f"scenario={plan.scenario} "
                f"seed={plan.seed} "
                f"episode_id={plan.episode_id} "
                f"simulator_seed={plan.simulator_seed}"
            )
            tqdm.write(plan_banner)
            for mode_label, controller in controller_specs:
                progress.set_postfix_str(
                    f"mode={mode_label} city={plan.city_id} scenario={plan.scenario} seed={plan.seed}"
                )
                episode_row, mode_step_rows, mode_trace_rows = run_episode(
                    plan=plan,
                    mode_label=mode_label,
                    controller=controller,
                    env_config=env_config,
                    save_step_metrics=args.save_step_metrics,
                    save_guidance_traces=args.save_guidance_traces,
                )
                episode_rows.append(episode_row)
                step_rows.extend(mode_step_rows)
                guidance_trace_rows.extend(mode_trace_rows)
                result_banner = (
                    "[episode-result] "
                    f"mode={mode_label} "
                    f"return={episode_row['total_return']:.3f} "
                    f"avg_queue={episode_row['avg_queue']:.3f} "
                    f"avg_wait={episode_row['avg_wait']:.3f} "
                    f"throughput={episode_row['throughput']:.3f}"
                )
                tqdm.write(result_banner)
                progress.update(1)

    config_payload = build_config_payload(
        args=args,
        env_config=env_config,
        episode_plans=episode_plans,
    )
    summary_payload = build_summary_payload(
        episode_rows=episode_rows,
        config_payload=config_payload,
    )

    # Always-written artifacts.
    write_json(output_dir / "config.json", config_payload)
    write_json(output_dir / "summary.json", summary_payload)
    write_csv_rows(output_dir / "episode_metrics.csv", episode_rows)
    write_jsonl(output_dir / "episode_metrics.jsonl", episode_rows)
    episode_parquet_written = try_write_parquet(output_dir / "episode_metrics.parquet", episode_rows)

    # Optional artifacts, controlled by CLI flags.
    if args.save_step_metrics:
        write_csv_rows(output_dir / "step_metrics.csv", step_rows)
        write_jsonl(output_dir / "step_metrics.jsonl", step_rows)
        try_write_parquet(output_dir / "step_metrics.parquet", step_rows)

    if args.save_guidance_traces:
        write_jsonl(output_dir / "guidance_traces.jsonl", guidance_trace_rows)

    print(json.dumps(summary_payload, indent=2, sort_keys=True))
    if episode_parquet_written:
        return
    print(
        "[warning] episode_metrics.parquet was not written because neither pyarrow nor pandas "
        "is available in the current Python environment."
    )
|
|
|
|
def resolve_scenario_specs(
    dataset: CityFlowDataset,
    args: argparse.Namespace,
) -> list[ScenarioSpec]:
    """Expand the CLI city/scenario selection into concrete scenario specs.

    Cities come from ``args.cities`` when given, otherwise from
    ``dataset.load_split(args.split)``. For each city, ``args.scenarios`` is
    used when given, otherwise every scenario the dataset lists for that city.

    Raises:
        ValueError: If a city has no scenarios, a requested scenario is not
            available for a city, or nothing was resolved at all.
    """
    if args.cities:
        city_ids = list(args.cities)
    else:
        city_ids = dataset.load_split(args.split)

    resolved: list[ScenarioSpec] = []
    for city_id in city_ids:
        available_scenarios = dataset.scenarios_for_city(city_id)
        if not available_scenarios:
            raise ValueError(f"No scenarios found for city '{city_id}'.")

        if args.scenarios:
            wanted = list(args.scenarios)
        else:
            wanted = available_scenarios

        for scenario_name in wanted:
            if scenario_name in available_scenarios:
                resolved.append(dataset.build_scenario_spec(city_id, scenario_name))
                continue
            raise ValueError(
                f"Scenario '{scenario_name}' is not available for city '{city_id}'. "
                f"Available scenarios: {available_scenarios}"
            )

    if resolved:
        return resolved
    raise ValueError("No scenario specs were resolved for evaluation.")
|
|
|
|
def build_episode_plans(
    scenario_specs: list[ScenarioSpec],
    seeds: list[int],
    num_episodes: int,
    seeded_config_root: Path,
) -> list[EpisodePlan]:
    """Create one plan per (scenario, seed, episode index) combination.

    Plans are ordered scenario-major, then by seed, then by episode index.
    Building a plan also writes its seeded config under
    ``seeded_config_root`` via ``build_seeded_scenario_spec``.
    """

    def plan_for(scenario_spec: ScenarioSpec, seed: int, episode_id: int) -> EpisodePlan:
        # The simulator seed is seed * 1000 + episode index.
        simulator_seed = int(seed) * 1000 + int(episode_id)
        seeded_spec = build_seeded_scenario_spec(
            scenario_spec=scenario_spec,
            simulator_seed=simulator_seed,
            seeded_config_root=seeded_config_root,
        )
        return EpisodePlan(
            city_id=scenario_spec.city_id,
            scenario=scenario_spec.scenario_name,
            seed=int(seed),
            episode_id=int(episode_id),
            simulator_seed=int(simulator_seed),
            scenario_spec=scenario_spec,
            seeded_scenario_spec=seeded_spec,
        )

    return [
        plan_for(scenario_spec, seed, episode_id)
        for scenario_spec in scenario_specs
        for seed in seeds
        for episode_id in range(num_episodes)
    ]
|
|
|
|
def build_seeded_scenario_spec(
    scenario_spec: ScenarioSpec,
    simulator_seed: int,
    seeded_config_root: Path,
) -> ScenarioSpec:
    """Write a copy of the scenario config with its ``seed`` replaced.

    The copy is stored at
    ``<seeded_config_root>/<city_id>/<scenario_name>/seed_<8-digit seed>/config.json``.

    Returns:
        A ``ScenarioSpec`` whose ``config_path`` points at the written copy;
        every other field is taken unchanged from ``scenario_spec``.
    """
    seeded_payload = load_json(scenario_spec.config_path)
    seeded_payload["seed"] = int(simulator_seed)

    destination_dir = seeded_config_root.joinpath(
        scenario_spec.city_id,
        scenario_spec.scenario_name,
        f"seed_{int(simulator_seed):08d}",
    )
    destination_dir.mkdir(parents=True, exist_ok=True)

    config_path = destination_dir / "config.json"
    write_json(config_path, seeded_payload)

    return ScenarioSpec(
        city_id=scenario_spec.city_id,
        scenario_name=scenario_spec.scenario_name,
        city_dir=scenario_spec.city_dir,
        scenario_dir=scenario_spec.scenario_dir,
        config_path=config_path,
        roadnet_path=scenario_spec.roadnet_path,
        district_map_path=scenario_spec.district_map_path,
        metadata_path=scenario_spec.metadata_path,
    )
|
|
|
|
def build_mode_controllers(
    args: argparse.Namespace,
    rl_policy: FixedRLPolicyAdapter,
) -> dict[str, DistrictGuidedRLController]:
    """Build one controller per requested evaluation configuration.

    ``rl_only`` yields a single controller keyed ``"rl_only"`` with the
    ``no_op`` wrapper mode and no guidance provider. Every other mode yields
    one controller per entry in ``args.wrapper_modes``, keyed
    ``"<mode>+<wrapper_mode>"``. All controllers share ``rl_policy``.

    Returns:
        Controllers keyed by label, in the order the modes were requested.
    """

    def influence_for(
        wrapper_mode: str,
        *,
        apply_global_bias: bool,
        apply_target_only: bool,
        log_guidance_debug: bool,
    ) -> GuidanceInfluenceConfig:
        # Everything except the four arguments comes straight from the CLI.
        return GuidanceInfluenceConfig(
            wrapper_mode=wrapper_mode,
            bias_strength=args.bias_strength,
            target_only_bias_strength=args.targeted_bias_strength,
            corridor_bias_strength=args.corridor_bias_strength,
            max_intersections_affected=args.max_intersections_affected,
            guidance_refresh_steps=args.guidance_refresh_steps,
            guidance_persistence_steps=args.guidance_persistence_steps,
            max_guidance_duration=args.max_guidance_duration,
            apply_global_bias=apply_global_bias,
            apply_target_only=apply_target_only,
            gating_mode=args.gating_mode,
            min_avg_queue_for_guidance=args.min_avg_queue_for_guidance,
            min_queue_imbalance_for_guidance=args.min_queue_imbalance_for_guidance,
            require_incident_or_spillback=args.require_incident_or_spillback,
            allow_guidance_in_normal_conditions=args.allow_guidance_in_normal_conditions,
            enable_bias_decay=args.enable_bias_decay,
            bias_decay_schedule=args.bias_decay_schedule,
            fallback_policy=args.fallback_policy,
            log_guidance_debug=log_guidance_debug,
        )

    heuristic_provider = HeuristicGuidanceProvider(
        config=HeuristicGuidanceConfig(
            max_target_intersections=args.max_target_intersections,
        )
    )

    # The LLM is only loaded when an rl_llm run was actually requested.
    llm_inference = None
    if "rl_llm" in args.modes:
        repair_config = RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        )
        llm_inference = DistrictLLMInference(
            model_name_or_path=args.llm_model_path,
            device=args.device,
            repair_config=repair_config,
        )

    controllers: dict[str, DistrictGuidedRLController] = {}
    for mode in args.modes:
        if mode == "rl_only":
            # Baseline: fixed wrapper settings, independent of the CLI toggles.
            controllers["rl_only"] = DistrictGuidedRLController(
                policy=rl_policy,
                mode_source="rl_only",
                summary_builder=None,
                guidance_provider=None,
                influence_config=influence_for(
                    "no_op",
                    apply_global_bias=False,
                    apply_target_only=True,
                    log_guidance_debug=False,
                ),
                heuristic_provider=None,
            )
            continue

        for wrapper_mode in args.wrapper_modes:
            influence_config = influence_for(
                wrapper_mode,
                apply_global_bias=args.apply_global_bias,
                apply_target_only=args.apply_target_only,
                log_guidance_debug=args.log_guidance_debug,
            )
            summary_builder = DistrictStateSummaryBuilder(
                top_k=3,
                candidate_limit=max(6, int(args.max_target_intersections)),
            )
            if mode == "rl_heuristic":
                guidance_provider = heuristic_provider
            else:
                assert llm_inference is not None
                guidance_provider = LLMGuidanceProvider(
                    inference=llm_inference,
                    max_new_tokens=args.max_new_tokens,
                )
            controllers[f"{mode}+{wrapper_mode}"] = DistrictGuidedRLController(
                policy=rl_policy,
                mode_source=mode,
                summary_builder=summary_builder,
                guidance_provider=guidance_provider,
                influence_config=influence_config,
                heuristic_provider=heuristic_provider,
            )
    return controllers
|
|
|
|
def run_episode(
    plan: EpisodePlan,
    mode_label: str,
    controller: DistrictGuidedRLController,
    env_config: EnvConfig,
    save_step_metrics: bool,
    save_guidance_traces: bool,
    show_step_progress: bool = True,
) -> tuple[dict[str, Any], list[dict[str, Any]], list[dict[str, Any]]]:
    """Run one episode of ``plan`` under ``controller`` and collect its metrics.

    Args:
        plan: Episode to run; the environment is built from its seeded spec.
        mode_label: Label stored in the ``mode`` column of every produced row.
        controller: Controller queried for actions at each decision step.
        env_config: Environment configuration passed to ``build_env``.
        save_step_metrics: Collect one row per decision step when true.
        save_guidance_traces: Collect one row per guidance refresh when true.
        show_step_progress: Show a transient per-step progress bar.

    Returns:
        ``(episode_row, step_rows, guidance_trace_rows)``. The two lists stay
        empty unless the corresponding ``save_*`` flag is set.
    """
    env = build_env(env_config, plan.seeded_scenario_spec)
    controller.reset()
    observation_batch = env.reset()
    # Only used to size the progress bar; the loop itself runs until ``done``.
    estimated_steps = max(
        1,
        int(np.ceil(float(env.max_episode_seconds) / float(env.env_config.decision_interval))),
    )
    step_progress = None
    if show_step_progress:
        step_progress = tqdm(
            total=estimated_steps,
            desc=f"{mode_label} {plan.city_id}/{plan.scenario} seed={plan.seed}",
            unit="step",
            leave=False,
        )
    episode_started = perf_counter()
    wrapper_runtime_seconds = 0.0
    guidance_runtime_seconds = 0.0
    guidance_refresh_count = 0
    fallback_used_count = 0
    invalid_guidance_count = 0
    repaired_guidance_count = 0
    action_changes_vs_base = 0
    queue_series: list[float] = []
    wait_series: list[float] = []
    running_vehicle_series: list[float] = []
    spillback_total = 0.0
    spillback_event_steps = 0
    step_rows: list[dict[str, Any]] = []
    guidance_trace_rows: list[dict[str, Any]] = []
    # Metadata is read from the original (unseeded) scenario directory.
    scenario_metadata = load_scenario_metadata(plan.scenario_spec)
    done = False

    try:
        while not done:
            action_batch = controller.act(env=env, observation_batch=observation_batch)
            wrapper_runtime_seconds += float(action_batch.runtime_seconds)
            # Count positions where the final action differs from the base action.
            action_changes_vs_base += int(np.sum(action_batch.actions != action_batch.base_actions))

            # Each refresh trace reports one guidance refresh made during act().
            for trace in action_batch.refresh_traces:
                guidance_refresh_count += 1
                guidance_runtime_seconds += float(trace.guidance.get("runtime_seconds", 0.0))
                fallback_used_count += int(trace.fallback_used)
                invalid_guidance_count += int(trace.guidance.get("invalid_before_repair", False))
                repaired_guidance_count += int(trace.guidance.get("repair_applied", False))
                if save_guidance_traces:
                    guidance_trace_rows.append(
                        build_guidance_trace_row(
                            plan=plan,
                            mode_label=mode_label,
                            trace=trace,
                            controller=controller,
                        )
                    )

            next_observation_batch, rewards, done, info = env.step(action_batch.actions)
            metrics = info["metrics"]
            queue_total = safe_float(metrics.get("total_incoming_vehicles"))
            wait_total = safe_float(metrics.get("total_waiting_vehicles"))
            queue_series.append(queue_total)
            wait_series.append(wait_total)
            running_vehicle_series.append(safe_float(metrics.get("running_vehicles")))

            # NOTE(review): spillback is estimated from the observation the
            # action was chosen on (pre-step), while the queue/wait values
            # above come from the post-step info — confirm this is intended.
            spillback_intersections = estimate_spillback_intersections(observation_batch)
            spillback_total += float(spillback_intersections)
            spillback_event_steps += int(spillback_intersections > 0)

            if step_progress is not None:
                step_progress.set_postfix_str(
                    " ".join(
                        [
                            f"sim={int(info['sim_time'])}",
                            f"queue={queue_total:.0f}",
                            f"wait={wait_total:.0f}",
                            f"thr={safe_float(metrics.get('throughput')):.0f}",
                            f"refresh={guidance_refresh_count}",
                            f"fallback={fallback_used_count}",
                        ]
                    )
                )
                step_progress.update(1)

            if save_step_metrics:
                step_rows.append(
                    build_step_row(
                        plan=plan,
                        mode_label=mode_label,
                        info=info,
                        action_batch=action_batch,
                        controller=controller,
                        spillback_intersections=spillback_intersections,
                        rewards=rewards,
                    )
                )

            observation_batch = next_observation_batch
    finally:
        # Close the per-episode bar even when the episode fails midway.
        if step_progress is not None:
            step_progress.close()

    episode_runtime_seconds = perf_counter() - episode_started
    final_metrics = env.last_info["metrics"]
    wrapper_debug = controller.episode_debug_summary()
    episode_row = {
        "mode": mode_label,
        "mode_source": controller.mode_source,
        "wrapper_mode": controller.influence_config.wrapper_mode,
        "city_id": plan.city_id,
        "scenario": plan.scenario,
        "seed": int(plan.seed),
        "episode_id": int(plan.episode_id),
        "simulator_seed": int(plan.simulator_seed),
        "total_return": safe_float(env.total_episode_return),
        "mean_return": safe_float(env.episode_return),
        "avg_queue": average(queue_series),
        "max_queue": max_or_zero(queue_series),
        "total_queue": float(sum(queue_series)),
        "avg_wait": average(wait_series),
        "max_wait": max_or_zero(wait_series),
        "total_wait": float(sum(wait_series)),
        "throughput": safe_float(final_metrics.get("throughput")),
        "travel_time": safe_float(final_metrics.get("average_travel_time")),
        "avg_running_vehicles": average(running_vehicle_series),
        "max_running_vehicles": max_or_zero(running_vehicle_series),
        "spillback_count": float(spillback_total),
        "spillback_event_steps": float(spillback_event_steps),
        "incident_scenario": float(bool(scenario_metadata.get("blocked_roads"))),
        "construction_scenario": float(scenario_metadata.get("name") == "construction"),
        "event_scenario": float(bool(scenario_metadata.get("event_district"))),
        "overload_scenario": float(bool(scenario_metadata.get("overload_district"))),
        "num_guidance_refreshes": float(guidance_refresh_count),
        "runtime_seconds": float(episode_runtime_seconds),
        "guidance_inference_seconds": float(guidance_runtime_seconds),
        "wrapper_runtime_seconds": float(wrapper_runtime_seconds),
        "fallback_used_count": float(fallback_used_count),
        "invalid_guidance_count": float(invalid_guidance_count),
        "repaired_guidance_count": float(repaired_guidance_count),
        "action_changes_vs_base": float(action_changes_vs_base),
        # The step count is taken from the environment's own counter.
        "decision_steps": float(env.decision_step_count),
        "num_controlled_intersections": safe_float(final_metrics.get("num_controlled_intersections")),
    }
    # Controller-provided debug counters are merged last and win on key clashes.
    episode_row.update(wrapper_debug)
    return episode_row, step_rows, guidance_trace_rows
|
|
|
|
def build_step_row(
    plan: EpisodePlan,
    mode_label: str,
    info: dict[str, Any],
    action_batch,
    controller: DistrictGuidedRLController,
    spillback_intersections: int,
    rewards: np.ndarray,
) -> dict[str, Any]:
    """Build the flat metrics row describing a single decision step.

    Scalar metrics are read from ``info["metrics"]``; the guidance currently
    active in the controller is serialised into JSON string columns so the
    row stays flat.
    """
    metrics = info["metrics"]
    active_guidance = controller.active_guidance_snapshot()

    def metric(name: str) -> float:
        return safe_float(metrics.get(name))

    def per_district_json(field: str) -> str:
        # One entry per district holding that district's value for ``field``.
        values = {
            district_id: payload.get(field)
            for district_id, payload in sorted(active_guidance.items())
        }
        return json.dumps(values, sort_keys=True)

    return {
        "mode": mode_label,
        "mode_source": controller.mode_source,
        "wrapper_mode": controller.influence_config.wrapper_mode,
        "city_id": plan.city_id,
        "scenario": plan.scenario,
        "seed": int(plan.seed),
        "episode_id": int(plan.episode_id),
        "simulator_seed": int(plan.simulator_seed),
        "step": int(info["decision_step"]),
        "sim_time": int(info["sim_time"]),
        "queue": metric("total_incoming_vehicles"),
        "wait": metric("total_waiting_vehicles"),
        "throughput": metric("throughput"),
        "travel_time": metric("average_travel_time"),
        "running_vehicles": metric("running_vehicles"),
        "step_total_reward": float(np.asarray(rewards, dtype=np.float32).sum()),
        "action_changes_vs_base": int(np.sum(action_batch.actions != action_batch.base_actions)),
        "mean_abs_q_bias": float(np.abs(action_batch.q_bias).mean()),
        "spillback_intersections": int(spillback_intersections),
        "active_guidance_count": int(len(active_guidance)),
        "active_guidance_json": json.dumps(active_guidance, sort_keys=True),
        "selected_target_intersections_json": json.dumps(
            collect_target_intersections(active_guidance),
            sort_keys=True,
        ),
        "phase_bias_json": per_district_json("phase_bias"),
        "priority_corridor_json": per_district_json("priority_corridor"),
    }
|
|
|
|
def build_guidance_trace_row(
    plan: EpisodePlan,
    mode_label: str,
    trace,
    controller: DistrictGuidedRLController,
) -> dict[str, Any]:
    """Return ``trace.to_dict()`` extended with run-identifying columns.

    The added keys overwrite same-named keys already present in the trace
    dictionary.
    """
    row = trace.to_dict()
    influence_config = controller.influence_config
    # Evaluate every value before touching ``row`` so a failure leaves it unchanged.
    annotations = (
        ("mode", mode_label),
        ("wrapper_mode", influence_config.wrapper_mode),
        ("city_id", plan.city_id),
        ("scenario", plan.scenario),
        ("seed", int(plan.seed)),
        ("episode_id", int(plan.episode_id)),
        ("simulator_seed", int(plan.simulator_seed)),
        ("influence_config", guidance_config_payload(controller.influence_config)),
    )
    for key, value in annotations:
        row[key] = value
    return row
|
|
|
|
def estimate_spillback_intersections(observation_batch: dict[str, Any]) -> int:
    """Count the entries of the batch flagged as spilling back.

    An entry is flagged when its outgoing congestion is at least
    ``max(8, 0.5 * incoming_total)``; entries with a positive boundary mask
    are also flagged at the lower bar ``max(4, 0.4 * incoming_total)``.
    ``incoming_total`` is ``incoming_counts`` summed over axis 1.
    """

    def as_float32(key: str) -> np.ndarray:
        return np.asarray(observation_batch[key], dtype=np.float32)

    incoming_totals = as_float32("incoming_counts").sum(axis=1)
    outgoing_load = as_float32("outgoing_congestion")
    on_boundary = as_float32("boundary_mask") > 0.0

    general_limit = np.maximum(8.0, incoming_totals * 0.5)
    boundary_limit = np.maximum(4.0, incoming_totals * 0.4)

    over_general = outgoing_load >= general_limit
    over_boundary = on_boundary & (outgoing_load >= boundary_limit)
    flagged = over_general | over_boundary
    return int(np.sum(flagged))
|
|
|
|
def collect_target_intersections(active_guidance: dict[str, dict[str, Any]]) -> dict[str, list[str]]:
    """Map each district id to a list copy of its ``target_intersections``.

    Districts are inserted in sorted order; a district without the key maps
    to an empty list.
    """
    targets_by_district: dict[str, list[str]] = {}
    for district_id, payload in sorted(active_guidance.items()):
        targets_by_district[district_id] = list(payload.get("target_intersections", []))
    return targets_by_district
|
|
|
|
def load_scenario_metadata(scenario_spec: ScenarioSpec) -> dict[str, Any]:
    """Load ``scenario_metadata.json`` from the scenario directory.

    Returns:
        The parsed file, or an empty dict when the file does not exist.
    """
    metadata_path = scenario_spec.scenario_dir / "scenario_metadata.json"
    if not metadata_path.exists():
        return {}
    return load_json(metadata_path)
|
|
|
|
def build_summary_payload(
    episode_rows: list[dict[str, Any]],
    config_payload: dict[str, Any],
) -> dict[str, Any]:
    """Aggregate the episode rows into the per-mode summary document.

    For every distinct ``mode`` value, each key metric is summarised with
    ``distribution_summary`` and the number of episodes is recorded. The
    comparison scope is copied from ``config_payload``.
    """
    key_metrics = (
        "total_return",
        "mean_return",
        "avg_queue",
        "avg_wait",
        "throughput",
        "travel_time",
        "spillback_count",
        "fallback_used_count",
        "invalid_guidance_count",
        "repaired_guidance_count",
        "num_steps_guidance_blocked_by_gate",
        "num_guidance_refreshes_blocked_by_gate",
        "mean_bias_magnitude",
        "max_bias_magnitude",
        "avg_num_targeted_intersections",
        "avg_num_affected_intersections",
        "percent_steps_with_active_guidance",
        "num_noop_guidance_events",
    )

    mode_labels = sorted({str(row["mode"]) for row in episode_rows})
    metrics_by_mode: dict[str, Any] = {}
    for mode in mode_labels:
        mode_rows = [row for row in episode_rows if row["mode"] == mode]
        mode_summary: dict[str, Any] = {}
        for metric_name in key_metrics:
            values = [safe_float(row.get(metric_name)) for row in mode_rows]
            mode_summary[metric_name] = distribution_summary(values)
        mode_summary["num_episodes"] = int(len(mode_rows))
        metrics_by_mode[mode] = mode_summary

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "comparison_scope": config_payload["comparison_scope"],
        "pairing_keys": ["city_id", "scenario", "seed", "episode_id"],
        "metrics_by_mode": metrics_by_mode,
        "analysis_summary": build_analysis_summary(episode_rows),
    }
|
|
|
|
def build_config_payload(
    args: argparse.Namespace,
    env_config: EnvConfig,
    episode_plans: list[EpisodePlan],
) -> dict[str, Any]:
    """Assemble the run configuration document written next to the results.

    It records the CLI selections, the comparison scope derived from the
    episode plans, the plans themselves, and the guidance, repair, heuristic
    and environment configurations rebuilt from ``args``.
    """
    payload: dict[str, Any] = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "rl_checkpoint": str(args.rl_checkpoint),
        "llm_model_path": args.llm_model_path,
        "modes": list(args.modes),
        "wrapper_modes": list(args.wrapper_modes),
    }

    payload["comparison_scope"] = {
        "num_episode_plans": int(len(episode_plans)),
        "cities": sorted({plan.city_id for plan in episode_plans}),
        "scenarios": sorted({plan.scenario for plan in episode_plans}),
        "seeds": sorted({int(plan.seed) for plan in episode_plans}),
        "episodes_per_seed": int(args.num_episodes),
        # rl_only counts once; every other mode counts once per wrapper mode.
        "total_mode_runs": int(
            sum(1 if mode == "rl_only" else len(args.wrapper_modes) for mode in args.modes)
            * len(episode_plans)
        ),
    }
    payload["episode_plans"] = [plan.to_dict() for plan in episode_plans]

    # Only the first wrapper mode is recorded in this config snapshot.
    influence_config = GuidanceInfluenceConfig(
        wrapper_mode=args.wrapper_modes[0],
        bias_strength=args.bias_strength,
        target_only_bias_strength=args.targeted_bias_strength,
        corridor_bias_strength=args.corridor_bias_strength,
        max_intersections_affected=args.max_intersections_affected,
        guidance_refresh_steps=args.guidance_refresh_steps,
        guidance_persistence_steps=args.guidance_persistence_steps,
        max_guidance_duration=args.max_guidance_duration,
        apply_global_bias=args.apply_global_bias,
        apply_target_only=args.apply_target_only,
        gating_mode=args.gating_mode,
        min_avg_queue_for_guidance=args.min_avg_queue_for_guidance,
        min_queue_imbalance_for_guidance=args.min_queue_imbalance_for_guidance,
        require_incident_or_spillback=args.require_incident_or_spillback,
        allow_guidance_in_normal_conditions=args.allow_guidance_in_normal_conditions,
        enable_bias_decay=args.enable_bias_decay,
        bias_decay_schedule=args.bias_decay_schedule,
        fallback_policy=args.fallback_policy,
        log_guidance_debug=args.log_guidance_debug,
    )
    payload["guidance_influence_config"] = guidance_config_payload(influence_config)

    repair_config = RepairConfig(
        allow_only_visible_candidates=args.allow_only_visible_candidates,
        max_target_intersections=args.max_target_intersections,
        fallback_on_empty_targets=args.fallback_on_empty_targets,
        fallback_mode=args.fallback_mode,
    )
    payload["repair_config"] = asdict(repair_config)

    heuristic_config = HeuristicGuidanceConfig(
        max_target_intersections=args.max_target_intersections,
    )
    payload["heuristic_config"] = asdict(heuristic_config)

    payload["env_config"] = env_config_to_payload(env_config)
    payload["save_step_metrics"] = bool(args.save_step_metrics)
    payload["save_guidance_traces"] = bool(args.save_guidance_traces)
    return payload
|
|
|
|
def _mean_of(rows: list[dict[str, Any]], key: str) -> float:
    """Return the mean of ``key`` across ``rows``.

    Values that are missing or ``None`` are coerced to 0.0 by ``safe_float``
    before averaging, so they pull the mean towards zero rather than being
    skipped.
    """
    return distribution_summary([safe_float(row.get(key)) for row in rows])["mean"]


def build_analysis_summary(episode_rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-episode rows into a cross-mode comparison summary.

    Args:
        episode_rows: One dict per evaluated episode. Every row must carry
            ``mode_source``; rows whose ``mode_source`` is not ``"rl_only"``
            must also carry ``mode`` and ``wrapper_mode``. Metric keys
            (``total_return``, ``avg_queue``, ...) are read with ``.get`` and
            default to 0.0 when absent.

    Returns:
        An empty dict when ``episode_rows`` is empty. Otherwise a dict with:
        ``rl_only_mean_total_return``, ``guided_mode_rankings`` (sorted by mean
        return, best first), ``least_degrading_guided_mode`` (first ranking
        entry or ``None``), ``mode_source_summary``,
        ``heuristic_vs_llm_by_wrapper``, ``conservative_guidance_modes`` and
        ``aggressive_guidance_modes``.
    """
    if not episode_rows:
        return {}

    aggressive_wrappers = {"global_soft", "current_legacy"}
    conservative_wrappers = {"no_op", "target_only_soft", "target_only_medium", "corridor_soft"}

    # Single pass: bucket rows by every key the summary needs. All grouping
    # keys are normalised with str() so that the key used to *collect* a group
    # is the same key used to *look it up* (the previous implementation mixed
    # str() keys with raw comparisons, which broke for non-str values).
    rl_only_rows: list[dict[str, Any]] = []
    rows_by_source: dict[str, list[dict[str, Any]]] = {}
    guided_mode_names: set[str] = set()
    rows_by_source_and_wrapper: dict[str, dict[str, list[dict[str, Any]]]] = {
        "rl_heuristic": {},
        "rl_llm": {},
    }
    for row in episode_rows:
        source = row["mode_source"]
        rows_by_source.setdefault(str(source), []).append(row)
        if source == "rl_only":
            rl_only_rows.append(row)
        else:
            guided_mode_names.add(str(row["mode"]))
        wrapper_groups = rows_by_source_and_wrapper.get(source)
        if wrapper_groups is not None:
            wrapper_groups.setdefault(str(row["wrapper_mode"]), []).append(row)

    # NOTE: with no rl_only rows this baseline is 0.0, so the reported deltas
    # are then relative to zero rather than to a real baseline.
    rl_only_return = average([safe_float(row.get("total_return")) for row in rl_only_rows])

    # Rows are grouped by mode name regardless of their source, matching the
    # original selection rule; only built when at least one guided mode exists
    # so rl_only-only inputs are never required to carry a "mode" key.
    rows_by_mode: dict[str, list[dict[str, Any]]] = {}
    if guided_mode_names:
        for row in episode_rows:
            rows_by_mode.setdefault(str(row["mode"]), []).append(row)

    ranked_guided: list[dict[str, Any]] = []
    for mode_name in sorted(guided_mode_names):
        mode_rows = rows_by_mode[mode_name]
        first_row = mode_rows[0]
        mean_return = _mean_of(mode_rows, "total_return")
        ranked_guided.append(
            {
                "mode": mode_name,
                "mode_source": str(first_row["mode_source"]),
                "wrapper_mode": str(first_row["wrapper_mode"]),
                "mean_total_return": mean_return,
                "return_delta_vs_rl_only": mean_return - rl_only_return,
                "mean_avg_queue": _mean_of(mode_rows, "avg_queue"),
                "mean_fallback_used_count": _mean_of(mode_rows, "fallback_used_count"),
                "mean_avg_num_affected_intersections": _mean_of(
                    mode_rows, "avg_num_affected_intersections"
                ),
                "mean_percent_steps_with_active_guidance": _mean_of(
                    mode_rows, "percent_steps_with_active_guidance"
                ),
                "mean_num_steps_guidance_blocked_by_gate": _mean_of(
                    mode_rows, "num_steps_guidance_blocked_by_gate"
                ),
            }
        )
    # Stable sort: modes with equal mean return keep their alphabetical order.
    ranked_guided.sort(key=lambda item: item["mean_total_return"], reverse=True)

    mode_source_summary = {
        mode_source: {
            "mean_total_return": _mean_of(rows, "total_return"),
            "mean_fallback_used_count": _mean_of(rows, "fallback_used_count"),
            "mean_avg_num_affected_intersections": _mean_of(
                rows, "avg_num_affected_intersections"
            ),
            "mean_num_steps_guidance_blocked_by_gate": _mean_of(
                rows, "num_steps_guidance_blocked_by_gate"
            ),
        }
        for mode_source, rows in sorted(rows_by_source.items())
    }

    # Only wrappers evaluated under BOTH guidance sources can be compared.
    heuristic_by_wrapper = rows_by_source_and_wrapper["rl_heuristic"]
    llm_by_wrapper = rows_by_source_and_wrapper["rl_llm"]
    heuristic_vs_llm_by_wrapper: list[dict[str, Any]] = []
    for wrapper_mode in sorted(set(heuristic_by_wrapper) & set(llm_by_wrapper)):
        heuristic_return = _mean_of(heuristic_by_wrapper[wrapper_mode], "total_return")
        llm_return = _mean_of(llm_by_wrapper[wrapper_mode], "total_return")
        heuristic_vs_llm_by_wrapper.append(
            {
                "wrapper_mode": wrapper_mode,
                "heuristic_mean_total_return": heuristic_return,
                "llm_mean_total_return": llm_return,
                "llm_minus_heuristic_return": llm_return - heuristic_return,
            }
        )

    return {
        "rl_only_mean_total_return": rl_only_return,
        "guided_mode_rankings": ranked_guided,
        "least_degrading_guided_mode": ranked_guided[0] if ranked_guided else None,
        "mode_source_summary": mode_source_summary,
        "heuristic_vs_llm_by_wrapper": heuristic_vs_llm_by_wrapper,
        "conservative_guidance_modes": [
            item for item in ranked_guided if item["wrapper_mode"] in conservative_wrappers
        ],
        "aggressive_guidance_modes": [
            item for item in ranked_guided if item["wrapper_mode"] in aggressive_wrappers
        ],
    }
|
|
|
|
def distribution_summary(values: list[float]) -> dict[str, float]:
    """Compute descriptive statistics for a list of numbers.

    ``None`` entries are dropped before any statistic is computed. The result
    always has the keys ``count``, ``mean``, ``std`` (population, ``ddof=0``),
    ``median``, ``p25``, ``p75``, ``min`` and ``max``; every value is a plain
    ``float``. When nothing remains after dropping ``None`` all statistics
    are 0.0.
    """
    stat_names = ("count", "mean", "std", "median", "p25", "p75", "min", "max")
    samples = [float(item) for item in values if item is not None]
    if not samples:
        return dict.fromkeys(stat_names, 0.0)

    data = np.asarray(samples, dtype=np.float64)
    summary: dict[str, float] = {}
    summary["count"] = float(data.size)
    summary["mean"] = float(data.mean())
    summary["std"] = float(data.std(ddof=0))
    summary["median"] = float(median(samples))
    summary["p25"] = float(np.percentile(data, 25))
    summary["p75"] = float(np.percentile(data, 75))
    summary["min"] = float(data.min())
    summary["max"] = float(data.max())
    return summary
|
|
|
|
def default_env_config() -> EnvConfig:
    """Build the ``EnvConfig`` this script uses when no override is supplied.

    The observation settings are ``ObservationConfig``'s own defaults and the
    reward uses the ``"wait_queue_throughput"`` variant; no episode length cap
    is set (``max_episode_seconds=None``).
    """
    observation_config = ObservationConfig()
    reward_config = RewardConfig(variant="wait_queue_throughput")
    config = EnvConfig(
        observation=observation_config,
        reward=reward_config,
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        max_episode_seconds=None,
    )
    return config
|
|
|
|
def env_config_to_payload(env_config: EnvConfig) -> dict[str, Any]:
    """Flatten an ``EnvConfig`` into a plain dict for serialization.

    Scalar settings are copied as-is; the nested ``observation`` and
    ``reward`` configs are expanded with ``dataclasses.asdict``.
    """
    scalar_fields = (
        "simulator_interval",
        "decision_interval",
        "min_green_time",
        "thread_num",
        "max_episode_seconds",
    )
    payload: dict[str, Any] = {name: getattr(env_config, name) for name in scalar_fields}
    payload["observation"] = asdict(env_config.observation)
    payload["reward"] = asdict(env_config.reward)
    return payload
|
|
|
|
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented, key-sorted JSON.

    Parent directories are created if needed. The payload is passed through
    ``to_jsonable`` first, and the file ends with a trailing newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(to_jsonable(payload), indent=2, sort_keys=True) + "\n"
    # Explicit encoding keeps this writer consistent with write_jsonl and
    # write_csv_rows instead of relying on the platform default.
    path.write_text(text, encoding="utf-8")
|
|
|
|
def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines (one key-sorted object per line).

    Parent directories are created if needed; each row goes through
    ``to_jsonable`` before encoding. The file is UTF-8 and is truncated if it
    already exists.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(to_jsonable(record), sort_keys=True) + "\n"
            for record in rows
        )
|
|
|
|
def write_csv_rows(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as a CSV file with a header line.

    The header is the sorted union of every key seen in any row; keys a row
    lacks are passed to ``format_csv_value`` as ``None``. Parent directories
    are created if needed. With no rows an empty file is written (no header).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if not rows:
        path.write_text("")
        return

    columns = sorted(set().union(*(record.keys() for record in rows)))
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {column: format_csv_value(record.get(column)) for column in columns}
            for record in rows
        )
|
|
|
|
def try_write_parquet(path: Path, rows: list[dict[str, Any]]) -> bool:
    """Best-effort export of ``rows`` to a Parquet file at ``path``.

    Tries ``pyarrow`` first and falls back to ``pandas``. Both are optional:
    any failure (missing package, unsupported data, write error) is swallowed
    on purpose so callers can treat Parquet output as a nice-to-have.

    Args:
        path: Destination file. Its parent directory is created if missing,
            matching the other writers in this module.
        rows: Row dicts; each is passed through ``to_jsonable`` before writing.

    Returns:
        ``True`` if one of the backends wrote the file, ``False`` if ``rows``
        is empty or no backend succeeded.
    """
    if not rows:
        return False
    json_ready_rows = [to_jsonable(row) for row in rows]

    # Without this, a missing output directory makes both backends fail and
    # the function reports False with no hint as to why.
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
    except OSError:
        return False

    try:
        import pyarrow as pa
        import pyarrow.parquet as pq

        table = pa.Table.from_pylist(json_ready_rows)
        pq.write_table(table, path)
        return True
    except Exception:
        # Deliberate best-effort: fall through to the pandas backend.
        pass

    try:
        import pandas as pd

        frame = pd.DataFrame(json_ready_rows)
        frame.to_parquet(path, index=False)
        return True
    except Exception:
        # Deliberate best-effort: Parquet output is optional.
        return False
|
|
|
|
def to_jsonable(value: Any) -> Any:
    """Recursively convert ``value`` into types ``json.dumps`` accepts.

    Conversions performed:
      * dict: keys become ``str``, values are converted recursively
      * list / tuple: become a list of converted items
      * ``Path``: becomes ``str``
      * ``numpy.ndarray``: becomes nested lists via ``tolist()``
      * numpy floating / integer / bool scalars: become ``float`` / ``int`` / ``bool``

    Anything else is returned unchanged.
    """
    if isinstance(value, dict):
        converted: dict[str, Any] = {}
        for key, item in value.items():
            converted[str(key)] = to_jsonable(item)
        return converted
    if isinstance(value, (list, tuple)):
        return list(map(to_jsonable, value))
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    numpy_scalar_casts = (
        (np.floating, float),
        (np.integer, int),
        (np.bool_, bool),
    )
    for numpy_kind, builtin_cast in numpy_scalar_casts:
        if isinstance(value, numpy_kind):
            return builtin_cast(value)
    return value
|
|
|
|
def format_csv_value(value: Any) -> Any:
    """Prepare one cell value for CSV output.

    Containers (dict, list, tuple) are encoded as a key-sorted JSON string so
    they fit in a single cell; every other value is only passed through
    ``to_jsonable``.
    """
    converted = to_jsonable(value)
    is_container = isinstance(value, (dict, list, tuple))
    if not is_container:
        return converted
    return json.dumps(converted, sort_keys=True)
|
|
|
|
def average(values: list[float]) -> float:
    """Return the arithmetic mean of ``values`` as a float, or 0.0 if empty."""
    if values:
        samples = np.asarray(values, dtype=np.float64)
        return float(np.mean(samples))
    return 0.0
|
|
|
|
def max_or_zero(values: list[float]) -> float:
    """Return the largest element of ``values`` as a float, or 0.0 if empty."""
    if not values:
        return 0.0
    largest = max(values)
    return float(largest)
|
|
|
|
def safe_float(value: Any) -> float:
    """Convert ``value`` with ``float()``, mapping ``None`` to 0.0.

    Non-``None`` values that ``float()`` rejects still raise as usual.
    """
    return 0.0 if value is None else float(value)
|
|
|
|
# Script entry point: call main() (defined elsewhere in this file) only when
# the module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|