File size: 4,751 Bytes
3d2dbcf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | from __future__ import annotations
import argparse
import json
from collections import Counter
import numpy as np
from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy
from training.cityflow_dataset import CityFlowDataset
from training.train_local_policy import build_env, build_env_config
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Smoke test the CityFlow environment and district-type routing."
)
parser.add_argument("--generated-root", default="data/generated")
parser.add_argument("--splits-root", default="data/splits")
parser.add_argument("--city-id", default="city_0001")
parser.add_argument("--scenario-name", default="normal")
parser.add_argument("--steps", type=int, default=5)
parser.add_argument("--policy", choices=("hold", "fixed", "random"), default="random")
parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--decision-interval", type=int, default=5)
parser.add_argument("--simulator-interval", type=int, default=1)
parser.add_argument("--min-green-time", type=int, default=10)
parser.add_argument("--thread-num", type=int, default=1)
parser.add_argument("--max-episode-seconds", type=int, default=120)
parser.add_argument("--max-incoming-lanes", type=int, default=16)
parser.add_argument("--count-scale", type=float, default=20.0)
parser.add_argument("--elapsed-time-scale", type=float, default=60.0)
parser.add_argument("--disable-district-context", action="store_true")
parser.add_argument("--disable-outgoing-congestion", action="store_true")
parser.add_argument("--waiting-weight", type=float, default=1.0)
parser.add_argument("--vehicle-weight", type=float, default=0.25)
parser.add_argument("--pressure-weight", type=float, default=0.0)
parser.add_argument("--reward-scale", type=float, default=1.0)
return parser.parse_args()
def main() -> None:
    """Run a short CityFlow rollout and print JSON diagnostics to stdout.

    Builds the environment for the requested city/scenario, prints a JSON
    summary of the initial observation, then steps the environment up to
    ``--steps`` times with the chosen policy, printing one JSON record per
    step. The loop ends early as soon as the environment reports ``done``.
    """
    args = parse_args()
    generator = np.random.default_rng(args.seed)

    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
    env = build_env(build_env_config(args), spec)
    obs = env.reset()

    summary = {
        "city_id": args.city_id,
        "scenario_name": args.scenario_name,
        "observation_shape": list(obs["observations"].shape),
        "observation_dim": env.observation_dim,
        "controlled_intersections": len(obs["intersection_ids"]),
        "district_type_counts": Counter(obs["district_types"]),
        "district_type_indices_sample": obs["district_type_indices"][:10].tolist(),
        "boundary_fraction": float(obs["boundary_mask"].mean()),
    }
    print(json.dumps(summary, indent=2))

    # Metric keys containing any of these substrings are reported per step.
    district_tags = ("residential", "commercial", "industrial", "mixed")
    policy = resolve_policy(args.policy)
    use_random = args.policy == "random"

    for step in range(args.steps):
        # "random" has no policy object; sample legal actions from the mask.
        if use_random:
            actions = sample_random_actions(obs["action_mask"], generator)
        else:
            actions = policy.act(obs)
        obs, rewards, done, info = env.step(actions)

        metrics = info["metrics"]
        district_metrics = {
            key: value
            for key, value in metrics.items()
            if any(tag in key for tag in district_tags)
        }
        record = {
            "step": step,
            "sim_time": info["sim_time"],
            "reward_mean": float(rewards.mean()),
            "reward_min": float(rewards.min()),
            "reward_max": float(rewards.max()),
            "waiting_mean": metrics["mean_waiting_vehicles"],
            "throughput": metrics["throughput"],
            "district_type_metrics": district_metrics,
        }
        print(json.dumps(record, indent=2))

        if done:
            break
def resolve_policy(name: str):
    """Instantiate the scripted policy selected by ``name``.

    Returns a ``HoldPhasePolicy`` for ``"hold"`` and a ``FixedCyclePolicy``
    for ``"fixed"``. Any other name (including ``"random"``) yields ``None``;
    for ``"random"``, ``main`` samples actions itself instead of calling a policy.
    """
    if name not in ("hold", "fixed"):
        return None
    return HoldPhasePolicy() if name == "hold" else FixedCyclePolicy()
def sample_random_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Pick one uniformly random legal action for each row of ``action_mask``.

    Args:
        action_mask: Mask expected to be 2-D, one row per agent (presumably one
            per controlled intersection — confirm against the env). Entries
            strictly greater than zero mark legal action indices. Array-likes
            are accepted and converted with ``np.asarray``.
        rng: Generator used for sampling. Rows are processed in order with one
            ``rng.choice`` call each, so a fixed seed reproduces the same actions.

    Returns:
        ``int64`` array of length ``action_mask.shape[0]`` with the chosen
        action index per row.

    Raises:
        ValueError: If some row has no legal action (previously this surfaced
            as an opaque numpy "cannot be empty" error from ``rng.choice``).
    """
    action_mask = np.asarray(action_mask)
    actions = np.zeros(action_mask.shape[0], dtype=np.int64)
    for row_index, mask in enumerate(action_mask):
        valid_actions = np.flatnonzero(mask > 0.0)
        if valid_actions.size == 0:
            raise ValueError(
                f"action_mask row {row_index} has no valid actions; cannot sample"
            )
        # One choice per row keeps the RNG stream identical to per-row sampling.
        actions[row_index] = int(rng.choice(valid_actions))
    return actions
# Script entry point: run the smoke test only when executed directly, not on import.
if __name__ == "__main__":
    main()
|