"""
Overflow OpenENV — Continuous PPO Training + Live Incident Management Dashboard
Training loop runs in a background thread.
Every step:
- Classifies the scene into an incident type
- Grades the agent's action against that incident
- Records the incident in a live feed
Dashboard:
Left: 2D road canvas + live incident feed (step-by-step decisions)
Right: Episode reward curve + incident response accuracy chart
GET / → HTML dashboard
GET /api/state → JSON snapshot
GET /api/stream → SSE (0.5s heartbeats + episode/incident events)
POST /api/mode → switch reward mode {capped, uncapped}
"""
from __future__ import annotations
import asyncio
import json
import math
import sys
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Deque, Dict, List, Optional
import numpy as np
import torch
# ── Absolute-import fallback ──────────────────────────────────────────────────
try:
from ..training.overflow_gym_env import OverflowGymEnv
from ..training.curriculum import CurriculumManager
from ..training.reward import compute_episode_bonus, IncidentType
from ..training.ppo_trainer import RolloutBuffer
from ..policies.ticket_attention_policy import TicketAttentionPolicy
from ..policies.policy_spec import OBS_DIM
except ImportError:
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from training.overflow_gym_env import OverflowGymEnv
from training.curriculum import CurriculumManager
from training.reward import compute_episode_bonus, IncidentType
from training.ppo_trainer import RolloutBuffer
from policies.ticket_attention_policy import TicketAttentionPolicy
from policies.policy_spec import OBS_DIM
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
# ── Reward mode config ────────────────────────────────────────────────────────
# Two reward-shaping modes, switchable at runtime via POST /api/mode:
#   "capped"   — positive per-step rewards are clipped at REWARD_CAP; penalties pass through.
#   "uncapped" — a token-proportional bonus (tokens * TOKEN_SCALE, at most MAX_TOKEN_BONUS) is added.
REWARD_CAP = 8.0 # capped mode: ceiling per step (matches max correct response)
TOKEN_SCALE = 0.002 # uncapped: reward += tokens * TOKEN_SCALE
MAX_TOKEN_BONUS = 10.0 # uncapped mode max bonus per step
# ── Shared state ──────────────────────────────────────────────────────────────
@dataclass
class CarSnapshot:
    """Position/kinematics of one car, as drawn on the dashboard road canvas."""
    car_id: int  # 0 identifies the ego/agent car (see ego lookups in _training_loop)
    x: float     # longitudinal position along the road
    y: float     # lateral offset; computed as (lane - 2) * 3.7 in _training_loop
    lane: int    # integer lane index
    speed: float
@dataclass
class IncidentEvent:
    """One graded incident/decision record.

    NOTE(review): this class appears unused here — the training loop builds
    plain dicts (`evt`) for the incident feed instead. Confirm no external
    consumers before removing.
    """
    step: int           # global step counter at the time of the event
    episode: int        # 1-based episode number
    incident_type: str  # scene classification (e.g. CRASH_IMMINENT)
    decision: str       # agent's chosen action label
    reward: float       # shaped reward received for this step
    grade_desc: str     # human-readable grading explanation
@dataclass
class TrainingState:
    """Mutable snapshot of the training run, shared between the background
    training thread and the web handlers.

    Every read/write of an instance goes through ``_state_lock`` (see the
    accessors and the training loop).

    Fix: ``steps_per_sec`` was previously attached dynamically by the
    training thread (``_state.steps_per_sec = ...``), forcing readers to
    guard with ``hasattr``. It is now a declared field with a 0.0 default,
    which is backward compatible (all construction sites call
    ``TrainingState()`` with no arguments).
    """
    # Road
    cars: List[CarSnapshot] = field(default_factory=list)
    ego_x: float = 0.0
    ego_lane: int = 2
    goal_x: float = 180.0
    # Training metrics
    total_steps: int = 0
    n_updates: int = 0
    n_episodes: int = 0
    episode_reward: float = 0.0
    episode_steps: int = 0
    mean_reward_100: float = 0.0   # mean episode reward over the last 100 episodes
    mean_ep_len: float = 0.0       # mean episode length over the last 100 episodes
    stage: int = 1
    stage_name: str = "Survival"
    reward_mode: str = "capped"    # "capped" or "uncapped"; switched via POST /api/mode
    # Incident stats
    incident_feed: List[Dict] = field(default_factory=list)  # last 100 events
    incident_counts: Dict = field(default_factory=dict)      # {type: count}
    correct_responses: int = 0
    total_responses: int = 0
    # History for charts
    reward_history: List[float] = field(default_factory=list)   # per-episode cumulative
    cumulative_reward: List[float] = field(default_factory=list) # running net total
    accuracy_history: List[float] = field(default_factory=list)  # per-episode correct%
    episode_history: List[Dict] = field(default_factory=list)
    # PPO
    last_pg_loss: float = 0.0
    last_vf_loss: float = 0.0
    last_entropy: float = 0.0
    # Throughput since thread start; updated by the training thread after each PPO update.
    steps_per_sec: float = 0.0
    running: bool = True
    error: Optional[str] = None
