# Source: pranit / train.py
# Author: RushiMane2003 — "Upload 41 files" (commit 99f938a, verified)
"""
train.py — IPPO Agent Training (Step 3)
========================================
Trains 4 independent PPO agents (one per lane: N, S, E, W) using
Stable-Baselines3. Each agent learns from its own local observation of
an IntersectionSimulator: during training the other three lanes follow a
fixed queue-length heuristic, and at evaluation all four trained agents
share one simulator.
Run:
    python train.py
Outputs:
    agent_N.zip, agent_S.zip, agent_E.zip, agent_W.zip
    results.png (bar chart + per-episode line chart: Fixed 30s vs MARL)
"""
import os
import sys
import time
import numpy as np
import matplotlib
matplotlib.use("Agg") # Non-interactive backend β€” safe on any OS
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from traffic_env import IntersectionSimulator, AgentEnv, PHASES, LANES
# ── Training hyperparameters ───────────────────────────────────────────────
TOTAL_TIMESTEPS = 50_000  # ~2-3 min per agent on CPU (increase for better results)
NET_ARCH = [64, 64]  # Two hidden layers — small but effective for this task
LEARNING_RATE = 3e-4
N_STEPS = 512  # env steps collected per PPO update (SB3 `n_steps`)
BATCH_SIZE = 64  # minibatch size; SB3 expects N_STEPS to be a multiple of it
N_EPOCHS = 10  # optimisation passes over each collected rollout

# Startup banner summarising the run configuration.
print("=" * 60)
print(" 🚦 AUTONOMOUS TRAFFIC CONTROL — IPPO TRAINING")
print(" 4 Independent PPO Agents (one per lane: N, S, E, W)")
print("=" * 60)
print(f" Timesteps per agent : {TOTAL_TIMESTEPS:,}")
print(f" Network : MLP {NET_ARCH}")
print(f" Learning rate : {LEARNING_RATE}")
print()
# ── Single-agent training wrapper ─────────────────────────────────────────
class SingleAgentTrainEnv(AgentEnv):
    """
    Wrap AgentEnv so Stable-Baselines3 can train one lane's agent.

    Each instance owns a private IntersectionSimulator. Only ``self.lane`` is
    driven by the learner; the other three lanes vote with a fixed heuristic
    (request green when their queue exceeds 5). Training each agent
    independently in this way is the standard IPPO approach.
    """

    def __init__(self, lane: str):
        # A private simulator per training env, so agents trained in sequence
        # never share state with each other.
        self._sim = IntersectionSimulator()
        super().__init__(lane, self._sim)

    def reset(self, seed=None, options=None):
        """Reset the simulator and return ``(obs, info)`` as Gymnasium expects."""
        super().reset(seed=seed)
        self._sim.reset()
        return self.get_obs(), {}

    def step(self, action: int):
        """Advance the simulator one step.

        Combines this lane's vote (``action``) with heuristic votes for the
        other lanes.

        Returns:
            The Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
            ``truncated`` is always False here.
        """
        # This agent's vote
        agent_actions = {self.lane: int(action)}
        # Other 3 agents use heuristic: request green (1) if queue > 5
        for other in LANES:
            if other != self.lane:
                agent_actions[other] = 1 if self._sim.queues[other] > 5 else 0
        rewards, done = self._sim.step(agent_actions)
        obs = self.get_obs()
        reward = rewards[self.lane]
        # SB3's check_env requires `terminated` to be a plain Python bool; the
        # simulator may hand back a numpy bool, so normalise it here.
        return obs, float(reward), bool(done), False, {}
# ── Train all 4 agents ────────────────────────────────────────────────────
agents = {}  # lane -> trained PPO model, reused by the evaluation below
train_times = {}  # lane -> wall-clock training time in seconds

for lane in LANES:
    print(f" Training Agent-{lane} ({TOTAL_TIMESTEPS:,} timesteps)...")
    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    t_start = time.perf_counter()
    env = SingleAgentTrainEnv(lane)
    # Quick sanity check on the environment. A failure is reported, but is
    # deliberately non-fatal so training still proceeds.
    try:
        check_env(env, warn=True, skip_render_check=True)
    except Exception as e:
        print(f" ⚠️ Env check warning: {e}")
    model = PPO(
        "MlpPolicy", env,
        verbose=0,
        learning_rate=LEARNING_RATE,
        n_steps=N_STEPS,
        batch_size=BATCH_SIZE,
        n_epochs=N_EPOCHS,
        policy_kwargs=dict(net_arch=NET_ARCH),
    )
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    model.save(f"agent_{lane}")  # SB3 appends the ".zip" extension
    elapsed = time.perf_counter() - t_start
    train_times[lane] = elapsed
    agents[lane] = model
    print(f" ✓ Agent-{lane} saved → agent_{lane}.zip ({elapsed:.0f}s)\n")

