# agentic-traffic/scripts/run_demo.py
# Uploaded by Aditya2162 via huggingface_hub (commit 3d2dbcf, verified).
from __future__ import annotations
import json
from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from training.rollout import run_episode
def make_env(
    config_path: str = "data/cityflow/config.json",
    coordination_interval: int = 20,
    max_steps: int = 200,
):
    """Build the four-intersection, two-district demo ``TrafficEnv``.

    The road network is a simple chain ``I1 - I2 - I3 - I4``:

    * district ``D0`` owns ``I1`` and ``I2``;
    * district ``D1`` owns ``I3`` and ``I4``.

    ``I2`` and ``I3`` are the intersections adjacent to the other district and
    are flagged ``is_border=True``. Every intersection gets four incoming lanes
    named ``<id>_N``, ``<id>_S``, ``<id>_E`` and ``<id>_W``. No outgoing lanes
    are configured.

    Args:
        config_path: CityFlow config path forwarded to ``TrafficEnv``.
        coordination_interval: Forwarded to ``TrafficEnv`` unchanged.
        max_steps: Forwarded to ``TrafficEnv`` unchanged.

    Returns:
        The configured ``TrafficEnv`` instance.
    """
    # Imported inside the function (as in the original script), presumably to
    # defer loading the simulator until an environment is actually built.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # One row per intersection: (intersection_id, district_id, neighbors, is_border).
    topology = [
        ("I1", "D0", ["I2"], False),
        ("I2", "D0", ["I1", "I3"], True),
        ("I3", "D1", ["I2", "I4"], True),
        ("I4", "D1", ["I3"], False),
    ]
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=owner,
            incoming_lanes=[f"{iid}_{side}" for side in ("N", "S", "E", "W")],
            outgoing_lanes=[],
            neighbors=list(adjacent),
            is_border=border,
        )
        for iid, owner, adjacent, border in topology
    }

    # district_id -> ids of neighboring districts.
    district_neighbors = {"D0": ["D1"], "D1": ["D0"]}
    districts = {
        did: DistrictConfig(
            district_id=did,
            # Derived from the topology table so district membership cannot
            # drift out of sync with each intersection's district_id.
            intersection_ids=[iid for iid, owner, _, _ in topology if owner == did],
            neighbor_districts=list(nearby),
        )
        for did, nearby in district_neighbors.items()
    }

    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=coordination_interval,
        max_steps=max_steps,
    )
def main():
    """Run one demo episode with rule-based coordinators and print a JSON summary.

    The environment comes from ``make_env()``. Every intersection is driven by a
    single ``SharedHeuristicLocalPolicy``. Each of the districts ``D0`` and
    ``D1`` gets its own ``RuleBasedDistrictCoordinator``. The episode runs with
    seed 0 for at most 200 steps. Its seed, step count, reward, waiting and
    queue totals, and final info are then printed to stdout as indented JSON.
    """
    env = make_env()
    policy = SharedHeuristicLocalPolicy()
    coordinators = {
        district_id: RuleBasedDistrictCoordinator()
        for district_id in ("D0", "D1")
    }

    episode = run_episode(
        env=env,
        local_policy=policy,
        district_coordinators=coordinators,
        seed=0,
        max_steps=200,
        record_history=True,
        # NOTE(review): presumably keeps the local policy frozen for this demo.
        policy_update=False,
    )

    # The JSON keys are exactly the episode-result attribute names.
    summary_fields = (
        "seed",
        "steps",
        "mean_reward",
        "total_waiting",
        "total_queue",
        "final_info",
    )
    summary = {name: getattr(episode, name) for name in summary_fields}
    print(json.dumps(summary, indent=2))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()