# Singleton shared state; every access (reads included) holds _state_lock.
_state = TrainingState()
_state_lock = threading.Lock()
# Bounded broadcast queue of JSON-encoded SSE messages; oldest dropped past 100.
_sse_queue: Deque[str] = deque(maxlen=100)
# Per-episode counters (reset each episode)
_ep_correct = 0  # correct incident responses in the current episode
_ep_total = 0    # total graded responses in the current episode
# Which decisions count as "correct" for each incident type.
# Threatening scenes accept evasive actions; benign scenes accept progress actions.
_CORRECT_RESPONSES: Dict[str, set] = {
    IncidentType.CRASH_IMMINENT.value: {"brake", "lane_change_left", "lane_change_right"},
    IncidentType.NEAR_MISS_AHEAD.value: {"brake", "lane_change_left", "lane_change_right"},
    IncidentType.NEAR_MISS_SIDE.value: {"brake", "maintain"},
    IncidentType.BLOCKED_AHEAD.value: {"brake", "lane_change_left", "lane_change_right"},
    IncidentType.APPROACHING_GOAL.value: {"accelerate", "maintain"},
    IncidentType.CLEAR_ROAD.value: {"accelerate", "maintain"},
}
def _is_correct(incident_type: str, decision: str) -> bool:
    """Return True when *decision* is an accepted response for *incident_type*.

    Unknown incident types have no accepted responses and always grade False.
    """
    accepted = _CORRECT_RESPONSES.get(incident_type)
    return accepted is not None and decision in accepted
def get_reward_mode() -> str:
    """Thread-safe read of the active reward mode ("capped" or "uncapped")."""
    with _state_lock:
        mode = _state.reward_mode
    return mode
def set_reward_mode(mode: str) -> None:
    """Thread-safe write of the reward mode.

    No validation here — the /api/mode handler restricts values before calling.
    """
    with _state_lock:
        _state.reward_mode = mode
def apply_reward_mode(base_reward: float, token_count: int = 0) -> float:
    """Shape *base_reward* according to the currently selected reward mode.

    Uncapped mode adds a token-proportional bonus limited to MAX_TOKEN_BONUS.
    Capped mode clips positive rewards at REWARD_CAP; penalties pass through
    unchanged.
    """
    if get_reward_mode() == "uncapped":
        token_bonus = min(token_count * TOKEN_SCALE, MAX_TOKEN_BONUS)
        return base_reward + token_bonus
    # Only cap positive rewards — keep penalties intact.
    if base_reward > 0:
        return min(base_reward, REWARD_CAP)
    return base_reward
def _push_sse(data: dict) -> None:
    """Serialize *data* to JSON and append it to the bounded SSE queue."""
    payload = json.dumps(data)
    _sse_queue.append(payload)
