Spaces:
Runtime error
Runtime error
| """ | |
| OpenENV RL Demo β auto-runs 20 steps per episode using the openenv policy system. | |
| Policies: FlatMLPPolicy / TicketAttentionPolicy (from openenv) | |
| Training: PPO mini-update after each episode β rewards increase over time | |
| Display: Live step-by-step feed + episode reward history | |
| """ | |
| import math, time, threading | |
| import numpy as np | |
| import torch | |
| import torch.optim as optim | |
| import gradio as gr | |
| from overflow_env.environment import OverflowEnvironment | |
| from overflow_env.models import OverflowAction | |
| from policies.flat_mlp_policy import FlatMLPPolicy | |
| from policies.ticket_attention_policy import TicketAttentionPolicy | |
| from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM | |
STEPS_PER_EPISODE = 20  # fixed rollout horizon per episode (also the PPO batch size)

# ── Observation adapter ──────────────────────────────────────────────────────
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Convert an OverflowEnvironment observation into the flat policy obs vector.

    Builds one "ticket" per non-ego car within 80 m, then packs ego state and
    tickets via the shared policy_spec helpers.

    Args:
        overflow_obs: environment observation; assumed to expose `.cars`, each
            car with `.carId`, `.speed`, `.lane` and `.position.x`
            (TODO confirm against OverflowEnvironment's observation model).

    Returns:
        A float32 vector of length OBS_DIM (all zeros when no cars are present).
    """
    cars = overflow_obs.cars
    if not cars:
        # Degenerate frame with no cars at all (not even ego) -> zero obs.
        return np.zeros(OBS_DIM, dtype=np.float32)

    # Ego is carId 0; fall back to the first car if that id is missing.
    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_spd = ego.speed / 4.5          # normalize speed (4.5 presumably max speed — verify)
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7       # lane index -> lateral metres (3.7 m lane width)

    tickets = []
    for car in cars:
        if car.carId == 0:
            continue  # skip ego itself
        rx = car.position.x - ego.position.x   # longitudinal offset (+ = ahead)
        ry = (car.lane - ego.lane) * 3.7       # lateral offset in metres
        cs = car.speed / 4.5
        d = math.sqrt(rx ** 2 + ry ** 2)
        if d > 80:
            continue  # ignore far-away traffic

        # BUG FIX: the original computed math.copysign(1, max(rx, 0.01)),
        # but max(rx, 0.01) is always positive, so the sign term was always
        # +1 and cars *behind* the ego got a wrong closing speed.  The gap
        # |rx| shrinks at (ego_spd - cs) for a car ahead and (cs - ego_spd)
        # for a car behind, i.e. (ego_spd - cs) * sign(rx).
        direction = math.copysign(1.0, rx) if rx != 0 else 1.0
        closing = max((ego_spd - cs) * direction, 0.1)  # floor avoids div-by-zero

        tickets.append(build_ticket_vector(
            severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5,
            ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0,
            vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=d, time_to_collision=min(d / closing, 30.0),
            bearing=math.atan2(ry, max(rx, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    tv = np.array(tickets, dtype=np.float32) if tickets else None
    return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
                     ego_vx=ego_spd, ego_vy=0.0,
                     heading=0.0, speed=ego_spd,
                     steer=0.0, throttle=0.5, brake=0.0,
                     ticket_vectors=tv)
def action_to_decision(a: np.ndarray) -> str:
    """Map a continuous [steer, throttle, brake] action to a discrete decision.

    Priority order: steering beats braking beats throttle; when no component
    clears its threshold the ego simply maintains its current state.
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])
    if abs(steer) > 0.35:
        # Negative steer = left, positive = right.
        return "lane_change_left" if steer < 0 else "lane_change_right"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
# ── Global training state ────────────────────────────────────────────────────
# One shared policy + optimizer.  The background episode thread is the only
# writer of the rollout buffers; the Gradio timer thread reads only the
# display lists (step_log / episode_history), guarded by _lock.
policy = FlatMLPPolicy(obs_dim=OBS_DIM)
optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)

