| from __future__ import annotations |
|
|
| import json |
|
|
| from agents.district_coordinator import RuleBasedDistrictCoordinator |
| from agents.local_policy import SharedHeuristicLocalPolicy |
| from training.trainer import DistrictCoordinatorEvaluator |
|
|
|
|
def make_env(
    *,
    config_path: str = "data/cityflow/config.json",
    coordination_interval: int = 20,
    max_steps: int = 200,
):
    """Build the four-intersection, two-district ``TrafficEnv`` used for evaluation.

    Topology is a chain I1 - I2 - I3 - I4. I1 and I2 belong to district D0,
    and I3 and I4 belong to district D1. Each intersection gets four incoming
    lanes named ``<id>_N``, ``<id>_S``, ``<id>_E`` and ``<id>_W``.

    ``is_border`` and each district's ``neighbor_districts`` are derived from
    the single adjacency table below. Editing the topology therefore cannot
    leave those flags inconsistent with the ``neighbors`` lists.

    All parameters are keyword-only. Their defaults reproduce the original
    hard-coded values, so ``make_env()`` behaves exactly as before.

    Args:
        config_path: CityFlow config path forwarded to ``TrafficEnv``
            (presumably resolved relative to the current working
            directory -- confirm in ``TrafficEnv``).
        coordination_interval: Forwarded unchanged to ``TrafficEnv``.
        max_steps: Forwarded unchanged to ``TrafficEnv``.

    Returns:
        A newly constructed ``TrafficEnv``.
    """
    # Project imports are done inside the function, as in the original
    # (presumably to defer loading the simulator until an env is built).
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # intersection id -> (district id, neighbouring intersection ids)
    topology = {
        "I1": ("D0", ["I2"]),
        "I2": ("D0", ["I1", "I3"]),
        "I3": ("D1", ["I2", "I4"]),
        "I4": ("D1", ["I3"]),
    }
    district_of = {iid: did for iid, (did, _) in topology.items()}

    intersections = {}
    for iid, (did, nbrs) in topology.items():
        intersections[iid] = IntersectionConfig(
            intersection_id=iid,
            district_id=did,
            incoming_lanes=[f"{iid}_{d}" for d in "NSEW"],
            # NOTE(review): outgoing lanes are left empty, as in the original;
            # confirm TrafficEnv does not rely on them.
            outgoing_lanes=[],
            neighbors=list(nbrs),
            # Border = at least one neighbour lives in a different district.
            is_border=any(district_of[n] != did for n in nbrs),
        )

    # district id -> member intersection ids, in topology order.
    members: dict[str, list[str]] = {}
    for iid, did in district_of.items():
        members.setdefault(did, []).append(iid)

    districts = {
        did: DistrictConfig(
            district_id=did,
            intersection_ids=ids,
            # Ordered, de-duplicated set of other districts reachable via
            # one intersection-to-intersection edge.
            neighbor_districts=list(
                dict.fromkeys(
                    district_of[n]
                    for iid in ids
                    for n in topology[iid][1]
                    if district_of[n] != did
                )
            ),
        )
        for did, ids in members.items()
    }

    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=coordination_interval,
        max_steps=max_steps,
    )
|
|
|
|
def main():
    """Compare local-only control with rule-based district coordination.

    The evaluator runs over seeds 0-4 with ``max_steps=200`` in two
    configurations:

    * with no district coordinators;
    * with one ``RuleBasedDistrictCoordinator`` per district (D0, D1).

    A JSON summary is then printed to stdout. It contains the three averaged
    metrics for each configuration plus the evaluator's ``improvements``
    entry.
    """
    evaluator = DistrictCoordinatorEvaluator(
        env_factory=make_env,
        local_policy=SharedHeuristicLocalPolicy(),
    )

    # Each district gets its own coordinator instance (not a shared one).
    per_district = {did: RuleBasedDistrictCoordinator() for did in ("D0", "D1")}

    results = evaluator.compare(
        seeds=list(range(5)),
        local_only_coordinators={},
        coordinated_coordinators=per_district,
        max_steps=200,
    )

    metric_names = ("avg_mean_reward", "avg_total_waiting", "avg_total_queue")
    summary = {
        mode: {name: results[mode][name] for name in metric_names}
        for mode in ("local_only", "coordinated")
    }
    summary["improvements"] = results["improvements"]

    print(json.dumps(summary, indent=2))
|
|
|
|
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()
|
|