Spaces:
Running
Running
| from __future__ import annotations | |
| import random | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| from district_llm.guided_control import DistrictGuidedLocalController | |
| from district_llm.schema import DistrictAction | |
| from district_llm.summary_builder import DistrictStateSummaryBuilder | |
| from district_llm.teachers import build_teacher, parse_teacher_spec | |
| from env.observation_builder import ObservationConfig | |
| from env.reward import RewardConfig | |
| from env.traffic_env import EnvConfig, TrafficEnv | |
| from training.cityflow_dataset import CityFlowDataset | |
class OpenEnvTrafficWrapper:
    """
    OpenEnv-style district environment backed by the current DQN local stack.

    External action:
    - a dict of district-level directives keyed by district_id

    Internal execution:
    - the shared DQN (or a baseline fallback) produces low-level actions
    - district directives bias those low-level actions over a slower district window

    The public surface is ``reset()``, ``step()``, ``state()`` and ``health()``;
    each returns a plain ``dict`` payload.
    """

    def __init__(
        self,
        generated_root: str | Path = "data/generated",
        splits_root: str | Path = "data/splits",
        split: str = "train",
        controller_spec: str | None = None,
        district_decision_interval: int = 10,
        seed: int = 7,
    ):
        """Build the dataset, the low-level teacher and the guided controller.

        Args:
            generated_root: Forwarded to ``CityFlowDataset``.
            splits_root: Forwarded to ``CityFlowDataset``.
            split: Split name handed to ``dataset.sample_scenario`` when
                ``reset()`` has to sample a scenario.
            controller_spec: Teacher spec string understood by
                ``parse_teacher_spec``. When ``None``, the default checkpoint
                ``artifacts/dqn_shared/best_validation.pt`` is used if that
                file exists, otherwise ``"queue_greedy"``.
            district_decision_interval: Maximum number of environment steps
                executed per ``step()`` call.
            seed: Seeds the scenario-sampling RNG and is passed to
                ``build_teacher``.

        Raises:
            ImportError: If ``build_teacher`` raises it while the spec is
                already ``"queue_greedy"`` (there is nothing left to fall
                back to).
        """
        self.dataset = CityFlowDataset(
            generated_root=generated_root,
            splits_root=splits_root,
        )
        self.dataset.generate_default_splits()
        self.split = split
        self.rng = random.Random(seed)
        self.district_decision_interval = int(district_decision_interval)
        self.summary_builder = DistrictStateSummaryBuilder()

        default_checkpoint = Path("artifacts/dqn_shared/best_validation.pt")
        if controller_spec is None:
            controller_spec = (
                f"rl_checkpoint={default_checkpoint}"
                if default_checkpoint.exists()
                else "queue_greedy"
            )
        controller_type, checkpoint = parse_teacher_spec(controller_spec)
        try:
            self.teacher = build_teacher(
                controller_type=controller_type,
                checkpoint=checkpoint,
                seed=seed,
            )
        except ImportError:
            # Best-effort fallback to the queue_greedy baseline (presumably an
            # optional dependency of the requested controller is missing --
            # confirm against build_teacher). Re-raise when the baseline itself
            # was the controller that failed.
            if controller_spec == "queue_greedy":
                raise
            self.teacher = build_teacher(
                controller_type="queue_greedy",
                checkpoint=None,
                seed=seed,
            )

        self.guided_controller = DistrictGuidedLocalController(base_teacher=self.teacher)
        # Prefer the env config carried by the teacher; otherwise use defaults.
        self.env_config = self.teacher.env_config or self._default_env_config()

        # Episode state, populated by reset() / step().
        self.env: TrafficEnv | None = None
        self.current_scenario_spec = None
        self.last_obs: dict[str, Any] | None = None
        self.last_info: dict[str, Any] | None = None
        self.last_summaries: dict[str, Any] = {}

    def reset(
        self,
        seed: int | None = None,
        city_id: str | None = None,
        scenario_name: str | None = None,
    ) -> dict[str, Any]:
        """Start a new episode on a fresh ``TrafficEnv``.

        When both ``city_id`` and ``scenario_name`` are truthy the scenario is
        built directly; otherwise one is sampled from the configured split
        (whatever was given is passed along to ``sample_scenario``).

        Args:
            seed: Passed to ``TrafficEnv.reset`` and echoed in ``info``.
            city_id: Optional city identifier.
            scenario_name: Optional scenario identifier.

        Returns:
            ``{"observation": ..., "info": ...}``.
        """
        if city_id and scenario_name:
            scenario_spec = self.dataset.build_scenario_spec(city_id, scenario_name)
        else:
            scenario_spec = self.dataset.sample_scenario(
                split_name=self.split,
                rng=self.rng,
                city_id=city_id,
                scenario_name=scenario_name,
            )
        self.current_scenario_spec = scenario_spec

        self.env = TrafficEnv(
            city_id=scenario_spec.city_id,
            scenario_name=scenario_spec.scenario_name,
            city_dir=scenario_spec.city_dir,
            scenario_dir=scenario_spec.scenario_dir,
            config_path=scenario_spec.config_path,
            roadnet_path=scenario_spec.roadnet_path,
            district_map_path=scenario_spec.district_map_path,
            metadata_path=scenario_spec.metadata_path,
            env_config=self.env_config,
        )
        self.summary_builder.reset()
        self.last_obs = self.env.reset(seed=seed)
        self.last_summaries = self.summary_builder.build_all(self.env, self.last_obs)

        metadata = self.teacher.metadata
        self.last_info = {
            "seed": seed,
            "city_id": scenario_spec.city_id,
            "scenario_name": scenario_spec.scenario_name,
            "controller_type": metadata.controller_type,
            "controller_family": metadata.controller_family,
            "teacher_algorithm": metadata.teacher_algorithm,
            "district_decision_interval": self.district_decision_interval,
        }
        return {
            "observation": self._build_observation_payload(),
            "info": self.last_info,
        }

    def step(self, action: dict[str, Any] | None) -> dict[str, Any]:
        """Apply district directives for up to one district decision window.

        The environment is reset implicitly if no episode is active. The
        guided controller is then queried and the environment stepped up to
        ``district_decision_interval`` times, stopping early when the
        environment reports ``done``.

        Args:
            action: Mapping with an optional ``"district_actions"`` entry
                keyed by district id. ``None``, a missing entry or a
                non-mapping entry all mean "default hold for every district".

        Returns:
            Dict with ``observation``, ``reward`` (sum over the executed
            steps of each step's mean reward), ``done``, ``truncated``
            (always ``False``) and ``info``.

        Raises:
            RuntimeError: If no environment exists even after the implicit
                reset.
        """
        if self.env is None or self.last_obs is None:
            self.reset(seed=None)
        if self.env is None:
            # Explicit check instead of `assert`, which `python -O` strips.
            raise RuntimeError("Environment is not initialized; reset() did not create one.")

        # Tolerate action=None and a null "district_actions" entry.
        raw_directives = (action or {}).get("district_actions") or {}
        district_actions = self._parse_district_actions(raw_directives)

        done = False
        reward_total = 0.0
        steps_executed = 0
        info: dict[str, Any] = {}
        for _ in range(self.district_decision_interval):
            local_actions = self.guided_controller.act(
                observation_batch=self.last_obs,
                district_actions=district_actions,
            )
            next_obs, rewards, done, info = self.env.step(local_actions)
            # Collapse this step's rewards into a single scalar via their mean.
            reward_total += float(np.asarray(rewards, dtype=np.float32).mean())
            self.last_obs = next_obs
            steps_executed += 1
            if done:
                break

        self.last_summaries = self.summary_builder.build_all(self.env, self.last_obs)
        metadata = self.teacher.metadata
        # Keys below override same-named keys from the env's last info dict.
        self.last_info = {
            **info,
            "controller_type": metadata.controller_type,
            "controller_family": metadata.controller_family,
            "teacher_algorithm": metadata.teacher_algorithm,
            "steps_executed": steps_executed,
            "district_actions": {
                district_id: directive.to_dict()
                for district_id, directive in district_actions.items()
            },
        }
        return {
            "observation": self._build_observation_payload(),
            "reward": float(reward_total),
            "done": bool(done),
            "truncated": False,
            "info": self.last_info,
        }

    def state(self) -> dict[str, Any]:
        """Return a snapshot of the wrapper state under the ``"state"`` key.

        The snapshot holds the current scenario (or ``None`` before the first
        reset), the controller metadata, the district decision interval, the
        latest district summaries and the most recent ``info`` dict.
        """
        spec = self.current_scenario_spec
        scenario = (
            None
            if spec is None
            else {
                "city_id": spec.city_id,
                "scenario_name": spec.scenario_name,
            }
        )
        return {
            "state": {
                "scenario": scenario,
                "controller": self.teacher.metadata.to_dict(),
                "district_decision_interval": self.district_decision_interval,
                "district_summaries": {
                    district_id: summary.to_dict()
                    for district_id, summary in self.last_summaries.items()
                },
                "last_info": self.last_info or {},
            }
        }

    def health(self) -> dict[str, Any]:
        """Return a static liveness payload."""
        return {
            "status": "ok",
            "message": "DistrictFlow OpenEnv wrapper is running.",
        }

    def _build_observation_payload(self) -> dict[str, Any]:
        """Serialize the latest observation and district summaries.

        Returns only an empty ``district_summaries`` mapping when no episode
        is active. Otherwise reads ``decision_step`` and ``sim_time`` from the
        last observation.
        """
        if self.env is None or self.last_obs is None:
            return {"district_summaries": {}}
        return {
            "city_id": self.env.city_id,
            "scenario_name": self.env.scenario_name,
            "decision_step": int(self.last_obs["decision_step"]),
            "sim_time": int(self.last_obs["sim_time"]),
            "district_summaries": {
                district_id: summary.to_dict()
                for district_id, summary in self.last_summaries.items()
            },
        }

    def _parse_district_actions(
        self, payload: dict[str, Any] | None
    ) -> dict[str, DistrictAction]:
        """Turn a raw directive payload into one ``DistrictAction`` per district.

        Every district in ``self.env.districts`` gets an entry. Districts that
        are absent from ``payload``, or whose directive cannot be parsed,
        receive ``DistrictAction.default_hold`` spanning one district window.
        Keys in ``payload`` that are not districts of the env are ignored.

        Args:
            payload: Mapping of district id to raw directive. ``None`` or a
                value without a ``get`` method is treated as empty.

        Returns:
            Parsed directives keyed by district id; empty when no env exists.
        """
        if self.env is None:
            return {}
        if not hasattr(payload, "get"):
            # Malformed top-level payload: same best-effort fallback as for a
            # malformed individual directive.
            payload = {}

        parsed: dict[str, DistrictAction] = {}
        for district_id in self.env.districts:
            raw_action = payload.get(district_id)
            if raw_action is None:
                parsed[district_id] = DistrictAction.default_hold(
                    duration_steps=self.district_decision_interval
                )
                continue
            try:
                parsed[district_id] = DistrictAction.from_dict(raw_action)
            except Exception:
                # Deliberate best-effort: an unparseable directive must not
                # abort the whole step, so the district holds instead.
                parsed[district_id] = DistrictAction.default_hold(
                    duration_steps=self.district_decision_interval
                )
        return parsed

    @staticmethod
    def _default_env_config() -> EnvConfig:
        """Return the ``EnvConfig`` used when the teacher provides none.

        Declared as a static method because it takes no ``self`` yet is
        invoked as ``self._default_env_config()`` from ``__init__``.
        """
        return EnvConfig(
            simulator_interval=1,
            decision_interval=5,
            min_green_time=10,
            thread_num=1,
            max_episode_seconds=None,
            observation=ObservationConfig(),
            reward=RewardConfig(variant="wait_queue_throughput"),
        )