# agentic-traffic/scripts/evaluate.py
# Provenance: uploaded by Aditya2162 via huggingface_hub (commit 3d2dbcf).
from __future__ import annotations
import json
from agents.district_coordinator import RuleBasedDistrictCoordinator
from agents.local_policy import SharedHeuristicLocalPolicy
from training.trainer import DistrictCoordinatorEvaluator
def make_env(
    *,
    config_path: str = "data/cityflow/config.json",
    coordination_interval: int = 20,
    max_steps: int = 200,
):
    """Build the four-intersection, two-district evaluation environment.

    The layout is a chain I1 - I2 - I3 - I4:

    * District D0 holds I1 and I2.
    * District D1 holds I3 and I4.
    * I2 and I3 are flagged as border intersections.

    Every parameter is keyword-only, so the zero-argument call made through
    ``env_factory`` keeps working. Callers that need a different scenario can
    override the defaults, which are the values this script always used.

    Args:
        config_path: Path to the CityFlow config, forwarded to ``TrafficEnv``.
        coordination_interval: Forwarded to ``TrafficEnv`` unchanged.
        max_steps: Forwarded to ``TrafficEnv`` unchanged.

    Returns:
        A freshly constructed ``TrafficEnv``.
    """
    # Imported inside the function, presumably so that importing this module
    # does not load the simulator; kept local to preserve that behavior.
    from env.traffic_env import TrafficEnv
    from env.intersection_config import IntersectionConfig, DistrictConfig

    # (intersection_id, district_id, neighbors, is_border)
    layout = [
        ("I1", "D0", ["I2"], False),
        ("I2", "D0", ["I1", "I3"], True),
        ("I3", "D1", ["I2", "I4"], True),
        ("I4", "D1", ["I3"], False),
    ]
    intersections = {
        iid: IntersectionConfig(
            intersection_id=iid,
            district_id=district,
            # One incoming lane per approach, named "<id>_<N|S|E|W>".
            incoming_lanes=[f"{iid}_{d}" for d in ("N", "S", "E", "W")],
            # NOTE(review): outgoing lanes are empty for every intersection;
            # confirm TrafficEnv does not rely on them.
            outgoing_lanes=[],
            neighbors=neighbors,
            is_border=is_border,
        )
        for iid, district, neighbors, is_border in layout
    }
    districts = {
        "D0": DistrictConfig(
            district_id="D0",
            intersection_ids=["I1", "I2"],
            neighbor_districts=["D1"],
        ),
        "D1": DistrictConfig(
            district_id="D1",
            intersection_ids=["I3", "I4"],
            neighbor_districts=["D0"],
        ),
    }
    return TrafficEnv(
        config_path=config_path,
        intersections=intersections,
        districts=districts,
        coordination_interval=coordination_interval,
        max_steps=max_steps,
    )
def main():
    """Compare local-only control against rule-based district coordination.

    Runs ``DistrictCoordinatorEvaluator.compare`` over seeds 0-4 with
    ``max_steps=200``. It then prints, as indented JSON:

    * the three averaged metrics for both configurations;
    * the evaluator's ``improvements`` entry.
    """
    metric_keys = ("avg_mean_reward", "avg_total_waiting", "avg_total_queue")

    evaluator = DistrictCoordinatorEvaluator(
        env_factory=make_env,
        local_policy=SharedHeuristicLocalPolicy(),
    )
    # One independent rule-based coordinator per district.
    per_district = {
        district: RuleBasedDistrictCoordinator() for district in ("D0", "D1")
    }
    results = evaluator.compare(
        seeds=list(range(5)),
        local_only_coordinators={},  # empty mapping: no district coordinators
        coordinated_coordinators=per_district,
        max_steps=200,
    )

    summary = {
        mode: {key: results[mode][key] for key in metric_keys}
        for mode in ("local_only", "coordinated")
    }
    summary["improvements"] = results["improvements"]
    print(json.dumps(summary, indent=2))
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()