| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| import sys |
|
|
| import torch |
| from tqdm.auto import tqdm |
|
|
# Resolve the directory one level above this file's folder and make sure it is
# the first entry on the import path, so the project imports that follow can be
# resolved when this file is executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from agents.local_policy import FixedCyclePolicy, RandomPhasePolicy |
| from training.cityflow_dataset import CityFlowDataset |
| from training.device import configure_torch_runtime, resolve_torch_device |
| from training.models import RunningNormalizer, TrafficControlQNetwork |
| from training.rollout import evaluate_policy |
| from training.train_local_policy import build_env, build_env_config, load_env_config |
| from training.trainer import aggregate_metrics, aggregate_metrics_by_scenario |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description=( |
| "Compare a learned local policy checkpoint against fixed and random " |
| "baselines under the same reward config." |
| ) |
| ) |
| parser.add_argument("--checkpoint", required=True) |
| parser.add_argument("--city-id", default=None) |
| parser.add_argument("--scenario-name", default=None) |
| parser.add_argument("--split", default="val", choices=("train", "val", "test")) |
| parser.add_argument("--max-val-cities", type=int, default=None) |
| parser.add_argument("--scenarios-per-city", type=int, default=1) |
| parser.add_argument("--generated-root", default="data/generated") |
| parser.add_argument("--splits-root", default="data/splits") |
| parser.add_argument("--device", default=None) |
|
|
| parser.add_argument("--decision-interval", type=int, default=5) |
| parser.add_argument("--simulator-interval", type=int, default=1) |
| parser.add_argument("--min-green-time", type=int, default=10) |
| parser.add_argument("--thread-num", type=int, default=1) |
| parser.add_argument("--max-episode-seconds", type=int, default=None) |
| parser.add_argument("--max-incoming-lanes", type=int, default=16) |
| parser.add_argument("--count-scale", type=float, default=20.0) |
| parser.add_argument("--elapsed-time-scale", type=float, default=60.0) |
| parser.add_argument("--disable-district-context", action="store_true") |
| parser.add_argument("--disable-outgoing-congestion", action="store_true") |
| parser.add_argument("--reward-variant", default="wait_queue_throughput") |
| parser.add_argument("--waiting-weight", type=float, default=1.0) |
| parser.add_argument("--vehicle-weight", type=float, default=0.1) |
| parser.add_argument("--pressure-weight", type=float, default=0.0) |
| parser.add_argument("--reward-scale", type=float, default=0.1) |
| parser.add_argument("--disable-lane-reward-normalization", action="store_true") |
| parser.add_argument("--reward-clip", type=float, default=5.0) |
| parser.add_argument("--queue-delta-weight", type=float, default=2.0) |
| parser.add_argument("--wait-delta-weight", type=float, default=4.0) |
| parser.add_argument("--queue-level-weight", type=float, default=0.5) |
| parser.add_argument("--wait-level-weight", type=float, default=1.0) |
| parser.add_argument("--throughput-weight", type=float, default=0.1) |
| parser.add_argument("--imbalance-weight", type=float, default=0.1) |
| parser.add_argument("--reward-delta-clip", type=float, default=2.0) |
| parser.add_argument("--reward-level-normalizer", type=float, default=10.0) |
| parser.add_argument("--throughput-normalizer", type=float, default=2.0) |
| parser.add_argument("--policy-arch", default="single_head_with_district_feature") |
| parser.add_argument("--fixed-green-time", type=int, default=20) |
| parser.add_argument("--random-seed", type=int, default=7) |
| parser.add_argument("--disable-tqdm", action="store_true") |
| parser.add_argument("--verbose-progress", action="store_true") |
| return parser.parse_args() |
|
|
|
|
def main() -> None:
    """Evaluate the learned checkpoint and two baselines, then print a JSON summary.

    Steps:
        1. Parse CLI options and resolve the scenarios to evaluate.
        2. Load the checkpoint and rebuild the learned Q-network from it.
        3. Run the "learned", "fixed" and "random" policies over every scenario.
        4. Print a JSON document with aggregate results, per-scenario
           breakdowns and learned-minus-baseline deltas.

    Raises:
        ValueError: If only one of ``--city-id`` / ``--scenario-name`` is given.
    """
    args = parse_args()
    if (args.city_id is None) != (args.scenario_name is None):
        raise ValueError("--city-id and --scenario-name must be provided together.")

    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    scenario_specs = build_scenario_specs(dataset, args)

    device = resolve_torch_device(args.device)
    configure_torch_runtime(device)
    print(f"[setup] torch_device={device.type}")

    # The CLI-derived config is only a fallback: a non-empty config stored in
    # the checkpoint replaces it below.
    env_config = build_env_config(args)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code embedded in the file. Only point
    # --checkpoint at files from a trusted source. The flag is left unchanged
    # because the checkpoint is also read for non-tensor entries (env_config,
    # obs_normalizer) -- TODO confirm whether weights_only=True could load them.
    checkpoint = torch.load(
        args.checkpoint,
        map_location=device,
        weights_only=False,
    )
    if checkpoint.get("env_config"):
        env_config = load_env_config(checkpoint["env_config"])

    actor, obs_normalizer = _load_learned_policy(checkpoint, device, args.policy_arch)

    # Each entry is (policy, device, observation normalizer); the baselines are
    # evaluated without a device or normalizer.
    policies = {
        "learned": (actor, device, obs_normalizer),
        "fixed": (FixedCyclePolicy(green_time=args.fixed_green_time), None, None),
        "random": (RandomPhasePolicy(seed=args.random_seed), None, None),
    }
    # Computed once and reused in the final summary.
    scope = build_scope_summary(args, scenario_specs)
    print(
        "[compare] "
        f"num_cities={scope['num_cities']} "
        f"num_scenarios={scope['num_scenarios']} "
        f"reward_variant={env_config.reward.variant}"
    )

    aggregate_results: dict[str, dict[str, float]] = {}
    scenario_breakdowns: dict[str, dict[str, float]] = {}
    for name, (policy, policy_device, normalizer) in policies.items():
        print(f"[compare] starting policy={name}")
        episode_metrics = _evaluate_over_scenarios(
            name,
            policy,
            policy_device,
            normalizer,
            scenario_specs,
            env_config,
            args,
        )
        aggregate_results[name] = aggregate_metrics(episode_metrics)
        scenario_breakdowns[name] = aggregate_metrics_by_scenario(episode_metrics)

        result = aggregate_results[name]
        mean_return = result.get("mean_episode_return", float("nan"))
        mean_wait = result.get("mean_mean_waiting_vehicles", float("nan"))
        mean_throughput = result.get("mean_throughput", float("nan"))
        _emit(
            f"[compare] finished policy={name} "
            f"mean_return={mean_return:.3f} "
            f"wait={mean_wait:.3f} "
            f"throughput={mean_throughput:.1f}",
            use_tqdm=not args.disable_tqdm,
        )

    summary = {
        "comparison_scope": scope,
        "reward_variant": env_config.reward.variant,
        "checkpoint": args.checkpoint,
        "results": aggregate_results,
        "scenario_breakdowns": scenario_breakdowns,
        "deltas": _build_deltas(aggregate_results),
    }
    print(json.dumps(summary, indent=2))


