# openenv-rl-demo / app.py
# aparekh02's picture
# bundle overflow_env locally, drop openenv-core git dep (websockets conflict fix)
# 69d4a95 verified
"""
OpenENV RL Demo — auto-runs 20 steps per episode using the openenv policy system.
Policies: FlatMLPPolicy / TicketAttentionPolicy (from openenv)
Training: PPO mini-update after each episode — rewards increase over time
Display: Live step-by-step feed + episode reward history
"""
import math, time, threading
import numpy as np
import torch
import torch.optim as optim
import gradio as gr
from overflow_env.environment import OverflowEnvironment
from overflow_env.models import OverflowAction
from policies.flat_mlp_policy import FlatMLPPolicy
from policies.ticket_attention_policy import TicketAttentionPolicy
from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
STEPS_PER_EPISODE = 20
# ── Observation adapter ───────────────────────────────────────────────────────
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Convert an Overflow observation into the policy's observation vector.

    The ego vehicle is the car whose ``carId`` is 0; if no such car exists the
    first car in the list is used instead. Every other car within 80 distance
    units of the ego is encoded as a "collision_risk" ticket via
    ``build_ticket_vector``; the ego state plus those tickets are packed by
    ``build_obs``.

    Args:
        overflow_obs: Observation exposing ``cars``; each car is read for
            ``carId``, ``speed``, ``lane`` and ``position.x``.

    Returns:
        The value produced by ``build_obs``, or an all-zero float32 vector of
        length ``OBS_DIM`` when the observation contains no cars.
    """
    cars = overflow_obs.cars
    if not cars:
        return np.zeros(OBS_DIM, dtype=np.float32)

    ego = next((c for c in cars if c.carId == 0), cars[0])
    # NOTE(review): 4.5 looks like a speed normaliser and 3.7 a lane width;
    # 2 is presumably the index of the centre lane — confirm against the env.
    ego_spd = ego.speed / 4.5
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7

    tickets = []
    for car in cars:
        # Skip the ego itself. The identity check matters when the ego was
        # picked by the cars[0] fallback (no carId 0 present): without it the
        # ego would be reported as a zero-distance collision risk with itself.
        if car is ego or car.carId == 0:
            continue
        rx = car.position.x - ego.position.x
        ry = (car.lane - ego.lane) * 3.7
        cs = car.speed / 4.5
        d = math.sqrt(rx**2 + ry**2)
        if d > 80:
            continue
        # Closing speed, floored at 0.1 so the division below stays finite.
        # NOTE(review): the copysign factor is always +1 because its argument
        # is clamped to >= 0.01, so this is effectively ego_spd - cs.
        cl = max(ego_spd - cs * math.copysign(1, max(rx, 0.01)), 0.1)
        tickets.append(build_ticket_vector(
            severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5,
            ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0,
            vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=d, time_to_collision=min(d / cl, 30.0),
            bearing=math.atan2(ry, max(rx, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    # build_obs receives None (not an empty array) when there are no tickets.
    tv = np.array(tickets, dtype=np.float32) if tickets else None
    return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
                     ego_vx=ego_spd, ego_vy=0.0,
                     heading=0.0, speed=ego_spd,
                     steer=0.0, throttle=0.5, brake=0.0,
                     ticket_vectors=tv)
def action_to_decision(a: np.ndarray) -> str:
    """Translate a continuous (steer, throttle, brake) action into a decision.

    Steering takes priority over braking, which takes priority over
    accelerating; if no threshold is exceeded the decision is "maintain".

    Args:
        a: Indexable with at least three entries: steer, throttle, brake.

    Returns:
        One of "lane_change_left", "lane_change_right", "brake",
        "accelerate" or "maintain".
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])

    # A strong steering command wins; its sign selects the direction.
    if steer < -0.35:
        return "lane_change_left"
    if steer > 0.35:
        return "lane_change_right"

    # Otherwise the first pedal over its threshold decides, brake first.
    for level, threshold, decision in (
        (brake, 0.25, "brake"),
        (throttle, 0.20, "accelerate"),
    ):
        if level > threshold:
            return decision
    return "maintain"
