agentic-traffic / scripts /smoke_test_env.py
Aditya2162's picture
Upload folder using huggingface_hub
3d2dbcf verified
from __future__ import annotations
import argparse
import json
from collections import Counter
import numpy as np
from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy
from training.cityflow_dataset import CityFlowDataset
from training.train_local_policy import build_env, build_env_config
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Smoke test the CityFlow environment and district-type routing."
)
parser.add_argument("--generated-root", default="data/generated")
parser.add_argument("--splits-root", default="data/splits")
parser.add_argument("--city-id", default="city_0001")
parser.add_argument("--scenario-name", default="normal")
parser.add_argument("--steps", type=int, default=5)
parser.add_argument("--policy", choices=("hold", "fixed", "random"), default="random")
parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--decision-interval", type=int, default=5)
parser.add_argument("--simulator-interval", type=int, default=1)
parser.add_argument("--min-green-time", type=int, default=10)
parser.add_argument("--thread-num", type=int, default=1)
parser.add_argument("--max-episode-seconds", type=int, default=120)
parser.add_argument("--max-incoming-lanes", type=int, default=16)
parser.add_argument("--count-scale", type=float, default=20.0)
parser.add_argument("--elapsed-time-scale", type=float, default=60.0)
parser.add_argument("--disable-district-context", action="store_true")
parser.add_argument("--disable-outgoing-congestion", action="store_true")
parser.add_argument("--waiting-weight", type=float, default=1.0)
parser.add_argument("--vehicle-weight", type=float, default=0.25)
parser.add_argument("--pressure-weight", type=float, default=0.0)
parser.add_argument("--reward-scale", type=float, default=1.0)
return parser.parse_args()
def main() -> None:
args = parse_args()
rng = np.random.default_rng(args.seed)
dataset = CityFlowDataset(
generated_root=args.generated_root,
splits_root=args.splits_root,
)
scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
env = build_env(build_env_config(args), scenario_spec)
observation = env.reset()
print(
json.dumps(
{
"city_id": args.city_id,
"scenario_name": args.scenario_name,
"observation_shape": list(observation["observations"].shape),
"observation_dim": env.observation_dim,
"controlled_intersections": len(observation["intersection_ids"]),
"district_type_counts": Counter(observation["district_types"]),
"district_type_indices_sample": observation["district_type_indices"][:10].tolist(),
"boundary_fraction": float(observation["boundary_mask"].mean()),
},
indent=2,
)
)
policy = resolve_policy(args.policy)
for step in range(args.steps):
if args.policy == "random":
actions = sample_random_actions(observation["action_mask"], rng)
else:
actions = policy.act(observation)
observation, rewards, done, info = env.step(actions)
print(
json.dumps(
{
"step": step,
"sim_time": info["sim_time"],
"reward_mean": float(rewards.mean()),
"reward_min": float(rewards.min()),
"reward_max": float(rewards.max()),
"waiting_mean": info["metrics"]["mean_waiting_vehicles"],
"throughput": info["metrics"]["throughput"],
"district_type_metrics": {
key: value
for key, value in info["metrics"].items()
if "residential" in key
or "commercial" in key
or "industrial" in key
or "mixed" in key
},
},
indent=2,
)
)
if done:
break
def resolve_policy(name: str):
if name == "hold":
return HoldPhasePolicy()
if name == "fixed":
return FixedCyclePolicy()
return None
def sample_random_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
actions = np.zeros(action_mask.shape[0], dtype=np.int64)
for row_index, mask in enumerate(action_mask):
valid_actions = np.flatnonzero(mask > 0.0)
actions[row_index] = int(rng.choice(valid_actions))
return actions
if __name__ == "__main__":
main()