def _emit(message: str, *, use_tqdm: bool) -> None:
    """Print ``message``, going through ``tqdm.write`` when progress bars are on."""
    if use_tqdm:
        tqdm.write(message)
    else:
        print(message)


def _load_learned_policy(checkpoint: dict, device, default_policy_arch: str):
    """Rebuild the learned Q-network (and optional normalizer) from a checkpoint.

    Args:
        checkpoint: Mapping returned by ``torch.load``. Architecture settings are
            read from ``network_architecture`` (falling back to
            ``policy_architecture``) and ``dqn_config``; weights come from
            ``q_network_state_dict`` (falling back to ``policy_state_dict``).
        device: Device the network is moved to.
        default_policy_arch: Used only when neither the architecture entry nor
            ``dqn_config`` names a ``policy_arch``.

    Returns:
        ``(actor, obs_normalizer)``: the network in eval mode, and a restored
        ``RunningNormalizer`` or ``None`` when the checkpoint has none.

    Raises:
        KeyError: If ``observation_dim`` or both state-dict keys are missing.
    """
    network_architecture = checkpoint.get("network_architecture") or checkpoint.get(
        "policy_architecture", {}
    )
    trainer_config = checkpoint.get("dqn_config", {})
    policy_arch = network_architecture.get(
        "policy_arch",
        trainer_config.get("policy_arch", default_policy_arch),
    )

    actor = TrafficControlQNetwork(
        observation_dim=int(network_architecture["observation_dim"]),
        action_dim=int(network_architecture.get("action_dim", 2)),
        hidden_dim=int(trainer_config.get("hidden_dim", 256)),
        num_layers=int(trainer_config.get("hidden_layers", 2)),
        district_types=tuple(network_architecture.get("district_types", ())),
        policy_arch=policy_arch,
        dueling=bool(network_architecture.get("dueling", True)),
    ).to(device)
    actor.load_state_dict(
        checkpoint.get("q_network_state_dict") or checkpoint["policy_state_dict"]
    )
    actor.eval()

    obs_normalizer = None
    if checkpoint.get("obs_normalizer"):
        obs_normalizer = RunningNormalizer()
        obs_normalizer.load_state_dict(checkpoint["obs_normalizer"])
    return actor, obs_normalizer


