from __future__ import annotations import os import uuid from pathlib import Path from typing import Any from district_llm.inference import DistrictLLMInference from district_llm.schema import DistrictAction from models import ( AgenticTrafficAction, AgenticTrafficObservation, AgenticTrafficState, ) from openenv.core.env_server.interfaces import Environment from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper REPO_ROOT = Path(__file__).resolve().parents[1] DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (REPO_ROOT / "data" / "generated")) SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (REPO_ROOT / "data" / "splits")) DISTRICT_LLM_ADAPTER_PATH = Path( os.environ.get("DISTRICT_LLM_ADAPTER_PATH", "") or (REPO_ROOT / "artifacts" / "district_llm_adapter_v3" / "main_run" / "adapter") ) DISTRICT_LLM_DEVICE = os.environ.get("DISTRICT_LLM_DEVICE") class AgenticTrafficEnvironment( Environment[AgenticTrafficAction, AgenticTrafficObservation, AgenticTrafficState] ): """Minimal OpenEnv-compatible wrapper around the existing district controller stack.""" def __init__(self) -> None: super().__init__() self.wrapper = OpenEnvTrafficWrapper( generated_root=DATA_DIR, splits_root=SPLITS_DIR, ) self._state = AgenticTrafficState() self._llm_inference: DistrictLLMInference | None = None self._llm_load_attempted = False self._llm_error: str | None = None def reset( self, seed: int | None = None, episode_id: str | None = None, **kwargs: Any, ) -> AgenticTrafficObservation: payload = self.wrapper.reset( seed=seed, city_id=kwargs.get("city_id"), scenario_name=kwargs.get("scenario_name"), ) self._state.episode_id = episode_id or str(uuid.uuid4()) self._state.step_count = 0 self._sync_state() observation = AgenticTrafficObservation.model_validate(payload["observation"]) observation.reward = None observation.done = False observation.metadata["llm"] = self._llm_status() return observation def step( self, action: AgenticTrafficAction, timeout_s: float | None = None, **kwargs: Any, ) -> AgenticTrafficObservation: del timeout_s, kwargs payload = self.wrapper.step(action=self._build_step_payload(action)) self._state.step_count += 1 self._sync_state() observation = AgenticTrafficObservation.model_validate(payload["observation"]) observation.done = bool(payload.get("done", False)) observation.reward = float(payload.get("reward", 0.0)) observation.metadata["llm"] = self._llm_status() return observation @property def state(self) -> AgenticTrafficState: self._sync_state() return self._state def _build_step_payload(self, action: AgenticTrafficAction) -> dict[str, Any]: district_actions = dict(action.district_actions) llm_generated_actions: dict[str, Any] = {} if action.use_llm: llm_generated_actions = self._generate_llm_actions( existing_actions=district_actions, max_new_tokens=action.llm_max_new_tokens, ) for district_id, directive in llm_generated_actions.items(): district_actions.setdefault(district_id, directive) payload = {"district_actions": district_actions} payload["metadata"] = { "use_llm": bool(action.use_llm), "llm_generated_districts": sorted(llm_generated_actions), "llm": self._llm_status(), } return payload def _generate_llm_actions( self, existing_actions: dict[str, Any], max_new_tokens: int, ) -> dict[str, Any]: if not self.wrapper.last_summaries: return {} inference = self._get_llm_inference() if inference is None: return {} generated_actions: dict[str, Any] = {} for district_id, summary in self.wrapper.last_summaries.items(): if district_id in existing_actions: continue result = inference.predict_with_result(summary=summary, max_new_tokens=max_new_tokens) generated_actions[district_id] = result.action.to_dict() return generated_actions def _get_llm_inference(self) -> DistrictLLMInference | None: if self._llm_inference is not None: return self._llm_inference if self._llm_load_attempted: return None self._llm_load_attempted = True if not DISTRICT_LLM_ADAPTER_PATH.exists(): self._llm_error = f"Adapter not found at {DISTRICT_LLM_ADAPTER_PATH}" return None try: self._llm_inference = DistrictLLMInference( model_name_or_path=str(DISTRICT_LLM_ADAPTER_PATH), device=DISTRICT_LLM_DEVICE, fallback_action=DistrictAction.default_hold( duration_steps=self.wrapper.district_decision_interval ), ) except Exception as exc: self._llm_error = f"{type(exc).__name__}: {exc}" self._llm_inference = None return self._llm_inference def _llm_status(self) -> dict[str, Any]: return { "adapter_path": str(DISTRICT_LLM_ADAPTER_PATH), "adapter_present": DISTRICT_LLM_ADAPTER_PATH.exists(), "loaded": self._llm_inference is not None, "load_attempted": self._llm_load_attempted, "error": self._llm_error, } def _sync_state(self) -> None: payload = self.wrapper.state()["state"] self._state = AgenticTrafficState.model_validate( { **payload, "episode_id": self._state.episode_id, "step_count": self._state.step_count, "llm": self._llm_status(), } )