# ── Training thread ───────────────────────────────────────────────────────────
def _training_loop() -> None:
    """Background PPO training loop (runs forever in a daemon thread).

    Each outer iteration collects N_STEPS environment transitions, grading
    every action against the detected incident type and publishing road,
    incident, and episode snapshots to the shared state and SSE queue, then
    performs a clipped-PPO update. On any exception the traceback is recorded
    in the shared state (shown on the dashboard) and the loop exits.
    """
    global _ep_correct, _ep_total
    try:
        # ── Setup: policy, env, curriculum, optimizer ─────────────────────
        policy = TicketAttentionPolicy(obs_dim=OBS_DIM)
        env = OverflowGymEnv()
        curriculum = CurriculumManager()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        policy.to(device)
        optimizer = torch.optim.Adam(policy.parameters(), lr=5e-5, eps=1e-5)
        # PPO hyperparameters.
        GAMMA = 0.99; GAE_LAMBDA = 0.95; CLIP = 0.1
        ENT_COEF = 0.005; VF_COEF = 0.5; MAX_GRAD = 0.5
        N_STEPS = 1024; BATCH_SIZE = 128; N_EPOCHS = 3
        STEP_DELAY = 0.008  # seconds — slows training so episodes are visible
        buf = RolloutBuffer(N_STEPS, OBS_DIM, device)
        # Rolling windows over the last 100 episodes for dashboard means.
        ep_rewards: deque = deque(maxlen=100)
        ep_lengths: deque = deque(maxlen=100)
        ep_accuracies: deque = deque(maxlen=100)  # NOTE(review): appended to but never read
        obs, _ = env.reset()
        ep_reward = 0.0; ep_steps = 0
        total_steps = 0; n_updates = 0; n_episodes = 0
        _ep_correct = 0; _ep_total = 0
        t0 = time.time()
        while True:
            # ── Rollout collection ───────────────────────────────────────
            buf.reset()
            policy.eval()
            for _ in range(N_STEPS):
                curriculum.step(env._sim_time)
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                with torch.no_grad():
                    act_mean, val = policy(obs_t.unsqueeze(0))
                    act_mean = act_mean.squeeze(0)
                    val = val.squeeze(0)
                # Fixed-std (0.15) Gaussian policy; sampled actions clamped to [-1, 1].
                dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.15)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()
                next_obs, base_reward, term, trunc, info = env.step(action.cpu().numpy())
                reward = apply_reward_mode(base_reward)
                # Pace the training so episodes are visible in the dashboard
                if STEP_DELAY > 0:
                    time.sleep(STEP_DELAY)
                buf.add(obs, action.cpu().numpy(), reward, float(val), float(logp), float(term or trunc))
                obs = next_obs
                ep_reward += reward
                ep_steps += 1
                total_steps += 1
                # ── Incident tracking: grade this step's decision ────────
                inc_type = info.get("incident_type", "CLEAR_ROAD")
                decision = info.get("decision", "maintain")
                correct = _is_correct(inc_type, decision)
                _ep_correct += int(correct)
                _ep_total += 1
                evt = {
                    "step": total_steps,
                    "episode": n_episodes + 1,
                    "incident_type": inc_type,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "correct": correct,
                    "grade_desc": info.get("grade_desc", ""),
                }
                # Pull car positions from env internals (for the road canvas).
                cars = []
                if hasattr(env._env, "_cars"):
                    for c in env._env._cars:
                        cars.append(CarSnapshot(
                            car_id=c.car_id, x=c.position,
                            y=(c.lane - 2) * 3.7, lane=c.lane, speed=c.speed,
                        ))
                goal_x = 180.0
                if hasattr(env._env, "_cars") and env._env._cars:
                    # car_id 0 is the agent's car; use its goal for the canvas marker.
                    agent_car = next((c for c in env._env._cars if c.car_id == 0), None)
                    if agent_car:
                        goal_x = agent_car.goal_position
                with _state_lock:
                    _state.total_steps = total_steps
                    _state.episode_reward = round(ep_reward, 2)
                    _state.episode_steps = ep_steps
                    _state.goal_x = goal_x
                    _state.correct_responses += int(correct)
                    _state.total_responses += 1
                    _state.incident_counts[inc_type] = _state.incident_counts.get(inc_type, 0) + 1
                    _state.incident_feed.append(evt)
                    # Keep only the most recent 100 feed entries.
                    if len(_state.incident_feed) > 100:
                        _state.incident_feed = _state.incident_feed[-100:]
                    if cars:
                        _state.cars = cars
                        ego = next((c for c in cars if c.car_id == 0), None)
                        if ego:
                            _state.ego_x = ego.x
                            _state.ego_lane = ego.lane
                # Push incident event to SSE (only threatening incidents or every 10 steps)
                if inc_type not in (IncidentType.CLEAR_ROAD.value,) or total_steps % 10 == 0:
                    _push_sse({"type": "incident", "data": evt})
                # ── Episode end: bonus, curriculum, chart histories ──────
                if term or trunc:
                    bonus = compute_episode_bonus(
                        total_steps=ep_steps,
                        survived=not info.get("collision", False),
                    )
                    ep_reward += bonus
                    n_episodes += 1
                    ep_rewards.append(ep_reward)
                    ep_lengths.append(ep_steps)
                    # Per-episode incident-response accuracy in percent.
                    acc = (_ep_correct / _ep_total * 100) if _ep_total > 0 else 0.0
                    ep_accuracies.append(acc)
                    advanced = curriculum.record_episode_reward(ep_reward)  # NOTE(review): return value unused
                    outcome = (
                        "crash" if info.get("collision") else
                        "goal" if info.get("goal_reached") else "timeout"
                    )
                    ep_rec = {
                        "episode": n_episodes,
                        "steps": ep_steps,
                        "reward": round(ep_reward, 2),
                        "outcome": outcome,
                        "stage": curriculum.current_stage,
                        "accuracy": round(acc, 1),
                        "reward_mode": get_reward_mode(),
                    }
                    with _state_lock:
                        _state.n_episodes = n_episodes
                        _state.stage = curriculum.current_stage
                        _state.stage_name = curriculum.config.name
                        _state.mean_reward_100 = round(float(np.mean(ep_rewards)), 2)
                        _state.mean_ep_len = round(float(np.mean(ep_lengths)), 1)
                        _state.reward_history.append(round(ep_reward, 2))
                        # Running net total across all episodes (drives the "net" chart).
                        prev_cum = _state.cumulative_reward[-1] if _state.cumulative_reward else 0.0
                        _state.cumulative_reward.append(round(prev_cum + ep_reward, 2))
                        _state.accuracy_history.append(round(acc, 1))
                        _state.episode_history.append(ep_rec)
                    _push_sse({"type": "episode", "data": ep_rec})
                    obs, _ = env.reset()
                    ep_reward = 0.0; ep_steps = 0
                    _ep_correct = 0; _ep_total = 0
            # ── PPO update ────────────────────────────────────────────────
            # Bootstrap value for the final (possibly mid-episode) state.
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                _, last_val = policy(obs_t.unsqueeze(0))
            buf.compute_returns(float(last_val), GAMMA, GAE_LAMBDA)
            policy.train()
            all_obs = buf.obs; all_acts = buf.acts; old_logp = buf.logp
            # Normalized GAE advantages.
            adv = buf.ret - buf.val
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)
            ret = buf.ret; old_val = buf.val
            indices = torch.randperm(N_STEPS, device=device)
            pg_ls, vf_ls, ents = [], [], []
            for _ in range(N_EPOCHS):
                for start in range(0, N_STEPS, BATCH_SIZE):
                    idx = indices[start: start + BATCH_SIZE]
                    am, val = policy(all_obs[idx])
                    val = val.squeeze(-1)
                    dist = torch.distributions.Normal(am, torch.ones_like(am) * 0.15)
                    logp = dist.log_prob(all_acts[idx]).sum(dim=-1)
                    ent = dist.entropy().sum(dim=-1).mean()
                    # Clipped surrogate policy objective.
                    ratio = torch.exp(logp - old_logp[idx])
                    pg_loss = torch.max(-adv[idx]*ratio, -adv[idx]*ratio.clamp(1-CLIP, 1+CLIP)).mean()
                    # Clipped value loss around the old value estimates.
                    vc = old_val[idx] + (val - old_val[idx]).clamp(-CLIP, CLIP)
                    vf_loss = 0.5*torch.max((val-ret[idx])**2, (vc-ret[idx])**2).mean()
                    loss = pg_loss + VF_COEF*vf_loss - ENT_COEF*ent
                    optimizer.zero_grad(); loss.backward()
                    torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD)
                    optimizer.step()
                    pg_ls.append(float(pg_loss)); vf_ls.append(float(vf_loss)); ents.append(float(ent))
            n_updates += 1
            with _state_lock:
                _state.n_updates = n_updates
                _state.last_pg_loss = round(float(np.mean(pg_ls)), 5)
                _state.last_vf_loss = round(float(np.mean(vf_ls)), 5)
                _state.last_entropy = round(float(np.mean(ents)), 5)
                # Mean throughput since thread start; readers default to 0 before this is set.
                _state.steps_per_sec = round(total_steps / max(time.time() - t0, 1.0), 1)
            _push_sse({"type": "update", "data": {
                "n_updates": n_updates, "total_steps": total_steps,
                "mean_reward": round(float(np.mean(ep_rewards)) if ep_rewards else 0.0, 2),
                "stage": curriculum.current_stage,
                "pg_loss": round(float(np.mean(pg_ls)), 5),
                "vf_loss": round(float(np.mean(vf_ls)), 5),
                "entropy": round(float(np.mean(ents)), 5),
            }})
    except Exception as exc:
        # Surface the failure on the dashboard instead of dying silently.
        import traceback
        tb = traceback.format_exc()
        with _state_lock:
            _state.running = False
            _state.error = f"{exc}\n\n{tb}"
        print(f"[Training] FATAL: {exc}\n{tb}", flush=True)
