Spaces:
Sleeping
Sleeping
"""
train.py — IPPO Agent Training (Step 3)
========================================
Trains 4 independent PPO agents (one per lane: N, S, E, W) using
Stable-Baselines3. Each agent learns from its own local observation
of the shared IntersectionSimulator.

Run:
    python train.py

Outputs:
    agent_N.zip, agent_S.zip, agent_E.zip, agent_W.zip
    results.png (bar chart: Fixed 30s vs MARL)
"""
| import os | |
| import sys | |
| import time | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") # Non-interactive backend β safe on any OS | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.env_checker import check_env | |
| from traffic_env import IntersectionSimulator, AgentEnv, PHASES, LANES | |
# ── Training hyperparameters ──────────────────────────────────────────────
# PPO settings applied identically to every per-lane agent.
TOTAL_TIMESTEPS = 50_000  # ~2-3 min per agent on CPU (increase for better results)
NET_ARCH = [64, 64]  # Two hidden layers — small but effective for this task
LEARNING_RATE = 3e-4
N_STEPS = 512  # passed to PPO as n_steps
BATCH_SIZE = 64  # passed to PPO as batch_size
N_EPOCHS = 10  # passed to PPO as n_epochs

# Startup banner summarising the run configuration.
print("=" * 60)
print(" 🚦 AUTONOMOUS TRAFFIC CONTROL — IPPO TRAINING")
print(" 4 Independent PPO Agents (one per lane: N, S, E, W)")
print("=" * 60)
print(f" Timesteps per agent : {TOTAL_TIMESTEPS:,}")
print(f" Network : MLP {NET_ARCH}")
print(f" Learning rate : {LEARNING_RATE}")
print()
# ── Single-agent training wrapper ─────────────────────────────────────────
class SingleAgentTrainEnv(AgentEnv):
    """
    Training environment that lets Stable-Baselines3 drive one lane's agent.

    The environment owns a private IntersectionSimulator. On each step the
    learning agent supplies the vote for its own lane, while every other
    lane votes by a fixed rule: 1 (request green) when its queue is longer
    than 5, otherwise 0. Each agent is therefore trained on its own,
    without needing the other learned policies.
    """

    def __init__(self, lane: str):
        # Keep a handle on the simulator so reset()/step() can drive it.
        self._sim = IntersectionSimulator()
        super().__init__(lane, self._sim)

    def reset(self, seed=None, options=None):
        """Reset the base env and the simulator; return (observation, info).

        `options` is accepted for API compatibility and is not used.
        """
        super().reset(seed=seed)
        self._sim.reset()
        info = {}
        return self.get_obs(), info

    def step(self, action: int):
        """Advance the simulator one step using this agent's action.

        Returns the 5-tuple (observation, reward, terminated, truncated,
        info); `truncated` is always False and `info` is always empty.
        """
        # The learning agent's vote goes in first, then the rule-based
        # votes of the remaining lanes (in LANES order).
        votes = {self.lane: int(action)}
        votes.update(
            {
                name: (1 if self._sim.queues[name] > 5 else 0)
                for name in LANES
                if name != self.lane
            }
        )

        lane_rewards, finished = self._sim.step(votes)
        observation = self.get_obs()
        return observation, float(lane_rewards[self.lane]), finished, False, {}
# ── Train all 4 agents ────────────────────────────────────────────────────
# One PPO model is trained per lane, each in its own SingleAgentTrainEnv,
# and saved to agent_<lane>.zip.
agents = {}  # lane -> trained PPO model
train_times = {}  # lane -> seconds spent building, checking, training, saving
for lane in LANES:
    print(f" Training Agent-{lane} ({TOTAL_TIMESTEPS:,} timesteps)...")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a multi-minute training run.
    t_start = time.perf_counter()
    env = SingleAgentTrainEnv(lane)

    # Quick sanity check on the environment. Deliberately best-effort:
    # a failed check is reported but does not abort training.
    try:
        check_env(env, warn=True, skip_render_check=True)
    except Exception as e:
        print(f" ⚠️ Env check warning: {e}")

    model = PPO(
        "MlpPolicy", env,
        verbose=0,
        learning_rate=LEARNING_RATE,
        n_steps=N_STEPS,
        batch_size=BATCH_SIZE,
        n_epochs=N_EPOCHS,
        policy_kwargs=dict(net_arch=NET_ARCH),
    )
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    model.save(f"agent_{lane}")

    elapsed = time.perf_counter() - t_start
    train_times[lane] = elapsed
    agents[lane] = model
    print(f" ✓ Agent-{lane} saved → agent_{lane}.zip ({elapsed:.0f}s)\n")
