from __future__ import annotations import argparse import json import random import numpy as np from env.observation_builder import ObservationConfig from env.reward import RewardConfig from env.traffic_env import EnvConfig, TrafficEnv from training.cityflow_dataset import CityFlowDataset def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Smoke test for the CityFlow RL environment.") parser.add_argument("--generated-root", default="data/generated") parser.add_argument("--splits-root", default="data/splits") parser.add_argument("--city-id", default=None) parser.add_argument("--scenario-name", default=None) parser.add_argument("--decision-steps", type=int, default=5) parser.add_argument("--decision-interval", type=int, default=5) parser.add_argument("--min-green-time", type=int, default=10) parser.add_argument("--thread-num", type=int, default=1) parser.add_argument("--seed", type=int, default=7) return parser.parse_args() def main() -> None: args = parse_args() rng = random.Random(args.seed) dataset = CityFlowDataset( generated_root=args.generated_root, splits_root=args.splits_root, ) dataset.generate_default_splits() scenario_spec = ( dataset.build_scenario_spec(args.city_id, args.scenario_name) if args.city_id and args.scenario_name else dataset.sample_scenario("train", rng) ) env = TrafficEnv( city_id=scenario_spec.city_id, scenario_name=scenario_spec.scenario_name, city_dir=scenario_spec.city_dir, scenario_dir=scenario_spec.scenario_dir, config_path=scenario_spec.config_path, roadnet_path=scenario_spec.roadnet_path, district_map_path=scenario_spec.district_map_path, metadata_path=scenario_spec.metadata_path, env_config=EnvConfig( decision_interval=args.decision_interval, min_green_time=args.min_green_time, thread_num=args.thread_num, observation=ObservationConfig(), reward=RewardConfig(), ), ) observation_batch = env.reset() print( json.dumps( { "city_id": env.city_id, "scenario_name": env.scenario_name, "num_controlled_intersections": len(observation_batch["intersection_ids"]), "observation_shape": list(observation_batch["observations"].shape), "lane_mask_shape": list(observation_batch["lane_mask"].shape), "observation_dim": env.observation_dim, }, indent=2, ) ) for decision_step in range(args.decision_steps): random_actions = np.asarray( [rng.randint(0, 1) for _ in observation_batch["intersection_ids"]], dtype=np.int64, ) observation_batch, rewards, done, info = env.step(random_actions) reward_summary = { "decision_step": decision_step + 1, "reward_mean": float(rewards.mean()), "reward_min": float(rewards.min()), "reward_max": float(rewards.max()), "mean_waiting_vehicles": info["metrics"]["mean_waiting_vehicles"], "throughput": info["metrics"]["throughput"], "sim_time": info["sim_time"], } print(json.dumps(reward_summary)) if done: break if __name__ == "__main__": main()