# Rollout buffer (lightweight — one episode at a time)
_buf_obs = []    # per-step observation vectors (np.ndarray of length OBS_DIM)
_buf_acts = []   # sampled actions (np.ndarray of length 3)
_buf_rews = []   # scalar rewards
_buf_logps = []  # log-probs of the sampled actions under the acting policy
_buf_vals = []   # critic value estimates at each step
_buf_dones = []  # termination flags as floats (0.0 / 1.0)

episode_history = []  # [{ep, steps, reward, outcome}]
step_log = []         # [{ep, step, decision, reward, scene}]
_running = False      # set by Start/Stop handlers; polled by the worker thread
_lock = threading.Lock()  # guards step_log / episode_history across threads
def _ppo_mini_update():
    """Single PPO gradient step on the just-completed episode.

    Reads the module-level rollout buffers (filled by run_episodes_loop),
    computes GAE advantages/returns, and applies one clipped-surrogate
    update to the shared `policy` via `optimizer`.  No-op when fewer than
    two transitions were collected.
    """
    if len(_buf_obs) < 2:
        return
    # Stack the episode rollout into tensors.
    obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
    acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
    rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
    logp_t = torch.tensor(_buf_logps, dtype=torch.float32)  # behavior-policy log-probs
    vals_t = torch.tensor(_buf_vals, dtype=torch.float32)   # critic estimates
    done_t = torch.tensor(_buf_dones, dtype=torch.float32)
    # GAE returns
    gamma, lam = 0.99, 0.95
    adv = torch.zeros_like(rews_t)
    gae = 0.0
    for t in reversed(range(len(rews_t))):
        # Bootstrap value is 0 after the last recorded step (even on
        # timeout — no value estimate is kept for the terminal state).
        nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1])
        d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t]  # TD residual
        gae = d + gamma * lam * (1 - done_t[t]) * gae
        adv[t] = gae
    ret = adv + vals_t                               # value targets
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)    # normalize advantages
    policy.train()
    act_mean, val = policy(obs_t)
    val = val.squeeze(-1)
    # Fixed-std Gaussian (0.3), matching the distribution used when acting.
    dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
    logp = dist.log_prob(acts_t).sum(dim=-1)
    entropy = dist.entropy().sum(dim=-1).mean()
    ratio = torch.exp(logp - logp_t)  # importance ratio vs. acting policy
    # Clipped PPO surrogate (epsilon = 0.2), written as max of negated terms.
    pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
    vf = 0.5 * ((val - ret) ** 2).mean()
    loss = pg + 0.5 * vf - 0.02 * entropy  # policy + value + entropy bonus
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
def run_episodes_loop():
    """Background thread — runs episodes continuously, updates policy after each.

    Loops until `_running` is cleared: resets the environment, acts with the
    current policy for up to STEPS_PER_EPISODE steps, records each transition
    into the rollout buffers and `step_log`, then runs a PPO mini-update and
    appends a summary row to `episode_history`.
    """
    global _running
    ep_num = 0
    env = OverflowEnvironment()
    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"  # overwritten below when the env signals done
        # Fresh rollout buffers for this episode (only this thread writes them).
        _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear()
        _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear()
        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                break  # Stop button pressed mid-episode
            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                # Sample from a fixed-std (0.3) Gaussian around the policy mean,
                # clamped to the valid action range.
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()
            decision = action_to_decision(action.numpy())
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done
            # Record the transition for the post-episode PPO update.
            _buf_obs.append(obs_vec)
            _buf_acts.append(action.numpy())
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))
            ep_rew += reward
            # Publish the step for the UI timer (read under _lock in get_updates).
            with _lock:
                step_log.append({
                    "ep": ep_num, "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "scene": obs.scene_description.split("\n")[0],
                    "incident": obs.incident_report or "",
                })
            if done:
                outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
                break
            time.sleep(0.6)  # pace so UI can show each step
        _ppo_mini_update()
        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": step,  # last executed step index from the loop above
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def start_training():
    """Start the background training loop (no-op while already running).

    Returns Gradio updates for (start_btn, stop_btn): Start is disabled and
    shows "Running...", Stop is enabled.
    """
    global _running
    if not _running:
        _running = True
        # FIX: clear the display lists under _lock so the 1-second timer
        # thread (get_updates, which also takes _lock) never observes a
        # half-cleared state — the original cleared them unguarded.
        with _lock:
            step_log.clear()
            episode_history.clear()
        t = threading.Thread(target=run_episodes_loop, daemon=True)
        t.start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
