| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import asdict |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any |
| import sys |
|
|
| from tqdm.auto import tqdm |
|
|
| REPO_ROOT = Path(__file__).resolve().parents[1] |
| if str(REPO_ROOT) not in sys.path: |
| sys.path.insert(0, str(REPO_ROOT)) |
|
|
| from district_llm.heuristic_guidance import HeuristicGuidanceConfig |
| from district_llm.inference import DistrictLLMInference |
| from district_llm.repair import RepairConfig |
| from district_llm.rl_guidance_wrapper import ( |
| DistrictGuidedRLController, |
| FixedRLPolicyAdapter, |
| GuidanceInfluenceConfig, |
| HeuristicGuidanceProvider, |
| LLMGuidanceProvider, |
| guidance_config_payload, |
| ) |
| from district_llm.summary_builder import DistrictStateSummaryBuilder |
| from env.traffic_env import EnvConfig |
| from scripts.eval_rl_guidance_ablation import ( |
| build_episode_plans, |
| default_env_config, |
| distribution_summary, |
| env_config_to_payload, |
| run_episode, |
| safe_float, |
| try_write_parquet, |
| write_csv_rows, |
| write_json, |
| ) |
| from training.cityflow_dataset import CityFlowDataset |
|
|
|
|
| DEFAULT_SEEDS: tuple[int, ...] = (7,) |
| PREFERRED_DEFAULT_CITIES: tuple[str, ...] = ("city_0001",) |
| PREFERRED_DEFAULT_SCENARIOS: tuple[str, ...] = ("normal",) |
| SCENARIO_ALIASES: dict[str, str] = { |
| "rush": "morning_rush", |
| } |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description=( |
| "Quick paired evaluation for rl_only vs rl_heuristic vs rl_llm using the " |
| "best target_only_soft wrapper settings." |
| ) |
| ) |
| parser.add_argument("--rl-checkpoint", required=True) |
| parser.add_argument("--llm-model-path", required=True) |
| parser.add_argument("--generated-root", default="data/generated") |
| parser.add_argument("--splits-root", default="data/splits") |
| parser.add_argument("--split", default="val", choices=("train", "val", "test")) |
| parser.add_argument("--cities", nargs="+", default=None) |
| parser.add_argument("--scenarios", nargs="+", default=None) |
| parser.add_argument("--seeds", nargs="+", type=int, default=list(DEFAULT_SEEDS)) |
| parser.add_argument("--episodes-per-seed", type=int, default=1) |
| parser.add_argument( |
| "--max-episode-seconds", |
| type=int, |
| default=120, |
| help="Short default horizon so the quick check stays under roughly 10-20 minutes.", |
| ) |
| parser.add_argument("--max-new-tokens", type=int, default=128) |
| parser.add_argument("--device", default=None) |
| parser.add_argument("--output-dir", default="artifacts/quick_rl_llm_eval") |
| parser.add_argument( |
| "--allow-only-visible-candidates", |
| action=argparse.BooleanOptionalAction, |
| default=True, |
| ) |
| parser.add_argument("--max-target-intersections", type=int, default=3) |
| parser.add_argument( |
| "--fallback-on-empty-targets", |
| action=argparse.BooleanOptionalAction, |
| default=True, |
| ) |
| parser.add_argument( |
| "--fallback-mode", |
| choices=("heuristic", "hold", "none"), |
| default="heuristic", |
| ) |
| return parser.parse_args() |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| output_dir = Path(args.output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
| seeded_config_root = output_dir / "seeded_configs" |
| seeded_config_root.mkdir(parents=True, exist_ok=True) |
|
|
| dataset = CityFlowDataset( |
| generated_root=args.generated_root, |
| splits_root=args.splits_root, |
| ) |
| dataset.generate_default_splits() |
|
|
| city_ids = resolve_quick_cities(dataset=dataset, requested_cities=args.cities) |
| scenario_specs = resolve_quick_scenario_specs( |
| dataset=dataset, |
| city_ids=city_ids, |
| requested_scenarios=args.scenarios, |
| ) |
| episode_plans = build_episode_plans( |
| scenario_specs=scenario_specs, |
| seeds=args.seeds, |
| num_episodes=args.episodes_per_seed, |
| seeded_config_root=seeded_config_root, |
| ) |
|
|
| rl_policy = FixedRLPolicyAdapter( |
| checkpoint_path=args.rl_checkpoint, |
| device=args.device, |
| ) |
| env_config = rl_policy.env_config or default_env_config() |
| env_config = EnvConfig( |
| simulator_interval=env_config.simulator_interval, |
| decision_interval=env_config.decision_interval, |
| min_green_time=env_config.min_green_time, |
| thread_num=env_config.thread_num, |
| max_episode_seconds=int(args.max_episode_seconds), |
| observation=env_config.observation, |
| reward=env_config.reward, |
| ) |
|
|
| tuned_config = GuidanceInfluenceConfig( |
| wrapper_mode="target_only_soft", |
| bias_strength=0.025, |
| target_only_bias_strength=0.025, |
| corridor_bias_strength=0.0125, |
| max_intersections_affected=2, |
| guidance_refresh_steps=10, |
| guidance_persistence_steps=5, |
| max_guidance_duration=10, |
| apply_global_bias=False, |
| apply_target_only=True, |
| gating_mode="queue_or_imbalance", |
| min_avg_queue_for_guidance=150.0, |
| min_queue_imbalance_for_guidance=20.0, |
| require_incident_or_spillback=False, |
| allow_guidance_in_normal_conditions=False, |
| enable_bias_decay=False, |
| bias_decay_schedule="linear", |
| fallback_policy="no_op", |
| log_guidance_debug=False, |
| ).validate() |
|
|
| controllers = build_controllers( |
| args=args, |
| rl_policy=rl_policy, |
| tuned_config=tuned_config, |
| ) |
|
|
| episode_rows: list[dict[str, Any]] = [] |
| rows_by_pair: dict[tuple[str, str, int, int], dict[str, dict[str, Any]]] = {} |
| total_runs = len(episode_plans) * len(controllers) |
| progress = tqdm(total=total_runs, desc="Quick RL+LLM eval", unit="run") |
| try: |
| for plan in episode_plans: |
| for mode_label, controller in controllers.items(): |
| progress.set_postfix_str( |
| f"mode={mode_label} city={plan.city_id} scenario={plan.scenario} seed={plan.seed}" |
| ) |
| episode_row, _, _ = run_episode( |
| plan=plan, |
| mode_label=mode_label, |
| controller=controller, |
| env_config=env_config, |
| save_step_metrics=False, |
| save_guidance_traces=False, |
| show_step_progress=False, |
| ) |
| episode_row = augment_episode_row(episode_row, tuned_config) |
| episode_rows.append(episode_row) |
| rows_by_pair.setdefault(plan.pairing_key(), {})[mode_label] = episode_row |
| progress.update(1) |
| finally: |
| progress.close() |
|
|
| paired_delta_rows = build_paired_delta_rows(rows_by_pair) |
| summary_payload = build_summary_payload( |
| episode_rows=episode_rows, |
| paired_delta_rows=paired_delta_rows, |
| tuned_config=tuned_config, |
| args=args, |
| scenario_specs=scenario_specs, |
| ) |
|
|
| write_csv_rows(output_dir / "episode_metrics.csv", episode_rows) |
| episode_parquet_written = try_write_parquet(output_dir / "episode_metrics.parquet", episode_rows) |
| write_csv_rows(output_dir / "paired_deltas.csv", paired_delta_rows) |
| try_write_parquet(output_dir / "paired_deltas.parquet", paired_delta_rows) |
| write_json(output_dir / "summary.json", summary_payload) |
|
|
| print(json.dumps(summary_payload, indent=2, sort_keys=True)) |
| if not episode_parquet_written: |
| print( |
| "[warning] episode_metrics.parquet was not written because neither pyarrow nor pandas " |
| "is available in the current Python environment." |
| ) |
|
|
|
|
| def resolve_quick_cities( |
| dataset: CityFlowDataset, |
| requested_cities: list[str] | None, |
| ) -> list[str]: |
| available = set(dataset.discover_cities()) |
| if requested_cities: |
| selected = [city_id for city_id in requested_cities if city_id in available] |
| if not selected: |
| raise ValueError(f"None of the requested cities are available: {requested_cities}") |
| return selected |
| defaults = [city_id for city_id in PREFERRED_DEFAULT_CITIES if city_id in available] |
| if defaults: |
| return defaults[:1] |
| discovered = sorted(available) |
| if not discovered: |
| raise ValueError("No generated cities were found under the generated-root.") |
| return discovered[:1] |
|
|
|
|
| def resolve_quick_scenario_specs( |
| dataset: CityFlowDataset, |
| city_ids: list[str], |
| requested_scenarios: list[str] | None, |
| ) -> list[Any]: |
| specs: list[Any] = [] |
| for city_id in city_ids: |
| available_scenarios = set(dataset.scenarios_for_city(city_id)) |
| if requested_scenarios: |
| desired = [ |
| SCENARIO_ALIASES.get(scenario_name, scenario_name) |
| for scenario_name in requested_scenarios |
| ] |
| else: |
| desired = [ |
| scenario_name |
| for scenario_name in PREFERRED_DEFAULT_SCENARIOS |
| if scenario_name in available_scenarios |
| ][:2] |
| selected = [scenario_name for scenario_name in desired if scenario_name in available_scenarios] |
| if not selected: |
| raise ValueError( |
| f"No requested/default scenarios are available for city '{city_id}'. " |
| f"Available scenarios: {sorted(available_scenarios)}" |
| ) |
| for scenario_name in selected: |
| specs.append(dataset.build_scenario_spec(city_id, scenario_name)) |
| if not specs: |
| raise ValueError("No scenario specs were resolved for the quick evaluation.") |
| return specs |
|
|
|
|
| def build_controllers( |
| args: argparse.Namespace, |
| rl_policy: FixedRLPolicyAdapter, |
| tuned_config: GuidanceInfluenceConfig, |
| ) -> dict[str, DistrictGuidedRLController]: |
| heuristic_provider = HeuristicGuidanceProvider( |
| config=HeuristicGuidanceConfig( |
| max_target_intersections=args.max_target_intersections, |
| ) |
| ) |
| llm_inference = DistrictLLMInference( |
| model_name_or_path=args.llm_model_path, |
| device=args.device, |
| repair_config=RepairConfig( |
| allow_only_visible_candidates=args.allow_only_visible_candidates, |
| max_target_intersections=args.max_target_intersections, |
| fallback_on_empty_targets=args.fallback_on_empty_targets, |
| fallback_mode=args.fallback_mode, |
| ), |
| ) |
| llm_provider = LLMGuidanceProvider( |
| inference=llm_inference, |
| max_new_tokens=args.max_new_tokens, |
| ) |
|
|
| def summary_builder() -> DistrictStateSummaryBuilder: |
| return DistrictStateSummaryBuilder( |
| top_k=3, |
| candidate_limit=max(6, int(args.max_target_intersections)), |
| ) |
|
|
| return { |
| "rl_only": DistrictGuidedRLController( |
| policy=rl_policy, |
| mode_source="rl_only", |
| summary_builder=None, |
| guidance_provider=None, |
| influence_config=GuidanceInfluenceConfig( |
| wrapper_mode="no_op", |
| bias_strength=0.0, |
| target_only_bias_strength=0.0, |
| corridor_bias_strength=0.0, |
| max_intersections_affected=1, |
| guidance_refresh_steps=tuned_config.guidance_refresh_steps, |
| guidance_persistence_steps=1, |
| max_guidance_duration=tuned_config.max_guidance_duration, |
| fallback_policy="no_op", |
| enable_bias_decay=False, |
| ), |
| heuristic_provider=None, |
| ), |
| "rl_heuristic": DistrictGuidedRLController( |
| policy=rl_policy, |
| mode_source="rl_heuristic", |
| summary_builder=summary_builder(), |
| guidance_provider=heuristic_provider, |
| influence_config=tuned_config, |
| heuristic_provider=heuristic_provider, |
| ), |
| "rl_llm": DistrictGuidedRLController( |
| policy=rl_policy, |
| mode_source="rl_llm", |
| summary_builder=summary_builder(), |
| guidance_provider=llm_provider, |
| influence_config=tuned_config, |
| heuristic_provider=heuristic_provider, |
| ), |
| } |
|
|
|
|
| def augment_episode_row( |
| row: dict[str, Any], |
| tuned_config: GuidanceInfluenceConfig, |
| ) -> dict[str, Any]: |
| payload = dict(row) |
| payload.update( |
| { |
| "wrapper_mode": tuned_config.wrapper_mode if row["mode"] != "rl_only" else "no_op", |
| "bias_strength": 0.0 if row["mode"] == "rl_only" else tuned_config.bias_strength, |
| "target_only_bias_strength": 0.0 |
| if row["mode"] == "rl_only" |
| else tuned_config.target_only_bias_strength, |
| "corridor_bias_strength": 0.0 |
| if row["mode"] == "rl_only" |
| else tuned_config.corridor_bias_strength, |
| "max_intersections_affected": 0 |
| if row["mode"] == "rl_only" |
| else tuned_config.max_intersections_affected, |
| "gating_mode": "always_on" if row["mode"] == "rl_only" else tuned_config.gating_mode, |
| "guidance_persistence_steps": 0 |
| if row["mode"] == "rl_only" |
| else tuned_config.guidance_persistence_steps, |
| "guidance_refresh_steps": 0 |
| if row["mode"] == "rl_only" |
| else tuned_config.guidance_refresh_steps, |
| "enable_bias_decay": False if row["mode"] == "rl_only" else tuned_config.enable_bias_decay, |
| "min_avg_queue_for_guidance": 0.0 |
| if row["mode"] == "rl_only" |
| else tuned_config.min_avg_queue_for_guidance, |
| "min_queue_imbalance_for_guidance": 0.0 |
| if row["mode"] == "rl_only" |
| else tuned_config.min_queue_imbalance_for_guidance, |
| } |
| ) |
| return payload |
|
|
|
|
| def build_paired_delta_rows( |
| rows_by_pair: dict[tuple[str, str, int, int], dict[str, dict[str, Any]]], |
| ) -> list[dict[str, Any]]: |
| comparison_modes = ("rl_heuristic", "rl_llm") |
| paired_rows: list[dict[str, Any]] = [] |
| for (city_id, scenario, seed, episode_id), mode_rows in sorted(rows_by_pair.items()): |
| rl_only_row = mode_rows.get("rl_only") |
| if rl_only_row is None: |
| continue |
| for comparison_mode in comparison_modes: |
| other_row = mode_rows.get(comparison_mode) |
| if other_row is None: |
| continue |
| paired_rows.append( |
| { |
| "city_id": city_id, |
| "scenario": scenario, |
| "seed": int(seed), |
| "episode_id": int(episode_id), |
| "comparison": f"{comparison_mode}_vs_rl_only", |
| "mode": comparison_mode, |
| "total_return_delta": safe_float(other_row.get("total_return")) |
| - safe_float(rl_only_row.get("total_return")), |
| "avg_queue_delta": safe_float(other_row.get("avg_queue")) |
| - safe_float(rl_only_row.get("avg_queue")), |
| "avg_wait_delta": safe_float(other_row.get("avg_wait")) |
| - safe_float(rl_only_row.get("avg_wait")), |
| "throughput_delta": safe_float(other_row.get("throughput")) |
| - safe_float(rl_only_row.get("throughput")), |
| "travel_time_delta": safe_float(other_row.get("travel_time")) |
| - safe_float(rl_only_row.get("travel_time")), |
| "spillback_delta": safe_float(other_row.get("spillback_count")) |
| - safe_float(rl_only_row.get("spillback_count")), |
| "return_beats_rl_only": float( |
| safe_float(other_row.get("total_return")) |
| > safe_float(rl_only_row.get("total_return")) |
| ), |
| } |
| ) |
| return paired_rows |
|
|
|
|
| def build_summary_payload( |
| episode_rows: list[dict[str, Any]], |
| paired_delta_rows: list[dict[str, Any]], |
| tuned_config: GuidanceInfluenceConfig, |
| args: argparse.Namespace, |
| scenario_specs: list[Any], |
| ) -> dict[str, Any]: |
| metrics_by_mode: dict[str, dict[str, float]] = {} |
| for mode in ("rl_only", "rl_heuristic", "rl_llm"): |
| mode_rows = [row for row in episode_rows if row["mode"] == mode] |
| metrics_by_mode[mode] = { |
| "mean_total_return": distribution_summary( |
| [safe_float(row.get("total_return")) for row in mode_rows] |
| )["mean"], |
| "std_total_return": distribution_summary( |
| [safe_float(row.get("total_return")) for row in mode_rows] |
| )["std"], |
| "mean_avg_queue": distribution_summary( |
| [safe_float(row.get("avg_queue")) for row in mode_rows] |
| )["mean"], |
| "mean_avg_wait": distribution_summary( |
| [safe_float(row.get("avg_wait")) for row in mode_rows] |
| )["mean"], |
| "mean_throughput": distribution_summary( |
| [safe_float(row.get("throughput")) for row in mode_rows] |
| )["mean"], |
| "mean_travel_time": distribution_summary( |
| [safe_float(row.get("travel_time")) for row in mode_rows] |
| )["mean"], |
| "mean_spillback_count": distribution_summary( |
| [safe_float(row.get("spillback_count")) for row in mode_rows] |
| )["mean"], |
| "mean_percent_steps_with_active_guidance": distribution_summary( |
| [safe_float(row.get("percent_steps_with_active_guidance")) for row in mode_rows] |
| )["mean"], |
| "mean_avg_num_affected_intersections": distribution_summary( |
| [safe_float(row.get("avg_num_affected_intersections")) for row in mode_rows] |
| )["mean"], |
| "mean_fallback_used_count": distribution_summary( |
| [safe_float(row.get("fallback_used_count")) for row in mode_rows] |
| )["mean"], |
| "mean_invalid_guidance_count": distribution_summary( |
| [safe_float(row.get("invalid_guidance_count")) for row in mode_rows] |
| )["mean"], |
| } |
|
|
| rl_only_metrics = metrics_by_mode["rl_only"] |
| paired_summary = { |
| comparison: { |
| "mean_total_return_delta": distribution_summary( |
| [safe_float(row.get("total_return_delta")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["mean"], |
| "std_total_return_delta": distribution_summary( |
| [safe_float(row.get("total_return_delta")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["std"], |
| "mean_avg_queue_delta": distribution_summary( |
| [safe_float(row.get("avg_queue_delta")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["mean"], |
| "mean_avg_wait_delta": distribution_summary( |
| [safe_float(row.get("avg_wait_delta")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["mean"], |
| "mean_throughput_delta": distribution_summary( |
| [safe_float(row.get("throughput_delta")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["mean"], |
| "beats_fraction": distribution_summary( |
| [safe_float(row.get("return_beats_rl_only")) for row in paired_delta_rows if row["comparison"] == comparison] |
| )["mean"], |
| } |
| for comparison in ("rl_heuristic_vs_rl_only", "rl_llm_vs_rl_only") |
| } |
|
|
| return { |
| "generated_at": datetime.now(timezone.utc).isoformat(), |
| "comparison_scope": { |
| "cities": sorted({spec.city_id for spec in scenario_specs}), |
| "scenarios": sorted({spec.scenario_name for spec in scenario_specs}), |
| "seeds": [int(seed) for seed in args.seeds], |
| "episodes_per_seed": int(args.episodes_per_seed), |
| "max_episode_seconds": int(args.max_episode_seconds), |
| "total_runs": int(len(episode_rows)), |
| }, |
| "wrapper_config": guidance_config_payload(tuned_config), |
| "repair_config": asdict( |
| RepairConfig( |
| allow_only_visible_candidates=args.allow_only_visible_candidates, |
| max_target_intersections=args.max_target_intersections, |
| fallback_on_empty_targets=args.fallback_on_empty_targets, |
| fallback_mode=args.fallback_mode, |
| ) |
| ), |
| "metrics_by_mode": metrics_by_mode, |
| "paired_summary": paired_summary, |
| "rl_only_mean_return": rl_only_metrics["mean_total_return"], |
| "rl_heuristic_mean_return": metrics_by_mode["rl_heuristic"]["mean_total_return"], |
| "rl_llm_mean_return": metrics_by_mode["rl_llm"]["mean_total_return"], |
| "rl_heuristic_return_delta_vs_rl_only": ( |
| metrics_by_mode["rl_heuristic"]["mean_total_return"] - rl_only_metrics["mean_total_return"] |
| ), |
| "rl_llm_return_delta_vs_rl_only": ( |
| metrics_by_mode["rl_llm"]["mean_total_return"] - rl_only_metrics["mean_total_return"] |
| ), |
| "rl_heuristic_avg_queue_delta_vs_rl_only": ( |
| metrics_by_mode["rl_heuristic"]["mean_avg_queue"] - rl_only_metrics["mean_avg_queue"] |
| ), |
| "rl_llm_avg_queue_delta_vs_rl_only": ( |
| metrics_by_mode["rl_llm"]["mean_avg_queue"] - rl_only_metrics["mean_avg_queue"] |
| ), |
| "rl_heuristic_avg_wait_delta_vs_rl_only": ( |
| metrics_by_mode["rl_heuristic"]["mean_avg_wait"] - rl_only_metrics["mean_avg_wait"] |
| ), |
| "rl_llm_avg_wait_delta_vs_rl_only": ( |
| metrics_by_mode["rl_llm"]["mean_avg_wait"] - rl_only_metrics["mean_avg_wait"] |
| ), |
| "rl_heuristic_throughput_delta_vs_rl_only": ( |
| metrics_by_mode["rl_heuristic"]["mean_throughput"] - rl_only_metrics["mean_throughput"] |
| ), |
| "rl_llm_throughput_delta_vs_rl_only": ( |
| metrics_by_mode["rl_llm"]["mean_throughput"] - rl_only_metrics["mean_throughput"] |
| ), |
| "heuristic_beats_rl_fraction": paired_summary["rl_heuristic_vs_rl_only"]["beats_fraction"], |
| "llm_beats_rl_fraction": paired_summary["rl_llm_vs_rl_only"]["beats_fraction"], |
| } |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|