# traffic-visualizer / server/environment.py
# (Uploaded via the upload-large-folder tool, commit 248a619.)
from __future__ import annotations
import os
import uuid
from pathlib import Path
from typing import Any
from district_llm.inference import DistrictLLMInference
from district_llm.schema import DistrictAction
from models import (
AgenticTrafficAction,
AgenticTrafficObservation,
AgenticTrafficState,
)
from openenv.core.env_server.interfaces import Environment
from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper
# Repo-relative default locations; each can be overridden via an environment
# variable (an empty env value falls through to the default via `or`).
REPO_ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (REPO_ROOT / "data" / "generated"))
SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (REPO_ROOT / "data" / "splits"))
# Path to the district-LLM adapter weights; loaded lazily by
# AgenticTrafficEnvironment._get_llm_inference, which tolerates it being absent.
DISTRICT_LLM_ADAPTER_PATH = Path(
    os.environ.get("DISTRICT_LLM_ADAPTER_PATH", "")
    or (REPO_ROOT / "artifacts" / "district_llm_adapter_v3" / "main_run" / "adapter")
)
# Optional device string passed to DistrictLLMInference; None when unset
# (presumably letting the inference library choose — confirm in district_llm).
DISTRICT_LLM_DEVICE = os.environ.get("DISTRICT_LLM_DEVICE")
class AgenticTrafficEnvironment(
    Environment[AgenticTrafficAction, AgenticTrafficObservation, AgenticTrafficState]
):
    """Minimal OpenEnv-compatible wrapper around the existing district controller stack."""

    def __init__(self) -> None:
        super().__init__()
        # Facade over the actual traffic simulation; reads generated scenarios
        # and split definitions from disk.
        self.wrapper = OpenEnvTrafficWrapper(
            generated_root=DATA_DIR,
            splits_root=SPLITS_DIR,
        )
        self._state = AgenticTrafficState()
        # The district LLM is loaded lazily on first use; a failed load is
        # remembered so we only ever attempt it once per process.
        self._llm_inference: DistrictLLMInference | None = None
        self._llm_load_attempted = False
        self._llm_error: str | None = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> AgenticTrafficObservation:
        """Begin a new episode and return its initial observation.

        Optional ``city_id`` / ``scenario_name`` kwargs are forwarded to the
        underlying wrapper; a fresh UUID is minted when no episode_id is given.
        """
        reset_payload = self.wrapper.reset(
            seed=seed,
            city_id=kwargs.get("city_id"),
            scenario_name=kwargs.get("scenario_name"),
        )
        self._state.episode_id = episode_id if episode_id else str(uuid.uuid4())
        self._state.step_count = 0
        self._sync_state()
        observation = AgenticTrafficObservation.model_validate(reset_payload["observation"])
        # No reward exists before the first step; episode is freshly open.
        observation.reward = None
        observation.done = False
        observation.metadata["llm"] = self._llm_status()
        return observation

    def step(
        self,
        action: AgenticTrafficAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> AgenticTrafficObservation:
        """Advance the environment one decision step and return the observation."""
        del timeout_s, kwargs  # accepted for interface compatibility only
        step_payload = self.wrapper.step(action=self._build_step_payload(action))
        self._state.step_count += 1
        self._sync_state()
        observation = AgenticTrafficObservation.model_validate(step_payload["observation"])
        observation.reward = float(step_payload.get("reward", 0.0))
        observation.done = bool(step_payload.get("done", False))
        observation.metadata["llm"] = self._llm_status()
        return observation

    @property
    def state(self) -> AgenticTrafficState:
        """Current environment state, refreshed from the wrapper on access."""
        self._sync_state()
        return self._state

    def _build_step_payload(self, action: AgenticTrafficAction) -> dict[str, Any]:
        """Translate an AgenticTrafficAction into the wrapper's step payload.

        Explicit per-district actions win; LLM-generated directives only fill
        districts the caller left unspecified (via ``setdefault``).
        """
        merged_actions = dict(action.district_actions)
        llm_actions: dict[str, Any] = {}
        if action.use_llm:
            llm_actions = self._generate_llm_actions(
                existing_actions=merged_actions,
                max_new_tokens=action.llm_max_new_tokens,
            )
        for district, directive in llm_actions.items():
            merged_actions.setdefault(district, directive)
        return {
            "district_actions": merged_actions,
            "metadata": {
                "use_llm": bool(action.use_llm),
                "llm_generated_districts": sorted(llm_actions),
                "llm": self._llm_status(),
            },
        }

    def _generate_llm_actions(
        self,
        existing_actions: dict[str, Any],
        max_new_tokens: int,
    ) -> dict[str, Any]:
        """Ask the district LLM for directives covering districts without explicit actions.

        Returns an empty dict when no summaries are available or the LLM
        could not be loaded.
        """
        if not self.wrapper.last_summaries:
            return {}
        inference = self._get_llm_inference()
        if inference is None:
            return {}
        return {
            district_id: inference.predict_with_result(
                summary=summary, max_new_tokens=max_new_tokens
            ).action.to_dict()
            for district_id, summary in self.wrapper.last_summaries.items()
            if district_id not in existing_actions
        }

    def _get_llm_inference(self) -> DistrictLLMInference | None:
        """Lazily construct the district LLM, attempting the load at most once.

        Returns None (and records a human-readable error) when the adapter is
        missing or construction raises.
        """
        if self._llm_inference is not None:
            return self._llm_inference
        if self._llm_load_attempted:
            return None
        self._llm_load_attempted = True
        if not DISTRICT_LLM_ADAPTER_PATH.exists():
            self._llm_error = f"Adapter not found at {DISTRICT_LLM_ADAPTER_PATH}"
            return None
        try:
            self._llm_inference = DistrictLLMInference(
                model_name_or_path=str(DISTRICT_LLM_ADAPTER_PATH),
                device=DISTRICT_LLM_DEVICE,
                fallback_action=DistrictAction.default_hold(
                    duration_steps=self.wrapper.district_decision_interval
                ),
            )
        except Exception as err:
            # Best-effort: record the failure instead of crashing the env.
            self._llm_error = f"{type(err).__name__}: {err}"
            self._llm_inference = None
        return self._llm_inference

    def _llm_status(self) -> dict[str, Any]:
        """Snapshot of the lazy LLM loader's state, attached to metadata/state."""
        return {
            "adapter_path": str(DISTRICT_LLM_ADAPTER_PATH),
            "adapter_present": DISTRICT_LLM_ADAPTER_PATH.exists(),
            "loaded": self._llm_inference is not None,
            "load_attempted": self._llm_load_attempted,
            "error": self._llm_error,
        }

    def _sync_state(self) -> None:
        """Rebuild self._state from the wrapper, preserving episode bookkeeping."""
        merged = dict(self.wrapper.state()["state"])
        merged["episode_id"] = self._state.episode_id
        merged["step_count"] = self._state.step_count
        merged["llm"] = self._llm_status()
        self._state = AgenticTrafficState.model_validate(merged)