File size: 2,749 Bytes
3d2dbcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations

import json

from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from training.rollout import run_episode


def make_env(
    *,
    config_path: str = "data/cityflow/config.json",
    coordination_interval: int = 20,
    max_steps: int = 200,
):
    """Build the four-intersection, two-district demo ``TrafficEnv``.

    The topology is a chain ``I1 - I2 - I3 - I4``, split into two districts:
    ``D0 = {I1, I2}`` and ``D1 = {I3, I4}``. ``I2`` and ``I3`` are the
    intersections adjacent to the other district and are marked
    ``is_border=True``. Each intersection gets four incoming lanes named
    ``<id>_N``, ``<id>_S``, ``<id>_E`` and ``<id>_W``. No outgoing lanes are
    configured.

    All parameters are keyword-only and default to the values this script
    has always used, so a bare ``make_env()`` call is unchanged.

    Args:
        config_path: CityFlow config file, forwarded to ``TrafficEnv``. A
            relative path is interpreted by ``TrafficEnv``, presumably
            against the current working directory (TODO confirm).
        coordination_interval: Forwarded to ``TrafficEnv`` as-is.
        max_steps: Forwarded to ``TrafficEnv`` as-is.

    Returns:
        The constructed ``TrafficEnv`` instance.
    """
    # Imports are deferred to call time, as in the original. NOTE(review):
    # presumably this avoids loading the simulator just by importing this
    # module; confirm.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    lane_suffixes = ("N", "S", "E", "W")

    # intersection_id -> (district_id, neighbor intersection ids, is_border)
    layout = {
        "I1": ("D0", ("I2",), False),
        "I2": ("D0", ("I1", "I3"), True),
        "I3": ("D1", ("I2", "I4"), True),
        "I4": ("D1", ("I3",), False),
    }

    # Fresh list objects per config so no two configs share a mutable list.
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=district_id,
            incoming_lanes=[f"{iid}_{suffix}" for suffix in lane_suffixes],
            outgoing_lanes=[],
            neighbors=list(neighbors),
            is_border=is_border,
        )
        for iid, (district_id, neighbors, is_border) in layout.items()
    }

    districts = {
        "D0": DistrictConfig(
            district_id="D0",
            intersection_ids=["I1", "I2"],
            neighbor_districts=["D1"],
        ),
        "D1": DistrictConfig(
            district_id="D1",
            intersection_ids=["I3", "I4"],
            neighbor_districts=["D0"],
        ),
    }

    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=coordination_interval,
        max_steps=max_steps,
    )


def main():
    """Run a single episode with the rule-based agents and print a JSON summary.

    Builds the demo environment via ``make_env()``, one shared
    ``SharedHeuristicLocalPolicy``, and one ``RuleBasedDistrictCoordinator``
    per district (``D0`` and ``D1``). It then calls ``run_episode`` with
    ``seed=0``, ``max_steps=200``, ``record_history=True`` and
    ``policy_update=False``. The selected result fields are written to stdout
    as indented JSON.
    """
    traffic_env = make_env()
    shared_policy = SharedHeuristicLocalPolicy()
    coordinators_by_district = {
        district_id: RuleBasedDistrictCoordinator()
        for district_id in ("D0", "D1")
    }

    episode = run_episode(
        env=traffic_env,
        local_policy=shared_policy,
        district_coordinators=coordinators_by_district,
        seed=0,
        max_steps=200,
        record_history=True,
        policy_update=False,
    )

    # Key order here determines the order of keys in the printed JSON.
    reported_fields = (
        "seed",
        "steps",
        "mean_reward",
        "total_waiting",
        "total_queue",
        "final_info",
    )
    summary = {field: getattr(episode, field) for field in reported_fields}

    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()