File size: 4,751 Bytes
3d2dbcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

import argparse
import json
from collections import Counter

import numpy as np

from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy
from training.cityflow_dataset import CityFlowDataset
from training.train_local_policy import build_env, build_env_config


def parse_args() -> argparse.Namespace:
    """Parse the smoke-test command line.

    Returns:
        The parsed namespace. It holds dataset locations, scenario selection,
        the action policy to drive the rollout, and the environment/reward
        settings that ``main`` hands to ``build_env_config``.
    """
    # (flag, add_argument keyword options) in the order they are registered.
    option_specs = (
        ("--generated-root", {"default": "data/generated"}),
        ("--splits-root", {"default": "data/splits"}),
        ("--city-id", {"default": "city_0001"}),
        ("--scenario-name", {"default": "normal"}),
        ("--steps", {"type": int, "default": 5}),
        ("--policy", {"choices": ("hold", "fixed", "random"), "default": "random"}),
        ("--seed", {"type": int, "default": 7}),
        ("--decision-interval", {"type": int, "default": 5}),
        ("--simulator-interval", {"type": int, "default": 1}),
        ("--min-green-time", {"type": int, "default": 10}),
        ("--thread-num", {"type": int, "default": 1}),
        ("--max-episode-seconds", {"type": int, "default": 120}),
        ("--max-incoming-lanes", {"type": int, "default": 16}),
        ("--count-scale", {"type": float, "default": 20.0}),
        ("--elapsed-time-scale", {"type": float, "default": 60.0}),
        ("--disable-district-context", {"action": "store_true"}),
        ("--disable-outgoing-congestion", {"action": "store_true"}),
        ("--waiting-weight", {"type": float, "default": 1.0}),
        ("--vehicle-weight", {"type": float, "default": 0.25}),
        ("--pressure-weight", {"type": float, "default": 0.0}),
        ("--reward-scale", {"type": float, "default": 1.0}),
    )
    cli = argparse.ArgumentParser(
        description="Smoke test the CityFlow environment and district-type routing."
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def main() -> None:
    """Run a short CityFlow rollout and print JSON diagnostics.

    Builds the environment for ``--city-id`` / ``--scenario-name``, prints a
    summary of the reset observation, then steps the environment up to
    ``--steps`` times. Each step's actions come either from uniform random
    sampling over the action mask (``--policy random``) or from the resolved
    policy's ``act``. One JSON record is printed per step, and the loop stops
    early when the environment reports ``done``.
    """
    args = parse_args()
    rng = np.random.default_rng(args.seed)
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
    env = build_env(build_env_config(args), scenario_spec)

    observation = env.reset()
    _print_json(_summarize_reset(args, env, observation))

    policy = resolve_policy(args.policy)
    for step in range(args.steps):
        if args.policy == "random":
            actions = sample_random_actions(observation["action_mask"], rng)
        else:
            actions = policy.act(observation)

        observation, rewards, done, info = env.step(actions)
        _print_json(_summarize_step(step, rewards, info))
        if done:
            break


# Substrings that mark per-district-type entries in ``info["metrics"]``.
_DISTRICT_TYPE_KEYWORDS = ("residential", "commercial", "industrial", "mixed")


def _json_default(value):
    """Convert NumPy values that the ``json`` module cannot serialize natively."""
    # np.int64 / np.bool_ are not subclasses of int / bool, so json rejects them.
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _print_json(payload: dict) -> None:
    """Print ``payload`` as indented JSON, tolerating NumPy scalars and arrays."""
    print(json.dumps(payload, indent=2, default=_json_default))


def _summarize_reset(args: argparse.Namespace, env, observation) -> dict:
    """Build the one-off summary record for the observation returned by ``reset``."""
    return {
        "city_id": args.city_id,
        "scenario_name": args.scenario_name,
        "observation_shape": list(observation["observations"].shape),
        "observation_dim": env.observation_dim,
        "controlled_intersections": len(observation["intersection_ids"]),
        "district_type_counts": Counter(observation["district_types"]),
        "district_type_indices_sample": observation["district_type_indices"][:10].tolist(),
        "boundary_fraction": float(observation["boundary_mask"].mean()),
    }


def _summarize_step(step: int, rewards, info) -> dict:
    """Build the per-step record from the rewards and ``info`` returned by ``env.step``.

    Reward statistics are reported as ``None`` when ``rewards`` is empty,
    since min/max of an empty array would raise.
    """
    reward_values = np.asarray(rewards, dtype=float)
    has_rewards = reward_values.size > 0
    metrics = info["metrics"]
    return {
        "step": step,
        "sim_time": info["sim_time"],
        "reward_mean": float(reward_values.mean()) if has_rewards else None,
        "reward_min": float(reward_values.min()) if has_rewards else None,
        "reward_max": float(reward_values.max()) if has_rewards else None,
        "waiting_mean": metrics["mean_waiting_vehicles"],
        "throughput": metrics["throughput"],
        "district_type_metrics": {
            key: value
            for key, value in metrics.items()
            if any(keyword in key for keyword in _DISTRICT_TYPE_KEYWORDS)
        },
    }


def resolve_policy(name: str):
    """Instantiate the scripted policy selected by ``name``.

    Returns a ``HoldPhasePolicy`` for ``"hold"`` and a ``FixedCyclePolicy``
    for ``"fixed"``. Any other name (including ``"random"``) yields ``None``;
    ``main`` samples random actions itself in that case.
    """
    chosen = None
    if name == "fixed":
        chosen = FixedCyclePolicy()
    elif name == "hold":
        chosen = HoldPhasePolicy()
    return chosen


def sample_random_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw one uniformly random valid action per row of ``action_mask``.

    Args:
        action_mask: 2-D array-like mask with one row per agent (presumably
            one per controlled intersection — confirm against the env).
            Entries greater than zero mark actions that may be taken.
        rng: NumPy generator used for sampling, so a fixed seed gives a
            reproducible action sequence.

    Returns:
        ``int64`` array holding one sampled action index per mask row.

    Raises:
        ValueError: if any row has no positive entry; the message names the row.
    """
    # asarray lets callers pass nested lists as well as ndarrays.
    mask = np.asarray(action_mask)
    actions = np.zeros(mask.shape[0], dtype=np.int64)
    for row_index, row in enumerate(mask):
        valid_actions = np.flatnonzero(row > 0.0)
        if valid_actions.size == 0:
            # rng.choice would fail with an opaque "a cannot be empty" error.
            raise ValueError(
                f"action_mask row {row_index} has no valid actions to sample from"
            )
        actions[row_index] = int(rng.choice(valid_actions))
    return actions


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()