print(" ✅ All 4 agents trained!\n")
# ── Evaluation: MARL vs Fixed-cycle baseline ──────────────────────────────
def run_marl_episode(agents_dict: dict, n_steps: int = 200) -> float:
    """Roll out the four trained policies greedily for one episode.

    Each step, every lane's model chooses a deterministic action from its own
    observation of a single shared simulator. The summed queue lengths after
    the step are accumulated. The rollout stops early if the simulator
    reports the episode is done.

    Args:
        agents_dict: mapping lane -> trained model exposing ``predict``.
        n_steps: maximum number of simulator steps.

    Returns:
        Cumulative vehicle-steps spent waiting.
    """
    sim = IntersectionSimulator()
    sim.reset()
    # One observer per lane, all looking at the same simulator instance.
    lane_views = {name: AgentEnv(name, sim) for name in LANES}
    waited = 0
    for _step in range(n_steps):
        votes = {}
        for name in LANES:
            prediction = agents_dict[name].predict(
                lane_views[name].get_obs(), deterministic=True
            )
            votes[name] = int(prediction[0])  # predict -> (action, state)
        _rewards, finished = sim.step(votes)
        waited += sum(sim.queues.values())
        if finished:
            break
    return waited
def run_fixed_episode(n_steps: int = 200, cycle_len: int = 6) -> float:
    """Fixed 30-second cycle baseline. Returns total vehicle-steps waiting.

    The signal advances round-robin through ``PHASES`` every ``cycle_len``
    steps. The "30-second" label presumably means 6 steps x 5 s; confirm the
    simulator's step length. Queue dynamics are applied directly here rather
    than through ``sim.step``. Each step, every GREEN lane discharges up to 2
    vehicles, and every other lane gains 0-2 arrivals, capped at
    ``sim.MAX_QUEUE``.

    NOTE(review): these hand-rolled dynamics may differ from
    ``IntersectionSimulator.step`` used by ``run_marl_episode``, and this
    rollout never terminates early. Confirm the comparison is like-for-like.

    Args:
        n_steps: number of simulation steps to run.
        cycle_len: steps each phase is held before advancing.

    Returns:
        Cumulative vehicle-steps spent waiting (sum of all queues after every
        step), as a float.
    """
    sim = IntersectionSimulator()
    sim.reset()
    phase_list = list(PHASES.keys())
    # Start the rotation from the phase reset() actually left the simulator
    # in. Assuming index 0 would make the first switch jump to phase_list[1]
    # regardless of the starting phase, repeating or skipping phases.
    phase_idx = phase_list.index(sim.phase)
    timer = 0
    total_wait = 0
    for _ in range(n_steps):
        if timer >= cycle_len:
            phase_idx = (phase_idx + 1) % len(phase_list)
            sim.phase = phase_list[phase_idx]
            sim.time_in_phase = 0
            timer = 0
        signals = PHASES[sim.phase]
        for lane in LANES:
            if signals[lane] == 'GREEN':
                sim.queues[lane] = max(0, sim.queues[lane] - 2)
            else:
                sim.queues[lane] = min(
                    sim.MAX_QUEUE,
                    sim.queues[lane] + int(np.random.randint(0, 3)),
                )
        total_wait += sum(sim.queues.values())
        timer += 1
    return float(total_wait)
