# Viewer metadata (not part of the program): file size 3,281 bytes; revision id 3d2dbcf.
from __future__ import annotations
import json
from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from training.trainer import DistrictCoordinatorEvaluator
def make_env():
    """Build the TrafficEnv used by this evaluation script.

    The network is hard-coded: four intersections (I1..I4) whose neighbor
    lists form a chain I1 - I2 - I3 - I4, split into two districts
    (D0 = I1, I2 and D1 = I3, I4). I2 and I3 are flagged as border
    intersections. Every intersection gets four incoming lanes named
    ``<id>_N``, ``<id>_S``, ``<id>_E``, ``<id>_W`` and an empty list of
    outgoing lanes.

    Returns:
        A ``TrafficEnv`` created with ``config_path="data/cityflow/config.json"``,
        ``coordination_interval=20`` and ``max_steps=200``.
    """
    # Project imports stay function-local, in the original order.
    # NOTE(review): presumably this defers loading the simulator until an
    # environment is actually requested — confirm.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # One row per intersection: (id, district, neighbors, is_border).
    layout = (
        ("I1", "D0", ["I2"], False),
        ("I2", "D0", ["I1", "I3"], True),
        ("I3", "D1", ["I2", "I4"], True),
        ("I4", "D1", ["I3"], False),
    )
    intersections = {
        node_id: IntersectionConfig(
            intersection_id=node_id,
            district_id=owner,
            # Lane order is N, S, E, W for every intersection.
            incoming_lanes=[f"{node_id}_{heading}" for heading in "NSEW"],
            outgoing_lanes=[],
            neighbors=adjacent,
            is_border=on_border,
        )
        for node_id, owner, adjacent, on_border in layout
    }

    # One row per district: (id, member intersections, neighboring districts).
    membership = (
        ("D0", ["I1", "I2"], ["D1"]),
        ("D1", ["I3", "I4"], ["D0"]),
    )
    districts = {
        name: DistrictConfig(
            district_id=name,
            intersection_ids=members,
            neighbor_districts=adjacent,
        )
        for name, members, adjacent in membership
    }

    return TrafficEnv(
        config_path="data/cityflow/config.json",
        intersections=intersections,
        districts=districts,
        coordination_interval=20,
        max_steps=200,
    )
def main():
    """Compare local-only control against rule-based district coordination.

    Builds a ``DistrictCoordinatorEvaluator`` around ``make_env`` and a
    ``SharedHeuristicLocalPolicy``, calls ``compare`` for seeds 0..4 with
    ``max_steps=200``, once with no coordinators and once with a
    ``RuleBasedDistrictCoordinator`` for each of "D0" and "D1", then prints
    a JSON summary (indent 2) to stdout.

    The summary holds ``avg_mean_reward``, ``avg_total_waiting`` and
    ``avg_total_queue`` for the "local_only" and "coordinated" runs, plus the
    evaluator's "improvements" entry unchanged.
    """
    evaluator = DistrictCoordinatorEvaluator(
        env_factory=make_env,
        local_policy=SharedHeuristicLocalPolicy(),
    )

    # Baseline: an empty mapping, i.e. no coordinator for any district.
    baseline_coordinators = {}
    # Treatment: a separate coordinator instance per district.
    district_coordinators = {
        district_id: RuleBasedDistrictCoordinator()
        for district_id in ("D0", "D1")
    }

    comparison = evaluator.compare(
        seeds=list(range(5)),
        local_only_coordinators=baseline_coordinators,
        coordinated_coordinators=district_coordinators,
        max_steps=200,
    )

    # Only these aggregate metrics are reported for each run mode.
    metric_names = ("avg_mean_reward", "avg_total_waiting", "avg_total_queue")
    report = {
        mode: {metric: comparison[mode][metric] for metric in metric_names}
        for mode in ("local_only", "coordinated")
    }
    report["improvements"] = comparison["improvements"]

    print(json.dumps(report, indent=2))
# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
# End of file (stray table delimiter left by the web file viewer removed).