"""Run one evaluation episode of the rule-based hierarchical traffic controller.

Builds a four-intersection, two-district toy network, runs a single episode
with the shared heuristic local policy and one rule-based coordinator per
district, and prints a JSON summary of the result to stdout.
"""

from __future__ import annotations

import json

from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from training.rollout import run_episode

# Episode settings. The environment horizon and the rollout horizon were
# previously two independent literals; one constant keeps them in sync.
CONFIG_PATH = "data/cityflow/config.json"
COORDINATION_INTERVAL = 20
MAX_STEPS = 200
SEED = 0

# Toy topology: a chain I1 - I2 - I3 - I4 split into two districts.
# intersection_id -> (district_id, neighbors, is_border)
_TOPOLOGY = {
    "I1": ("D0", ("I2",), False),
    "I2": ("D0", ("I1", "I3"), True),
    "I3": ("D1", ("I2", "I4"), True),
    "I4": ("D1", ("I3",), False),
}

# district_id -> neighbor_districts. Also the single source of truth for
# which district coordinators `main` creates.
_DISTRICTS = {
    "D0": ("D1",),
    "D1": ("D0",),
}

# Approach suffixes for incoming lanes, in the order the lanes are listed.
_APPROACHES = ("N", "S", "E", "W")


def make_env(
    *,
    config_path: str = CONFIG_PATH,
    coordination_interval: int = COORDINATION_INTERVAL,
    max_steps: int = MAX_STEPS,
):
    """Build the toy two-district ``TrafficEnv``.

    Args:
        config_path: Path passed through to ``TrafficEnv`` as ``config_path``.
        coordination_interval: Passed through to ``TrafficEnv``.
        max_steps: Passed through to ``TrafficEnv``.

    Returns:
        A ``TrafficEnv`` built from ``_TOPOLOGY`` and ``_DISTRICTS``.
    """
    # Imported lazily as in the original script; presumably so this module
    # can be imported without the simulator backend. Confirm before hoisting.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # Fresh lists are built on every call so that configs never share
    # mutable state with the module-level tables or with earlier envs.
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=district_id,
            incoming_lanes=[f"{iid}_{d}" for d in _APPROACHES],
            outgoing_lanes=[],
            neighbors=list(neighbors),
            is_border=is_border,
        )
        for iid, (district_id, neighbors, is_border) in _TOPOLOGY.items()
    }

    # Intersection ids are grouped per district in _TOPOLOGY insertion order.
    districts = {
        district_id: DistrictConfig(
            district_id=district_id,
            intersection_ids=[
                iid for iid, (did, _, _) in _TOPOLOGY.items() if did == district_id
            ],
            neighbor_districts=list(neighbor_districts),
        )
        for district_id, neighbor_districts in _DISTRICTS.items()
    }

    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=coordination_interval,
        max_steps=max_steps,
    )


def main():
    """Run one episode and print a JSON summary of the result to stdout."""
    env = make_env()
    local_policy = SharedHeuristicLocalPolicy()
    # One coordinator per district, keyed by the same ids used to build the
    # env, so the two cannot fall out of step.
    district_coordinators = {
        district_id: RuleBasedDistrictCoordinator() for district_id in _DISTRICTS
    }

    result = run_episode(
        env=env,
        local_policy=local_policy,
        district_coordinators=district_coordinators,
        seed=SEED,
        max_steps=MAX_STEPS,
        record_history=True,
        policy_update=False,
    )

    output = {
        "seed": result.seed,
        "steps": result.steps,
        "mean_reward": result.mean_reward,
        "total_waiting": result.total_waiting,
        "total_queue": result.total_queue,
        # NOTE(review): json.dumps raises TypeError if final_info holds
        # non-JSON types (e.g. numpy scalars); confirm what the env returns.
        "final_info": result.final_info,
    }
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main()