from __future__ import annotations import argparse import json from dataclasses import asdict, dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any import sys from tqdm.auto import tqdm REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from district_llm.heuristic_guidance import HeuristicGuidanceConfig from district_llm.inference import DistrictLLMInference from district_llm.repair import RepairConfig from district_llm.rl_guidance_wrapper import ( BIAS_DECAY_SCHEDULES, GATING_MODES, DistrictGuidedRLController, FixedRLPolicyAdapter, GuidanceInfluenceConfig, HeuristicGuidanceProvider, LLMGuidanceProvider, guidance_config_payload, ) from district_llm.summary_builder import DistrictStateSummaryBuilder from env.traffic_env import EnvConfig from scripts.eval_rl_guidance_ablation import ( build_episode_plans, default_env_config, distribution_summary, env_config_to_payload, resolve_scenario_specs, run_episode, safe_float, try_write_parquet, write_csv_rows, write_json, write_jsonl, ) from training.cityflow_dataset import CityFlowDataset PRESET_CHOICES: tuple[str, ...] = ( "strength_only", "strength_and_targets", "strength_targets_gating", "full_conservative", ) DEFAULT_CITIES: tuple[str, ...] = ("city_0001",) DEFAULT_SCENARIOS: tuple[str, ...] = ("normal",) @dataclass(frozen=True) class SweepConfigSpec: config_id: str description: str wrapper_mode: str bias_strength: float target_only_bias_strength: float corridor_bias_strength: float max_intersections_affected: int guidance_persistence_steps: int guidance_refresh_steps: int max_guidance_duration: int gating_mode: str min_avg_queue_for_guidance: float min_queue_imbalance_for_guidance: float require_incident_or_spillback: bool allow_guidance_in_normal_conditions: bool enable_bias_decay: bool bias_decay_schedule: str fallback_policy: str is_reference: bool = False def to_influence_config(self) -> GuidanceInfluenceConfig: return GuidanceInfluenceConfig( wrapper_mode=self.wrapper_mode, bias_strength=self.bias_strength, target_only_bias_strength=self.target_only_bias_strength, corridor_bias_strength=self.corridor_bias_strength, max_intersections_affected=self.max_intersections_affected, guidance_refresh_steps=self.guidance_refresh_steps, guidance_persistence_steps=self.guidance_persistence_steps, max_guidance_duration=self.max_guidance_duration, apply_global_bias=False, apply_target_only=True, gating_mode=self.gating_mode, min_avg_queue_for_guidance=self.min_avg_queue_for_guidance, min_queue_imbalance_for_guidance=self.min_queue_imbalance_for_guidance, require_incident_or_spillback=self.require_incident_or_spillback, allow_guidance_in_normal_conditions=self.allow_guidance_in_normal_conditions, enable_bias_decay=self.enable_bias_decay, bias_decay_schedule=self.bias_decay_schedule, fallback_policy=self.fallback_policy, log_guidance_debug=False, ).validate() def to_dict(self) -> dict[str, Any]: return asdict(self) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description=( "Cheap paired hyperparameter sweep for the fixed RL + district LLM wrapper. " "The RL checkpoint and LLM checkpoint stay fixed; only inference-time wrapper " "hyperparameters are varied." ) ) parser.add_argument("--rl-checkpoint", required=True) parser.add_argument("--llm-model-path", required=True) parser.add_argument("--generated-root", default="data/generated") parser.add_argument("--splits-root", default="data/splits") parser.add_argument("--split", default="val", choices=("train", "val", "test")) parser.add_argument("--cities", nargs="+", default=list(DEFAULT_CITIES)) parser.add_argument("--scenarios", nargs="+", default=list(DEFAULT_SCENARIOS)) parser.add_argument("--seeds", nargs="+", type=int, default=[7, 11, 13]) parser.add_argument("--episodes-per-seed", type=int, default=1) parser.add_argument( "--max-episode-seconds", type=int, default=300, help="Cheap default horizon for wrapper tuning sweeps.", ) parser.add_argument( "--preset", choices=PRESET_CHOICES, default="strength_targets_gating", ) parser.add_argument("--guidance-refresh-steps", type=int, default=10) parser.add_argument("--max-guidance-duration", type=int, default=10) parser.add_argument("--queue-threshold", type=float, default=150.0) parser.add_argument("--imbalance-threshold", type=float, default=20.0) parser.add_argument("--max-new-tokens", type=int, default=128) parser.add_argument("--device", default=None) parser.add_argument("--output-dir", default="artifacts/rl_llm_wrapper_sweep") parser.add_argument( "--allow-only-visible-candidates", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument("--max-target-intersections", type=int, default=3) parser.add_argument( "--fallback-on-empty-targets", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument( "--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic", ) parser.add_argument( "--fallback-policy", choices=("no_op", "hold_previous", "heuristic_weak"), default="no_op", ) parser.add_argument( "--save-step-metrics", action=argparse.BooleanOptionalAction, default=False, ) parser.add_argument( "--save-guidance-traces", action=argparse.BooleanOptionalAction, default=False, ) parser.add_argument( "--bias-decay-schedule", choices=BIAS_DECAY_SCHEDULES, default="linear", ) return parser.parse_args() def main() -> None: args = parse_args() output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) seeded_config_root = output_dir / "seeded_configs" seeded_config_root.mkdir(parents=True, exist_ok=True) dataset = CityFlowDataset( generated_root=args.generated_root, splits_root=args.splits_root, ) dataset.generate_default_splits() scenario_specs = resolve_scenario_specs(dataset=dataset, args=args) episode_plans = build_episode_plans( scenario_specs=scenario_specs, seeds=args.seeds, num_episodes=args.episodes_per_seed, seeded_config_root=seeded_config_root, ) sweep_configs = build_sweep_configs(args) rl_policy = FixedRLPolicyAdapter( checkpoint_path=args.rl_checkpoint, device=args.device, ) env_config = rl_policy.env_config or default_env_config() if args.max_episode_seconds is not None: env_config = EnvConfig( simulator_interval=env_config.simulator_interval, decision_interval=env_config.decision_interval, min_green_time=env_config.min_green_time, thread_num=env_config.thread_num, max_episode_seconds=int(args.max_episode_seconds), observation=env_config.observation, reward=env_config.reward, ) rl_only_controller = build_rl_only_controller( rl_policy=rl_policy, guidance_refresh_steps=args.guidance_refresh_steps, max_guidance_duration=args.max_guidance_duration, ) guided_controllers = build_guided_controllers( args=args, rl_policy=rl_policy, sweep_configs=sweep_configs, ) sweep_rows: list[dict[str, Any]] = [] paired_rows: list[dict[str, Any]] = [] rl_only_rows: list[dict[str, Any]] = [] step_rows: list[dict[str, Any]] = [] guidance_trace_rows: list[dict[str, Any]] = [] total_runs = len(episode_plans) * (1 + len(sweep_configs)) progress = tqdm(total=total_runs, desc="RL+LLM wrapper sweep", unit="run") try: for plan_index, plan in enumerate(episode_plans, start=1): progress.set_postfix_str( f"rl_only city={plan.city_id} scenario={plan.scenario} seed={plan.seed}" ) rl_only_row, rl_only_step_rows, rl_only_trace_rows = run_episode( plan=plan, mode_label="rl_only", controller=rl_only_controller, env_config=env_config, save_step_metrics=args.save_step_metrics, save_guidance_traces=args.save_guidance_traces, show_step_progress=False, ) rl_only_row = augment_rl_only_row(rl_only_row) rl_only_rows.append(rl_only_row) if args.save_step_metrics: step_rows.extend( augment_auxiliary_rows( rows=rl_only_step_rows, config_id="rl_only", config_spec=None, ) ) if args.save_guidance_traces: guidance_trace_rows.extend( augment_auxiliary_rows( rows=rl_only_trace_rows, config_id="rl_only", config_spec=None, ) ) progress.update(1) for config in sweep_configs: controller = guided_controllers[config.config_id] progress.set_postfix_str( f"{config.config_id} city={plan.city_id} scenario={plan.scenario} seed={plan.seed}" ) episode_row, mode_step_rows, mode_trace_rows = run_episode( plan=plan, mode_label=config.config_id, controller=controller, env_config=env_config, save_step_metrics=args.save_step_metrics, save_guidance_traces=args.save_guidance_traces, show_step_progress=False, ) episode_row = augment_guided_row(episode_row, config) sweep_rows.append(episode_row) paired_rows.append(build_paired_row(guided_row=episode_row, rl_only_row=rl_only_row)) if args.save_step_metrics: step_rows.extend( augment_auxiliary_rows( rows=mode_step_rows, config_id=config.config_id, config_spec=config, ) ) if args.save_guidance_traces: guidance_trace_rows.extend( augment_auxiliary_rows( rows=mode_trace_rows, config_id=config.config_id, config_spec=config, ) ) progress.update(1) tqdm.write( "[sweep-plan] " f"{plan_index}/{len(episode_plans)} " f"city={plan.city_id} scenario={plan.scenario} seed={plan.seed} complete" ) finally: progress.close() ranking_rows = build_config_rankings( paired_rows=paired_rows, sweep_configs=sweep_configs, ) summary_report = build_summary_report( paired_rows=paired_rows, ranking_rows=ranking_rows, rl_only_rows=rl_only_rows, args=args, sweep_configs=sweep_configs, ) config_payload = build_config_payload( args=args, env_config=env_config, episode_plans=episode_plans, sweep_configs=sweep_configs, ) write_json(output_dir / "config.json", config_payload) write_csv_rows(output_dir / "sweep_results.csv", sweep_rows) write_jsonl(output_dir / "sweep_results.jsonl", sweep_rows) try_write_parquet(output_dir / "sweep_results.parquet", sweep_rows) write_csv_rows(output_dir / "paired_episode_metrics.csv", paired_rows) write_jsonl(output_dir / "paired_episode_metrics.jsonl", paired_rows) try_write_parquet(output_dir / "paired_episode_metrics.parquet", paired_rows) write_csv_rows(output_dir / "rl_only_episode_metrics.csv", rl_only_rows) write_json(output_dir / "ranking.json", ranking_rows) write_json(output_dir / "summary_report.json", summary_report) if args.save_step_metrics: write_csv_rows(output_dir / "step_metrics.csv", step_rows) write_jsonl(output_dir / "step_metrics.jsonl", step_rows) try_write_parquet(output_dir / "step_metrics.parquet", step_rows) if args.save_guidance_traces: write_jsonl(output_dir / "guidance_traces.jsonl", guidance_trace_rows) print(json.dumps(summary_report, indent=2, sort_keys=True)) def build_rl_only_controller( rl_policy: FixedRLPolicyAdapter, guidance_refresh_steps: int, max_guidance_duration: int, ) -> DistrictGuidedRLController: return DistrictGuidedRLController( policy=rl_policy, mode_source="rl_only", summary_builder=None, guidance_provider=None, influence_config=GuidanceInfluenceConfig( wrapper_mode="no_op", bias_strength=0.0, target_only_bias_strength=0.0, corridor_bias_strength=0.0, max_intersections_affected=1, guidance_refresh_steps=guidance_refresh_steps, guidance_persistence_steps=1, max_guidance_duration=max_guidance_duration, gating_mode="always_on", enable_bias_decay=False, fallback_policy="no_op", ), heuristic_provider=None, ) def build_guided_controllers( args: argparse.Namespace, rl_policy: FixedRLPolicyAdapter, sweep_configs: list[SweepConfigSpec], ) -> dict[str, DistrictGuidedRLController]: repair_config = RepairConfig( allow_only_visible_candidates=args.allow_only_visible_candidates, max_target_intersections=args.max_target_intersections, fallback_on_empty_targets=args.fallback_on_empty_targets, fallback_mode=args.fallback_mode, ) llm_inference = DistrictLLMInference( model_name_or_path=args.llm_model_path, device=args.device, repair_config=repair_config, ) heuristic_provider = HeuristicGuidanceProvider( config=HeuristicGuidanceConfig( max_target_intersections=args.max_target_intersections, ) ) llm_provider = LLMGuidanceProvider( inference=llm_inference, max_new_tokens=args.max_new_tokens, ) controllers: dict[str, DistrictGuidedRLController] = {} for config in sweep_configs: controllers[config.config_id] = DistrictGuidedRLController( policy=rl_policy, mode_source="rl_llm", summary_builder=DistrictStateSummaryBuilder( top_k=3, candidate_limit=max(6, int(args.max_target_intersections)), ), guidance_provider=llm_provider, influence_config=config.to_influence_config(), heuristic_provider=heuristic_provider, ) return controllers def build_sweep_configs(args: argparse.Namespace) -> list[SweepConfigSpec]: configs: list[SweepConfigSpec] = [ build_baseline_reference_config(args), ] if args.preset == "strength_only": for bias_strength in (0.025, 0.05, 0.075, 0.10): configs.append( build_target_only_soft_config( args=args, bias_strength=bias_strength, max_intersections_affected=2, guidance_persistence_steps=5, gating_mode="queue_or_imbalance", enable_bias_decay=False, ) ) elif args.preset == "strength_and_targets": for bias_strength in (0.025, 0.05, 0.075, 0.10): for max_intersections_affected in (1, 2): configs.append( build_target_only_soft_config( args=args, bias_strength=bias_strength, max_intersections_affected=max_intersections_affected, guidance_persistence_steps=5, gating_mode="queue_or_imbalance", enable_bias_decay=False, ) ) elif args.preset == "strength_targets_gating": for bias_strength in (0.025, 0.05, 0.075): for max_intersections_affected in (1, 2): for gating_mode in ("always_on", "incident_or_spillback", "queue_or_imbalance"): configs.append( build_target_only_soft_config( args=args, bias_strength=bias_strength, max_intersections_affected=max_intersections_affected, guidance_persistence_steps=5, gating_mode=gating_mode, enable_bias_decay=False, ) ) else: for bias_strength in (0.025, 0.05, 0.075): for max_intersections_affected in (1, 2): for gating_mode, guidance_persistence_steps, enable_bias_decay in ( ("queue_or_imbalance", 5, False), ("queue_or_imbalance", 10, True), ("incident_or_spillback", 5, False), ("incident_or_spillback", 10, True), ): configs.append( build_target_only_soft_config( args=args, bias_strength=bias_strength, max_intersections_affected=max_intersections_affected, guidance_persistence_steps=guidance_persistence_steps, gating_mode=gating_mode, enable_bias_decay=enable_bias_decay, ) ) return dedupe_sweep_configs(configs) def build_baseline_reference_config(args: argparse.Namespace) -> SweepConfigSpec: return SweepConfigSpec( config_id="baseline_current_soft", description="Current rl_llm + target_only_soft reference config from the smoke runs.", wrapper_mode="target_only_soft", bias_strength=0.12, target_only_bias_strength=0.18, corridor_bias_strength=0.05, max_intersections_affected=3, guidance_persistence_steps=3, guidance_refresh_steps=args.guidance_refresh_steps, max_guidance_duration=max(args.max_guidance_duration, 3), gating_mode="always_on", min_avg_queue_for_guidance=args.queue_threshold, min_queue_imbalance_for_guidance=args.imbalance_threshold, require_incident_or_spillback=False, allow_guidance_in_normal_conditions=True, enable_bias_decay=True, bias_decay_schedule=args.bias_decay_schedule, fallback_policy=args.fallback_policy, is_reference=True, ) def build_target_only_soft_config( args: argparse.Namespace, bias_strength: float, max_intersections_affected: int, guidance_persistence_steps: int, gating_mode: str, enable_bias_decay: bool, ) -> SweepConfigSpec: target_only_bias_strength = bias_strength corridor_bias_strength = 0.5 * bias_strength config_id = ( f"bs{format_float_token(bias_strength)}" f"_aff{int(max_intersections_affected)}" f"_gate{gating_mode_token(gating_mode)}" f"_p{int(guidance_persistence_steps)}" f"_decay{int(enable_bias_decay)}" ) return SweepConfigSpec( config_id=config_id, description=( "Curated conservative target_only_soft sweep config with locally tied target/corridor " "bias strengths." ), wrapper_mode="target_only_soft", bias_strength=float(bias_strength), target_only_bias_strength=float(target_only_bias_strength), corridor_bias_strength=float(corridor_bias_strength), max_intersections_affected=int(max_intersections_affected), guidance_persistence_steps=int(guidance_persistence_steps), guidance_refresh_steps=int(args.guidance_refresh_steps), max_guidance_duration=max(int(args.max_guidance_duration), int(guidance_persistence_steps)), gating_mode=gating_mode, min_avg_queue_for_guidance=float(args.queue_threshold), min_queue_imbalance_for_guidance=float(args.imbalance_threshold), require_incident_or_spillback=False, allow_guidance_in_normal_conditions=(gating_mode == "always_on"), enable_bias_decay=bool(enable_bias_decay), bias_decay_schedule=args.bias_decay_schedule, fallback_policy=args.fallback_policy, is_reference=False, ) def dedupe_sweep_configs(configs: list[SweepConfigSpec]) -> list[SweepConfigSpec]: deduped: list[SweepConfigSpec] = [] seen_ids: set[str] = set() for config in configs: if config.config_id in seen_ids: continue deduped.append(config) seen_ids.add(config.config_id) return deduped def augment_rl_only_row(row: dict[str, Any]) -> dict[str, Any]: payload = dict(row) payload.update( { "config_id": "rl_only", "description": "Fixed RL policy with no district guidance.", "is_reference": True, "bias_strength": 0.0, "target_only_bias_strength": 0.0, "corridor_bias_strength": 0.0, "max_intersections_affected": 0, "guidance_persistence_steps": 0, "guidance_refresh_steps": 0, "max_guidance_duration": 0, "gating_mode": "always_on", "min_avg_queue_for_guidance": 0.0, "min_queue_imbalance_for_guidance": 0.0, "require_incident_or_spillback": False, "allow_guidance_in_normal_conditions": True, "enable_bias_decay": False, "bias_decay_schedule": "linear", } ) return payload def augment_guided_row(row: dict[str, Any], config: SweepConfigSpec) -> dict[str, Any]: payload = dict(row) payload.update( { "config_id": config.config_id, "description": config.description, "is_reference": bool(config.is_reference), "bias_strength": float(config.bias_strength), "target_only_bias_strength": float(config.target_only_bias_strength), "corridor_bias_strength": float(config.corridor_bias_strength), "max_intersections_affected": int(config.max_intersections_affected), "guidance_persistence_steps": int(config.guidance_persistence_steps), "guidance_refresh_steps": int(config.guidance_refresh_steps), "max_guidance_duration": int(config.max_guidance_duration), "gating_mode": config.gating_mode, "min_avg_queue_for_guidance": float(config.min_avg_queue_for_guidance), "min_queue_imbalance_for_guidance": float(config.min_queue_imbalance_for_guidance), "require_incident_or_spillback": bool(config.require_incident_or_spillback), "allow_guidance_in_normal_conditions": bool(config.allow_guidance_in_normal_conditions), "enable_bias_decay": bool(config.enable_bias_decay), "bias_decay_schedule": config.bias_decay_schedule, } ) return payload def augment_auxiliary_rows( rows: list[dict[str, Any]], config_id: str, config_spec: SweepConfigSpec | None, ) -> list[dict[str, Any]]: augmented: list[dict[str, Any]] = [] for row in rows: payload = dict(row) payload["config_id"] = config_id payload["is_reference"] = bool(config_spec.is_reference) if config_spec is not None else False if config_spec is not None: payload["gating_mode"] = config_spec.gating_mode payload["bias_strength"] = float(config_spec.bias_strength) payload["max_intersections_affected"] = int(config_spec.max_intersections_affected) payload["guidance_persistence_steps"] = int(config_spec.guidance_persistence_steps) payload["enable_bias_decay"] = bool(config_spec.enable_bias_decay) augmented.append(payload) return augmented def build_paired_row(guided_row: dict[str, Any], rl_only_row: dict[str, Any]) -> dict[str, Any]: paired_row = dict(guided_row) paired_row.update( { "rl_only_total_return": safe_float(rl_only_row.get("total_return")), "rl_only_avg_queue": safe_float(rl_only_row.get("avg_queue")), "rl_only_avg_wait": safe_float(rl_only_row.get("avg_wait")), "rl_only_throughput": safe_float(rl_only_row.get("throughput")), "rl_only_travel_time": safe_float(rl_only_row.get("travel_time")), "total_return_delta_vs_rl_only": safe_float(guided_row.get("total_return")) - safe_float(rl_only_row.get("total_return")), "avg_queue_delta_vs_rl_only": safe_float(guided_row.get("avg_queue")) - safe_float(rl_only_row.get("avg_queue")), "avg_wait_delta_vs_rl_only": safe_float(guided_row.get("avg_wait")) - safe_float(rl_only_row.get("avg_wait")), "throughput_delta_vs_rl_only": safe_float(guided_row.get("throughput")) - safe_float(rl_only_row.get("throughput")), "travel_time_delta_vs_rl_only": safe_float(guided_row.get("travel_time")) - safe_float(rl_only_row.get("travel_time")), } ) return paired_row def build_config_rankings( paired_rows: list[dict[str, Any]], sweep_configs: list[SweepConfigSpec], ) -> list[dict[str, Any]]: rows_by_config = { config.config_id: [row for row in paired_rows if row["config_id"] == config.config_id] for config in sweep_configs } rankings: list[dict[str, Any]] = [] config_lookup = {config.config_id: config for config in sweep_configs} for config_id, rows in rows_by_config.items(): if not rows: continue config = config_lookup[config_id] summary = { "config_id": config_id, "description": config.description, "is_reference": bool(config.is_reference), "wrapper_mode": config.wrapper_mode, "bias_strength": float(config.bias_strength), "target_only_bias_strength": float(config.target_only_bias_strength), "corridor_bias_strength": float(config.corridor_bias_strength), "max_intersections_affected": int(config.max_intersections_affected), "guidance_persistence_steps": int(config.guidance_persistence_steps), "guidance_refresh_steps": int(config.guidance_refresh_steps), "gating_mode": config.gating_mode, "min_avg_queue_for_guidance": float(config.min_avg_queue_for_guidance), "min_queue_imbalance_for_guidance": float(config.min_queue_imbalance_for_guidance), "require_incident_or_spillback": bool(config.require_incident_or_spillback), "allow_guidance_in_normal_conditions": bool(config.allow_guidance_in_normal_conditions), "enable_bias_decay": bool(config.enable_bias_decay), "mean_total_return": distribution_summary( [safe_float(row.get("total_return")) for row in rows] )["mean"], "mean_return_delta_vs_rl_only": distribution_summary( [safe_float(row.get("total_return_delta_vs_rl_only")) for row in rows] )["mean"], "mean_avg_queue_delta_vs_rl_only": distribution_summary( [safe_float(row.get("avg_queue_delta_vs_rl_only")) for row in rows] )["mean"], "mean_avg_wait_delta_vs_rl_only": distribution_summary( [safe_float(row.get("avg_wait_delta_vs_rl_only")) for row in rows] )["mean"], "mean_throughput_delta_vs_rl_only": distribution_summary( [safe_float(row.get("throughput_delta_vs_rl_only")) for row in rows] )["mean"], "mean_travel_time_delta_vs_rl_only": distribution_summary( [safe_float(row.get("travel_time_delta_vs_rl_only")) for row in rows] )["mean"], "mean_percent_steps_with_active_guidance": distribution_summary( [safe_float(row.get("percent_steps_with_active_guidance")) for row in rows] )["mean"], "mean_avg_num_affected_intersections": distribution_summary( [safe_float(row.get("avg_num_affected_intersections")) for row in rows] )["mean"], "mean_avg_num_targeted_intersections": distribution_summary( [safe_float(row.get("avg_num_targeted_intersections")) for row in rows] )["mean"], "mean_num_steps_guidance_blocked_by_gate": distribution_summary( [safe_float(row.get("num_steps_guidance_blocked_by_gate")) for row in rows] )["mean"], "mean_fallback_used_count": distribution_summary( [safe_float(row.get("fallback_used_count")) for row in rows] )["mean"], "mean_invalid_guidance_count": distribution_summary( [safe_float(row.get("invalid_guidance_count")) for row in rows] )["mean"], "num_episodes": int(len(rows)), } summary["beats_rl_only"] = bool(summary["mean_return_delta_vs_rl_only"] >= 0.0) rankings.append(summary) rankings.sort( key=lambda item: ( float(item["mean_return_delta_vs_rl_only"]), float(item["mean_throughput_delta_vs_rl_only"]), -float(item["mean_avg_queue_delta_vs_rl_only"]), -float(item["mean_avg_wait_delta_vs_rl_only"]), ), reverse=True, ) return rankings def build_summary_report( paired_rows: list[dict[str, Any]], ranking_rows: list[dict[str, Any]], rl_only_rows: list[dict[str, Any]], args: argparse.Namespace, sweep_configs: list[SweepConfigSpec], ) -> dict[str, Any]: rl_only_mean_total_return = distribution_summary( [safe_float(row.get("total_return")) for row in rl_only_rows] )["mean"] top_5 = ranking_rows[:5] best_config = ranking_rows[0] if ranking_rows else None configs_beating_rl_only = [row for row in ranking_rows if row["beats_rl_only"]] bias_effects = group_rankings_by_parameter(ranking_rows, "bias_strength") affected_intersections_effects = group_rankings_by_parameter(ranking_rows, "max_intersections_affected") gating_effects = group_rankings_by_parameter(ranking_rows, "gating_mode") persistence_effects = group_rankings_by_parameter(ranking_rows, "guidance_persistence_steps") decay_effects = group_rankings_by_parameter(ranking_rows, "enable_bias_decay") best_bias = best_group_value(bias_effects) best_max_affected = best_group_value(affected_intersections_effects) best_gating = best_group_value(gating_effects) best_persistence = best_group_value(persistence_effects) best_decay = best_group_value(decay_effects) recommendation = None if best_config is not None: recommendation = ( f"Start the next paired eval with {best_config['config_id']} " f"(bias={best_config['bias_strength']}, max_affected={best_config['max_intersections_affected']}, " f"gate={best_config['gating_mode']}, persistence={best_config['guidance_persistence_steps']}, " f"decay={best_config['enable_bias_decay']})." ) return { "generated_at": datetime.now(timezone.utc).isoformat(), "preset": args.preset, "comparison_scope": { "cities": list(args.cities), "scenarios": list(args.scenarios), "seeds": [int(seed) for seed in args.seeds], "episodes_per_seed": int(args.episodes_per_seed), "num_sweep_configs": int(len(sweep_configs)), "num_paired_rows": int(len(paired_rows)), }, "rl_only_mean_total_return": rl_only_mean_total_return, "best_overall_config": best_config, "did_any_rl_llm_config_beat_rl_only": bool(configs_beating_rl_only), "closest_if_no_beat": None if configs_beating_rl_only else best_config, "top_5_configs": top_5, "parameter_effects": { "bias_strength": bias_effects, "max_intersections_affected": affected_intersections_effects, "gating_mode": gating_effects, "guidance_persistence_steps": persistence_effects, "enable_bias_decay": decay_effects, }, "analysis_answers": { "which_config_was_best_overall": None if best_config is None else best_config["config_id"], "did_any_rl_llm_config_beat_rl_only": bool(configs_beating_rl_only), "did_weaker_bias_help": best_bias in {"0.025", "0.05"}, "did_affecting_fewer_intersections_help": best_max_affected == "1", "did_gating_help": best_gating not in {None, "always_on"}, "did_shorter_persistence_help": best_persistence == "5", "did_bias_decay_help": best_decay == "True", }, "recommendation": recommendation, } def group_rankings_by_parameter( ranking_rows: list[dict[str, Any]], parameter_name: str, ) -> list[dict[str, Any]]: buckets: dict[str, list[dict[str, Any]]] = {} for row in ranking_rows: key = str(row[parameter_name]) buckets.setdefault(key, []).append(row) grouped: list[dict[str, Any]] = [] for key, rows in sorted(buckets.items(), key=lambda item: item[0]): grouped.append( { "value": key, "num_configs": int(len(rows)), "mean_return_delta_vs_rl_only": distribution_summary( [safe_float(row.get("mean_return_delta_vs_rl_only")) for row in rows] )["mean"], "mean_throughput_delta_vs_rl_only": distribution_summary( [safe_float(row.get("mean_throughput_delta_vs_rl_only")) for row in rows] )["mean"], "mean_avg_queue_delta_vs_rl_only": distribution_summary( [safe_float(row.get("mean_avg_queue_delta_vs_rl_only")) for row in rows] )["mean"], "mean_avg_wait_delta_vs_rl_only": distribution_summary( [safe_float(row.get("mean_avg_wait_delta_vs_rl_only")) for row in rows] )["mean"], "mean_percent_steps_with_active_guidance": distribution_summary( [safe_float(row.get("mean_percent_steps_with_active_guidance")) for row in rows] )["mean"], } ) grouped.sort( key=lambda item: ( float(item["mean_return_delta_vs_rl_only"]), float(item["mean_throughput_delta_vs_rl_only"]), -float(item["mean_avg_queue_delta_vs_rl_only"]), -float(item["mean_avg_wait_delta_vs_rl_only"]), ), reverse=True, ) return grouped def best_group_value(grouped_rows: list[dict[str, Any]]) -> str | None: return grouped_rows[0]["value"] if grouped_rows else None def build_config_payload( args: argparse.Namespace, env_config: EnvConfig, episode_plans: list[Any], sweep_configs: list[SweepConfigSpec], ) -> dict[str, Any]: return { "generated_at": datetime.now(timezone.utc).isoformat(), "preset": args.preset, "rl_checkpoint": str(args.rl_checkpoint), "llm_model_path": str(args.llm_model_path), "comparison_scope": { "num_episode_plans": int(len(episode_plans)), "cities": sorted({plan.city_id for plan in episode_plans}), "scenarios": sorted({plan.scenario for plan in episode_plans}), "seeds": sorted({int(plan.seed) for plan in episode_plans}), "episodes_per_seed": int(args.episodes_per_seed), "max_episode_seconds": args.max_episode_seconds, "total_runs": int(len(episode_plans) * (1 + len(sweep_configs))), }, "episode_plans": [plan.to_dict() for plan in episode_plans], "sweep_configs": [config.to_dict() for config in sweep_configs], "influence_configs": { config.config_id: guidance_config_payload(config.to_influence_config()) for config in sweep_configs }, "repair_config": asdict( RepairConfig( allow_only_visible_candidates=args.allow_only_visible_candidates, max_target_intersections=args.max_target_intersections, fallback_on_empty_targets=args.fallback_on_empty_targets, fallback_mode=args.fallback_mode, ) ), "env_config": env_config_to_payload(env_config), "save_step_metrics": bool(args.save_step_metrics), "save_guidance_traces": bool(args.save_guidance_traces), } def format_float_token(value: float) -> str: text = f"{float(value):.3f}".rstrip("0").rstrip(".") return text.replace("-", "m").replace(".", "p") def gating_mode_token(value: str) -> str: return { "always_on": "always", "incident_or_spillback": "incident", "queue_threshold": "queue", "imbalance_threshold": "imbalance", "queue_or_imbalance": "queue_or_imb", "combined": "combined", }[value] if __name__ == "__main__": main()