# ── FastAPI ───────────────────────────────────────────────────────────────────
app = FastAPI(title="Overflow OpenENV")
# Handle to the background training thread; (re)assigned on startup and /api/restart.
_training_thread: Optional[threading.Thread] = None
@app.on_event("startup")
def _start_training():
    """Launch the PPO training loop in a daemon thread when the app boots."""
    global _training_thread
    worker = threading.Thread(target=_training_loop, daemon=True)
    worker.start()
    _training_thread = worker
@app.get("/health")
def health():
    """Liveness probe — always reports OK while the process is up."""
    return dict(status="ok")
class ModeRequest(BaseModel):
    """Request body for POST /api/mode."""
    mode: str  # expected "capped" or "uncapped"; validated in the handler
@app.post("/api/mode")
def set_mode(req: ModeRequest):
    """Switch the reward-shaping mode; only "capped"/"uncapped" are accepted."""
    if req.mode in ("capped", "uncapped"):
        set_reward_mode(req.mode)
        return {"mode": req.mode}
    return {"error": "mode must be capped or uncapped"}
@app.post("/api/restart")
def restart_training():
    """Reset the shared state and start a new training thread.

    The currently selected reward mode is carried over; everything else
    (metrics, feeds, histories, SSE queue) is cleared.
    """
    global _state, _training_thread, _ep_correct, _ep_total
    with _state_lock:
        mode = _state.reward_mode  # preserve current mode selection
        _state = TrainingState()
        _state.reward_mode = mode
        _ep_correct = 0
        _ep_total = 0
    _sse_queue.clear()
    # Start a fresh training thread.
    # NOTE(review): the previous thread is NOT stopped here — daemon threads
    # only terminate at interpreter shutdown, and _training_loop has no stop
    # flag, so each restart leaves an extra loop running and mutating the
    # shared state concurrently. Confirm whether a stop event / generation
    # counter should be added to _training_loop.
    _training_thread = threading.Thread(target=_training_loop, daemon=True)
    _training_thread.start()
    return {"status": "restarted"}
