File size: 3,479 Bytes
3d2dbcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import annotations

import argparse
import json
import random

import numpy as np

from env.observation_builder import ObservationConfig
from env.reward import RewardConfig
from env.traffic_env import EnvConfig, TrafficEnv
from training.cityflow_dataset import CityFlowDataset


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the environment smoke test.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing zero-argument callers behave
            exactly as before. Passing an explicit list lets the parser be
            driven programmatically.

    Returns:
        Namespace with ``generated_root``, ``splits_root``, ``city_id``,
        ``scenario_name``, ``decision_steps``, ``decision_interval``,
        ``min_green_time``, ``thread_num`` and ``seed``.

    Raises:
        SystemExit: On malformed options (standard argparse behaviour), or
            when only one of ``--city-id`` / ``--scenario-name`` is supplied.
    """
    parser = argparse.ArgumentParser(description="Smoke test for the CityFlow RL environment.")
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--splits-root", default="data/splits")
    parser.add_argument(
        "--city-id",
        default=None,
        help="Must be given together with --scenario-name.",
    )
    parser.add_argument(
        "--scenario-name",
        default=None,
        help="Must be given together with --city-id.",
    )
    parser.add_argument("--decision-steps", type=int, default=5)
    parser.add_argument("--decision-interval", type=int, default=5)
    parser.add_argument("--min-green-time", type=int, default=10)
    parser.add_argument("--thread-num", type=int, default=1)
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args(argv)

    # main() only pins a scenario when BOTH values are truthy; a lone value
    # used to be dropped without notice and a random scenario ran instead.
    # Truthiness (not ``is None``) is used to mirror that check.
    if bool(args.city_id) != bool(args.scenario_name):
        parser.error("--city-id and --scenario-name must be provided together")

    return args


def main() -> None:
    """Run a short random-action rollout and print JSON diagnostics.

    Picks a scenario (the one named on the command line when both
    ``--city-id`` and ``--scenario-name`` are set, otherwise one sampled from
    the ``"train"`` split with the seeded RNG), builds a ``TrafficEnv`` for
    it, prints a summary of the reset observation batch, then steps the
    environment with random 0/1 actions for at most ``--decision-steps``
    decisions, printing one JSON line per step and stopping early once the
    environment reports ``done``.
    """
    cli = parse_args()
    sampler = random.Random(cli.seed)

    scenarios = CityFlowDataset(
        generated_root=cli.generated_root,
        splits_root=cli.splits_root,
    )
    scenarios.generate_default_splits()

    # An explicit scenario is used only when both identifiers are provided.
    if cli.city_id and cli.scenario_name:
        spec = scenarios.build_scenario_spec(cli.city_id, cli.scenario_name)
    else:
        spec = scenarios.sample_scenario("train", sampler)

    settings = EnvConfig(
        decision_interval=cli.decision_interval,
        min_green_time=cli.min_green_time,
        thread_num=cli.thread_num,
        observation=ObservationConfig(),
        reward=RewardConfig(),
    )
    env = TrafficEnv(
        city_id=spec.city_id,
        scenario_name=spec.scenario_name,
        city_dir=spec.city_dir,
        scenario_dir=spec.scenario_dir,
        config_path=spec.config_path,
        roadnet_path=spec.roadnet_path,
        district_map_path=spec.district_map_path,
        metadata_path=spec.metadata_path,
        env_config=settings,
    )

    batch = env.reset()
    reset_summary = {
        "city_id": env.city_id,
        "scenario_name": env.scenario_name,
        "num_controlled_intersections": len(batch["intersection_ids"]),
        "observation_shape": list(batch["observations"].shape),
        "lane_mask_shape": list(batch["lane_mask"].shape),
        "observation_dim": env.observation_dim,
    }
    print(json.dumps(reset_summary, indent=2))

    # Step numbers are reported 1-based in the per-step output.
    for step_number in range(1, cli.decision_steps + 1):
        # One random 0/1 action per entry in the batch's intersection id list.
        choices = [sampler.randint(0, 1) for _ in batch["intersection_ids"]]
        batch, rewards, done, info = env.step(np.asarray(choices, dtype=np.int64))
        print(
            json.dumps(
                {
                    "decision_step": step_number,
                    "reward_mean": float(rewards.mean()),
                    "reward_min": float(rewards.min()),
                    "reward_max": float(rewards.max()),
                    "mean_waiting_vehicles": info["metrics"]["mean_waiting_vehicles"],
                    "throughput": info["metrics"]["throughput"],
                    "sim_time": info["sim_time"],
                }
            )
        )
        if done:
            break


# Run the smoke test only when this file is executed as a script, so that
# importing the module does not trigger the rollout.
if __name__ == "__main__":
    main()