Spaces:
Runtime error
Runtime error
| """ | |
| OpenENV RL Demo β Gradio UI entrypoint for HuggingFace Spaces. | |
| Runs inside the overflow_env package root. All imports use absolute paths | |
| so they work both as a package (installed) and as a Space (flat root). | |
| """ | |
| import sys, os | |
| # When running as HF Space, make server/ importable with absolute paths | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import math, time, threading | |
| import numpy as np | |
| import torch | |
| import torch.optim as optim | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import gradio as gr | |
| from server.overflow_environment import OverflowEnvironment | |
| from models import OverflowAction | |
| from policies.flat_mlp_policy import FlatMLPPolicy | |
| from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM | |
# Episode/road configuration constants.
STEPS_PER_EPISODE = 20   # max policy steps per episode before "timeout" outcome
NUM_LANES = 3            # number of lanes drawn/used on the road
ROAD_LENGTH = 200        # road length, same units as car position.x
| # ββ Observation adapter βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Convert an Overflow environment observation into a flat policy vector.

    Extracts ego-vehicle features plus one "ticket" vector per nearby
    non-ego car (within 80 distance units), then packs everything with
    ``build_obs`` from the shared policy spec.

    Args:
        overflow_obs: environment observation exposing ``.cars``; each car
            has ``carId``, ``lane``, ``speed`` and ``position.x``
            (attribute shapes assumed from usage here — defined in the
            project's models module).

    Returns:
        float32 vector of length ``OBS_DIM``; all zeros when no cars exist.
    """
    cars = overflow_obs.cars
    if not cars:
        # No cars yet (e.g. before the first env step): neutral observation.
        return np.zeros(OBS_DIM, dtype=np.float32)
    # Ego is carId 0 by convention; fall back to the first car if absent.
    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_spd = ego.speed / 4.5      # normalized speed (4.5 presumably max speed — TODO confirm)
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7   # lane index -> lateral metres, lane 2 centred
    tickets = []
    for car in cars:
        if car.carId == 0:
            continue
        rx = car.position.x - ego.position.x   # longitudinal offset (positive = ahead)
        ry = (car.lane - ego.lane) * 3.7       # lateral offset in metres
        cs = car.speed / 4.5                   # other car's normalized speed
        d = math.hypot(rx, ry)
        if d > 80:
            # Ignore far-away traffic entirely.
            continue
        # Closing speed toward the other car, floored at 0.1 so the
        # time-to-collision division below is always finite.
        # BUGFIX: the original wrote math.copysign(1, max(rx, 0.01)), but
        # max(rx, 0.01) is always >= 0.01, so the sign was unconditionally
        # +1 and cars *behind* the ego were scored like cars ahead. Use the
        # sign of rx itself so a trailing car's speed adds to the closing rate.
        cl = max(ego_spd - cs * math.copysign(1.0, rx), 0.1)
        tickets.append(build_ticket_vector(
            severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5,
            ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0,
            vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=d, time_to_collision=min(d / cl, 30.0),
            # NOTE(review): clamping rx to >= 0.01 collapses all rear
            # bearings toward +/- pi/2 — looks intentional but verify.
            bearing=math.atan2(ry, max(rx, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))
    tv = np.array(tickets, dtype=np.float32) if tickets else None
    return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
                     ego_vx=ego_spd, ego_vy=0.0,
                     heading=0.0, speed=ego_spd,
                     steer=0.0, throttle=0.5, brake=0.0,
                     ticket_vectors=tv)
def action_to_decision(a: np.ndarray) -> str:
    """Map a continuous [steer, throttle, brake] action to a discrete decision.

    Priority order: steering beyond +/-0.35 wins, then brake above 0.25,
    then throttle above 0.20; anything weaker maps to "maintain".
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])
    if abs(steer) > 0.35:
        return "lane_change_right" if steer > 0 else "lane_change_left"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
