""" OpenENV RL Demo — Gradio UI entrypoint for HuggingFace Spaces. Runs inside the overflow_env package root. All imports use absolute paths so they work both as a package (installed) and as a Space (flat root). """ import sys, os # When running as HF Space, make server/ importable with absolute paths sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import math, time, threading import numpy as np import torch import torch.optim as optim import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as patches import gradio as gr from server.overflow_environment import OverflowEnvironment from models import OverflowAction from policies.flat_mlp_policy import FlatMLPPolicy from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM STEPS_PER_EPISODE = 20 NUM_LANES = 3 ROAD_LENGTH = 200 # ── Observation adapter ─────────────────────────────────────────────────────── def obs_to_vec(overflow_obs) -> np.ndarray: cars = overflow_obs.cars if not cars: return np.zeros(OBS_DIM, dtype=np.float32) ego = next((c for c in cars if c.carId == 0), cars[0]) ego_spd = ego.speed / 4.5 ego_x = ego.position.x ego_y = (ego.lane - 2) * 3.7 tickets = [] for car in cars: if car.carId == 0: continue rx = car.position.x - ego.position.x ry = (car.lane - ego.lane) * 3.7 cs = car.speed / 4.5 d = math.sqrt(rx**2 + ry**2) if d > 80: continue cl = max(ego_spd - cs * math.copysign(1, max(rx, 0.01)), 0.1) tickets.append(build_ticket_vector( severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5, ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0, vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0, size_length=4.0, size_width=2.0, size_height=1.5, distance=d, time_to_collision=min(d / cl, 30.0), bearing=math.atan2(ry, max(rx, 0.01)), ticket_type="collision_risk", entity_type="vehicle", confidence=1.0, )) tv = np.array(tickets, dtype=np.float32) if tickets else None return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0, ego_vx=ego_spd, ego_vy=0.0, heading=0.0, speed=ego_spd, steer=0.0, throttle=0.5, brake=0.0, ticket_vectors=tv) def action_to_decision(a: np.ndarray) -> str: s, t, b = float(a[0]), float(a[1]), float(a[2]) if abs(s) > 0.35: return "lane_change_left" if s < 0 else "lane_change_right" if b > 0.25: return "brake" if t > 0.20: return "accelerate" return "maintain" # ── Global training state ───────────────────────────────────────────────────── policy = FlatMLPPolicy(obs_dim=OBS_DIM) optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5) _buf_obs = [] _buf_acts = [] _buf_rews = [] _buf_logps = [] _buf_vals = [] _buf_dones = [] episode_history = [] step_log = [] _running = False _lock = threading.Lock() def _ppo_mini_update(): if len(_buf_obs) < 2: return obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32) acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32) rews_t = torch.tensor(_buf_rews, dtype=torch.float32) logp_t = torch.tensor(_buf_logps, dtype=torch.float32) vals_t = torch.tensor(_buf_vals, dtype=torch.float32) done_t = torch.tensor(_buf_dones, dtype=torch.float32) gamma, lam = 0.99, 0.95 adv = torch.zeros_like(rews_t) gae = 0.0 for t in reversed(range(len(rews_t))): nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1]) d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t] gae = d + gamma * lam * (1 - done_t[t]) * gae adv[t] = gae ret = adv + vals_t adv = (adv - adv.mean()) / (adv.std() + 1e-8) policy.train() act_mean, val = policy(obs_t) val = val.squeeze(-1) dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3) logp = dist.log_prob(acts_t).sum(dim=-1) entropy = dist.entropy().sum(dim=-1).mean() ratio = torch.exp(logp - logp_t) pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean() vf = 0.5 * ((val - ret) ** 2).mean() loss = pg + 0.5 * vf - 0.02 * entropy optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5) optimizer.step() def run_episodes_loop(): global _running ep_num = 0 env = OverflowEnvironment() while _running: ep_num += 1 obs = env.reset() ep_rew = 0.0 outcome = "timeout" _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear() _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear() for step in range(1, STEPS_PER_EPISODE + 1): if not _running: break obs_vec = obs_to_vec(obs) policy.eval() with torch.no_grad(): obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0) act_mean, val = policy(obs_t) dist = torch.distributions.Normal(act_mean.squeeze(0), torch.ones(3) * 0.3) action = dist.sample().clamp(-1, 1) logp = dist.log_prob(action).sum() decision = action_to_decision(action.numpy()) obs = env.step(OverflowAction(decision=decision, reasoning="")) reward = float(obs.reward or 0.0) done = obs.done ep_rew += reward _buf_obs.append(obs_vec) _buf_acts.append(action.numpy()) _buf_rews.append(reward) _buf_logps.append(float(logp)) _buf_vals.append(float(val.squeeze())) _buf_dones.append(float(done)) with _lock: step_log.append({ "ep": ep_num, "step": step, "decision": decision, "reward": round(reward, 2), "ep_reward": round(ep_rew, 2), "incident": obs.incident_report or "", "cars": [(c.carId, c.lane, c.position.x, c.speed) for c in obs.cars], }) if done: outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL" break time.sleep(0.6) _ppo_mini_update() with _lock: episode_history.append({ "ep": ep_num, "steps": step, "reward": round(ep_rew, 2), "outcome": outcome, }) # ── Plot helpers ────────────────────────────────────────────────────────────── DECISION_COLORS = { "accelerate": "#22c55e", "brake": "#ef4444", "lane_change_left": "#f59e0b", "lane_change_right": "#f59e0b", "maintain": "#60a5fa", } def render_road(cars_snapshot, last_decision, last_incident): fig, ax = plt.subplots(figsize=(10, 2.8)) fig.patch.set_facecolor("#0f172a") ax.set_facecolor("#1e293b") ax.set_xlim(0, ROAD_LENGTH) ax.set_ylim(0, NUM_LANES + 1) ax.set_yticks([]) ax.set_xlabel("Position", color="#94a3b8", fontsize=9) ax.tick_params(colors="#94a3b8") for spine in ax.spines.values(): spine.set_edgecolor("#334155") for lane in range(1, NUM_LANES): ax.axhline(y=lane + 0.5, color="#334155", linewidth=1, linestyle="--", alpha=0.6) for lane in range(1, NUM_LANES + 1): ax.text(2, lane, f"L{lane}", color="#475569", fontsize=8, va="center") ax.axvspan(160, ROAD_LENGTH, alpha=0.12, color="#22c55e") ax.text(162, NUM_LANES + 0.6, "GOAL ZONE", color="#22c55e", fontsize=7, alpha=0.8) car_w, car_h = 8, 0.55 for car_id, lane, pos_x, speed in cars_snapshot: is_ego = car_id == 0 color = "#3b82f6" if is_ego else "#94a3b8" outline = "#60a5fa" if is_ego else "#475569" lw = 2.0 if is_ego else 1.0 rect = patches.FancyBboxPatch( (pos_x - car_w / 2, lane - car_h / 2), car_w, car_h, boxstyle="round,pad=0.05", facecolor=color, edgecolor=outline, linewidth=lw, alpha=0.92, ) ax.add_patch(rect) label = f"{'EGO' if is_ego else f'C{car_id}'}\n{speed:.0f}" ax.text(pos_x, lane, label, ha="center", va="center", fontsize=6.5, color="white", fontweight="bold" if is_ego else "normal") dec_color = DECISION_COLORS.get(last_decision, "#60a5fa") ax.text(ROAD_LENGTH - 2, NUM_LANES + 0.65, f"Action: {last_decision.replace('_', ' ').upper()}", color=dec_color, fontsize=8, fontweight="bold", ha="right") if "CRASH" in last_incident: ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "CRASH", color="#ef4444", fontsize=10, fontweight="bold", ha="center") elif "NEAR MISS" in last_incident: ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "NEAR MISS", color="#f59e0b", fontsize=9, fontweight="bold", ha="center") elif "GOAL" in last_incident: ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "GOAL REACHED", color="#22c55e", fontsize=10, fontweight="bold", ha="center") plt.tight_layout(pad=0.3) return fig def render_reward_curve(eps): fig, ax = plt.subplots(figsize=(10, 2.8)) fig.patch.set_facecolor("#0f172a") ax.set_facecolor("#1e293b") for spine in ax.spines.values(): spine.set_edgecolor("#334155") ax.tick_params(colors="#94a3b8") ax.set_xlabel("Episode", color="#94a3b8", fontsize=9) ax.set_ylabel("Total Reward", color="#94a3b8", fontsize=9) if not eps: ax.text(0.5, 0.5, "Waiting for episodes...", transform=ax.transAxes, ha="center", va="center", color="#475569", fontsize=11) plt.tight_layout(pad=0.3) return fig xs = [e["ep"] for e in eps] ys = [e["reward"] for e in eps] outcome_colors = {"CRASH": "#ef4444", "GOAL": "#22c55e", "timeout": "#60a5fa"} for x, y, e in zip(xs, ys, eps): ax.bar(x, y, color=outcome_colors.get(e["outcome"], "#60a5fa"), alpha=0.6, width=0.7) if len(ys) >= 3: w = min(5, len(ys)) smoothed = np.convolve(ys, np.ones(w) / w, mode="valid") ax.plot(xs[w - 1:], smoothed, color="#f8fafc", linewidth=2) ax.axhline(0, color="#334155", linewidth=0.8) from matplotlib.patches import Patch legend_els = [Patch(facecolor="#ef4444", label="crash"), Patch(facecolor="#22c55e", label="goal"), Patch(facecolor="#60a5fa", label="timeout")] ax.legend(handles=legend_els, facecolor="#1e293b", labelcolor="#94a3b8", fontsize=8, framealpha=0.6, edgecolor="#334155", loc="upper left") plt.tight_layout(pad=0.3) return fig # ── Gradio UI ───────────────────────────────────────────────────────────────── def start_training(): global _running if not _running: _running = True step_log.clear() episode_history.clear() threading.Thread(target=run_episodes_loop, daemon=True).start() return gr.update(value="Running...", interactive=False), gr.update(interactive=True) def stop_training(): global _running _running = False return gr.update(value="Start", interactive=True), gr.update(interactive=False) def get_updates(): with _lock: logs = list(step_log[-20:]) eps = list(episode_history[-50:]) last = step_log[-1] if step_log else None road_fig = render_road(last["cars"], last["decision"], last["incident"]) if last \ else render_road([], "maintain", "") reward_fig = render_reward_curve(eps) lines = [] for e in reversed(logs): flag = "" if "CRASH" in e["incident"]: flag = " 💥" elif "GOAL" in e["incident"]: flag = " ✓" elif "NEAR MISS" in e["incident"]: flag = " ⚠" lines.append( f"ep {e['ep']:>3d} | step {e['step']:>2d} | " f"{e['decision']:<20} | r={e['reward']:>+6.2f} | " f"ep_total={e['ep_reward']:>7.2f}{flag}" ) step_text = "\n".join(lines) if lines else "Waiting for first episode..." ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44] for e in reversed(eps[-15:]): ep_lines.append( f" {e['ep']:>4d} | {e['steps']:>3d} | " f" {e['reward']:>+8.2f} | {e['outcome']}" ) ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet." if len(eps) >= 2: rewards = [e["reward"] for e in eps] n = len(rewards) half = max(n // 2, 1) early = sum(rewards[:half]) / half late = sum(rewards[half:]) / max(n - half, 1) arrow = "↑ improving" if late > early else "↓ declining" trend_text = f"Early {half} eps: {early:+.2f} → Last {n-half} eps: {late:+.2f} {arrow}" else: trend_text = "Collecting data..." status = "● RUNNING" if _running else "■ STOPPED" return road_fig, reward_fig, step_text, ep_text, trend_text, status _EMPTY_ROAD = render_road([], "maintain", "") _EMPTY_REWARD = render_reward_curve([]) with gr.Blocks(title="OpenENV RL Demo", theme=gr.themes.Base()) as demo: gr.Markdown( "# OpenENV RL — Live Policy Training\n" "**FlatMLPPolicy** drives Car 0 on a 3-lane road for 20 steps per episode. " "PPO mini-update after each episode — watch rewards trend upward over time." ) with gr.Row(): start_btn = gr.Button("Start", variant="primary", scale=1) stop_btn = gr.Button("Stop", variant="stop", interactive=False, scale=1) status_box = gr.Textbox(value="■ STOPPED", label="Status", interactive=False, scale=0, min_width=130) gr.Markdown("### Road View") road_plot = gr.Plot(value=_EMPTY_ROAD, show_label=False) gr.Markdown("### Episode Reward Curve") reward_plot = gr.Plot(value=_EMPTY_REWARD, show_label=False) gr.Markdown("### Live Step Feed (last 20 steps)") step_display = gr.Textbox( value="Press Start to begin...", lines=14, max_lines=14, interactive=False, ) with gr.Row(): with gr.Column(): gr.Markdown("### Episode History") ep_display = gr.Textbox(lines=10, interactive=False) with gr.Column(): gr.Markdown("### Reward Trend") trend_display = gr.Textbox(lines=3, interactive=False) timer = gr.Timer(value=1.0) timer.tick( fn=get_updates, outputs=[road_plot, reward_plot, step_display, ep_display, trend_display, status_box], ) start_btn.click(fn=start_training, outputs=[start_btn, stop_btn]) stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn]) if __name__ == "__main__": demo.launch()