def stop_training():
    """Signal the worker thread to stop and flip the buttons back.

    Returns Gradio updates for (start_btn, stop_btn): Start re-enabled with
    its original label, Stop disabled.  The background loop exits on its
    next `_running` check.
    """
    global _running
    _running = False
    return gr.update(value="Start", interactive=True), gr.update(interactive=False)
def get_updates():
    """Called by gr.Timer every second — returns latest display content.

    Returns:
        (step_text, ep_text, trend_text, status): formatted strings for the
        step feed, episode table, reward-trend line, and status box.
    """
    # Snapshot the shared lists under the lock; format outside it.
    with _lock:
        logs = list(step_log[-20:])
        eps = list(episode_history[-30:])

    # Step feed (newest first).
    # FIX: the incident flags, trend arrows, and status markers below were
    # mojibake in the original (UTF-8 emoji read as ISO-8859-7, e.g. "π₯");
    # restored to the intended symbols.
    lines = []
    for e in reversed(logs):
        flag = ""
        if "CRASH" in e["incident"]:
            flag = " 💥"
        elif "GOAL" in e["incident"]:
            flag = " ✅"
        elif "NEAR MISS" in e["incident"]:
            flag = " ⚠"
        lines.append(
            f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
            f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
            f"ep_total={e['ep_reward']:>7.2f}{flag}"
        )
    step_text = "\n".join(lines) if lines else "Waiting for first episode..."

    # Episode summary (newest first).
    ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
    for e in reversed(eps):
        ep_lines.append(
            f" {e['ep']:>4d} | {e['steps']:>3d} | "
            f" {e['reward']:>+8.2f} | {e['outcome']}"
        )
    ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet."

    # Mean reward trend: first half vs. second half of the visible window.
    if len(eps) >= 2:
        rewards = [e["reward"] for e in eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        trend = (f"Mean reward (early {half} eps): {early:+.2f} "
                 f"→ (last {n - half} eps): {late:+.2f}")
        arrow = "↑ improving" if late > early else "↓ declining"
        trend_text = f"{trend} {arrow}"
    else:
        trend_text = "Collecting data..."

    status = "● RUNNING" if _running else "■ STOPPED"
    return step_text, ep_text, trend_text, status
| with gr.Blocks(title="OpenENV RL Demo") as demo: | |
| gr.Markdown("# OpenENV RL β Live Policy Training\n" | |
| "FlatMLPPolicy runs 20 steps per episode on OverflowEnvironment. " | |
| "PPO mini-update after each episode β watch rewards improve over time.") | |
| with gr.Row(): | |
| start_btn = gr.Button("Start", variant="primary") | |
| stop_btn = gr.Button("Stop", variant="stop", interactive=False) | |
| status_box = gr.Textbox(value="β STOPPED", label="Status", | |
| interactive=False, scale=0, min_width=120) | |
| gr.Markdown("### Live Step Feed (most recent 20 steps)") | |
| step_display = gr.Textbox( | |
| value="Press Start to begin...", | |
| lines=22, max_lines=22, interactive=False, | |
| elem_id="step_feed", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### Episode History") | |
| ep_display = gr.Textbox(lines=12, interactive=False) | |
| with gr.Column(): | |
| gr.Markdown("### Reward Trend") | |
| trend_display = gr.Textbox(lines=3, interactive=False) | |
| # Auto-refresh every 1 second | |
| timer = gr.Timer(value=1.0) | |
| timer.tick( | |
| fn=get_updates, | |
| outputs=[step_display, ep_display, trend_display, status_box], | |
| ) | |
| start_btn.click(fn=start_training, outputs=[start_btn, stop_btn]) | |
| stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn]) | |
| if __name__ == "__main__": | |
| demo.launch() | |