def _evaluate_over_scenarios(
    name: str,
    policy,
    policy_device,
    normalizer,
    scenario_specs: list,
    env_config,
    args: argparse.Namespace,
) -> list:
    """Run one policy on every scenario and return the per-episode metrics."""
    use_tqdm = not args.disable_tqdm
    total = len(scenario_specs)
    iterator = enumerate(scenario_specs, start=1)
    if use_tqdm:
        iterator = tqdm(
            iterator,
            total=total,
            desc=f"compare:{name}",
            dynamic_ncols=True,
            leave=False,
        )

    episode_metrics = []
    for index, spec in iterator:
        if args.verbose_progress:
            _emit(
                f"[compare] policy={name} city={spec.city_id} "
                f"scenario={spec.scenario_name} i={index}/{total}",
                use_tqdm=use_tqdm,
            )
        metrics = evaluate_policy(
            # Defaults bind the current spec/config at definition time, so the
            # factory does not depend on the loop variable later changing.
            env_factory=lambda spec=spec, config=env_config: build_env(config, spec),
            actor=policy,
            device=policy_device,
            obs_normalizer=normalizer,
            deterministic=True,
        )
        episode_metrics.append(metrics)
        if use_tqdm:
            iterator.set_postfix(
                city=spec.city_id,
                scenario=spec.scenario_name,
                ret=f"{metrics['episode_return']:.3f}",
            )
    return episode_metrics


def _build_deltas(results: dict) -> dict:
    """Compute learned-minus-baseline differences for the summary.

    Keys follow ``learned_minus_<baseline>_<label>`` in the same order as the
    original hand-written mapping. A delta is ``None`` (JSON ``null``) when
    either side lacks the metric, so a missing metric no longer aborts the
    script after every evaluation has already finished.
    """
    metric_keys = (
        ("return", "mean_episode_return"),
        ("wait", "mean_mean_waiting_vehicles"),
        ("travel_time", "mean_average_travel_time"),
        ("throughput", "mean_throughput"),
    )
    learned = results["learned"]
    deltas = {}
    for label, key in metric_keys:
        for baseline in ("fixed", "random"):
            other = results[baseline]
            if key in learned and key in other:
                delta = float(learned[key]) - float(other[key])
            else:
                delta = None
            deltas[f"learned_minus_{baseline}_{label}"] = delta
    return deltas
|
|
|
|
def build_scenario_specs(dataset: CityFlowDataset, args: argparse.Namespace) -> list:
    """Resolve the scenarios to evaluate as a concrete list.

    Args:
        dataset: Object providing ``build_scenario_spec(city_id, scenario_name)``
            and ``iter_scenarios(...)``.
        args: Parsed CLI options; reads ``city_id``, ``scenario_name``,
            ``split``, ``scenarios_per_city`` and ``max_val_cities``.

    Returns:
        A one-element list when both ``city_id`` and ``scenario_name`` are set
        (truthy); otherwise every scenario yielded by ``iter_scenarios`` for
        the requested split.
    """
    if args.city_id and args.scenario_name:
        return [dataset.build_scenario_spec(args.city_id, args.scenario_name)]
    # Materialize the result: callers take len() of it and iterate it several
    # times, which would fail or silently see nothing on later passes if
    # iter_scenarios returns a lazy iterator. Copying an existing list is harmless.
    return list(
        dataset.iter_scenarios(
            split_name=args.split,
            scenarios_per_city=args.scenarios_per_city,
            max_cities=args.max_val_cities,
            diversify_single_scenario=True,
        )
    )
|
|
|
|
def build_scope_summary(args: argparse.Namespace, scenario_specs: list) -> dict[str, object]:
    """Describe which cities and scenarios a comparison run covers.

    Args:
        args: Parsed CLI options; reads ``split``, ``city_id`` and
            ``scenario_name``.
        scenario_specs: Objects exposing ``city_id`` and ``scenario_name``.

    Returns:
        Dict with the split (``None`` when a specific city was requested), the
        requested city/scenario, the counts, and the sorted unique city ids and
        scenario names.
    """
    unique_cities = set()
    unique_scenarios = set()
    for item in scenario_specs:
        unique_cities.add(item.city_id)
        unique_scenarios.add(item.scenario_name)

    # The split is only meaningful when no explicit city was requested.
    split = None if args.city_id else args.split

    summary: dict[str, object] = {}
    summary["split"] = split
    summary["city_id"] = args.city_id
    summary["scenario_name"] = args.scenario_name
    summary["num_cities"] = len(unique_cities)
    summary["num_scenarios"] = len(scenario_specs)
    summary["city_ids"] = sorted(unique_cities)
    summary["scenario_names"] = sorted(unique_scenarios)
    return summary
|
|
|
|
# Run the comparison only when this file is executed as a script; importing it
# as a module has no such side effect.
if __name__ == "__main__":
    main()
|
|