@app.get("/api/state")
def get_state():
    """Return a complete JSON snapshot of training progress for the dashboard.

    The response is assembled while holding ``_state_lock``, and every mutable
    container is copied (``list(...)``, ``dict(...)``, slices, fresh dicts from
    ``asdict``). FastAPI serializes the return value *after* the handler exits
    — i.e. after the lock is released — so returning the live lists by
    reference would race with the training thread's concurrent appends; the
    copies make the snapshot immutable from the caller's perspective.
    """
    with _state_lock:
        s = _state
        # Whole-run incident-response accuracy in percent.
        acc = round(s.correct_responses / s.total_responses * 100, 1) if s.total_responses > 0 else 0.0
        return {
            "total_steps": s.total_steps,
            "n_updates": s.n_updates,
            "n_episodes": s.n_episodes,
            "episode_reward": s.episode_reward,
            "episode_steps": s.episode_steps,
            "mean_reward": s.mean_reward_100,
            "mean_ep_len": s.mean_ep_len,
            "stage": s.stage,
            "stage_name": s.stage_name,
            "reward_mode": s.reward_mode,
            # steps_per_sec is set by the training thread after the first PPO
            # update; default to 0 until then.
            "steps_per_sec": getattr(s, "steps_per_sec", 0),
            "pg_loss": s.last_pg_loss,
            "vf_loss": s.last_vf_loss,
            "entropy": s.last_entropy,
            "reward_history": list(s.reward_history),
            "cumulative_reward": list(s.cumulative_reward),
            "accuracy_history": list(s.accuracy_history),
            "episode_history": s.episode_history[-50:],
            "incident_feed": s.incident_feed[-30:],
            "incident_counts": dict(s.incident_counts),
            "response_accuracy": acc,
            "cars": [asdict(c) for c in s.cars],
            "ego_x": s.ego_x,
            "ego_lane": s.ego_lane,
            "goal_x": s.goal_x,
            "running": s.running,
            "error": s.error,
        }
