""" OpenENV RL Demo — auto-runs 20 steps per episode using the openenv policy system. Policies: FlatMLPPolicy / TicketAttentionPolicy (from openenv) Training: PPO mini-update after each episode — rewards increase over time Display: Live step-by-step feed + episode reward history """ import math, time, threading import numpy as np import torch import torch.optim as optim import gradio as gr from overflow_env.environment import OverflowEnvironment from overflow_env.models import OverflowAction from policies.flat_mlp_policy import FlatMLPPolicy from policies.ticket_attention_policy import TicketAttentionPolicy from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM STEPS_PER_EPISODE = 20 # ── Observation adapter ─────────────────────────────────────────────────────── def obs_to_vec(overflow_obs) -> np.ndarray: cars = overflow_obs.cars if not cars: return np.zeros(OBS_DIM, dtype=np.float32) ego = next((c for c in cars if c.carId == 0), cars[0]) ego_spd = ego.speed / 4.5 ego_x = ego.position.x ego_y = (ego.lane - 2) * 3.7 tickets = [] for car in cars: if car.carId == 0: continue rx = car.position.x - ego.position.x ry = (car.lane - ego.lane) * 3.7 cs = car.speed / 4.5 d = math.sqrt(rx**2 + ry**2) if d > 80: continue cl = max(ego_spd - cs * math.copysign(1, max(rx, 0.01)), 0.1) tickets.append(build_ticket_vector( severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5, ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0, vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0, size_length=4.0, size_width=2.0, size_height=1.5, distance=d, time_to_collision=min(d / cl, 30.0), bearing=math.atan2(ry, max(rx, 0.01)), ticket_type="collision_risk", entity_type="vehicle", confidence=1.0, )) tv = np.array(tickets, dtype=np.float32) if tickets else None return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0, ego_vx=ego_spd, ego_vy=0.0, heading=0.0, speed=ego_spd, steer=0.0, throttle=0.5, brake=0.0, ticket_vectors=tv) def action_to_decision(a: np.ndarray) -> str: s, t, b = float(a[0]), float(a[1]), float(a[2]) if abs(s) > 0.35: return "lane_change_left" if s < 0 else "lane_change_right" if b > 0.25: return "brake" if t > 0.20: return "accelerate" return "maintain" # ── Global training state ───────────────────────────────────────────────────── policy = FlatMLPPolicy(obs_dim=OBS_DIM) optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5) # Rollout buffer (lightweight — one episode at a time) _buf_obs = [] _buf_acts = [] _buf_rews = [] _buf_logps = [] _buf_vals = [] _buf_dones = [] episode_history = [] # [{ep, steps, reward, outcome}] step_log = [] # [{ep, step, decision, reward, scene}] _running = False _lock = threading.Lock() def _ppo_mini_update(): """Single PPO gradient step on the just-completed episode.""" if len(_buf_obs) < 2: return obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32) acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32) rews_t = torch.tensor(_buf_rews, dtype=torch.float32) logp_t = torch.tensor(_buf_logps, dtype=torch.float32) vals_t = torch.tensor(_buf_vals, dtype=torch.float32) done_t = torch.tensor(_buf_dones, dtype=torch.float32) # GAE returns gamma, lam = 0.99, 0.95 adv = torch.zeros_like(rews_t) gae = 0.0 for t in reversed(range(len(rews_t))): nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1]) d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t] gae = d + gamma * lam * (1 - done_t[t]) * gae adv[t] = gae ret = adv + vals_t adv = (adv - adv.mean()) / (adv.std() + 1e-8) policy.train() act_mean, val = policy(obs_t) val = val.squeeze(-1) dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3) logp = dist.log_prob(acts_t).sum(dim=-1) entropy = dist.entropy().sum(dim=-1).mean() ratio = torch.exp(logp - logp_t) pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean() vf = 0.5 * ((val - ret) ** 2).mean() loss = pg + 0.5 * vf - 0.02 * entropy optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5) optimizer.step() def run_episodes_loop(): """Background thread — runs episodes continuously, updates policy after each.""" global _running ep_num = 0 env = OverflowEnvironment() while _running: ep_num += 1 obs = env.reset() ep_rew = 0.0 outcome = "timeout" _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear() _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear() for step in range(1, STEPS_PER_EPISODE + 1): if not _running: break obs_vec = obs_to_vec(obs) policy.eval() with torch.no_grad(): obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0) act_mean, val = policy(obs_t) dist = torch.distributions.Normal(act_mean.squeeze(0), torch.ones(3) * 0.3) action = dist.sample().clamp(-1, 1) logp = dist.log_prob(action).sum() decision = action_to_decision(action.numpy()) obs = env.step(OverflowAction(decision=decision, reasoning="")) reward = float(obs.reward or 0.0) done = obs.done _buf_obs.append(obs_vec) _buf_acts.append(action.numpy()) _buf_rews.append(reward) _buf_logps.append(float(logp)) _buf_vals.append(float(val.squeeze())) _buf_dones.append(float(done)) ep_rew += reward with _lock: step_log.append({ "ep": ep_num, "step": step, "decision": decision, "reward": round(reward, 2), "ep_reward": round(ep_rew, 2), "scene": obs.scene_description.split("\n")[0], "incident": obs.incident_report or "", }) if done: outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL" break time.sleep(0.6) # pace so UI can show each step _ppo_mini_update() with _lock: episode_history.append({ "ep": ep_num, "steps": step, "reward": round(ep_rew, 2), "outcome": outcome, }) # ── Gradio UI ───────────────────────────────────────────────────────────────── def start_training(): global _running if not _running: _running = True step_log.clear() episode_history.clear() t = threading.Thread(target=run_episodes_loop, daemon=True) t.start() return gr.update(value="Running...", interactive=False), gr.update(interactive=True) def stop_training(): global _running _running = False return gr.update(value="Start", interactive=True), gr.update(interactive=False) def get_updates(): """Called by gr.Timer every second — returns latest display content.""" with _lock: logs = list(step_log[-20:]) eps = list(episode_history[-30:]) # Step feed lines = [] for e in reversed(logs): flag = "" if "CRASH" in e["incident"]: flag = " 💥" elif "GOAL" in e["incident"]: flag = " ✓" elif "NEAR MISS" in e["incident"]: flag = " ⚠" lines.append( f"ep {e['ep']:>3d} | step {e['step']:>2d} | " f"{e['decision']:<20} | r={e['reward']:>+6.2f} | " f"ep_total={e['ep_reward']:>7.2f}{flag}" ) step_text = "\n".join(lines) if lines else "Waiting for first episode..." # Episode summary ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44] for e in reversed(eps): ep_lines.append( f" {e['ep']:>4d} | {e['steps']:>3d} | " f" {e['reward']:>+8.2f} | {e['outcome']}" ) ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet." # Mean reward trend if len(eps) >= 2: rewards = [e["reward"] for e in eps] n = len(rewards) half = max(n // 2, 1) early = sum(rewards[:half]) / half late = sum(rewards[half:]) / max(n - half, 1) trend = f"Mean reward (early {half} eps): {early:+.2f} → (last {n-half} eps): {late:+.2f}" arrow = "↑ improving" if late > early else "↓ declining" trend_text = f"{trend} {arrow}" else: trend_text = "Collecting data..." status = "● RUNNING" if _running else "■ STOPPED" return step_text, ep_text, trend_text, status with gr.Blocks(title="OpenENV RL Demo") as demo: gr.Markdown("# OpenENV RL — Live Policy Training\n" "FlatMLPPolicy runs 20 steps per episode on OverflowEnvironment. " "PPO mini-update after each episode — watch rewards improve over time.") with gr.Row(): start_btn = gr.Button("Start", variant="primary") stop_btn = gr.Button("Stop", variant="stop", interactive=False) status_box = gr.Textbox(value="■ STOPPED", label="Status", interactive=False, scale=0, min_width=120) gr.Markdown("### Live Step Feed (most recent 20 steps)") step_display = gr.Textbox( value="Press Start to begin...", lines=22, max_lines=22, interactive=False, elem_id="step_feed", ) with gr.Row(): with gr.Column(): gr.Markdown("### Episode History") ep_display = gr.Textbox(lines=12, interactive=False) with gr.Column(): gr.Markdown("### Reward Trend") trend_display = gr.Textbox(lines=3, interactive=False) # Auto-refresh every 1 second timer = gr.Timer(value=1.0) timer.tick( fn=get_updates, outputs=[step_display, ep_display, trend_display, status_box], ) start_btn.click(fn=start_training, outputs=[start_btn, stop_btn]) stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn]) if __name__ == "__main__": demo.launch()