tokev commited on
Commit
793a22b
·
verified ·
1 Parent(s): 5e526e3

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. models.py +45 -11
  2. server/environment.py +122 -6
models.py CHANGED
@@ -1,27 +1,61 @@
1
  from __future__ import annotations
2
 
 
3
  from typing import Any
4
 
5
- from pydantic import BaseModel, Field
6
-
7
-
8
- class AgenticTrafficAction(BaseModel):
9
- district_actions: dict[str, Any] = Field(default_factory=dict)
10
-
11
-
12
- class AgenticTrafficObservation(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  city_id: str | None = None
14
  scenario_name: str | None = None
15
  decision_step: int = 0
16
  sim_time: int = 0
17
  district_summaries: dict[str, Any] = Field(default_factory=dict)
18
- done: bool = False
19
- reward: float = 0.0
20
 
21
 
22
- class AgenticTrafficState(BaseModel):
23
  scenario: dict[str, Any] | None = None
24
  controller: dict[str, Any] = Field(default_factory=dict)
25
  district_decision_interval: int = 0
26
  district_summaries: dict[str, Any] = Field(default_factory=dict)
 
27
  last_info: dict[str, Any] = Field(default_factory=dict)
 
1
  from __future__ import annotations
2
 
3
+ import json
4
  from typing import Any
5
 
6
+ from openenv.core.env_server import Action, Observation, State
7
+ from pydantic import Field, field_validator
8
+
9
+
10
+ class AgenticTrafficAction(Action):
11
+ use_llm: bool = Field(
12
+ default=False,
13
+ description=(
14
+ "When true, use the bundled district LLM adapter to generate district_actions "
15
+ "for districts not explicitly provided."
16
+ ),
17
+ )
18
+ district_actions: dict[str, Any] = Field(
19
+ default_factory=dict,
20
+ description=(
21
+ "JSON object keyed by district_id. Use {} for a no-op step, or provide "
22
+ 'entries like {"d_00":{"strategy":"hold","phase_bias":"NS","duration_steps":10}}.'
23
+ ),
24
+ )
25
+ llm_max_new_tokens: int = Field(
26
+ default=128,
27
+ ge=16,
28
+ le=512,
29
+ description="Maximum new tokens to generate per district when use_llm=true.",
30
+ )
31
+
32
+ @field_validator("district_actions", mode="before")
33
+ @classmethod
34
+ def parse_district_actions(cls, value: Any) -> dict[str, Any]:
35
+ if value is None or value == "":
36
+ return {}
37
+ if isinstance(value, str):
38
+ parsed = json.loads(value)
39
+ if not isinstance(parsed, dict):
40
+ raise ValueError("district_actions must decode to a JSON object.")
41
+ return parsed
42
+ if isinstance(value, dict):
43
+ return value
44
+ raise ValueError("district_actions must be a dict or JSON object string.")
45
+
46
+
47
+ class AgenticTrafficObservation(Observation):
48
  city_id: str | None = None
49
  scenario_name: str | None = None
50
  decision_step: int = 0
51
  sim_time: int = 0
52
  district_summaries: dict[str, Any] = Field(default_factory=dict)
 
 
53
 
54
 
55
+ class AgenticTrafficState(State):
56
  scenario: dict[str, Any] | None = None
57
  controller: dict[str, Any] = Field(default_factory=dict)
58
  district_decision_interval: int = 0
59
  district_summaries: dict[str, Any] = Field(default_factory=dict)
60
+ llm: dict[str, Any] = Field(default_factory=dict)
61
  last_info: dict[str, Any] = Field(default_factory=dict)
server/environment.py CHANGED
@@ -1,8 +1,12 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  from pathlib import Path
 
5
 
 
 
6
  from models import (
7
  AgenticTrafficAction,
8
  AgenticTrafficObservation,
@@ -15,6 +19,11 @@ from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper
15
  REPO_ROOT = Path(__file__).resolve().parents[1]
16
  DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (REPO_ROOT / "data" / "generated"))
17
  SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (REPO_ROOT / "data" / "splits"))
 
 
 
 
 
