from __future__ import annotations import argparse import json from collections import Counter import numpy as np from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy from training.cityflow_dataset import CityFlowDataset from training.train_local_policy import build_env, build_env_config def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Smoke test the CityFlow environment and district-type routing." ) parser.add_argument("--generated-root", default="data/generated") parser.add_argument("--splits-root", default="data/splits") parser.add_argument("--city-id", default="city_0001") parser.add_argument("--scenario-name", default="normal") parser.add_argument("--steps", type=int, default=5) parser.add_argument("--policy", choices=("hold", "fixed", "random"), default="random") parser.add_argument("--seed", type=int, default=7) parser.add_argument("--decision-interval", type=int, default=5) parser.add_argument("--simulator-interval", type=int, default=1) parser.add_argument("--min-green-time", type=int, default=10) parser.add_argument("--thread-num", type=int, default=1) parser.add_argument("--max-episode-seconds", type=int, default=120) parser.add_argument("--max-incoming-lanes", type=int, default=16) parser.add_argument("--count-scale", type=float, default=20.0) parser.add_argument("--elapsed-time-scale", type=float, default=60.0) parser.add_argument("--disable-district-context", action="store_true") parser.add_argument("--disable-outgoing-congestion", action="store_true") parser.add_argument("--waiting-weight", type=float, default=1.0) parser.add_argument("--vehicle-weight", type=float, default=0.25) parser.add_argument("--pressure-weight", type=float, default=0.0) parser.add_argument("--reward-scale", type=float, default=1.0) return parser.parse_args() def main() -> None: args = parse_args() rng = np.random.default_rng(args.seed) dataset = CityFlowDataset( generated_root=args.generated_root, splits_root=args.splits_root, ) scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name) env = build_env(build_env_config(args), scenario_spec) observation = env.reset() print( json.dumps( { "city_id": args.city_id, "scenario_name": args.scenario_name, "observation_shape": list(observation["observations"].shape), "observation_dim": env.observation_dim, "controlled_intersections": len(observation["intersection_ids"]), "district_type_counts": Counter(observation["district_types"]), "district_type_indices_sample": observation["district_type_indices"][:10].tolist(), "boundary_fraction": float(observation["boundary_mask"].mean()), }, indent=2, ) ) policy = resolve_policy(args.policy) for step in range(args.steps): if args.policy == "random": actions = sample_random_actions(observation["action_mask"], rng) else: actions = policy.act(observation) observation, rewards, done, info = env.step(actions) print( json.dumps( { "step": step, "sim_time": info["sim_time"], "reward_mean": float(rewards.mean()), "reward_min": float(rewards.min()), "reward_max": float(rewards.max()), "waiting_mean": info["metrics"]["mean_waiting_vehicles"], "throughput": info["metrics"]["throughput"], "district_type_metrics": { key: value for key, value in info["metrics"].items() if "residential" in key or "commercial" in key or "industrial" in key or "mixed" in key }, }, indent=2, ) ) if done: break def resolve_policy(name: str): if name == "hold": return HoldPhasePolicy() if name == "fixed": return FixedCyclePolicy() return None def sample_random_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray: actions = np.zeros(action_mask.shape[0], dtype=np.int64) for row_index, mask in enumerate(action_mask): valid_actions = np.flatnonzero(mask > 0.0) actions[row_index] = int(rng.choice(valid_actions)) return actions if __name__ == "__main__": main()