# agentic-traffic / scripts/compare_local_policies.py
# Uploaded by Aditya2162 with huggingface_hub (commit 3d2dbcf, verified).
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
import torch
from tqdm.auto import tqdm
REPO_ROOT = Path(__file__).resolve().parents[1]
# When this file is run as a script, Python puts the file's own directory at
# sys.path[0], not its parent. Prepend that parent directory (REPO_ROOT) so the
# ``agents`` and ``training`` imports that follow can be resolved.
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
from agents.local_policy import FixedCyclePolicy, RandomPhasePolicy
from training.cityflow_dataset import CityFlowDataset
from training.device import configure_torch_runtime, resolve_torch_device
from training.models import RunningNormalizer, TrafficControlQNetwork
from training.rollout import evaluate_policy
from training.train_local_policy import build_env, build_env_config, load_env_config
from training.trainer import aggregate_metrics, aggregate_metrics_by_scenario
def parse_args() -> argparse.Namespace:
    """Build the command-line interface for this comparison script and parse ``sys.argv``.

    The options cover four areas:
    - checkpoint and scenario selection;
    - environment construction;
    - reward shaping;
    - baseline and progress-reporting controls.

    ``main`` passes the whole namespace to ``build_env_config``. It then
    replaces that config with the checkpoint's stored ``env_config`` when one
    is present. Options are registered in a fixed order, so ``--help`` lists
    them in that order.
    """
    description = (
        "Compare a learned local policy checkpoint against fixed and random "
        "baselines under the same reward config."
    )
    parser = argparse.ArgumentParser(description=description)
    add = parser.add_argument

    # Checkpoint to evaluate and the scenarios to evaluate it on.
    add("--checkpoint", required=True)
    add("--city-id", default=None)
    add("--scenario-name", default=None)
    add("--split", default="val", choices=("train", "val", "test"))
    add("--max-val-cities", type=int, default=None)
    add("--scenarios-per-city", type=int, default=1)
    add("--generated-root", default="data/generated")
    add("--splits-root", default="data/splits")
    add("--device", default=None)

    # Simulator / observation settings.
    for flag, default in (
        ("--decision-interval", 5),
        ("--simulator-interval", 1),
        ("--min-green-time", 10),
        ("--thread-num", 1),
        ("--max-episode-seconds", None),
        ("--max-incoming-lanes", 16),
    ):
        add(flag, type=int, default=default)
    for flag, default in (
        ("--count-scale", 20.0),
        ("--elapsed-time-scale", 60.0),
    ):
        add(flag, type=float, default=default)
    add("--disable-district-context", action="store_true")
    add("--disable-outgoing-congestion", action="store_true")

    # Reward shaping.
    add("--reward-variant", default="wait_queue_throughput")
    for flag, default in (
        ("--waiting-weight", 1.0),
        ("--vehicle-weight", 0.1),
        ("--pressure-weight", 0.0),
        ("--reward-scale", 0.1),
    ):
        add(flag, type=float, default=default)
    add("--disable-lane-reward-normalization", action="store_true")
    for flag, default in (
        ("--reward-clip", 5.0),
        ("--queue-delta-weight", 2.0),
        ("--wait-delta-weight", 4.0),
        ("--queue-level-weight", 0.5),
        ("--wait-level-weight", 1.0),
        ("--throughput-weight", 0.1),
        ("--imbalance-weight", 0.1),
        ("--reward-delta-clip", 2.0),
        ("--reward-level-normalizer", 10.0),
        ("--throughput-normalizer", 2.0),
    ):
        add(flag, type=float, default=default)

    # --policy-arch is the last-resort architecture, used only when the
    # checkpoint records none. Then the baselines and progress reporting.
    add("--policy-arch", default="single_head_with_district_feature")
    add("--fixed-green-time", type=int, default=20)
    add("--random-seed", type=int, default=7)
    add("--disable-tqdm", action="store_true")
    add("--verbose-progress", action="store_true")
    return parser.parse_args()
