# Initial push: overflow_env with Gradio RL demo UI (author: aparekh02, commit cb054fe)
"""
OpenENV RL Demo β€” Gradio UI entrypoint for HuggingFace Spaces.
Runs inside the overflow_env package root. All imports use absolute paths
so they work both as a package (installed) and as a Space (flat root).
"""
import sys, os
# When running as HF Space, make server/ importable with absolute paths
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import math, time, threading
import numpy as np
import torch
import torch.optim as optim
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
from server.overflow_environment import OverflowEnvironment
from models import OverflowAction
from policies.flat_mlp_policy import FlatMLPPolicy
from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
# Simulation constants shared by the training loop and the road renderer.
STEPS_PER_EPISODE = 20  # max env steps per episode before the "timeout" outcome
NUM_LANES = 3           # number of lanes drawn in the road view
ROAD_LENGTH = 200       # x-axis extent of the road (goal zone shading starts at x=160)
# ── Observation adapter ───────────────────────────────────────────────────────
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Flatten an Overflow observation into the policy's fixed-size vector.

    Builds one "ticket" per non-ego car within 80 units of the ego vehicle
    and delegates the final packing to build_obs / build_ticket_vector from
    policies.policy_spec. Returns a zero vector when no cars are present.
    """
    all_cars = overflow_obs.cars
    if not all_cars:
        return np.zeros(OBS_DIM, dtype=np.float32)
    # Ego is carId 0; fall back to the first car if it is absent.
    ego = next((c for c in all_cars if c.carId == 0), all_cars[0])
    ego_speed_n = ego.speed / 4.5  # normalized speed (4.5 presumably the max — TODO confirm)
    ticket_rows = []
    for other in all_cars:
        if other.carId == 0:
            continue
        dx = other.position.x - ego.position.x
        dy = (other.lane - ego.lane) * 3.7  # 3.7 = lane width in meters
        other_speed_n = other.speed / 4.5
        dist = math.sqrt(dx**2 + dy**2)
        if dist > 80:
            # Cars beyond 80 units are ignored entirely.
            continue
        # Severity buckets by proximity.
        if dist < 8:
            severity = 1.0
        elif dist < 15:
            severity = 0.75
        else:
            severity = 0.5
        # Closing speed (floored at 0.1 to keep TTC finite); copysign flips
        # the other car's contribution for cars behind the ego.
        closing = max(ego_speed_n - other_speed_n * math.copysign(1, max(dx, 0.01)), 0.1)
        ticket_rows.append(build_ticket_vector(
            severity_weight=severity,
            ttl=5.0, pos_x=dx, pos_y=dy, pos_z=0.0,
            vel_x=other_speed_n, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=dist, time_to_collision=min(dist / closing, 30.0),
            bearing=math.atan2(dy, max(dx, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))
    ticket_array = np.array(ticket_rows, dtype=np.float32) if ticket_rows else None
    return build_obs(
        ego_x=ego.position.x,
        ego_y=(ego.lane - 2) * 3.7,
        ego_z=0.0,
        ego_vx=ego_speed_n, ego_vy=0.0,
        heading=0.0, speed=ego_speed_n,
        steer=0.0, throttle=0.5, brake=0.0,
        ticket_vectors=ticket_array,
    )
def action_to_decision(a: np.ndarray) -> str:
    """Map a continuous (steer, throttle, brake) action to a decision string.

    Priority order: steering (|steer| > 0.35) beats braking (brake > 0.25),
    which beats accelerating (throttle > 0.20); otherwise "maintain".
    Negative steer means a left lane change, positive means right.
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])
    if abs(steer) > 0.35:
        return "lane_change_left" if steer < 0 else "lane_change_right"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