print(" ✅ All 4 agents trained!\n")
# ── Evaluation: MARL vs Fixed-cycle baseline ──────────────────────────────
def run_marl_episode(agents_dict: dict, n_steps: int = 200) -> float:
    """Play one episode with the trained agents controlling every lane.

    Args:
        agents_dict: Mapping from lane name to a trained model exposing
            ``predict(obs, deterministic=True)``.
        n_steps: Maximum number of simulator steps to run.

    Returns:
        The sum, over all executed steps, of the total queue length across
        lanes after each step (cumulative vehicle-steps waiting). The
        episode stops early if the simulator reports it is done.
    """
    simulator = IntersectionSimulator()
    simulator.reset()

    # One observation view per lane, all backed by the same simulator.
    lane_envs = {}
    for lane in LANES:
        lane_envs[lane] = AgentEnv(lane, simulator)

    waiting = 0
    for _step in range(n_steps):
        # Collect every lane's action from the current state before stepping.
        votes = {}
        for lane in LANES:
            observation = lane_envs[lane].get_obs()
            prediction = agents_dict[lane].predict(observation, deterministic=True)
            votes[lane] = int(prediction[0])

        _, finished = simulator.step(votes)
        waiting += sum(simulator.queues.values())
        if finished:
            break
    return waiting
def run_fixed_episode(n_steps: int = 200, cycle_len: int = 6) -> float:
    """Play one episode under a fixed-cycle signal controller (the baseline).

    The signal advances to the next entry of PHASES every ``cycle_len``
    steps. Queues are updated directly here rather than through the
    simulator's own step: a lane on GREEN discharges up to 2 vehicles per
    step, and any other lane gains 0-2 random arrivals, capped at the
    simulator's MAX_QUEUE.

    NOTE(review): the original describes this as a "30-second" cycle; the
    seconds-per-step conversion is not visible in this file.

    Args:
        n_steps: Number of steps to simulate.
        cycle_len: Steps spent in each phase before switching.

    Returns:
        The sum, over all steps, of the total queue length across lanes
        after each step (cumulative vehicle-steps waiting).
    """
    simulator = IntersectionSimulator()
    simulator.reset()

    phase_names = list(PHASES.keys())
    current = 0
    steps_in_phase = 0
    waiting = 0

    for _step in range(n_steps):
        # Rotate to the next phase once the current one has run its course.
        if steps_in_phase >= cycle_len:
            current = (current + 1) % len(phase_names)
            simulator.phase = phase_names[current]
            simulator.time_in_phase = 0
            steps_in_phase = 0

        lights = PHASES[simulator.phase]
        for lane in LANES:
            queued = simulator.queues[lane]
            if lights[lane] == 'GREEN':
                simulator.queues[lane] = max(0, queued - 2)
                continue
            # Random draws happen only for non-green lanes, in LANES order.
            arrivals = int(np.random.randint(0, 3))
            simulator.queues[lane] = min(simulator.MAX_QUEUE, queued + arrivals)

        waiting += sum(simulator.queues.values())
        steps_in_phase += 1
    return waiting
# Compare the fixed-cycle baseline against the trained agents over several
# episodes. The baseline runs first, then the trained agents.
N_EVAL_EPISODES = 10
print(f" Evaluating over {N_EVAL_EPISODES} episodes each...")
fixed_scores = [run_fixed_episode() for _ in range(N_EVAL_EPISODES)]
marl_scores = [run_marl_episode(agents) for _ in range(N_EVAL_EPISODES)]

fixed_avg = float(np.mean(fixed_scores))
marl_avg = float(np.mean(marl_scores))
fixed_std = float(np.std(fixed_scores))
marl_std = float(np.std(marl_scores))
# Percentage reduction in waiting relative to the baseline; the max(...)
# guards against dividing by zero when the baseline average is below 1.
improvement = (fixed_avg - marl_avg) / max(fixed_avg, 1) * 100