18
 
19
 
20
  class AgenticTrafficEnvironment(
@@ -23,23 +32,50 @@ class AgenticTrafficEnvironment(
23
  """Minimal OpenEnv-compatible wrapper around the existing district controller stack."""
24
 
25
  def __init__(self) -> None:
 
26
  self.wrapper = OpenEnvTrafficWrapper(
27
  generated_root=DATA_DIR,
28
  splits_root=SPLITS_DIR,
29
  )
30
  self._state = AgenticTrafficState()
 
 
 
31
 
32
- def reset(self) -> AgenticTrafficObservation:
33
- payload = self.wrapper.reset(seed=None)
 
 
 
 
 
 
 
 
 
 
 
34
  self._sync_state()
35
- return AgenticTrafficObservation.model_validate(payload["observation"])
 
 
 
 
36
 
37
- def step(self, action: AgenticTrafficAction) -> AgenticTrafficObservation:
38
- payload = self.wrapper.step(action=action.model_dump())
 
 
 
 
 
 
 
39
  self._sync_state()
40
  observation = AgenticTrafficObservation.model_validate(payload["observation"])
41
  observation.done = bool(payload.get("done", False))
42
  observation.reward = float(payload.get("reward", 0.0))
 
43
  return observation
44
 
45
  @property
@@ -47,6 +83,86 @@ class AgenticTrafficEnvironment(
47
  self._sync_state()
48
  return self._state
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def _sync_state(self) -> None:
51
  payload = self.wrapper.state()["state"]
52
- self._state = AgenticTrafficState.model_validate(payload)
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import uuid
5
  from pathlib import Path
6
+ from typing import Any
7
 
8
+ from district_llm.inference import DistrictLLMInference
9
+ from district_llm.schema import DistrictAction
10
  from models import (
11
  AgenticTrafficAction,
12
  AgenticTrafficObservation,
 
19
  REPO_ROOT = Path(__file__).resolve().parents[1]
20
  DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (REPO_ROOT / "data" / "generated"))
21
  SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (REPO_ROOT / "data" / "splits"))
22
+ DISTRICT_LLM_ADAPTER_PATH = Path(
23
+ os.environ.get("DISTRICT_LLM_ADAPTER_PATH", "")
24
+ or (REPO_ROOT / "artifacts" / "district_llm_adapter_v3" / "main_run" / "adapter")
25
+ )
26
+ DISTRICT_LLM_DEVICE = os.environ.get("DISTRICT_LLM_DEVICE")
27
 
28
 
29
  class AgenticTrafficEnvironment(
 
32
  """Minimal OpenEnv-compatible wrapper around the existing district controller stack."""
33
 
34
  def __init__(self) -> None:
35
+ super().__init__()
36
  self.wrapper = OpenEnvTrafficWrapper(
37
  generated_root=DATA_DIR,
38
  splits_root=SPLITS_DIR,
39
  )
40
  self._state = AgenticTrafficState()
41
+ self._llm_inference: DistrictLLMInference | None = None
42
+ self._llm_load_attempted = False
43
+ self._llm_error: str | None = None
44
 
45
+ def reset(
46
+ self,
47
+ seed: int | None = None,
48
+ episode_id: str | None = None,
49
+ **kwargs: Any,
50
+ ) -> AgenticTrafficObservation:
51
+ payload = self.wrapper.reset(
52
+ seed=seed,
53
+ city_id=kwargs.get("city_id"),
54
+ scenario_name=kwargs.get("scenario_name"),
55
+ )
56
+ self._state.episode_id = episode_id or str(uuid.uuid4())
57
+ self._state.step_count = 0
58
  self._sync_state()
59
+ observation = AgenticTrafficObservation.model_validate(payload["observation"])
60
+ observation.reward = None
61
+ observation.done = False
62
+ observation.metadata["llm"] = self._llm_status()
63
+ return observation
64
 
65
+ def step(
66
+ self,
67
+ action: AgenticTrafficAction,
68
+ timeout_s: float | None = None,
69
+ **kwargs: Any,
70
+ ) -> AgenticTrafficObservation:
71
+ del timeout_s, kwargs
72
+ payload = self.wrapper.step(action=self._build_step_payload(action))
73
+ self._state.step_count += 1
74
  self._sync_state()
75
  observation = AgenticTrafficObservation.model_validate(payload["observation"])
76
  observation.done = bool(payload.get("done", False))
77
  observation.reward = float(payload.get("reward", 0.0))
78
+ observation.metadata["llm"] = self._llm_status()
79
  return observation
80
 
81
  @property
 
83
  self._sync_state()
84
  return self._state
85
 
86
+ def _build_step_payload(self, action: AgenticTrafficAction) -> dict[str, Any]:
87
+ district_actions = dict(action.district_actions)
88
+ llm_generated_actions: dict[str, Any] = {}
89
+
90
+ if action.use_llm:
91
+ llm_generated_actions = self._generate_llm_actions(
92
+ existing_actions=district_actions,
93
+ max_new_tokens=action.llm_max_new_tokens,
94
+ )
95
+ for district_id, directive in llm_generated_actions.items():
96
+ district_actions.setdefault(district_id, directive)
97
+
98
+ payload = {"district_actions": district_actions}
99
+ payload["metadata"] = {
100
+ "use_llm": bool(action.use_llm),
101
+ "llm_generated_districts": sorted(llm_generated_actions),
102
+ "llm": self._llm_status(),
103
+ }
104
+ return payload
105
+
106
+ def _generate_llm_actions(
107
+ self,
108
+ existing_actions: dict[str, Any],
109
+ max_new_tokens: int,
110
+ ) -> dict[str, Any]:
111
+ if not self.wrapper.last_summaries:
112
+ return {}
113
+
114
+ inference = self._get_llm_inference()
115
+ if inference is None:
116
+ return {}
117
+
118
+ generated_actions: dict[str, Any] = {}
119
+ for district_id, summary in self.wrapper.last_summaries.items():
120
+ if district_id in existing_actions:
121
+ continue
122
+ result = inference.predict_with_result(summary=summary, max_new_tokens=max_new_tokens)
123
+ generated_actions[district_id] = result.action.to_dict()
124
+ return generated_actions
125
+
126
+ def _get_llm_inference(self) -> DistrictLLMInference | None:
127
+ if self._llm_inference is not None:
128
+ return self._llm_inference
129
+ if self._llm_load_attempted:
130
+ return None
131
+
132
+ self._llm_load_attempted = True
133
+ if not DISTRICT_LLM_ADAPTER_PATH.exists():
134
+ self._llm_error = f"Adapter not found at {DISTRICT_LLM_ADAPTER_PATH}"
135
+ return None
136
+
137
+ try:
138
+ self._llm_inference = DistrictLLMInference(
139
+ model_name_or_path=str(DISTRICT_LLM_ADAPTER_PATH),
140
+ device=DISTRICT_LLM_DEVICE,
141
+ fallback_action=DistrictAction.default_hold(
142
+ duration_steps=self.wrapper.district_decision_interval
143
+ ),
144
+ )
145
+ except Exception as exc:
146
+ self._llm_error = f"{type(exc).__name__}: {exc}"
147
+ self._llm_inference = None
148
+ return self._llm_inference
149
+
150
+ def _llm_status(self) -> dict[str, Any]:
151
+ return {
152
+ "adapter_path": str(DISTRICT_LLM_ADAPTER_PATH),
153
+ "adapter_present": DISTRICT_LLM_ADAPTER_PATH.exists(),
154
+ "loaded": self._llm_inference is not None,
155
+ "load_attempted": self._llm_load_attempted,
156
+ "error": self._llm_error,
157
+ }
158
+
159
  def _sync_state(self) -> None:
160
  payload = self.wrapper.state()["state"]
161
+ self._state = AgenticTrafficState.model_validate(
162
+ {
163
+ **payload,
164
+ "episode_id": self._state.episode_id,
165
+ "step_count": self._state.step_count,
166
+ "llm": self._llm_status(),
167
+ }
168
+ )