# ── Global training state ─────────────────────────────────────────────────────
# Shared policy/optimizer, used by both the trainer thread and the UI thread.
policy = FlatMLPPolicy(obs_dim=OBS_DIM)
optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)
# Per-episode on-policy rollout buffers, cleared at the start of each episode
# by run_episodes_loop and consumed by _ppo_mini_update.
_buf_obs = []    # observation vectors (np.ndarray of length OBS_DIM)
_buf_acts = []   # sampled actions (np.ndarray of length 3)
_buf_rews = []   # scalar rewards
_buf_logps = []  # log-probs of the sampled actions under the behavior policy
_buf_vals = []   # critic value estimates at collection time
_buf_dones = []  # termination flags as floats (0.0 / 1.0)
episode_history = []  # per-episode summaries shown in the UI
step_log = []         # per-step records consumed by get_updates()
_running = False      # trainer-thread run flag, flipped by start/stop handlers
_lock = threading.Lock()  # guards step_log / episode_history across threads
def _ppo_mini_update():
    """Run one PPO update over the buffered episode (single epoch, full batch).

    Computes GAE(gamma=0.99, lam=0.95) advantages from the module-level
    rollout buffers, then takes one clipped-surrogate policy-gradient step
    with a value loss and an entropy bonus. No-op for fewer than two steps.
    """
    if len(_buf_obs) < 2:
        return
    obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
    acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
    rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
    logp_t = torch.tensor(_buf_logps, dtype=torch.float32)  # behavior-policy log-probs
    vals_t = torch.tensor(_buf_vals, dtype=torch.float32)   # critic values at collection time
    done_t = torch.tensor(_buf_dones, dtype=torch.float32)
    gamma, lam = 0.99, 0.95
    # Generalized Advantage Estimation, computed backwards over the rollout.
    adv = torch.zeros_like(rews_t)
    gae = 0.0
    for t in reversed(range(len(rews_t))):
        # Bootstrap value is zero past the end of the buffer.
        nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1])
        d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t]  # TD residual
        gae = d + gamma * lam * (1 - done_t[t]) * gae
        adv[t] = gae
    ret = adv + vals_t  # value-function regression targets
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)  # normalize advantages
    policy.train()
    act_mean, val = policy(obs_t)
    val = val.squeeze(-1)
    # Fixed-std Gaussian (sigma=0.3) — must match the sampling distribution
    # used during rollout in run_episodes_loop.
    dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
    logp = dist.log_prob(acts_t).sum(dim=-1)
    entropy = dist.entropy().sum(dim=-1).mean()
    ratio = torch.exp(logp - logp_t)
    # PPO clipped surrogate (clip range 0.8–1.2, i.e. eps=0.2), written as a
    # max of negated terms so the result is a loss to minimize.
    pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
    vf = 0.5 * ((val - ret) ** 2).mean()
    loss = pg + 0.5 * vf - 0.02 * entropy
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
def run_episodes_loop():
    """Trainer-thread main loop: roll out episodes and PPO-update after each.

    Runs until the module-global _running flag is cleared by stop_training().
    Each step samples an action from a fixed-std Gaussian around the policy
    mean, converts it to a discrete decision string for the environment, and
    records everything into the rollout buffers plus the UI-facing
    step_log/episode_history (the latter two under _lock).
    """
    global _running
    ep_num = 0
    env = OverflowEnvironment()
    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"  # overwritten with CRASH/GOAL if the env terminates early
        # Fresh on-policy buffers for this episode.
        _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear()
        _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear()
        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                break
            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                # Same sigma=0.3 exploration noise assumed by _ppo_mini_update.
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                action = dist.sample().clamp(-1, 1)
                # NOTE(review): log-prob is taken of the *clamped* sample, which
                # slightly mismatches the true sampling density at the bounds —
                # acceptable for a demo, confirm if training quality matters.
                logp = dist.log_prob(action).sum()
            decision = action_to_decision(action.numpy())
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done
            ep_rew += reward
            _buf_obs.append(obs_vec)
            _buf_acts.append(action.numpy())
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))
            # Publish the step record for the UI under the shared lock.
            with _lock:
                step_log.append({
                    "ep": ep_num,
                    "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "incident": obs.incident_report or "",
                    "cars": [(c.carId, c.lane, c.position.x, c.speed)
                             for c in obs.cars],
                })
            if done:
                # Classify the terminal step from the incident report text.
                outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
                break
            time.sleep(0.6)  # pace the rollout so the 1 s UI timer can follow it
        _ppo_mini_update()
        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": step,
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
# ── Plot helpers ──────────────────────────────────────────────────────────────
# Color per decision string for the "Action:" label in the road view;
# unknown decisions fall back to the "maintain" blue via dict.get.
DECISION_COLORS = {
    "accelerate": "#22c55e",        # green
    "brake": "#ef4444",             # red
    "lane_change_left": "#f59e0b",  # amber
    "lane_change_right": "#f59e0b",
    "maintain": "#60a5fa",          # blue
}
def render_road(cars_snapshot, last_decision, last_incident):
    """Draw the top-down road view and return the matplotlib Figure.

    cars_snapshot: iterable of (car_id, lane, pos_x, speed) tuples, where
    car_id 0 is the ego vehicle. last_decision colors the action label;
    last_incident (free text) triggers a CRASH/NEAR MISS/GOAL banner.
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))
    # Dark theme matching the rest of the UI.
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    ax.set_xlim(0, ROAD_LENGTH)
    ax.set_ylim(0, NUM_LANES + 1)
    ax.set_yticks([])
    ax.set_xlabel("Position", color="#94a3b8", fontsize=9)
    ax.tick_params(colors="#94a3b8")
    for spine in ax.spines.values():
        spine.set_edgecolor("#334155")
    # Dashed separators between lanes, plus a small label inside each lane.
    for boundary in range(1, NUM_LANES):
        ax.axhline(y=boundary + 0.5, color="#334155", linewidth=1,
                   linestyle="--", alpha=0.6)
    for lane_no in range(1, NUM_LANES + 1):
        ax.text(2, lane_no, f"L{lane_no}", color="#475569", fontsize=8, va="center")
    # Shaded goal zone at the far end of the road.
    ax.axvspan(160, ROAD_LENGTH, alpha=0.12, color="#22c55e")
    ax.text(162, NUM_LANES + 0.6, "GOAL ZONE", color="#22c55e", fontsize=7, alpha=0.8)
    car_w, car_h = 8, 0.55
    for car_id, lane, pos_x, speed in cars_snapshot:
        is_ego = car_id == 0
        box_style = dict(
            facecolor="#3b82f6" if is_ego else "#94a3b8",
            edgecolor="#60a5fa" if is_ego else "#475569",
            linewidth=2.0 if is_ego else 1.0,
        )
        ax.add_patch(patches.FancyBboxPatch(
            (pos_x - car_w / 2, lane - car_h / 2),
            car_w, car_h,
            boxstyle="round,pad=0.05", alpha=0.92, **box_style,
        ))
        name = "EGO" if is_ego else f"C{car_id}"
        ax.text(pos_x, lane, f"{name}\n{speed:.0f}", ha="center", va="center",
                fontsize=6.5, color="white",
                fontweight="bold" if is_ego else "normal")
    # Last decision, right-aligned above the road.
    ax.text(ROAD_LENGTH - 2, NUM_LANES + 0.65,
            f"Action: {last_decision.replace('_', ' ').upper()}",
            color=DECISION_COLORS.get(last_decision, "#60a5fa"),
            fontsize=8, fontweight="bold", ha="right")
    # Incident banner: first matching token wins (CRASH takes priority).
    for token, banner, banner_color, banner_size in (
            ("CRASH", "CRASH", "#ef4444", 10),
            ("NEAR MISS", "NEAR MISS", "#f59e0b", 9),
            ("GOAL", "GOAL REACHED", "#22c55e", 10)):
        if token in last_incident:
            ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, banner,
                    color=banner_color, fontsize=banner_size,
                    fontweight="bold", ha="center")
            break
    plt.tight_layout(pad=0.3)
    return fig
def render_reward_curve(eps):
    """Draw the per-episode reward bar chart and return the Figure.

    eps: list of episode-summary dicts with keys "ep", "reward", "outcome".
    Bars are colored by outcome; a moving-average line (window up to 5) is
    overlaid once at least three episodes exist.

    NOTE(review): figures created here are never plt.close()d by callers;
    with the 1 s UI timer this accumulates pyplot state over long runs —
    confirm and close superseded figures if memory grows.
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    for spine in ax.spines.values():
        spine.set_edgecolor("#334155")
    ax.tick_params(colors="#94a3b8")
    ax.set_xlabel("Episode", color="#94a3b8", fontsize=9)
    ax.set_ylabel("Total Reward", color="#94a3b8", fontsize=9)
    if not eps:
        # Placeholder frame before the first episode completes.
        ax.text(0.5, 0.5, "Waiting for episodes...", transform=ax.transAxes,
                ha="center", va="center", color="#475569", fontsize=11)
        plt.tight_layout(pad=0.3)
        return fig
    xs = [e["ep"] for e in eps]
    ys = [e["reward"] for e in eps]
    outcome_colors = {"CRASH": "#ef4444", "GOAL": "#22c55e", "timeout": "#60a5fa"}
    for x, y, e in zip(xs, ys, eps):
        ax.bar(x, y, color=outcome_colors.get(e["outcome"], "#60a5fa"),
               alpha=0.6, width=0.7)
    if len(ys) >= 3:
        # Moving average aligned to the end of each window.
        w = min(5, len(ys))
        smoothed = np.convolve(ys, np.ones(w) / w, mode="valid")
        ax.plot(xs[w - 1:], smoothed, color="#f8fafc", linewidth=2)
    ax.axhline(0, color="#334155", linewidth=0.8)
    # Use the module-level `matplotlib.patches as patches` import instead of
    # re-importing Patch inside the function body.
    legend_els = [patches.Patch(facecolor="#ef4444", label="crash"),
                  patches.Patch(facecolor="#22c55e", label="goal"),
                  patches.Patch(facecolor="#60a5fa", label="timeout")]
    ax.legend(handles=legend_els, facecolor="#1e293b", labelcolor="#94a3b8",
              fontsize=8, framealpha=0.6, edgecolor="#334155", loc="upper left")
    plt.tight_layout(pad=0.3)
    return fig
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def start_training():
    """UI handler for the Start button: launch the trainer thread once.

    The check-and-set on _running is done under _lock so two rapid clicks
    (or concurrent Gradio callbacks) cannot spawn two trainer threads —
    the original unguarded `if not _running` check was racy.
    Returns Gradio updates disabling Start and enabling Stop.
    """
    global _running
    with _lock:
        already_running = _running
        _running = True
    if not already_running:
        # Reset UI-facing history only when actually (re)starting.
        step_log.clear()
        episode_history.clear()
        threading.Thread(target=run_episodes_loop, daemon=True).start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