print("\n" + "=" * 60)
# Report the actual episode count rather than a hard-coded number so the
# header stays correct when N_EVAL_EPISODES is changed.
print(f" RESULTS (average over {N_EVAL_EPISODES} evaluation episodes):")
print(f" Fixed 30s cycle : {fixed_avg:>8.0f} ± {fixed_std:.0f} vehicle-steps waiting")
print(f" 4-Agent MARL : {marl_avg:>8.0f} ± {marl_std:.0f} vehicle-steps waiting")
print(f" Improvement : {improvement:>+.1f}%")
print("=" * 60)
# ── Plot results ──────────────────────────────────────────────────────────
# Two-panel dark-themed figure written to results.png:
#   left  — mean waiting per controller with std-dev error bars
#   right — per-episode scores with dashed lines at each mean
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
fig.patch.set_facecolor('#0d0d18')

# Bar chart
ax1 = axes[0]
ax1.set_facecolor('#111120')
bar_means = [fixed_avg, marl_avg]
bar_stds = [fixed_std, marl_std]
bars = ax1.bar(
    ["Fixed 30s Cycle", "4-Agent MARL"],
    bar_means,
    color=["#ef4444", "#22c55e"],
    width=0.5,
    edgecolor='none',
    yerr=bar_stds,
    capsize=6,
    error_kw=dict(ecolor='#ffffff', capthick=2, elinewidth=2),
)
for bar, val, err in zip(bars, bar_means, bar_stds):
    # Place each value label just above that bar's own error bar. Using
    # the baseline's std for both bars would misplace the MARL label.
    ax1.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + err + 20,
        f"{val:.0f}",
        ha='center', va='bottom',
        color='#e2e8f0', fontsize=12, fontweight='bold'
    )
ax1.set_ylabel("Cumulative Vehicle-Steps Waiting", color='#e2e8f0', fontsize=11)
ax1.set_title(
    f"Fixed Cycle vs 4-Agent MARL\nImprovement: {improvement:+.1f}%",
    color='#e2e8f0', fontsize=13, fontweight='bold', pad=12
)
ax1.tick_params(colors='#94a3b8')
ax1.spines['bottom'].set_color('#1e2030')
ax1.spines['left'].set_color('#1e2030')
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
# Leave headroom above the taller bar for its error bar and value label.
ax1.set_ylim(0, max(fixed_avg + fixed_std * 2, marl_avg + marl_std * 2) * 1.25)
# Overrides the label colour given to set_ylabel above.
ax1.yaxis.label.set_color('#94a3b8')

# Episode-by-episode line chart
ax2 = axes[1]
ax2.set_facecolor('#111120')
ep_x = list(range(1, N_EVAL_EPISODES + 1))
ax2.plot(ep_x, fixed_scores, 'o-', color='#ef4444', linewidth=2,
         markersize=6, label='Fixed 30s Cycle')
ax2.plot(ep_x, marl_scores, 's-', color='#22c55e', linewidth=2,
         markersize=6, label='4-Agent MARL')
# Dashed horizontal lines mark each controller's mean.
ax2.axhline(y=fixed_avg, color='#ef4444', linestyle='--', alpha=0.5, linewidth=1)
ax2.axhline(y=marl_avg, color='#22c55e', linestyle='--', alpha=0.5, linewidth=1)
ax2.set_xlabel("Evaluation Episode", color='#94a3b8', fontsize=11)
ax2.set_ylabel("Vehicle-Steps Waiting", color='#94a3b8', fontsize=11)
ax2.set_title("Episode-by-Episode Comparison", color='#e2e8f0', fontsize=13,
              fontweight='bold', pad=12)
ax2.legend(facecolor='#0d0d18', labelcolor='#e2e8f0', edgecolor='#1e2030',
           fontsize=10)
ax2.tick_params(colors='#94a3b8')
ax2.spines['bottom'].set_color('#1e2030')
ax2.spines['left'].set_color('#1e2030')
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)

fig.tight_layout(pad=2.0)
# Pass the figure's face colour explicitly so the saved image keeps the
# dark background.
fig.savefig("results.png", dpi=150, bbox_inches='tight',
            facecolor=fig.get_facecolor())
plt.close(fig)
# Final summary: confirm the chart was written, then list each saved model
# with its on-disk size and training time.
print("\n 📊 Saved: results.png")
print("\n Trained model files:")
for lane in LANES:
    size_kb = os.path.getsize(f"agent_{lane}.zip") / 1024
    print(f" agent_{lane}.zip ({size_kb:.1f} KB) in {train_times[lane]:.0f}s")
print()
print(" Next step: streamlit run dashboard_final.py")
print(" Or: python demo_evp.py")