def main() -> None:
    """Evaluate a learned checkpoint and fixed/random baselines, then print a JSON summary.

    Every policy is rolled out with ``deterministic=True`` on the same scenario
    list and the same environment config. Progress lines go to stdout. While
    tqdm is active they are written through ``tqdm.write``. The run finishes
    with one JSON document containing:
    - the comparison scope;
    - per-policy aggregate metrics;
    - per-scenario breakdowns;
    - learned-minus-baseline deltas.

    Raises:
        ValueError: if only one of ``--city-id`` / ``--scenario-name`` is
            given, or if the selection yields no scenarios.
    """
    args = parse_args()
    if (args.city_id is None) != (args.scenario_name is None):
        raise ValueError("--city-id and --scenario-name must be provided together.")
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    # Materialize once: the specs are iterated for every policy and len() is taken.
    scenario_specs = list(build_scenario_specs(dataset, args))
    if not scenario_specs:
        raise ValueError(
            "No scenarios selected for comparison; check --split, "
            "--scenarios-per-city and --max-val-cities."
        )
    device = resolve_torch_device(args.device)
    configure_torch_runtime(device)
    print(f"[setup] torch_device={device.type}")

    env_config = build_env_config(args)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from trusted sources. It is presumably needed for
    # the non-tensor metadata stored alongside the weights. Confirm before
    # switching to weights_only=True.
    checkpoint = torch.load(
        args.checkpoint,
        map_location=device,
        weights_only=False,
    )
    if checkpoint.get("env_config"):
        # The checkpoint's stored config replaces the one built from CLI flags.
        # Log it, because otherwise the CLI env/reward options are silently dropped.
        env_config = load_env_config(checkpoint["env_config"])
        print("[setup] env_config loaded from checkpoint; CLI env/reward flags ignored")

    actor, obs_normalizer = _load_learned_actor(checkpoint, device, args.policy_arch)
    # name -> (policy, torch device or None, observation normalizer or None)
    policies = {
        "learned": (actor, device, obs_normalizer),
        "fixed": (FixedCyclePolicy(green_time=args.fixed_green_time), None, None),
        "random": (RandomPhasePolicy(seed=args.random_seed), None, None),
    }

    scope = build_scope_summary(args, scenario_specs)
    print(
        "[compare] "
        f"num_cities={scope['num_cities']} "
        f"num_scenarios={scope['num_scenarios']} "
        f"reward_variant={env_config.reward.variant}"
    )

    use_tqdm = not args.disable_tqdm
    aggregate_results: dict[str, dict[str, float]] = {}
    scenario_breakdowns: dict[str, object] = {}
    for name, (policy, policy_device, normalizer) in policies.items():
        print(f"[compare] starting policy={name}")
        episode_metrics = _evaluate_on_scenarios(
            name,
            policy,
            policy_device,
            normalizer,
            scenario_specs,
            env_config,
            args,
        )
        aggregate_results[name] = aggregate_metrics(episode_metrics)
        scenario_breakdowns[name] = aggregate_metrics_by_scenario(episode_metrics)
        results = aggregate_results[name]
        mean_return = results.get("mean_episode_return", float("nan"))
        mean_wait = results.get("mean_mean_waiting_vehicles", float("nan"))
        mean_throughput = results.get("mean_throughput", float("nan"))
        _emit(
            f"[compare] finished policy={name} "
            f"mean_return={mean_return:.3f} "
            f"wait={mean_wait:.3f} "
            f"throughput={mean_throughput:.1f}",
            use_tqdm,
        )

    summary = {
        "comparison_scope": scope,
        "reward_variant": env_config.reward.variant,
        "checkpoint": args.checkpoint,
        "results": aggregate_results,
        "scenario_breakdowns": scenario_breakdowns,
        # Key order: return, wait, travel_time, throughput; fixed before random.
        "deltas": {
            f"learned_minus_{baseline}_{label}": _metric_delta(
                aggregate_results, metric, baseline
            )
            for label, metric in _DELTA_METRICS
            for baseline in ("fixed", "random")
        },
    }
    print(json.dumps(summary, indent=2))


# (label used in the delta key, aggregated metric name) pairs reported as deltas.
_DELTA_METRICS = (
    ("return", "mean_episode_return"),
    ("wait", "mean_mean_waiting_vehicles"),
    ("travel_time", "mean_average_travel_time"),
    ("throughput", "mean_throughput"),
)


def _load_learned_actor(checkpoint: dict, device: torch.device, fallback_policy_arch: str):
    # Rebuild the Q-network (eval mode) and optional observation normalizer from a checkpoint dict.
    network_architecture = (
        checkpoint.get("network_architecture")
        or checkpoint.get("policy_architecture")
        or {}
    )
    trainer_config = checkpoint.get("dqn_config") or {}
    # Architecture lookup order: network metadata, then trainer config, then CLI fallback.
    policy_arch = network_architecture.get(
        "policy_arch",
        trainer_config.get("policy_arch", fallback_policy_arch),
    )
    actor = TrafficControlQNetwork(
        observation_dim=int(network_architecture["observation_dim"]),
        action_dim=int(network_architecture.get("action_dim", 2)),
        hidden_dim=int(trainer_config.get("hidden_dim", 256)),
        num_layers=int(trainer_config.get("hidden_layers", 2)),
        district_types=tuple(network_architecture.get("district_types", ())),
        policy_arch=policy_arch,
        dueling=bool(network_architecture.get("dueling", True)),
    ).to(device)
    actor.load_state_dict(
        checkpoint.get("q_network_state_dict") or checkpoint["policy_state_dict"]
    )
    actor.eval()
    obs_normalizer = None
    if checkpoint.get("obs_normalizer"):
        obs_normalizer = RunningNormalizer()
        obs_normalizer.load_state_dict(checkpoint["obs_normalizer"])
    return actor, obs_normalizer