def stop_training():
    """UI handler for the Stop button: signal the trainer loop to exit.

    The trainer thread notices the cleared flag at its next step check and
    finishes on its own. Returns Gradio updates re-enabling Start and
    disabling Stop.
    """
    global _running
    _running = False
    start_update = gr.update(value="Start", interactive=True)
    stop_update = gr.update(interactive=False)
    return start_update, stop_update
def get_updates():
    """Timer callback: snapshot shared state and build every UI output.

    Returns (road_fig, reward_fig, step_text, ep_text, trend_text, status)
    in the order wired to the gr.Timer outputs.
    """
    # Snapshot shared lists under the lock so the trainer can keep appending.
    with _lock:
        recent_steps = list(step_log[-20:])
        recent_eps = list(episode_history[-50:])
        latest = step_log[-1] if step_log else None

    if latest is None:
        road_fig = render_road([], "maintain", "")
    else:
        road_fig = render_road(latest["cars"], latest["decision"], latest["incident"])
    reward_fig = render_reward_curve(recent_eps)

    def _flag(incident):
        # Emoji marker for a step's incident text; CRASH takes priority.
        if "CRASH" in incident:
            return " 💥"
        if "GOAL" in incident:
            return " ✓"
        if "NEAR MISS" in incident:
            return " ⚠"
        return ""

    step_lines = [
        f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
        f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
        f"ep_total={e['ep_reward']:>7.2f}{_flag(e['incident'])}"
        for e in reversed(recent_steps)
    ]
    step_text = "\n".join(step_lines) if step_lines else "Waiting for first episode..."

    ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
    ep_lines.extend(
        f" {e['ep']:>4d} | {e['steps']:>3d} | "
        f" {e['reward']:>+8.2f} | {e['outcome']}"
        for e in reversed(recent_eps[-15:])
    )
    ep_text = "\n".join(ep_lines) if recent_eps else "No episodes completed yet."

    # Compare mean reward of the first half vs. the rest to show a trend.
    if len(recent_eps) >= 2:
        rewards = [e["reward"] for e in recent_eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        arrow = "↑ improving" if late > early else "↓ declining"
        trend_text = f"Early {half} eps: {early:+.2f} → Last {n-half} eps: {late:+.2f} {arrow}"
    else:
        trend_text = "Collecting data..."

    status = "● RUNNING" if _running else "■ STOPPED"
    return road_fig, reward_fig, step_text, ep_text, trend_text, status
# Static placeholder figures shown before the first timer tick fires.
_EMPTY_ROAD = render_road([], "maintain", "")
_EMPTY_REWARD = render_reward_curve([])
# Gradio layout: header, Start/Stop controls, two plots, live step feed,
# and episode history/trend panels, all refreshed by a 1 s timer.
with gr.Blocks(title="OpenENV RL Demo", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        "# OpenENV RL — Live Policy Training\n"
        "**FlatMLPPolicy** drives Car 0 on a 3-lane road for 20 steps per episode. "
        "PPO mini-update after each episode — watch rewards trend upward over time."
    )
    # Start/Stop buttons plus a read-only status indicator.
    with gr.Row():
        start_btn = gr.Button("Start", variant="primary", scale=1)
        stop_btn = gr.Button("Stop", variant="stop", interactive=False, scale=1)
        status_box = gr.Textbox(value="■ STOPPED", label="Status",
                                interactive=False, scale=0, min_width=130)
    gr.Markdown("### Road View")
    road_plot = gr.Plot(value=_EMPTY_ROAD, show_label=False)
    gr.Markdown("### Episode Reward Curve")
    reward_plot = gr.Plot(value=_EMPTY_REWARD, show_label=False)
    gr.Markdown("### Live Step Feed (last 20 steps)")
    step_display = gr.Textbox(
        value="Press Start to begin...",
        lines=14, max_lines=14, interactive=False,
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Episode History")
            ep_display = gr.Textbox(lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### Reward Trend")
            trend_display = gr.Textbox(lines=3, interactive=False)
    # Poll get_updates() once per second to refresh every output component.
    timer = gr.Timer(value=1.0)
    timer.tick(
        fn=get_updates,
        outputs=[road_plot, reward_plot, step_display, ep_display, trend_display, status_box],
    )
    start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
    stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])

# Local launch entry point (HF Spaces imports `demo` directly).
if __name__ == "__main__":
    demo.launch()