from __future__ import annotations

import logging
import random
from collections.abc import Mapping
from pathlib import Path
from typing import Any

import numpy as np

from district_llm.guided_control import DistrictGuidedLocalController
from district_llm.schema import DistrictAction
from district_llm.summary_builder import DistrictStateSummaryBuilder
from district_llm.teachers import build_teacher, parse_teacher_spec
from env.observation_builder import ObservationConfig
from env.reward import RewardConfig
from env.traffic_env import EnvConfig, TrafficEnv
from training.cityflow_dataset import CityFlowDataset

logger = logging.getLogger(__name__)


class OpenEnvTrafficWrapper:
    """
    OpenEnv-style district environment backed by the current DQN local stack.

    External action:
    - a dict of district-level directives keyed by district_id

    Internal execution:
    - the shared DQN (or a baseline fallback) produces low-level actions
    - district directives bias those low-level actions over a slower district window

    Every public method returns a plain ``dict`` payload.
    """

    def __init__(
        self,
        generated_root: str | Path = "data/generated",
        splits_root: str | Path = "data/splits",
        split: str = "train",
        controller_spec: str | None = None,
        district_decision_interval: int = 10,
        seed: int = 7,
    ):
        """Build the dataset, the low-level teacher and the guided controller.

        Args:
            generated_root: Root directory handed to ``CityFlowDataset``.
            splits_root: Splits directory handed to ``CityFlowDataset``.
            split: Name of the split scenarios are sampled from in ``reset``.
            controller_spec: Teacher spec understood by ``parse_teacher_spec``.
                When ``None``, the default DQN checkpoint is used if the file
                exists, otherwise ``"queue_greedy"``.
            district_decision_interval: Number of low-level environment steps
                executed per call to ``step``. Must be at least 1.
            seed: Seed for the scenario-sampling RNG and for the teacher.

        Raises:
            ValueError: If ``district_decision_interval`` is smaller than 1.
            ImportError: If the ``queue_greedy`` teacher itself cannot be built.
        """
        interval = int(district_decision_interval)
        if interval < 1:
            # A non-positive interval would make step() execute zero low-level
            # steps and never report `done`, so fail fast instead.
            raise ValueError(
                f"district_decision_interval must be >= 1, got {district_decision_interval!r}"
            )

        self.dataset = CityFlowDataset(
            generated_root=generated_root,
            splits_root=splits_root,
        )
        self.dataset.generate_default_splits()
        self.split = split
        self.rng = random.Random(seed)
        self.district_decision_interval = interval
        self.summary_builder = DistrictStateSummaryBuilder()

        default_checkpoint = Path("artifacts/dqn_shared/best_validation.pt")
        if controller_spec is None:
            controller_spec = (
                f"rl_checkpoint={default_checkpoint}"
                if default_checkpoint.exists()
                else "queue_greedy"
            )
        controller_type, checkpoint = parse_teacher_spec(controller_spec)
        try:
            self.teacher = build_teacher(
                controller_type=controller_type,
                checkpoint=checkpoint,
                seed=seed,
            )
        except ImportError:
            # Nothing left to fall back to when the baseline itself is broken.
            if controller_spec == "queue_greedy":
                raise
            logger.warning(
                "Could not build teacher %r; falling back to 'queue_greedy'.",
                controller_spec,
                exc_info=True,
            )
            self.teacher = build_teacher(
                controller_type="queue_greedy",
                checkpoint=None,
                seed=seed,
            )

        self.guided_controller = DistrictGuidedLocalController(base_teacher=self.teacher)
        self.env_config = self.teacher.env_config or self._default_env_config()
        self.env: TrafficEnv | None = None
        self.current_scenario_spec = None
        self.last_obs: dict[str, Any] | None = None
        self.last_info: dict[str, Any] | None = None
        self.last_summaries: dict[str, Any] = {}

    def reset(
        self,
        seed: int | None = None,
        city_id: str | None = None,
        scenario_name: str | None = None,
    ) -> dict[str, Any]:
        """Start a new episode and return the initial observation payload.

        When both ``city_id`` and ``scenario_name`` are given, that exact
        scenario is built; otherwise one is sampled from the configured split
        (passing along whichever of the two filters was provided).

        Returns:
            A dict with keys ``"observation"`` and ``"info"``.
        """
        if city_id and scenario_name:
            scenario_spec = self.dataset.build_scenario_spec(city_id, scenario_name)
        else:
            scenario_spec = self.dataset.sample_scenario(
                split_name=self.split,
                rng=self.rng,
                city_id=city_id,
                scenario_name=scenario_name,
            )
        self.current_scenario_spec = scenario_spec

        self.env = TrafficEnv(
            city_id=scenario_spec.city_id,
            scenario_name=scenario_spec.scenario_name,
            city_dir=scenario_spec.city_dir,
            scenario_dir=scenario_spec.scenario_dir,
            config_path=scenario_spec.config_path,
            roadnet_path=scenario_spec.roadnet_path,
            district_map_path=scenario_spec.district_map_path,
            metadata_path=scenario_spec.metadata_path,
            env_config=self.env_config,
        )
        self.summary_builder.reset()
        self.last_obs = self.env.reset(seed=seed)
        self.last_summaries = self.summary_builder.build_all(self.env, self.last_obs)
        self.last_info = {
            "seed": seed,
            "city_id": scenario_spec.city_id,
            "scenario_name": scenario_spec.scenario_name,
            **self._controller_info(),
            "district_decision_interval": self.district_decision_interval,
        }
        return {
            "observation": self._build_observation_payload(),
            "info": self.last_info,
        }

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Apply district directives for one district decision window.

        The environment is reset implicitly if no episode is active. The
        guided controller is then queried and the environment stepped up to
        ``district_decision_interval`` times, stopping early when the
        environment reports ``done``.

        Args:
            action: Payload whose ``"district_actions"`` entry maps
                district_id to a raw directive dict. A missing, ``None`` or
                non-mapping payload results in default hold directives.

        Returns:
            A dict with keys ``"observation"``, ``"reward"`` (sum over the
            executed steps of the per-step mean reward), ``"done"``,
            ``"truncated"`` (always ``False``) and ``"info"``.
        """
        if self.env is None or self.last_obs is None:
            self.reset(seed=None)
        # reset() always assigns self.env; this only narrows the Optional type.
        assert self.env is not None

        raw_directives = (
            action.get("district_actions") if isinstance(action, Mapping) else None
        )
        district_actions = self._parse_district_actions(raw_directives)

        done = False
        reward_total = 0.0
        steps_executed = 0
        info: dict[str, Any] = {}

        for _ in range(self.district_decision_interval):
            local_actions = self.guided_controller.act(
                observation_batch=self.last_obs,
                district_actions=district_actions,
            )
            next_obs, rewards, done, info = self.env.step(local_actions)
            reward_values = np.asarray(rewards, dtype=np.float32)
            # mean() of an empty array is NaN (plus a RuntimeWarning), which
            # would poison the accumulated reward; treat it as zero instead.
            if reward_values.size:
                reward_total += float(reward_values.mean())
            self.last_obs = next_obs
            steps_executed += 1
            if done:
                break

        self.last_summaries = self.summary_builder.build_all(self.env, self.last_obs)
        self.last_info = {
            **info,
            **self._controller_info(),
            "steps_executed": steps_executed,
            "district_actions": {
                district_id: directive.to_dict()
                for district_id, directive in district_actions.items()
            },
        }
        return {
            "observation": self._build_observation_payload(),
            "reward": float(reward_total),
            "done": bool(done),
            "truncated": False,
            "info": self.last_info,
        }

    def state(self) -> dict[str, Any]:
        """Return a snapshot of scenario, controller and latest summaries."""
        spec = self.current_scenario_spec
        scenario = (
            None
            if spec is None
            else {
                "city_id": spec.city_id,
                "scenario_name": spec.scenario_name,
            }
        )
        return {
            "state": {
                "scenario": scenario,
                "controller": self.teacher.metadata.to_dict(),
                "district_decision_interval": self.district_decision_interval,
                "district_summaries": self._serialize_summaries(),
                "last_info": self.last_info or {},
            }
        }

    def health(self) -> dict[str, Any]:
        """Return a static liveness payload."""
        return {
            "status": "ok",
            "message": "DistrictFlow OpenEnv wrapper is running.",
        }

    def _controller_info(self) -> dict[str, Any]:
        """Return the teacher metadata fields attached to every info dict."""
        metadata = self.teacher.metadata
        return {
            "controller_type": metadata.controller_type,
            "controller_family": metadata.controller_family,
            "teacher_algorithm": metadata.teacher_algorithm,
        }

    def _serialize_summaries(self) -> dict[str, Any]:
        """Return the latest district summaries converted via ``to_dict``."""
        return {
            district_id: summary.to_dict()
            for district_id, summary in self.last_summaries.items()
        }

    def _build_observation_payload(self) -> dict[str, Any]:
        """Build the observation dict, or an empty stub before the first reset."""
        if self.env is None or self.last_obs is None:
            return {"district_summaries": {}}
        return {
            "city_id": self.env.city_id,
            "scenario_name": self.env.scenario_name,
            "decision_step": int(self.last_obs["decision_step"]),
            "sim_time": int(self.last_obs["sim_time"]),
            "district_summaries": self._serialize_summaries(),
        }

    def _parse_district_actions(
        self, payload: Mapping[str, Any] | None
    ) -> dict[str, DistrictAction]:
        """Convert raw directives into one ``DistrictAction`` per district.

        Districts that are missing from ``payload``, or whose directive cannot
        be parsed, receive a default hold lasting one decision window. Keys in
        ``payload`` that are not districts of the current environment are
        ignored. A ``None`` or non-mapping ``payload`` is treated as empty.

        Returns:
            A dict keyed by every district of the current environment, or an
            empty dict when no environment has been created yet.
        """
        if self.env is None:
            return {}
        if not isinstance(payload, Mapping):
            if payload is not None:
                logger.warning(
                    "Ignoring district_actions of unsupported type %s; holding all districts.",
                    type(payload).__name__,
                )
            payload = {}

        parsed: dict[str, DistrictAction] = {}
        for district_id in self.env.districts:
            raw_action = payload.get(district_id)
            if raw_action is not None:
                try:
                    parsed[district_id] = DistrictAction.from_dict(raw_action)
                    continue
                except Exception:
                    # Deliberately broad: directives come from external callers
                    # and one bad entry must not abort the whole step.
                    logger.warning(
                        "Malformed directive for district %s; holding instead.",
                        district_id,
                        exc_info=True,
                    )
            parsed[district_id] = DistrictAction.default_hold(
                duration_steps=self.district_decision_interval
            )
        return parsed

    @staticmethod
    def _default_env_config() -> EnvConfig:
        """Return the environment config used when the teacher supplies none."""
        return EnvConfig(
            simulator_interval=1,
            decision_interval=5,
            min_green_time=10,
            thread_num=1,
            max_episode_seconds=None,
            observation=ObservationConfig(),
            reward=RewardConfig(variant="wait_queue_throughput"),
        )