aparekh02 commited on
Commit
a9e4a6a
·
verified ·
1 Parent(s): 4df3960

Add auto-running RL demo with PPO mini-updates

Browse files
README.md CHANGED
@@ -1,12 +1,15 @@
1
  ---
2
- title: Openenv Rl Demo
3
- emoji: πŸ“š
4
- colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
  ---
2
+ title: OpenENV RL Demo
3
+ emoji: 🚗
4
+ colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # OpenENV RL — Live Policy Training
13
+
14
+ Auto-runs 20 steps per episode using the openenv policy system (FlatMLPPolicy).
15
+ PPO mini-update after each episode — rewards increase over time.
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenENV RL Demo β€” auto-runs 20 steps per episode using the openenv policy system.
3
+
4
+ Policies: FlatMLPPolicy / TicketAttentionPolicy (from openenv)
5
+ Training: PPO mini-update after each episode β€” rewards increase over time
6
+ Display: Live step-by-step feed + episode reward history
7
+ """
8
+
9
+ import math, time, threading
10
+ import numpy as np
11
+ import torch
12
+ import torch.optim as optim
13
+ import gradio as gr
14
+
15
+ from overflow_env.server.overflow_environment import OverflowEnvironment
16
+ from overflow_env.models import OverflowAction
17
+ from policies.flat_mlp_policy import FlatMLPPolicy
18
+ from policies.ticket_attention_policy import TicketAttentionPolicy
19
+ from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
20
+
21
+
22
+ STEPS_PER_EPISODE = 20
23
+
24
+
25
+ # ── Observation adapter ───────────────────────────────────────────────────────
26
+
27
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Translate an Overflow observation into the flat OBS_DIM policy vector.

    The ego car (carId == 0, falling back to the first car) supplies the ego
    state; every other car within 80 m becomes one "collision_risk" ticket.
    Speeds are divided by 4.5 to roughly normalize them — assumes env speeds
    top out near 4.5 (TODO confirm against OverflowEnvironment).
    """
    cars = overflow_obs.cars
    if not cars:
        # No cars at all: return an all-zero observation of the right shape.
        return np.zeros(OBS_DIM, dtype=np.float32)

    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_speed = ego.speed / 4.5
    ego_lat = (ego.lane - 2) * 3.7  # lane index -> metres from centre lane

    ticket_rows = []
    for other in cars:
        if other.carId == 0:
            continue  # skip the ego car itself
        rel_x = other.position.x - ego.position.x
        rel_y = (other.lane - ego.lane) * 3.7
        other_speed = other.speed / 4.5
        dist = math.sqrt(rel_x ** 2 + rel_y ** 2)
        if dist > 80:
            continue  # too far away to be a threat
        # Closing-speed proxy, floored at 0.1 so the TTC division below
        # never blows up.
        closing = max(ego_speed - other_speed * math.copysign(1, max(rel_x, 0.01)), 0.1)
        if dist < 8:
            severity = 1.0
        elif dist < 15:
            severity = 0.75
        else:
            severity = 0.5
        ticket_rows.append(build_ticket_vector(
            severity_weight=severity,
            ttl=5.0, pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
            vel_x=other_speed, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=dist, time_to_collision=min(dist / closing, 30.0),
            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    ticket_arr = np.array(ticket_rows, dtype=np.float32) if ticket_rows else None
    return build_obs(ego_x=ego.position.x, ego_y=ego_lat, ego_z=0.0,
                     ego_vx=ego_speed, ego_vy=0.0,
                     heading=0.0, speed=ego_speed,
                     steer=0.0, throttle=0.5, brake=0.0,
                     ticket_vectors=ticket_arr)
61
+
62
+
63
def action_to_decision(a: np.ndarray) -> str:
    """Map a continuous [steer, throttle, brake] action to a discrete decision.

    Steering dominates: a strong lateral command becomes a lane change;
    otherwise brake beats throttle, and small commands mean "maintain".
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])
    if abs(steer) > 0.35:
        return "lane_change_left" if steer < 0 else "lane_change_right"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
69
+
70
+
71
+ # ── Global training state ─────────────────────────────────────────────────────
72
+
73
+ policy = FlatMLPPolicy(obs_dim=OBS_DIM)
74
+ optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)
75
+
76
+ # Rollout buffer (lightweight β€” one episode at a time)
77
+ _buf_obs = []
78
+ _buf_acts = []
79
+ _buf_rews = []
80
+ _buf_logps = []
81
+ _buf_vals = []
82
+ _buf_dones = []
83
+
84
+ episode_history = [] # [{ep, steps, reward, outcome}]
85
+ step_log = [] # [{ep, step, decision, reward, scene}]
86
+ _running = False
87
+ _lock = threading.Lock()
88
+
89
+
90
def _ppo_mini_update():
    """Run one clipped-PPO gradient step over the episode in the buffers.

    Advantages come from GAE (gamma=0.99, lambda=0.95); the policy noise
    std (0.3) matches the std used when sampling during rollout, so the
    importance ratio is well-defined. Mutates `policy` via `optimizer`.
    """
    if len(_buf_obs) < 2:
        return  # too few transitions for a meaningful update

    obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
    acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
    rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
    old_logp = torch.tensor(_buf_logps, dtype=torch.float32)
    vals_t = torch.tensor(_buf_vals, dtype=torch.float32)
    dones_t = torch.tensor(_buf_dones, dtype=torch.float32)

    # Generalized Advantage Estimation, walked backwards through time.
    gamma, lam = 0.99, 0.95
    adv = torch.zeros_like(rews_t)
    running_gae = 0.0
    last = len(rews_t) - 1
    for t in range(last, -1, -1):
        next_val = 0.0 if t == last else float(vals_t[t + 1])
        mask = 1 - dones_t[t]  # zero out bootstrap past episode ends
        delta = rews_t[t] + gamma * next_val * mask - vals_t[t]
        running_gae = delta + gamma * lam * mask * running_gae
        adv[t] = running_gae
    returns = adv + vals_t
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    policy.train()
    act_mean, value = policy(obs_t)
    value = value.squeeze(-1)
    dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
    new_logp = dist.log_prob(acts_t).sum(dim=-1)
    entropy = dist.entropy().sum(dim=-1).mean()
    ratio = torch.exp(new_logp - old_logp)

    # PPO clipped surrogate + value loss - entropy bonus.
    policy_loss = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
    value_loss = 0.5 * ((value - returns) ** 2).mean()
    loss = policy_loss + 0.5 * value_loss - 0.02 * entropy

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
127
+
128
+
129
def run_episodes_loop():
    """Background worker: play episodes until stopped, PPO-updating after each.

    Per episode: reset env, roll out up to STEPS_PER_EPISODE sampled actions,
    log every step into `step_log` (under `_lock`), then run a PPO mini-update
    and append a summary to `episode_history`.
    """
    global _running
    env = OverflowEnvironment()
    ep_num = 0

    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"  # overwritten if the episode terminates early

        # Fresh rollout buffers for this episode.
        for buf in (_buf_obs, _buf_acts, _buf_rews,
                    _buf_logps, _buf_vals, _buf_dones):
            buf.clear()

        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                break

            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                # Same exploration std (0.3) as the PPO update assumes.
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()

            decision = action_to_decision(action.numpy())
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done

            # Record the transition for the post-episode PPO update.
            _buf_obs.append(obs_vec)
            _buf_acts.append(action.numpy())
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))

            ep_rew += reward

            with _lock:
                step_log.append({
                    "ep": ep_num, "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "scene": obs.scene_description.split("\n")[0],
                    "incident": obs.incident_report or "",
                })

            if done:
                outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
                break

            time.sleep(0.6)  # pace so the UI can show each step

        _ppo_mini_update()

        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": step,
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
197
+
198
+
199
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
200
+
201
def start_training():
    """Start the background rollout thread; no-op if one is already running.

    Returns Gradio updates that disable Start and enable Stop.
    """
    global _running
    if not _running:
        # Reset displays before flipping the flag and launching the worker.
        step_log.clear()
        episode_history.clear()
        _running = True
        worker = threading.Thread(target=run_episodes_loop, daemon=True)
        worker.start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
210
+
211
+
212
def stop_training():
    """Signal the rollout thread to exit after its current step.

    Returns Gradio updates that re-enable Start and disable Stop.
    """
    global _running
    _running = False
    return gr.update(value="Start", interactive=True), gr.update(interactive=False)
216
+
217
+
218
def get_updates():
    """Timer callback: snapshot shared state and render the four UI panels.

    Returns (step feed text, episode table text, trend text, status text).
    """
    # Copy under the lock so rendering happens on a consistent snapshot.
    with _lock:
        logs = list(step_log[-20:])
        eps = list(episode_history[-30:])

    # Step feed, most recent first.
    if logs:
        lines = []
        for e in reversed(logs):
            if "CRASH" in e["incident"]:
                flag = " πŸ’₯"
            elif "GOAL" in e["incident"]:
                flag = " βœ“"
            elif "NEAR MISS" in e["incident"]:
                flag = " ⚠"
            else:
                flag = ""
            lines.append(
                f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
                f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
                f"ep_total={e['ep_reward']:>7.2f}{flag}"
            )
        step_text = "\n".join(lines)
    else:
        step_text = "Waiting for first episode..."

    # Episode summary table, newest first.
    if eps:
        ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
        for e in reversed(eps):
            ep_lines.append(
                f" {e['ep']:>4d} | {e['steps']:>3d} | "
                f" {e['reward']:>+8.2f} | {e['outcome']}"
            )
        ep_text = "\n".join(ep_lines)
    else:
        ep_text = "No episodes completed yet."

    # Early-vs-late mean reward comparison to visualize learning progress.
    if len(eps) >= 2:
        rewards = [e["reward"] for e in eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        trend = f"Mean reward (early {half} eps): {early:+.2f} β†’ (last {n-half} eps): {late:+.2f}"
        arrow = "↑ improving" if late > early else "↓ declining"
        trend_text = f"{trend} {arrow}"
    else:
        trend_text = "Collecting data..."

    status = "● RUNNING" if _running else "β–  STOPPED"
    return step_text, ep_text, trend_text, status
263
+
264
+
265
# ── UI layout: controls, live step feed, episode history, reward trend ────────
with gr.Blocks(title="OpenENV RL Demo") as demo:
    gr.Markdown("# OpenENV RL β€” Live Policy Training\n"
                "FlatMLPPolicy runs 20 steps per episode on OverflowEnvironment. "
                "PPO mini-update after each episode β€” watch rewards improve over time.")

    # Control row: start/stop buttons plus a compact status indicator.
    with gr.Row():
        start_btn = gr.Button("Start", variant="primary")
        stop_btn = gr.Button("Stop", variant="stop", interactive=False)
        status_box = gr.Textbox(value="β–  STOPPED", label="Status",
                                interactive=False, scale=0, min_width=120)

    gr.Markdown("### Live Step Feed (most recent 20 steps)")
    step_display = gr.Textbox(
        value="Press Start to begin...",
        lines=22, max_lines=22, interactive=False,
        elem_id="step_feed",
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Episode History")
            ep_display = gr.Textbox(lines=12, interactive=False)
        with gr.Column():
            gr.Markdown("### Reward Trend")
            trend_display = gr.Textbox(lines=3, interactive=False)

    # Poll shared training state once per second to refresh every panel.
    refresh_timer = gr.Timer(value=1.0)
    refresh_timer.tick(
        fn=get_updates,
        outputs=[step_display, ep_display, trend_display, status_box],
    )

    start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
    stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])


if __name__ == "__main__":
    demo.launch()
policies/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .flat_mlp_policy import FlatMLPPolicy
2
+ from .ticket_attention_policy import TicketAttentionPolicy
policies/base_policy.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BasePolicy β€” abstract interface all policies implement.
3
+
4
+ All policies expose the same predict() and train_step() API so the
5
+ curriculum trainer can swap them out transparently.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import abc
11
+ from typing import Any, Dict, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
class BasePolicy(nn.Module, abc.ABC):
    """Abstract base class shared by every driving policy.

    Subclasses implement forward(obs) -> (action_mean, value). The common
    predict() rollout helper and the _mlp() builder live here so the
    curriculum trainer can swap policies transparently.
    """

    def __init__(self, obs_dim: int, action_dim: int = 3):
        super().__init__()
        # Recorded for introspection by trainers/adapters.
        self.obs_dim = obs_dim
        self.action_dim = action_dim

    @abc.abstractmethod
    def forward(
        self, obs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (action_mean of shape (B, action_dim), value of shape (B, 1))."""
        ...

    def predict(
        self,
        obs: np.ndarray,
        deterministic: bool = False,
    ) -> np.ndarray:
        """Numpy in, numpy out. Used by the env during rollout.

        Deterministic mode returns the mean action; otherwise a small
        Gaussian exploration noise (std 0.1) is added.
        """
        self.eval()
        with torch.no_grad():
            batched = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            mean, _ = self.forward(batched)
            if deterministic:
                chosen = mean
            else:
                chosen = mean + torch.randn_like(mean) * 0.1
        return chosen.squeeze(0).numpy()

    @staticmethod
    def _mlp(dims: list[int], activation=nn.Tanh) -> nn.Sequential:
        """Build a Linear stack with `activation` between all but the last layer."""
        layers: list[nn.Module] = []
        hidden_count = len(dims) - 2  # activations go after all but the final Linear
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if idx < hidden_count:
                layers.append(activation())
        return nn.Sequential(*layers)
policies/flat_mlp_policy.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlatMLPPolicy β€” sanity-check baseline.
3
+
4
+ Concatenates the full observation (ego + all tickets flattened) and passes
5
+ it through a standard MLP. No attention, no structure.
6
+
7
+ Use this to:
8
+ 1. Verify the reward signal and environment are working
9
+ 2. Establish a performance floor
10
+ 3. Confirm that TicketAttentionPolicy actually improves over this
11
+
12
+ If FlatMLPPolicy can't learn Stage 1 survival, the reward or env is broken.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from .base_policy import BasePolicy
21
+
22
+
23
class FlatMLPPolicy(BasePolicy):
    """Sanity-check baseline: one plain MLP over the full flat observation.

    No attention, no structure — establishes a performance floor and verifies
    that the reward signal and environment actually work.
    """

    def __init__(self, obs_dim: int, hidden: int = 256):
        """
        Parameters
        ----------
        obs_dim : flat observation dimensionality
        hidden  : width of the main actor layers (critic narrows to hidden/2)
        """
        super().__init__(obs_dim)

        # Actor: LayerNorm after the first layer stabilises early training;
        # the final Tanh keeps action means inside [-1, 1].
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.LayerNorm(hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden // 2), nn.Tanh(),
            nn.Linear(hidden // 2, 3), nn.Tanh(),
        )
        # Critic: smaller unnormalised MLP emitting a scalar state value.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden // 2), nn.Tanh(),
            nn.Linear(hidden // 2, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init everywhere; near-zero gain on the action output."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=1.0)
                nn.init.zeros_(module.bias)
        # actor[-2] is the final Linear (the trailing Tanh sits at [-1]);
        # tiny gain keeps initial actions near zero.
        nn.init.orthogonal_(self.actor[-2].weight, gain=0.01)

    def forward(self, obs: torch.Tensor):
        """Return (action_mean, value) for a batch of flat observations."""
        return self.actor(obs), self.critic(obs)
policies/policy_spec.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Policy data input specifications β€” formal contracts for observation, action, and ticket data.
3
+
4
+ This module defines the exact data shapes, normalization ranges, and semantic meaning
5
+ of every field consumed by OpenENV policies. Use this as the reference when:
6
+
7
+ 1. Building a new environment that targets these policies
8
+ 2. Writing a bridge/adapter from a different simulator
9
+ 3. Implementing a new policy that must interoperate with the existing set
10
+
11
+ All policies share the same raw observation layout (EGO + ticket matrix).
12
+ Specialized policies (ThreatAvoidance, SystemFailure) select subsets internally.
13
+
14
+ Example usage:
15
+ from openenv.policies.policy_spec import ObsSpec, ActionSpec, validate_obs
16
+
17
+ spec = ObsSpec()
18
+ obs = my_env.get_observation()
19
+ validate_obs(obs, spec) # raises ValueError on shape/range mismatch
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from dataclasses import dataclass, field
25
+ from typing import Any, Dict, List, Optional, Tuple
26
+
27
+ import numpy as np
28
+
29
+
30
+ # ── Ego state specification ──────────────────────────────────────────────────
31
+
32
+ EGO_STATE_DIM = 11
33
+
34
+ @dataclass(frozen=True)
35
+ class EgoField:
36
+ """Description of a single ego state field."""
37
+ index: int
38
+ name: str
39
+ unit: str
40
+ raw_range: Tuple[float, float] # physical range before normalization
41
+ norm_divisor: float # obs_value = raw_value / norm_divisor
42
+ description: str
43
+
44
+ EGO_FIELDS: List[EgoField] = [
45
+ EgoField(0, "x", "m", (-5000, 5000), 1000.0, "Forward displacement from episode start"),
46
+ EgoField(1, "y", "m", (-6.0, 6.0), 3.7, "Lateral displacement (0 = lane center, + = left)"),
47
+ EgoField(2, "z", "m", (-10, 10), 10.0, "Vertical position (flat road = 0)"),
48
+ EgoField(3, "vx", "m/s", (-20, 20), 20.0, "Forward velocity in world frame"),
49
+ EgoField(4, "vy", "m/s", (-20, 20), 20.0, "Lateral velocity in world frame"),
50
+ EgoField(5, "vz", "m/s", (0, 0), 1.0, "Vertical velocity (always 0 on flat road)"),
51
+ EgoField(6, "heading_sin", "rad", (-1, 1), 1.0, "sin(heading angle), 0 = forward"),
52
+ EgoField(7, "heading_cos", "rad", (-1, 1), 1.0, "cos(heading angle), 1 = forward"),
53
+ EgoField(8, "speed", "m/s", (0, 20), 20.0, "Scalar speed = sqrt(vx^2 + vy^2)"),
54
+ EgoField(9, "steer", "norm", (-1, 1), 1.0, "Current steering command [-1=full left, 1=full right]"),
55
+ EgoField(10, "net_drive", "norm", (-1, 1), 1.0, "throttle - brake [-1=full brake, 1=full throttle]"),
56
+ ]
57
+
58
+
59
+ # ── Ticket vector specification ──────────────────────────────────────────────
60
+
61
+ TICKET_VECTOR_DIM = 37 # 18 fixed + 14 type one-hot + 5 entity one-hot
62
+ MAX_TICKETS = 16
63
+
64
+ # Ticket types (14 total) β€” one-hot encoded starting at index 18
65
+ TICKET_TYPES = [
66
+ "collision_risk", "sudden_brake", "side_impact", "head_on",
67
+ "merge_cut", "rear_end_risk",
68
+ "pedestrian_crossing", "cyclist_lane",
69
+ "tire_blowout", "brake_fade", "steering_loss", "sensor_occlusion",
70
+ "road_hazard", "weather_visibility",
71
+ ]
72
+
73
+ # Entity types (5 total) β€” one-hot encoded after ticket types
74
+ ENTITY_TYPES = ["vehicle", "pedestrian", "cyclist", "obstacle", "system"]
75
+
76
+ # Verify dimension
77
+ assert 18 + len(TICKET_TYPES) + len(ENTITY_TYPES) == TICKET_VECTOR_DIM, (
78
+ f"Ticket vector dim mismatch: 18 + {len(TICKET_TYPES)} + {len(ENTITY_TYPES)} "
79
+ f"!= {TICKET_VECTOR_DIM}"
80
+ )
81
+
82
+ @dataclass(frozen=True)
83
+ class TicketField:
84
+ """Description of a single ticket vector field."""
85
+ offset: int # index within the TICKET_VECTOR_DIM vector
86
+ length: int # number of floats
87
+ name: str
88
+ unit: str
89
+ raw_range: Tuple[float, float]
90
+ norm_divisor: float
91
+ description: str
92
+
93
+ TICKET_FIELDS: List[TicketField] = [
94
+ TicketField(0, 1, "severity_weight", "norm", (0, 1), 1.0, "Severity: 0.25=LOW, 0.5=MED, 0.75=HIGH, 1.0=CRITICAL"),
95
+ TicketField(1, 1, "ttl_norm", "s", (0, 10), 10.0, "Time-to-live remaining, clamped to [0,1]"),
96
+ TicketField(2, 1, "pos_x", "m", (-100, 100), 100.0, "Ego-relative X (forward positive)"),
97
+ TicketField(3, 1, "pos_y", "m", (-50, 50), 50.0, "Ego-relative Y (left positive)"),
98
+ TicketField(4, 1, "pos_z", "m", (-10, 10), 10.0, "Ego-relative Z (up positive)"),
99
+ TicketField(5, 1, "vel_x", "m/s", (-30, 30), 30.0, "Entity velocity X in world frame"),
100
+ TicketField(6, 1, "vel_y", "m/s", (-30, 30), 30.0, "Entity velocity Y in world frame"),
101
+ TicketField(7, 1, "vel_z", "m/s", (-10, 10), 10.0, "Entity velocity Z in world frame"),
102
+ TicketField(8, 1, "heading_sin", "rad", (-1, 1), 1.0, "sin(entity heading relative to ego)"),
103
+ TicketField(9, 1, "heading_cos", "rad", (-1, 1), 1.0, "cos(entity heading relative to ego)"),
104
+ TicketField(10, 1, "size_length", "m", (0, 10), 10.0, "Entity bounding box length"),
105
+ TicketField(11, 1, "size_width", "m", (0, 5), 5.0, "Entity bounding box width"),
106
+ TicketField(12, 1, "size_height", "m", (0, 4), 4.0, "Entity bounding box height"),
107
+ TicketField(13, 1, "distance_norm", "m", (0, 100), 100.0, "Euclidean distance to ego, clamped to [0,1]"),
108
+ TicketField(14, 1, "ttc_norm", "s", (0, 30), 30.0, "Time-to-collision, clamped to [0,1]. 1.0 = no collision"),
109
+ TicketField(15, 1, "bearing_sin", "rad", (-1, 1), 1.0, "sin(bearing angle from ego forward axis)"),
110
+ TicketField(16, 1, "bearing_cos", "rad", (-1, 1), 1.0, "cos(bearing angle from ego forward axis)"),
111
+ TicketField(17, 1, "confidence", "norm", (0, 1), 1.0, "Perception confidence [0=unreliable, 1=certain]"),
112
+ TicketField(18, len(TICKET_TYPES), "type_onehot", "bool", (0, 1), 1.0, "One-hot ticket type"),
113
+ TicketField(18 + len(TICKET_TYPES), len(ENTITY_TYPES), "entity_onehot", "bool", (0, 1), 1.0, "One-hot entity type"),
114
+ ]
115
+
116
+
117
+ # ── Full observation specification ───────────────────────────────────────────
118
+
119
+ OBS_DIM = EGO_STATE_DIM + MAX_TICKETS * TICKET_VECTOR_DIM # 11 + 16*37 = 603
120
+
121
+ @dataclass(frozen=True)
122
+ class ObsSpec:
123
+ """Complete observation space specification."""
124
+ ego_dim: int = EGO_STATE_DIM
125
+ ticket_dim: int = TICKET_VECTOR_DIM
126
+ max_tickets: int = MAX_TICKETS
127
+ total_dim: int = OBS_DIM
128
+ dtype: str = "float32"
129
+ value_range: Tuple[float, float] = (-1.0, 1.0)
130
+
131
+ # Layout: obs[0:ego_dim] = ego state
132
+ # obs[ego_dim:] reshaped to (max_tickets, ticket_dim)
133
+ # Tickets are sorted by severity desc, distance asc. Zero-padded rows = empty slots.
134
+
135
+
136
+ # ── Action specification ─────────────────────────────────────────────────────
137
+
138
+ @dataclass(frozen=True)
139
+ class ActionField:
140
+ index: int
141
+ name: str
142
+ raw_range: Tuple[float, float]
143
+ description: str
144
+
145
+ ACTION_DIM = 3
146
+
147
+ ACTION_FIELDS: List[ActionField] = [
148
+ ActionField(0, "steer", (-1.0, 1.0), "Steering command. -1=full left, +1=full right. Scaled by MAX_STEER=0.6 rad"),
149
+ ActionField(1, "throttle", (-1.0, 1.0), "Throttle command. Only positive values used (clipped to [0,1]). Scaled by MAX_ACCEL=4.0 m/s^2"),
150
+ ActionField(2, "brake", (-1.0, 1.0), "Brake command. Only positive values used (clipped to [0,1]). Scaled by MAX_BRAKE=8.0 m/s^2"),
151
+ ]
152
+
153
+ @dataclass(frozen=True)
154
+ class ActionSpec:
155
+ """Action space specification."""
156
+ dim: int = ACTION_DIM
157
+ dtype: str = "float32"
158
+ value_range: Tuple[float, float] = (-1.0, 1.0)
159
+
160
+
161
+ # ── Policy input requirements ────────────────────────────────────────────────
162
+
163
+ @dataclass(frozen=True)
164
+ class PolicyInputSpec:
165
+ """Describes what a specific policy reads from the observation."""
166
+ name: str
167
+ reads_ego: bool
168
+ ego_indices: Tuple[int, ...] # which ego fields are used
169
+ reads_tickets: bool
170
+ ticket_filter: Optional[str] # None = all, or "kinematic" / "failure"
171
+ max_tickets_used: int # how many ticket slots the policy actually reads
172
+ requires_history: bool # whether GRU/recurrent hidden state is needed
173
+ description: str
174
+
175
+ POLICY_SPECS: Dict[str, PolicyInputSpec] = {
176
+ "SurvivalPolicy": PolicyInputSpec(
177
+ name="SurvivalPolicy",
178
+ reads_ego=True,
179
+ ego_indices=tuple(range(EGO_STATE_DIM)),
180
+ reads_tickets=False,
181
+ ticket_filter=None,
182
+ max_tickets_used=0,
183
+ requires_history=False,
184
+ description="Stage 1 baseline. Reads only ego state (first 11 dims). "
185
+ "Ticket portion of obs is ignored entirely.",
186
+ ),
187
+ "FlatMLPPolicy": PolicyInputSpec(
188
+ name="FlatMLPPolicy",
189
+ reads_ego=True,
190
+ ego_indices=tuple(range(EGO_STATE_DIM)),
191
+ reads_tickets=True,
192
+ ticket_filter=None,
193
+ max_tickets_used=MAX_TICKETS,
194
+ requires_history=False,
195
+ description="Sanity-check baseline. Reads full flat observation (ego + all tickets "
196
+ "concatenated). No attention or structure.",
197
+ ),
198
+ "TicketAttentionPolicy": PolicyInputSpec(
199
+ name="TicketAttentionPolicy",
200
+ reads_ego=True,
201
+ ego_indices=tuple(range(EGO_STATE_DIM)),
202
+ reads_tickets=True,
203
+ ticket_filter=None,
204
+ max_tickets_used=MAX_TICKETS,
205
+ requires_history=False,
206
+ description="Main policy (Stage 2+). Cross-attention: ego queries ticket set. "
207
+ "Order-invariant over tickets. Padding mask on zero-rows.",
208
+ ),
209
+ "ThreatAvoidancePolicy": PolicyInputSpec(
210
+ name="ThreatAvoidancePolicy",
211
+ reads_ego=True,
212
+ ego_indices=tuple(range(EGO_STATE_DIM)),
213
+ reads_tickets=True,
214
+ ticket_filter="kinematic",
215
+ max_tickets_used=1,
216
+ requires_history=False,
217
+ description="Specialist for kinematic threats (collision_risk, sudden_brake, "
218
+ "side_impact, head_on, merge_cut, rear_end_risk). Extracts the "
219
+ "highest-severity kinematic ticket and gates between brake/evade branches.",
220
+ ),
221
+ "SystemFailurePolicy": PolicyInputSpec(
222
+ name="SystemFailurePolicy",
223
+ reads_ego=True,
224
+ ego_indices=tuple(range(EGO_STATE_DIM)),
225
+ reads_tickets=True,
226
+ ticket_filter="failure",
227
+ max_tickets_used=1,
228
+ requires_history=False,
229
+ description="Specialist for onboard failures (tire_blowout, brake_fade, steering_loss). "
230
+ "Mixture-of-experts with one expert per failure type. Initialized with "
231
+ "domain-correct response priors.",
232
+ ),
233
+ "RecurrentPolicy": PolicyInputSpec(
234
+ name="RecurrentPolicy",
235
+ reads_ego=True,
236
+ ego_indices=tuple(range(EGO_STATE_DIM)),
237
+ reads_tickets=True,
238
+ ticket_filter=None,
239
+ max_tickets_used=MAX_TICKETS,
240
+ requires_history=True,
241
+ description="GRU-based policy for partial observability (Stage 4+). Carries hidden "
242
+ "state across timesteps. Requires h_prev to be tracked by caller.",
243
+ ),
244
+ }
245
+
246
+
247
+ # ── Validation helpers ───────────────────────────────────────────────────────
248
+
249
def validate_obs(obs: np.ndarray, spec: Optional[ObsSpec] = None) -> None:
    """Check an observation array against the spec.

    Verifies rank (must be 1-D), total length, and dtype (float32); value
    ranges are not checked here. Raises ValueError with a descriptive
    message on any mismatch.
    """
    if spec is None:
        spec = ObsSpec()
    if obs.ndim != 1:
        raise ValueError(f"Observation must be 1D, got shape {obs.shape}")
    if obs.shape[0] != spec.total_dim:
        raise ValueError(
            f"Observation dim mismatch: expected {spec.total_dim}, got {obs.shape[0]}. "
            f"Check ego_dim ({spec.ego_dim}) + max_tickets ({spec.max_tickets}) "
            f"* ticket_dim ({spec.ticket_dim})"
        )
    if obs.dtype != np.float32:
        raise ValueError(f"Observation dtype must be float32, got {obs.dtype}")
265
+
266
+
267
def validate_action(action: np.ndarray) -> None:
    """Raise ValueError unless `action` is shape (ACTION_DIM,) with values in [-1, 1]."""
    if action.shape != (ACTION_DIM,):
        raise ValueError(f"Action shape mismatch: expected ({ACTION_DIM},), got {action.shape}")
    out_of_range = np.any(action < -1.0) or np.any(action > 1.0)
    if out_of_range:
        raise ValueError(f"Action values must be in [-1, 1], got min={action.min()}, max={action.max()}")
273
+
274
+
275
def build_obs(
    ego_x: float, ego_y: float, ego_z: float,
    ego_vx: float, ego_vy: float,
    heading: float, speed: float,
    steer: float, throttle: float, brake: float,
    ticket_vectors: Optional[np.ndarray] = None,
    max_tickets: int = MAX_TICKETS,
) -> np.ndarray:
    """
    Assemble a policy-compatible flat observation from raw ego/ticket values.

    This is the primary entry point for external environments that want to
    produce observations compatible with OpenENV policies.

    Parameters
    ----------
    ego_x : forward displacement from episode start (metres)
    ego_y : lateral displacement from lane center (metres, + = left)
    ego_z : vertical position (metres)
    ego_vx : forward velocity (m/s)
    ego_vy : lateral velocity (m/s)
    heading : heading angle (radians, 0 = forward)
    speed : scalar speed (m/s)
    steer : current steering command [-1, 1]
    throttle : current throttle command [0, 1]
    brake : current brake command [0, 1]
    ticket_vectors : (N, TICKET_VECTOR_DIM) array of ticket vectors, or None.
        Use EventTicket.to_vector() or build_ticket_vector() to create these.
    max_tickets : number of ticket slots (must match policy expectation, default 16)

    Returns
    -------
    obs : np.ndarray of shape (EGO_STATE_DIM + max_tickets * TICKET_VECTOR_DIM,)
    """
    import math

    # Divisors mirror the norm_divisor column of EGO_FIELDS.
    ego_state = np.array([
        ego_x / 1000.0,
        ego_y / 3.7,        # ROAD_HALF_WIDTH
        ego_z / 10.0,
        ego_vx / 20.0,      # MAX_SPEED
        ego_vy / 20.0,
        0.0,                # vz: road is flat
        math.sin(heading),
        math.cos(heading),
        speed / 20.0,
        steer,
        throttle - brake,   # net drive signal
    ], dtype=np.float32)

    # Zero-padded ticket slots; extra tickets beyond max_tickets are dropped.
    tickets = np.zeros((max_tickets, TICKET_VECTOR_DIM), dtype=np.float32)
    if ticket_vectors is not None:
        count = min(len(ticket_vectors), max_tickets)
        tickets[:count] = ticket_vectors[:count]

    return np.concatenate([ego_state, tickets.flatten()])
331
+
332
+
333
def build_ticket_vector(
    severity_weight: float,
    ttl: float,
    pos_x: float, pos_y: float, pos_z: float,
    vel_x: float, vel_y: float, vel_z: float,
    heading: float,
    size_length: float, size_width: float, size_height: float,
    distance: float,
    time_to_collision: Optional[float],
    bearing: float,
    ticket_type: str,
    entity_type: str,
    confidence: float = 1.0,
) -> np.ndarray:
    """
    Build a single ticket vector from raw values without needing the full
    EventTicket class. Use this when adapting a different simulator.

    Parameters
    ----------
    severity_weight : 0.25 (LOW), 0.5 (MEDIUM), 0.75 (HIGH), 1.0 (CRITICAL)
    ttl : seconds remaining until ticket expires
    pos_x/y/z : ego-relative position (metres)
    vel_x/y/z : entity velocity in world frame (m/s)
    heading : entity heading relative to ego (radians)
    size_length/width/height : entity bounding box (metres)
    distance : euclidean distance to ego (metres)
    time_to_collision : seconds until collision, or None if no collision course
    bearing : angle from ego forward axis (radians)
    ticket_type : one of TICKET_TYPES (e.g., "collision_risk")
    entity_type : one of ENTITY_TYPES (e.g., "vehicle")
    confidence : perception confidence [0, 1]

    Returns
    -------
    vec : np.ndarray of shape (TICKET_VECTOR_DIM,) = (37,)

    Raises
    ------
    ValueError : if ticket_type or entity_type is not a known value.
    """
    import math

    # Validate categorical fields before doing any work.
    if ticket_type not in TICKET_TYPES:
        raise ValueError(f"Unknown ticket_type '{ticket_type}'. Must be one of {TICKET_TYPES}")
    if entity_type not in ENTITY_TYPES:
        raise ValueError(f"Unknown entity_type '{entity_type}'. Must be one of {ENTITY_TYPES}")

    def _unit_clamp(x: float) -> float:
        # TICKET_FIELDS declares ttl_norm/distance_norm/ttc_norm as clamped to
        # [0, 1]; the previous code only clamped the upper bound, letting
        # negative ttl/distance/TTC inputs leak out-of-range values.
        return min(max(x, 0.0), 1.0)

    ttl_norm = _unit_clamp(ttl / 10.0)
    dist_norm = _unit_clamp(distance / 100.0)
    # Absent TTC means "no collision course" and maps to the max (1.0).
    raw_ttc = time_to_collision if time_to_collision is not None else 30.0
    ttc_norm = _unit_clamp(raw_ttc / 30.0)

    type_oh = [0.0] * len(TICKET_TYPES)
    type_oh[TICKET_TYPES.index(ticket_type)] = 1.0
    entity_oh = [0.0] * len(ENTITY_TYPES)
    entity_oh[ENTITY_TYPES.index(entity_type)] = 1.0

    vec = [
        severity_weight,
        ttl_norm,
        pos_x / 100.0,
        pos_y / 50.0,
        pos_z / 10.0,
        vel_x / 30.0,
        vel_y / 30.0,
        vel_z / 10.0,
        math.sin(heading),
        math.cos(heading),
        size_length / 10.0,
        size_width / 5.0,
        size_height / 4.0,
        dist_norm,
        ttc_norm,
        math.sin(bearing),
        math.cos(bearing),
        confidence,
        *type_oh,
        *entity_oh,
    ]
    return np.array(vec, dtype=np.float32)
policies/ticket_attention_policy.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TicketAttentionPolicy β€” the main policy (Stage 2+).
3
+
4
+ Architecture: two-pass "reflective" cross-attention.
5
+
6
+ Pass 1: ego queries tickets β†’ raw threat context
7
+ Pass 2: (ego + raw context) queries tickets again β†’ refined context
8
+ This forces the policy to "think twice" β€” first perceive, then plan.
9
+
10
+ [ego | refined_context] β†’ steer head β†’ steer action
11
+ β†’ drive head β†’ throttle, brake
12
+ β†’ critic head β†’ value
13
+
14
+ Why two-pass:
15
+ The first pass gathers what threats exist. The second pass re-examines
16
+ tickets knowing what the overall threat picture looks like. This prevents
17
+ the impulsive single-shot responses that cause wild oscillation.
18
+
19
+ Why separate heads:
20
+ Steering requires smooth, conservative output (off-road = death).
21
+ Throttle/brake can be more aggressive. Separate heads + separate
22
+ noise levels let each dimension learn at its own pace.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+ from .base_policy import BasePolicy
32
+ EGO_STATE_DIM = 11
33
+ MAX_TICKETS = 16
34
+ TICKET_VECTOR_DIM = 37
35
+
36
+
37
+ class TicketAttentionPolicy(BasePolicy):
38
+ """
39
+ Two-pass reflective attention policy.
40
+
41
+ Pass 1: perceive β€” what threats exist?
42
+ Pass 2: plan β€” given what I see, which threats matter most?
43
+ Output: separate steer head (conservative) + drive head (throttle/brake)
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ obs_dim: int,
49
+ ego_embed: int = 64,
50
+ ticket_embed: int = 64,
51
+ n_heads: int = 4,
52
+ hidden: int = 256,
53
+ ):
54
+ super().__init__(obs_dim)
55
+ assert ego_embed % n_heads == 0
56
+ assert ticket_embed == ego_embed
57
+
58
+ self.ego_embed = ego_embed
59
+ self.max_tickets = MAX_TICKETS
60
+ self.ticket_dim = TICKET_VECTOR_DIM
61
+
62
+ # ── Encoders ──────────────────────────────────────────────────────
63
+ self.ego_encoder = nn.Sequential(
64
+ nn.Linear(EGO_STATE_DIM, hidden // 2),
65
+ nn.LayerNorm(hidden // 2),
66
+ nn.Tanh(),
67
+ nn.Linear(hidden // 2, ego_embed),
68
+ nn.LayerNorm(ego_embed),
69
+ )
70
+ self.ticket_encoder = nn.Sequential(
71
+ nn.Linear(TICKET_VECTOR_DIM, hidden // 2),
72
+ nn.LayerNorm(hidden // 2),
73
+ nn.ReLU(),
74
+ nn.Linear(hidden // 2, ticket_embed),
75
+ nn.LayerNorm(ticket_embed),
76
+ )
77
+
78
+ # ── Pass 1: perceive (ego queries tickets) ───────────────────────
79
+ self.attn_pass1 = nn.MultiheadAttention(
80
+ embed_dim=ego_embed, num_heads=n_heads,
81
+ dropout=0.0, batch_first=True,
82
+ )
83
+ self.norm1 = nn.LayerNorm(ego_embed)
84
+
85
+ # ── Reflection gate: fuse ego + pass1 context for second query ───
86
+ self.reflect_proj = nn.Sequential(
87
+ nn.Linear(ego_embed * 2, ego_embed),
88
+ nn.LayerNorm(ego_embed),
89
+ nn.Tanh(),
90
+ )
91
+
92
+ # ── Pass 2: plan (refined query re-attends to tickets) ───────────
93
+ self.attn_pass2 = nn.MultiheadAttention(
94
+ embed_dim=ego_embed, num_heads=n_heads,
95
+ dropout=0.0, batch_first=True,
96
+ )
97
+ self.norm2 = nn.LayerNorm(ego_embed)
98
+
99
+ # ── Fused representation ─────────────────────────────────────────
100
+ fused_dim = ego_embed + ego_embed # ego + refined context
101
+
102
+ # ── Steer head (conservative, smooth output) ─────────────────────
103
+ self.steer_head = nn.Sequential(
104
+ nn.Linear(fused_dim, hidden // 2),
105
+ nn.LayerNorm(hidden // 2),
106
+ nn.Tanh(),
107
+ nn.Linear(hidden // 2, hidden // 4),
108
+ nn.Tanh(),
109
+ nn.Linear(hidden // 4, 1),
110
+ nn.Tanh(),
111
+ )
112
+
113
+ # ── Drive head (throttle + brake) ────────────────────────────────
114
+ self.drive_head = nn.Sequential(
115
+ nn.Linear(fused_dim, hidden // 2),
116
+ nn.LayerNorm(hidden // 2),
117
+ nn.Tanh(),
118
+ nn.Linear(hidden // 2, hidden // 4),
119
+ nn.Tanh(),
120
+ nn.Linear(hidden // 4, 2),
121
+ nn.Tanh(),
122
+ )
123
+
124
+ # ── Critic head ──────────────────────────────────────────────────
125
+ self.critic = nn.Sequential(
126
+ nn.Linear(fused_dim, hidden),
127
+ nn.LayerNorm(hidden),
128
+ nn.Tanh(),
129
+ nn.Linear(hidden, hidden // 2),
130
+ nn.Tanh(),
131
+ nn.Linear(hidden // 2, 1),
132
+ )
133
+
134
+ self._init_weights()
135
+
136
+ def _init_weights(self):
137
+ for m in self.modules():
138
+ if isinstance(m, nn.Linear):
139
+ nn.init.orthogonal_(m.weight, gain=1.0)
140
+ if m.bias is not None:
141
+ nn.init.zeros_(m.bias)
142
+ # Very small initial actions β€” start by doing almost nothing
143
+ nn.init.orthogonal_(self.steer_head[-2].weight, gain=0.01)
144
+ nn.init.orthogonal_(self.drive_head[-2].weight, gain=0.01)
145
+ # Critic starts near zero
146
+ nn.init.orthogonal_(self.critic[-1].weight, gain=0.1)
147
+
148
+ def _attend(self, attn_module, norm_module, query, tk_emb, is_padding, all_empty):
149
+ """Run one attention pass with NaN-safe masking."""
150
+ B = query.shape[0]
151
+ q = query if query.dim() == 3 else query.unsqueeze(1)
152
+
153
+ if all_empty.all():
154
+ return torch.zeros(B, self.ego_embed, device=query.device)
155
+
156
+ safe_mask = is_padding.clone()
157
+ safe_mask[all_empty, 0] = False
158
+ attn_out, _ = attn_module(
159
+ query=q, key=tk_emb, value=tk_emb,
160
+ key_padding_mask=safe_mask,
161
+ )
162
+ context = attn_out.squeeze(1)
163
+ context[all_empty] = 0.0
164
+ return norm_module(context)
165
+
166
+ def forward(self, obs: torch.Tensor):
167
+ B = obs.shape[0]
168
+
169
+ # Split observation
170
+ ego_raw = obs[:, :EGO_STATE_DIM]
171
+ tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)
172
+
173
+ # Encode
174
+ ego_emb = self.ego_encoder(ego_raw)
175
+ tk_emb = self.ticket_encoder(tk_raw)
176
+
177
+ # Padding mask
178
+ is_padding = (tk_raw.abs().sum(dim=-1) == 0)
179
+ all_empty = is_padding.all(dim=-1)
180
+
181
+ # ── Pass 1: perceive ─────────────────────────────────────────────
182
+ ctx1 = self._attend(self.attn_pass1, self.norm1,
183
+ ego_emb, tk_emb, is_padding, all_empty)
184
+
185
+ # ── Reflect: combine ego + initial context into refined query ────
186
+ reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))
187
+
188
+ # ── Pass 2: plan (re-attend with richer query) ───────────────────
189
+ ctx2 = self._attend(self.attn_pass2, self.norm2,
190
+ reflected, tk_emb, is_padding, all_empty)
191
+
192
+ # ── Fuse and decode ──────────────────────────────────────────────
193
+ fused = torch.cat([ego_emb, ctx2], dim=-1)
194
+
195
+ steer = self.steer_head(fused) # (B, 1)
196
+ drive = self.drive_head(fused) # (B, 2)
197
+ action = torch.cat([steer, drive], dim=-1) # (B, 3)
198
+ value = self.critic(fused) # (B, 1)
199
+
200
+ return action, value
201
+
202
+ def get_attention_weights(self, obs: torch.Tensor) -> torch.Tensor:
203
+ """Returns pass-2 attention weights for interpretability."""
204
+ B = obs.shape[0]
205
+ ego_raw = obs[:, :EGO_STATE_DIM]
206
+ tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)
207
+ ego_emb = self.ego_encoder(ego_raw)
208
+ tk_emb = self.ticket_encoder(tk_raw)
209
+ is_padding = (tk_raw.abs().sum(dim=-1) == 0)
210
+ all_empty = is_padding.all(dim=-1)
211
+
212
+ # Pass 1
213
+ ctx1 = self._attend(self.attn_pass1, self.norm1,
214
+ ego_emb, tk_emb, is_padding, all_empty)
215
+ reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))
216
+
217
+ # Pass 2 β€” get weights
218
+ safe_mask = is_padding.clone()
219
+ safe_mask[all_empty, 0] = False
220
+ query = reflected.unsqueeze(1)
221
+ _, weights = self.attn_pass2(
222
+ query=query, key=tk_emb, value=tk_emb,
223
+ key_padding_mask=safe_mask,
224
+ need_weights=True, average_attn_weights=False,
225
+ )
226
+ weights[all_empty] = 0.0
227
+ return weights
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch==2.5.1+cpu
3
+ numpy>=1.24.0
4
+ gradio>=4.44.0
5
+ pydantic>=2.0.0
6
+ requests>=2.31.0
7
+ openenv-overflow-env @ git+https://huggingface.co/spaces/SteveDusty/overflow_env