# ── Global training state ─────────────────────────────────────────────────────
# Policy network and its Adam optimizer, created once at import time and used
# by both the rollout loop and the PPO update.
policy = FlatMLPPolicy(obs_dim=OBS_DIM)
optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)
# Rollout buffer (lightweight — one episode at a time).
# The lists are index-aligned: entry i of each list describes step i of the
# episode currently being collected; all are cleared when an episode starts.
_buf_obs = []  # observation vectors fed to the policy
_buf_acts = []  # sampled (clamped) actions
_buf_rews = []  # per-step rewards
_buf_logps = []  # log-probability of each sampled action at collection time
_buf_vals = []  # value estimates at collection time
_buf_dones = []  # 1.0 if the step ended the episode, else 0.0
episode_history = []  # [{ep, steps, reward, outcome}]
step_log = []  # [{ep, step, decision, reward, ep_reward, scene, incident}]
# Flag polled by the background worker; set by the Start handler and cleared
# by the Stop handler.
_running = False
# Held while the worker appends to, and the UI reads from, step_log and
# episode_history.
_lock = threading.Lock()
def _ppo_mini_update():
    """Apply one clipped-PPO gradient step using the buffered episode.

    Reads the module-level rollout buffers, computes GAE advantages and
    returns, and performs a single optimizer step on the shared policy.
    Returns without doing anything when fewer than two transitions are
    buffered.
    """
    if len(_buf_obs) < 2:
        return

    def _as_tensor(values):
        return torch.tensor(values, dtype=torch.float32)

    observations = _as_tensor(np.array(_buf_obs))
    actions = _as_tensor(np.array(_buf_acts))
    rewards = _as_tensor(_buf_rews)
    old_log_probs = _as_tensor(_buf_logps)
    values = _as_tensor(_buf_vals)
    dones = _as_tensor(_buf_dones)

    # Generalised advantage estimation, accumulated from the last step back.
    discount, gae_lambda = 0.99, 0.95
    horizon = len(rewards)
    advantages = torch.zeros_like(rewards)
    running_gae = 0.0
    for idx in range(horizon - 1, -1, -1):
        # No bootstrap value is available after the final buffered step.
        next_value = float(values[idx + 1]) if idx < horizon - 1 else 0.0
        not_done = 1 - dones[idx]
        td_error = rewards[idx] + discount * next_value * not_done - values[idx]
        running_gae = td_error + discount * gae_lambda * not_done * running_gae
        advantages[idx] = running_gae

    # Returns use the raw advantages; normalisation happens afterwards.
    returns = advantages + values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    policy.train()
    action_mean, value_pred = policy(observations)
    value_pred = value_pred.squeeze(-1)
    # Fixed standard deviation of 0.3, matching the one used during rollout.
    action_dist = torch.distributions.Normal(
        action_mean, torch.ones_like(action_mean) * 0.3
    )
    new_log_probs = action_dist.log_prob(actions).sum(dim=-1)
    entropy_bonus = action_dist.entropy().sum(dim=-1).mean()

    # Clipped surrogate objective, written as a loss (hence the negations).
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped_ratio = ratio.clamp(0.8, 1.2)
    policy_loss = torch.max(-advantages * ratio, -advantages * clipped_ratio).mean()
    value_loss = 0.5 * ((value_pred - returns) ** 2).mean()
    total_loss = policy_loss + 0.5 * value_loss - 0.02 * entropy_bonus

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
def run_episodes_loop():
    """Background thread — runs episodes continuously, updates policy after each.

    Each episode resets the environment, then takes up to
    ``STEPS_PER_EPISODE`` steps: the policy's action mean is perturbed with
    fixed-std Gaussian noise, mapped to a discrete decision, and sent to the
    environment. Every step is recorded in the rollout buffers and in
    ``step_log``; after the episode a PPO mini-update runs and a summary is
    appended to ``episode_history``.

    The loop exits once ``_running`` becomes false. An episode interrupted
    that way is recorded with outcome "stopped" and the number of steps that
    actually ran; if it was interrupted before its first step, nothing is
    recorded and no update is performed.
    """
    # Only the most recent entries are ever displayed, so cap both logs to
    # keep memory bounded during long runs.
    max_step_log = 500
    max_episode_history = 500

    ep_num = 0
    env = OverflowEnvironment()
    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"
        steps_taken = 0
        for buf in (_buf_obs, _buf_acts, _buf_rews,
                    _buf_logps, _buf_vals, _buf_dones):
            buf.clear()

        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                outcome = "stopped"
                break
            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                # The log-prob is taken of the clamped action, which is also
                # what the PPO update re-evaluates.
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()
            action_np = action.numpy()
            decision = action_to_decision(action_np)
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done
            incident = obs.incident_report or ""
            # Guarded like incident_report in case the description is missing.
            scene = (obs.scene_description or "").split("\n")[0]

            _buf_obs.append(obs_vec)
            _buf_acts.append(action_np)
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))
            ep_rew += reward
            steps_taken = step

            with _lock:
                step_log.append({
                    "ep": ep_num, "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "scene": scene,
                    "incident": incident,
                })
                if len(step_log) > max_step_log:
                    del step_log[:-max_step_log]
            if done:
                outcome = "CRASH" if "CRASH" in incident else "GOAL"
                break
            time.sleep(0.6)  # pace so UI can show each step

        if steps_taken == 0:
            # Interrupted before any step ran: nothing to learn from or report.
            continue

        _ppo_mini_update()
        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": steps_taken,
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
            if len(episode_history) > max_episode_history:
                del episode_history[:-max_episode_history]
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Handle of the most recently started training thread (None before first start).
_worker_thread = None