N_EVAL_EPISODES = 10
print(f" Evaluating over {N_EVAL_EPISODES} episodes each...")
fixed_scores = [run_fixed_episode() for _ in range(N_EVAL_EPISODES)]
marl_scores = [run_marl_episode(agents) for _ in range(N_EVAL_EPISODES)]
fixed_avg = float(np.mean(fixed_scores))
marl_avg = float(np.mean(marl_scores))
# np.std defaults to the population standard deviation (ddof=0).
fixed_std = float(np.std(fixed_scores))
marl_std = float(np.std(marl_scores))
# Percentage reduction in waiting relative to the fixed baseline. The
# max(..., 1) guard avoids dividing by zero when the baseline never queues.
improvement = (fixed_avg - marl_avg) / max(fixed_avg, 1) * 100
print("\n" + "=" * 60)
print(f" RESULTS (average over {N_EVAL_EPISODES} evaluation episodes):")
print(f" Fixed 30s cycle : {fixed_avg:>8.0f} ± {fixed_std:.0f} vehicle-steps waiting")
print(f" 4-Agent MARL : {marl_avg:>8.0f} ± {marl_std:.0f} vehicle-steps waiting")
print(f" Improvement : {improvement:>+.1f}%")
print("=" * 60)
# ── Plot results ──────────────────────────────────────────────────────────
def _style_dark_axes(ax):
    """Apply the shared dark theme: dark background, muted ticks, only left/bottom spines."""
    ax.set_facecolor('#111120')
    ax.tick_params(colors='#94a3b8')
    ax.spines['bottom'].set_color('#1e2030')
    ax.spines['left'].set_color('#1e2030')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


fig, axes = plt.subplots(1, 2, figsize=(12, 5))
fig.patch.set_facecolor('#0d0d18')

# Bar chart: mean waiting per controller with ±1 std error bars
ax1 = axes[0]
_style_dark_axes(ax1)
bar_means = [fixed_avg, marl_avg]
bar_stds = [fixed_std, marl_std]
# Fix the y-range first so value labels can be padded relative to it. A fixed
# absolute pad pushes labels off-chart when the values are small. Fall back
# to 1.0 so an all-zero result does not produce identical ylims.
y_top = max(fixed_avg + fixed_std * 2, marl_avg + marl_std * 2) * 1.25 or 1.0
label_pad = 0.02 * y_top
bars = ax1.bar(
    ["Fixed 30s Cycle", "4-Agent MARL"],
    bar_means,
    color=["#ef4444", "#22c55e"],
    width=0.5,
    edgecolor='none',
    yerr=bar_stds,
    capsize=6,
    error_kw=dict(ecolor='#ffffff', capthick=2, elinewidth=2),
)
for bar, val, std in zip(bars, bar_means, bar_stds):
    # Place each label just above its *own* error bar, not the baseline's.
    ax1.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + std + label_pad,
        f"{val:.0f}",
        ha='center', va='bottom',
        color='#e2e8f0', fontsize=12, fontweight='bold'
    )
ax1.set_ylabel("Cumulative Vehicle-Steps Waiting", color='#94a3b8', fontsize=11)
ax1.set_title(
    f"Fixed Cycle vs 4-Agent MARL\nImprovement: {improvement:+.1f}%",
    color='#e2e8f0', fontsize=13, fontweight='bold', pad=12
)
ax1.set_ylim(0, y_top)

# Episode-by-episode line chart, with dashed lines at each controller's mean
ax2 = axes[1]
_style_dark_axes(ax2)
ep_x = list(range(1, N_EVAL_EPISODES + 1))
ax2.plot(ep_x, fixed_scores, 'o-', color='#ef4444', linewidth=2,
         markersize=6, label='Fixed 30s Cycle')
ax2.plot(ep_x, marl_scores, 's-', color='#22c55e', linewidth=2,
         markersize=6, label='4-Agent MARL')
ax2.axhline(y=fixed_avg, color='#ef4444', linestyle='--', alpha=0.5, linewidth=1)
ax2.axhline(y=marl_avg, color='#22c55e', linestyle='--', alpha=0.5, linewidth=1)
ax2.set_xlabel("Evaluation Episode", color='#94a3b8', fontsize=11)
ax2.set_ylabel("Vehicle-Steps Waiting", color='#94a3b8', fontsize=11)
ax2.set_title("Episode-by-Episode Comparison", color='#e2e8f0', fontsize=13,
              fontweight='bold', pad=12)
ax2.legend(facecolor='#0d0d18', labelcolor='#e2e8f0', edgecolor='#1e2030',
           fontsize=10)

plt.tight_layout(pad=2.0)
plt.savefig("results.png", dpi=150, bbox_inches='tight',
            facecolor=fig.get_facecolor())
plt.close(fig)
print("\n 📊 Saved: results.png")
# Summarise the saved model files and point to the follow-up scripts.
print("\n Trained model files:")
for lane_name in LANES:
    model_path = f"agent_{lane_name}.zip"
    kilobytes = os.path.getsize(model_path) / 1024
    print(f" {model_path} ({kilobytes:.1f} KB) in {train_times[lane_name]:.0f}s")
print()
for hint in (" Next step: streamlit run dashboard_final.py", " Or: python demo_evp.py"):
    print(hint)