@app.get("/api/stream")
async def sse_stream():
    """SSE endpoint: forwards queued incident/episode/update events and emits
    a full "tick" snapshot every 0.5 s.

    Fix: the previous cursor was a plain index into the bounded event deque
    (``maxlen=100``). Once 100 messages had been queued, ``len()`` stopped
    growing, the index stuck at 100, and ``current[last_idx:]`` was empty
    forever — no further events were ever delivered to connected clients.
    We now remember the *identity* of the last forwarded message and resume
    just after it; ``_push_sse`` creates a fresh string object per message
    via ``json.dumps``, so ``is`` comparison is reliable.
    """
    async def gen():
        # Skip messages queued before this client connected (same behavior
        # as the old initial `last_idx = len(_sse_queue)`).
        last_sent: Optional[str] = _sse_queue[-1] if _sse_queue else None
        while True:
            current = list(_sse_queue)
            start = 0
            if last_sent is not None:
                # Scan from the tail — the last forwarded message is usually
                # near the end of the buffer.
                for i in range(len(current) - 1, -1, -1):
                    if current[i] is last_sent:
                        start = i + 1
                        break
                # If last_sent was evicted (a burst larger than the buffer
                # between polls), start stays 0: replay the whole buffer
                # rather than silently dropping events.
            for msg in current[start:]:
                yield f"data: {msg}\n\n"
            if current:
                last_sent = current[-1]
            # Periodic full snapshot, assembled while holding the state lock.
            with _state_lock:
                s = _state
                acc = round(s.correct_responses / s.total_responses * 100, 1) if s.total_responses > 0 else 0.0
                snap = {
                    "type": "tick",
                    "data": {
                        "total_steps": s.total_steps,
                        "episode_reward": s.episode_reward,
                        "episode_steps": s.episode_steps,
                        "stage": s.stage,
                        "stage_name": s.stage_name,
                        "reward_mode": s.reward_mode,
                        "response_accuracy": acc,
                        "cars": [asdict(c) for c in s.cars],
                        "ego_x": s.ego_x,
                        "goal_x": s.goal_x,
                    }
                }
            yield f"data: {json.dumps(snap)}\n\n"
            await asyncio.sleep(0.5)
    return StreamingResponse(gen(), media_type="text/event-stream")
# ── HTML Dashboard ────────────────────────────────────────────────────────────
# Served verbatim by GET / and GET /web (see `dashboard` below).
# NOTE(review): this template contains no HTML markup — only the page's text
# content. It looks like the original markup/script was stripped; confirm
# whether the full dashboard HTML should be restored here.
DASHBOARD_HTML = r"""
Overflow OpenENV — Incident Management Training
Response Acc.
0%
correct actions
INCIDENT FEED — LIVE DECISIONS ACC 0%
STEP
INCIDENT
DECISION
REWARD
OK?
EPISODE REWARD — CAPPED | mean: 0.00 | net: 0.00
Why does net reward start negative?
Early in training the agent has a random policy — it crashes frequently and earns large negative penalties
(collision = heavy punishment). The cumulative net drops as these early failures stack up. As PPO learns to
avoid crashes and respond correctly to incidents, per-episode rewards turn positive and the net climbs back.
The crossover point where net goes positive is when the agent's learned gains have fully offset its early
exploration cost — a sign that training is working.
RESPONSE ACCURACY % (per episode)
EPISODE LOG
| # | Steps | Reward | Outcome | Acc% | Stage |
PG: —
VF: —
Ent: —
Mode: capped
"""
@app.get("/", response_class=HTMLResponse)
@app.get("/web", response_class=HTMLResponse)
def dashboard():
    """Serve the live dashboard page at both / and /web."""
    page = DASHBOARD_HTML
    return HTMLResponse(content=page)
def main():
    """Entry point: serve the dashboard app with uvicorn on 0.0.0.0:8000."""
    import uvicorn
    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()