Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| from district_llm.inference import DistrictLLMInference | |
| from district_llm.schema import DistrictAction | |
| from models import ( | |
| AgenticTrafficAction, | |
| AgenticTrafficObservation, | |
| AgenticTrafficState, | |
| ) | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper | |
# Repository root, resolved relative to this file (this module sits one
# directory below the repo root).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Data locations; each environment variable overrides the repo-relative
# default.  The `or` fallback also covers a variable that is set but empty.
DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (REPO_ROOT / "data" / "generated"))
SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (REPO_ROOT / "data" / "splits"))
# Location of the fine-tuned district-LLM adapter weights; the model itself is
# only loaded lazily, on the first step that requests LLM actions.
DISTRICT_LLM_ADAPTER_PATH = Path(
    os.environ.get("DISTRICT_LLM_ADAPTER_PATH", "")
    or (REPO_ROOT / "artifacts" / "district_llm_adapter_v3" / "main_run" / "adapter")
)
# Optional device override (e.g. "cuda" / "cpu"); None leaves the choice to
# the inference backend — presumably its own auto-detection, TODO confirm.
DISTRICT_LLM_DEVICE = os.environ.get("DISTRICT_LLM_DEVICE")
class AgenticTrafficEnvironment(
    Environment[AgenticTrafficAction, AgenticTrafficObservation, AgenticTrafficState]
):
    """Minimal OpenEnv-compatible wrapper around the existing district controller stack.

    Bridges the OpenEnv ``Environment`` interface to ``OpenEnvTrafficWrapper``,
    optionally filling in per-district actions with a lazily loaded
    ``DistrictLLMInference`` model when a step requests it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Underlying simulation wrapper; reads generated scenarios and splits
        # from the configured data directories.
        self.wrapper = OpenEnvTrafficWrapper(
            generated_root=DATA_DIR,
            splits_root=SPLITS_DIR,
        )
        self._state = AgenticTrafficState()
        # LLM engine is loaded lazily (see _get_llm_inference); these three
        # fields track the load lifecycle so loading is attempted at most once
        # and any failure reason is surfaced via _llm_status().
        self._llm_inference: DistrictLLMInference | None = None
        self._llm_load_attempted = False
        self._llm_error: str | None = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> AgenticTrafficObservation:
        """Start a new episode and return the initial observation.

        Recognised kwargs: ``city_id`` and ``scenario_name`` are forwarded to
        the wrapper; anything else is ignored.  A fresh ``episode_id`` is
        generated when none is supplied.
        """
        payload = self.wrapper.reset(
            seed=seed,
            city_id=kwargs.get("city_id"),
            scenario_name=kwargs.get("scenario_name"),
        )
        # Episode bookkeeping must be set before _sync_state(), which copies
        # episode_id / step_count into the rebuilt state object.
        self._state.episode_id = episode_id or str(uuid.uuid4())
        self._state.step_count = 0
        self._sync_state()
        observation = AgenticTrafficObservation.model_validate(payload["observation"])
        # The initial observation carries no reward and is never terminal.
        observation.reward = None
        observation.done = False
        observation.metadata["llm"] = self._llm_status()
        return observation

    def step(
        self,
        action: AgenticTrafficAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> AgenticTrafficObservation:
        """Advance the environment by one step and return the new observation.

        ``timeout_s`` and extra kwargs are accepted for interface
        compatibility but not used.
        """
        del timeout_s, kwargs
        payload = self.wrapper.step(action=self._build_step_payload(action))
        self._state.step_count += 1
        self._sync_state()
        observation = AgenticTrafficObservation.model_validate(payload["observation"])
        # Missing keys degrade gracefully: done defaults to False, reward to 0.0.
        observation.done = bool(payload.get("done", False))
        observation.reward = float(payload.get("reward", 0.0))
        observation.metadata["llm"] = self._llm_status()
        return observation

    def state(self) -> AgenticTrafficState:
        """Return the current environment state, refreshed from the wrapper."""
        self._sync_state()
        return self._state

    def _build_step_payload(self, action: AgenticTrafficAction) -> dict[str, Any]:
        """Translate an ``AgenticTrafficAction`` into the wrapper's step payload.

        Caller-supplied district actions always win; LLM-generated ones only
        fill districts the caller left unspecified.
        """
        district_actions = dict(action.district_actions)
        llm_generated_actions: dict[str, Any] = {}
        if action.use_llm:
            llm_generated_actions = self._generate_llm_actions(
                existing_actions=district_actions,
                max_new_tokens=action.llm_max_new_tokens,
            )
        # setdefault is belt-and-braces: _generate_llm_actions already skips
        # districts present in existing_actions, so nothing is overwritten.
        for district_id, directive in llm_generated_actions.items():
            district_actions.setdefault(district_id, directive)
        payload = {"district_actions": district_actions}
        payload["metadata"] = {
            "use_llm": bool(action.use_llm),
            "llm_generated_districts": sorted(llm_generated_actions),
            "llm": self._llm_status(),
        }
        return payload

    def _generate_llm_actions(
        self,
        existing_actions: dict[str, Any],
        max_new_tokens: int,
    ) -> dict[str, Any]:
        """Generate action dicts for districts without an explicit action.

        Returns an empty dict when no district summaries are available yet or
        when the LLM could not be loaded (best-effort: the step proceeds with
        only the caller-supplied actions).
        """
        if not self.wrapper.last_summaries:
            return {}
        inference = self._get_llm_inference()
        if inference is None:
            return {}
        generated_actions: dict[str, Any] = {}
        for district_id, summary in self.wrapper.last_summaries.items():
            if district_id in existing_actions:
                continue
            result = inference.predict_with_result(summary=summary, max_new_tokens=max_new_tokens)
            # Serialise the predicted action into the plain-dict form the
            # wrapper's step payload expects.
            generated_actions[district_id] = result.action.to_dict()
        return generated_actions

    def _get_llm_inference(self) -> DistrictLLMInference | None:
        """Return the LLM inference engine, loading it on first call.

        Loading is attempted at most once per environment instance; on a
        missing adapter or a load failure the error is recorded in
        ``self._llm_error`` and ``None`` is returned on this and every
        subsequent call.
        """
        if self._llm_inference is not None:
            return self._llm_inference
        if self._llm_load_attempted:
            # A previous attempt failed; don't retry (the error is preserved).
            return None
        self._llm_load_attempted = True
        if not DISTRICT_LLM_ADAPTER_PATH.exists():
            self._llm_error = f"Adapter not found at {DISTRICT_LLM_ADAPTER_PATH}"
            return None
        try:
            self._llm_inference = DistrictLLMInference(
                model_name_or_path=str(DISTRICT_LLM_ADAPTER_PATH),
                device=DISTRICT_LLM_DEVICE,
                # If a prediction can't be parsed, fall back to holding the
                # current plan for one decision interval.
                fallback_action=DistrictAction.default_hold(
                    duration_steps=self.wrapper.district_decision_interval
                ),
            )
        except Exception as exc:
            # Broad on purpose: model loading can fail in many ways (missing
            # weights, CUDA, deps); record the reason and degrade to no-LLM.
            self._llm_error = f"{type(exc).__name__}: {exc}"
            self._llm_inference = None
        return self._llm_inference

    def _llm_status(self) -> dict[str, Any]:
        """Return a JSON-friendly snapshot of the LLM load state for metadata."""
        return {
            "adapter_path": str(DISTRICT_LLM_ADAPTER_PATH),
            "adapter_present": DISTRICT_LLM_ADAPTER_PATH.exists(),
            "loaded": self._llm_inference is not None,
            "load_attempted": self._llm_load_attempted,
            "error": self._llm_error,
        }

    def _sync_state(self) -> None:
        """Rebuild ``self._state`` from the wrapper's current state payload.

        Locally tracked fields (``episode_id``, ``step_count``) and the LLM
        status are layered on top of whatever the wrapper reports.
        """
        payload = self.wrapper.state()["state"]
        self._state = AgenticTrafficState.model_validate(
            {
                **payload,
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
                "llm": self._llm_status(),
            }
        )