| # ββ Global training state βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Shared training state. The background worker thread (run_episodes_loop)
# mutates these; Gradio callbacks read them. `_lock` guards only the
# UI-facing logs (step_log / episode_history); the rollout buffers are
# touched by the worker thread alone.
policy = FlatMLPPolicy(obs_dim=OBS_DIM)
optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)
# Per-episode rollout buffers consumed by _ppo_mini_update().
_buf_obs = []     # observation vectors (np.float32, length OBS_DIM)
_buf_acts = []    # sampled actions (np arrays of 3 floats)
_buf_rews = []    # per-step scalar rewards
_buf_logps = []   # behavior-policy log-probabilities at sample time
_buf_vals = []    # critic value estimates at sample time
_buf_dones = []   # done flags stored as floats (0.0 / 1.0)
episode_history = []   # completed-episode summaries shown in the UI
step_log = []          # per-step records for the live feed
_running = False       # worker-loop run flag, set/cleared by start/stop handlers
_lock = threading.Lock()
def _ppo_mini_update():
    """Run one PPO gradient step over the rollout buffers from the last episode.

    Computes GAE(lambda) advantages from the buffered transitions, then a
    single clipped-surrogate update on the shared module-level ``policy``
    via ``optimizer``. No-op when fewer than two transitions were collected.
    """
    if len(_buf_obs) < 2:
        return
    # Stack the Python-list buffers into tensors.
    obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
    acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
    rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
    logp_t = torch.tensor(_buf_logps, dtype=torch.float32)   # old (behavior) log-probs
    vals_t = torch.tensor(_buf_vals, dtype=torch.float32)    # critic values at collection time
    done_t = torch.tensor(_buf_dones, dtype=torch.float32)
    gamma, lam = 0.99, 0.95
    # Generalized Advantage Estimation, iterated backwards through time.
    adv = torch.zeros_like(rews_t)
    gae = 0.0
    for t in reversed(range(len(rews_t))):
        # Bootstrap value is 0 at the last buffered transition.
        nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1])
        d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t]  # TD residual
        gae = d + gamma * lam * (1 - done_t[t]) * gae
        adv[t] = gae
    ret = adv + vals_t                               # value-function regression targets
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)    # normalize advantages
    policy.train()
    act_mean, val = policy(obs_t)
    val = val.squeeze(-1)
    # Fixed-std Gaussian policy (sigma = 0.3) — must match the sampling
    # distribution used during rollout in run_episodes_loop.
    dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
    logp = dist.log_prob(acts_t).sum(dim=-1)
    entropy = dist.entropy().sum(dim=-1).mean()
    ratio = torch.exp(logp - logp_t)
    # PPO clipped surrogate written as max of negated terms (equivalent to
    # the usual -min form), clip range +/-20%.
    pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
    vf = 0.5 * ((val - ret) ** 2).mean()
    loss = pg + 0.5 * vf - 0.02 * entropy            # entropy bonus coeff 0.02
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
def run_episodes_loop():
    """Background worker: run training episodes until ``_running`` is cleared.

    Each episode rolls out up to STEPS_PER_EPISODE policy steps in a fresh
    OverflowEnvironment, appending transitions to the global rollout buffers
    and step records to the UI logs, then applies one PPO mini-update and
    records an episode summary. Intended to run in the daemon thread started
    by start_training().
    """
    global _running
    ep_num = 0
    env = OverflowEnvironment()
    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"  # default when the episode hits the step cap
        # Fresh rollout buffers for this episode (read by _ppo_mini_update).
        _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear()
        _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear()
        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                break
            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                # Same fixed-std (0.3) Gaussian as in _ppo_mini_update.
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()
            decision = action_to_decision(action.numpy())
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done
            ep_rew += reward
            # Store the transition for the post-episode PPO update.
            _buf_obs.append(obs_vec)
            _buf_acts.append(action.numpy())
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))
            # Publish the step under the lock shared with get_updates().
            with _lock:
                step_log.append({
                    "ep": ep_num,
                    "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "incident": obs.incident_report or "",
                    "cars": [(c.carId, c.lane, c.position.x, c.speed)
                             for c in obs.cars],
                })
            if done:
                # Classify the terminal step by its incident text.
                outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
                break
            time.sleep(0.6)  # pace the rollout so the 1 Hz UI timer can keep up
        _ppo_mini_update()
        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": step,
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
| # ββ Plot helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Action-banner colors: green = accelerate, red = brake, amber = lane
# changes, blue = maintain.
DECISION_COLORS = {
    "accelerate": "#22c55e",
    "brake": "#ef4444",
    "lane_change_left": "#f59e0b",
    "lane_change_right": "#f59e0b",
    "maintain": "#60a5fa",
}
def render_road(cars_snapshot, last_decision, last_incident):
    """Draw a top-down snapshot of the road, cars, goal zone and last action.

    Args:
        cars_snapshot: list of (car_id, lane, pos_x, speed) tuples;
            car_id 0 is the ego vehicle.
        last_decision: decision string; colors the action banner via
            DECISION_COLORS.
        last_incident: incident text; "CRASH" / "NEAR MISS" / "GOAL"
            substrings trigger a centered banner.

    Returns:
        A matplotlib Figure in the app's dark theme.
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    ax.set_xlim(0, ROAD_LENGTH)
    ax.set_ylim(0, NUM_LANES + 1)
    ax.set_yticks([])
    ax.set_xlabel("Position", color="#94a3b8", fontsize=9)
    ax.tick_params(colors="#94a3b8")
    for spine in ax.spines.values():
        spine.set_edgecolor("#334155")
    # Dashed lane dividers between adjacent lanes.
    for lane in range(1, NUM_LANES):
        ax.axhline(y=lane + 0.5, color="#334155", linewidth=1, linestyle="--", alpha=0.6)
    # Lane labels at the left edge.
    for lane in range(1, NUM_LANES + 1):
        ax.text(2, lane, f"L{lane}", color="#475569", fontsize=8, va="center")
    # Shaded goal zone at the far end of the road.
    ax.axvspan(160, ROAD_LENGTH, alpha=0.12, color="#22c55e")
    ax.text(162, NUM_LANES + 0.6, "GOAL ZONE", color="#22c55e", fontsize=7, alpha=0.8)
    car_w, car_h = 8, 0.55  # car box size in data units
    for car_id, lane, pos_x, speed in cars_snapshot:
        is_ego = car_id == 0
        # Ego car gets a blue fill and a thicker highlighted outline.
        color = "#3b82f6" if is_ego else "#94a3b8"
        outline = "#60a5fa" if is_ego else "#475569"
        lw = 2.0 if is_ego else 1.0
        rect = patches.FancyBboxPatch(
            (pos_x - car_w / 2, lane - car_h / 2),
            car_w, car_h,
            boxstyle="round,pad=0.05",
            facecolor=color, edgecolor=outline, linewidth=lw, alpha=0.92,
        )
        ax.add_patch(rect)
        # Label each car with its id (or EGO) and rounded speed.
        label = f"{'EGO' if is_ego else f'C{car_id}'}\n{speed:.0f}"
        ax.text(pos_x, lane, label, ha="center", va="center",
                fontsize=6.5, color="white", fontweight="bold" if is_ego else "normal")
    # Last-action banner (top right), colored by decision type.
    dec_color = DECISION_COLORS.get(last_decision, "#60a5fa")
    ax.text(ROAD_LENGTH - 2, NUM_LANES + 0.65,
            f"Action: {last_decision.replace('_', ' ').upper()}",
            color=dec_color, fontsize=8, fontweight="bold", ha="right")
    # Incident banner (crash takes precedence over near-miss over goal).
    if "CRASH" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "CRASH",
                color="#ef4444", fontsize=10, fontweight="bold", ha="center")
    elif "NEAR MISS" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "NEAR MISS",
                color="#f59e0b", fontsize=9, fontweight="bold", ha="center")
    elif "GOAL" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "GOAL REACHED",
                color="#22c55e", fontsize=10, fontweight="bold", ha="center")
    plt.tight_layout(pad=0.3)
    return fig
def render_reward_curve(eps):
    """Plot per-episode total reward as outcome-colored bars plus a moving average.

    Args:
        eps: list of episode dicts with keys "ep", "reward", "outcome"
            (as produced by run_episodes_loop).

    Returns:
        A matplotlib Figure in the app's dark theme (placeholder text when
        ``eps`` is empty).
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    for spine in ax.spines.values():
        spine.set_edgecolor("#334155")
    ax.tick_params(colors="#94a3b8")
    ax.set_xlabel("Episode", color="#94a3b8", fontsize=9)
    ax.set_ylabel("Total Reward", color="#94a3b8", fontsize=9)
    if not eps:
        # Nothing to plot yet: show a placeholder message instead.
        ax.text(0.5, 0.5, "Waiting for episodes...", transform=ax.transAxes,
                ha="center", va="center", color="#475569", fontsize=11)
        plt.tight_layout(pad=0.3)
        return fig
    xs = [e["ep"] for e in eps]
    ys = [e["reward"] for e in eps]
    outcome_colors = {"CRASH": "#ef4444", "GOAL": "#22c55e", "timeout": "#60a5fa"}
    for x, y, e in zip(xs, ys, eps):
        ax.bar(x, y, color=outcome_colors.get(e["outcome"], "#60a5fa"), alpha=0.6, width=0.7)
    if len(ys) >= 3:
        # Overlay a moving average (window up to 5) to show the trend;
        # 'valid' mode yields len(ys) - w + 1 points, aligned to xs[w-1:].
        w = min(5, len(ys))
        smoothed = np.convolve(ys, np.ones(w) / w, mode="valid")
        ax.plot(xs[w - 1:], smoothed, color="#f8fafc", linewidth=2)
    ax.axhline(0, color="#334155", linewidth=0.8)
    # Use the module-level `matplotlib.patches` import (the original had a
    # redundant function-local `from matplotlib.patches import Patch`).
    legend_els = [patches.Patch(facecolor="#ef4444", label="crash"),
                  patches.Patch(facecolor="#22c55e", label="goal"),
                  patches.Patch(facecolor="#60a5fa", label="timeout")]
    ax.legend(handles=legend_els, facecolor="#1e293b", labelcolor="#94a3b8",
              fontsize=8, framealpha=0.6, edgecolor="#334155", loc="upper left")
    plt.tight_layout(pad=0.3)
    return fig
| # ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def start_training():
    """Start the background training thread if idle; return new button states.

    Idempotent: clicking Start while already running only refreshes the
    button states. Returns updates for (start_btn, stop_btn).
    """
    global _running
    if _running:
        # Already running — leave the worker alone, just re-emit the states.
        return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
    _running = True
    step_log.clear()
    episode_history.clear()
    worker = threading.Thread(target=run_episodes_loop, daemon=True)
    worker.start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
def stop_training():
    """Signal the training loop to stop; return new (start_btn, stop_btn) states."""
    global _running
    _running = False  # worker thread observes this flag and exits its loops
    start_state = gr.update(value="Start", interactive=True)
    stop_state = gr.update(interactive=False)
    return start_state, stop_state
def get_updates():
    """Timer callback: snapshot shared logs and render every UI output.

    Returns a 6-tuple matching the gr.Timer wiring below:
    (road_fig, reward_fig, step_text, ep_text, trend_text, status).
    """
    # Snapshot under the lock so the worker thread can keep appending.
    with _lock:
        logs = list(step_log[-20:])
        eps = list(episode_history[-50:])
        last = step_log[-1] if step_log else None
    # Render the road from the latest step, or an empty road before any steps.
    road_fig = render_road(last["cars"], last["decision"], last["incident"]) if last \
        else render_road([], "maintain", "")
    reward_fig = render_reward_curve(eps)
    # Live step feed, newest first, with a flag character per incident type.
    lines = []
    for e in reversed(logs):
        flag = ""
        if "CRASH" in e["incident"]: flag = " π₯"
        elif "GOAL" in e["incident"]: flag = " β"
        elif "NEAR MISS" in e["incident"]: flag = " β "
        lines.append(
            f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
            f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
            f"ep_total={e['ep_reward']:>7.2f}{flag}"
        )
    step_text = "\n".join(lines) if lines else "Waiting for first episode..."
    # Fixed-width episode-history table (last 15 episodes, newest first).
    ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
    for e in reversed(eps[-15:]):
        ep_lines.append(
            f" {e['ep']:>4d} | {e['steps']:>3d} | "
            f" {e['reward']:>+8.2f} | {e['outcome']}"
        )
    ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet."
    # Crude learning-trend estimate: mean reward of the first half of the
    # (up to 50) visible episodes vs. the second half.
    if len(eps) >= 2:
        rewards = [e["reward"] for e in eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        arrow = "β improving" if late > early else "β declining"
        trend_text = f"Early {half} eps: {early:+.2f} β Last {n-half} eps: {late:+.2f} {arrow}"
    else:
        trend_text = "Collecting data..."
    # Unsynchronized read of the run flag — fine for a display-only status.
    status = "β RUNNING" if _running else "β STOPPED"
    return road_fig, reward_fig, step_text, ep_text, trend_text, status
# Pre-rendered placeholder figures shown before the first timer tick.
_EMPTY_ROAD = render_road([], "maintain", "")
_EMPTY_REWARD = render_reward_curve([])

# ── Gradio UI layout and event wiring ──
with gr.Blocks(title="OpenENV RL Demo", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        "# OpenENV RL β Live Policy Training\n"
        "**FlatMLPPolicy** drives Car 0 on a 3-lane road for 20 steps per episode. "
        "PPO mini-update after each episode β watch rewards trend upward over time."
    )
    with gr.Row():
        start_btn = gr.Button("Start", variant="primary", scale=1)
        stop_btn = gr.Button("Stop", variant="stop", interactive=False, scale=1)
        status_box = gr.Textbox(value="β STOPPED", label="Status",
                                interactive=False, scale=0, min_width=130)
    gr.Markdown("### Road View")
    road_plot = gr.Plot(value=_EMPTY_ROAD, show_label=False)
    gr.Markdown("### Episode Reward Curve")
    reward_plot = gr.Plot(value=_EMPTY_REWARD, show_label=False)
    gr.Markdown("### Live Step Feed (last 20 steps)")
    step_display = gr.Textbox(
        value="Press Start to begin...",
        lines=14, max_lines=14, interactive=False,
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Episode History")
            ep_display = gr.Textbox(lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### Reward Trend")
            trend_display = gr.Textbox(lines=3, interactive=False)
    # Poll once per second; get_updates re-renders every output each tick.
    timer = gr.Timer(value=1.0)
    timer.tick(
        fn=get_updates,
        outputs=[road_plot, reward_plot, step_display, ep_display, trend_display, status_box],
    )
    # Start/Stop toggle the worker thread and swap button interactivity.
    start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
    stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])

if __name__ == "__main__":
    demo.launch()