def _evaluate_on_scenarios(
    name: str,
    policy: object,
    policy_device: object,
    normalizer: object,
    scenario_specs: list,
    env_config: object,
    args: argparse.Namespace,
) -> list:
    # Roll out one policy (deterministic=True) on every scenario and collect the per-episode metrics.
    use_tqdm = not args.disable_tqdm
    total = len(scenario_specs)
    iterator = enumerate(scenario_specs, start=1)
    progress = None
    if use_tqdm:
        progress = tqdm(
            iterator,
            total=total,
            desc=f"compare:{name}",
            dynamic_ncols=True,
            leave=False,
        )
        iterator = progress
    episode_metrics = []
    for index, spec in iterator:
        if args.verbose_progress:
            _emit(
                f"[compare] policy={name} city={spec.city_id} "
                f"scenario={spec.scenario_name} i={index}/{total}",
                use_tqdm,
            )
        metrics = evaluate_policy(
            # Default arguments bind this iteration's spec/config into the factory.
            env_factory=lambda spec=spec, config=env_config: build_env(config, spec),
            actor=policy,
            device=policy_device,
            obs_normalizer=normalizer,
            deterministic=True,
        )
        episode_metrics.append(metrics)
        if progress is not None:
            progress.set_postfix(
                city=spec.city_id,
                scenario=spec.scenario_name,
                ret=f"{metrics['episode_return']:.3f}",
            )
    return episode_metrics


def _emit(message: str, use_tqdm: bool) -> None:
    # Print a log line, routing it through tqdm.write when a progress bar may be active.
    if use_tqdm:
        tqdm.write(message)
    else:
        print(message)


def _metric_delta(results: dict, metric: str, baseline: str) -> float:
    # Learned-minus-baseline difference of one aggregate metric; NaN if either side lacks it.
    nan = float("nan")
    learned_value = results["learned"].get(metric, nan)
    baseline_value = results[baseline].get(metric, nan)
    return float(learned_value) - float(baseline_value)
def build_scenario_specs(dataset: CityFlowDataset, args: argparse.Namespace) -> list:
    """Return the scenario specs selected by the CLI arguments, as a list.

    If both ``--city-id`` and ``--scenario-name`` are set, exactly one spec is
    built for that pair. Otherwise specs come from
    ``dataset.iter_scenarios`` for ``--split``. That call receives
    ``--scenarios-per-city``, ``--max-val-cities`` (as ``max_cities``) and
    ``diversify_single_scenario=True``. Their exact semantics live in
    ``CityFlowDataset``.

    The result is always materialized into a list. The comparison loop calls
    ``len()`` on it and iterates it once per policy, which a one-shot
    iterator would not support.
    """
    if args.city_id and args.scenario_name:
        return [dataset.build_scenario_spec(args.city_id, args.scenario_name)]
    return list(
        dataset.iter_scenarios(
            split_name=args.split,
            scenarios_per_city=args.scenarios_per_city,
            max_cities=args.max_val_cities,
            diversify_single_scenario=True,
        )
    )
def build_scope_summary(args: argparse.Namespace, scenario_specs: list) -> dict[str, object]:
    """Describe which scenarios a comparison covers, for logging and the JSON output.

    Fields:
    - ``split``: ``None`` whenever a (truthy) ``--city-id`` was given.
    - ``city_id``, ``scenario_name``: echoed from the CLI as-is.
    - ``city_ids``, ``scenario_names``: de-duplicated and sorted.
    - ``num_scenarios``: counts every spec, duplicates included.

    Keys are inserted in a fixed order, which is the order they appear in the
    printed JSON.
    """
    seen_cities = set()
    seen_scenarios = set()
    for spec in scenario_specs:
        seen_cities.add(spec.city_id)
        seen_scenarios.add(spec.scenario_name)

    explicit_city = bool(args.city_id)
    summary: dict[str, object] = {}
    summary["split"] = None if explicit_city else args.split
    summary["city_id"] = args.city_id
    summary["scenario_name"] = args.scenario_name
    summary["num_cities"] = len(seen_cities)
    summary["num_scenarios"] = len(scenario_specs)
    summary["city_ids"] = sorted(seen_cities)
    summary["scenario_names"] = sorted(seen_scenarios)
    return summary
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()