def start_training():
    """Start the background training thread if it is not already running.

    Before starting, waits briefly for a previous worker that is still
    winding down after Stop, so two workers never share the rollout buffers
    and the policy. The step log and episode history are cleared under
    ``_lock`` so a concurrent UI refresh never sees a half-cleared state.

    Returns:
        A pair of ``gr.update`` objects for the Start and Stop buttons.
    """
    global _running, _worker_thread
    if not _running:
        previous = _worker_thread
        if previous is not None and previous.is_alive():
            # The worker only polls _running between steps, so give it time
            # to notice the stop request and exit.
            previous.join(timeout=5.0)
            if previous.is_alive():
                # Still busy: leave the UI in the stopped state rather than
                # launching a second trainer alongside it.
                return (gr.update(value="Start", interactive=True),
                        gr.update(interactive=False))
        with _lock:
            step_log.clear()
            episode_history.clear()
        _running = True
        _worker_thread = threading.Thread(target=run_episodes_loop, daemon=True)
        _worker_thread.start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
def stop_training():
    """Ask the background training thread to stop and reset the buttons.

    Clears the ``_running`` flag that the worker loop polls.

    Returns:
        A pair of ``gr.update`` objects: Start re-enabled with its original
        label, Stop disabled.
    """
    global _running
    _running = False
    start_button_state = gr.update(value="Start", interactive=True)
    stop_button_state = gr.update(interactive=False)
    return start_button_state, stop_button_state
def get_updates():
    """Called by gr.Timer every second — returns latest display content.

    Snapshots the last 20 step-log entries and last 30 episode summaries
    under ``_lock`` and formats them outside the lock.

    Returns:
        A 4-tuple of strings: the step feed (newest first), the episode
        table (newest first), the reward-trend line, and the status line.
    """
    with _lock:
        logs = list(step_log[-20:])
        eps = list(episode_history[-30:])

    # Step feed
    lines = []
    for e in reversed(logs):
        flag = ""
        if "CRASH" in e["incident"]:
            flag = " 💥"
        elif "GOAL" in e["incident"]:
            flag = " ✓"
        elif "NEAR MISS" in e["incident"]:
            flag = " ⚠"
        lines.append(
            f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
            f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
            f"ep_total={e['ep_reward']:>7.2f}{flag}"
        )
    step_text = "\n".join(lines) if lines else "Waiting for first episode..."

    # Episode summary
    ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
    for e in reversed(eps):
        ep_lines.append(
            f" {e['ep']:>4d} | {e['steps']:>3d} | "
            f" {e['reward']:>+8.2f} | {e['outcome']}"
        )
    ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet."

    # Mean reward trend: first half of the window versus the second half.
    if len(eps) >= 2:
        rewards = [e["reward"] for e in eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        trend = f"Mean reward (early {half} eps): {early:+.2f} → (last {n-half} eps): {late:+.2f}"
        arrow = "↑ improving" if late > early else "↓ declining"
        trend_text = f"{trend} {arrow}"
    else:
        trend_text = "Collecting data..."

    status = "● RUNNING" if _running else "■ STOPPED"
    return step_text, ep_text, trend_text, status
# Page layout: control row, live step feed, then episode history and reward
# trend side by side. A one-second timer refreshes all four read-only outputs.
with gr.Blocks(title="OpenENV RL Demo") as demo:
    gr.Markdown(
        "# OpenENV RL — Live Policy Training\n"
        f"FlatMLPPolicy runs {STEPS_PER_EPISODE} steps per episode on OverflowEnvironment. "
        "PPO mini-update after each episode — watch rewards improve over time."
    )
    with gr.Row():
        start_btn = gr.Button("Start", variant="primary")
        stop_btn = gr.Button("Stop", variant="stop", interactive=False)
        status_box = gr.Textbox(value="■ STOPPED", label="Status",
                                interactive=False, scale=0, min_width=120)
    gr.Markdown("### Live Step Feed (most recent 20 steps)")
    step_display = gr.Textbox(
        value="Press Start to begin...",
        lines=22, max_lines=22, interactive=False,
        elem_id="step_feed",
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Episode History")
            ep_display = gr.Textbox(lines=12, interactive=False)
        with gr.Column():
            gr.Markdown("### Reward Trend")
            trend_display = gr.Textbox(lines=3, interactive=False)
    # Auto-refresh every 1 second
    timer = gr.Timer(value=1.0)
    timer.tick(
        fn=get_updates,
        outputs=[step_display, ep_display, trend_display, status_box],
    )
    # Both handlers return updates for the Start and Stop buttons, in that order.
    start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
    stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()