""" main.py — RL Traffic Signal Control entry point. Automated pipeline (recommended): python main.py --auto Manual usage: python main.py --mode train --agent q_learning --episodes 50 python main.py --mode train --agent dqn --episodes 150 python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth python main.py --mode eval --agent dqn --model-path models/dqn_best.pth python main.py --mode fixed # Fixed-signal baseline only The --auto flag runs the full pipeline: 1. Fixed-signal baseline (10 episodes) 2. Q-Learning training (50 episodes) 3. DQN training (150 episodes) 4. Evaluation & comparison plots """ from __future__ import annotations import argparse import sys from pathlib import Path import numpy as np # ── Project imports ─────────────────────────────────────────────────────────── import config as cfg from environment import TrafficEnvironment from agent import QLearningAgent, DQNAgent, DQN_AVAILABLE from training import Trainer, Evaluator from utils import setup_logger, MetricsTracker, plot_training_curves from utils.visualizer import plot_comparison, plot_bar_comparison logger = setup_logger("main") # ═══════════════════════════════════════════════════════════════════════════════ # Factory helpers # ═══════════════════════════════════════════════════════════════════════════════ def make_env() -> TrafficEnvironment: """Create a fresh environment instance.""" return TrafficEnvironment(cfg) def make_q_learning_agent() -> QLearningAgent: """Instantiate a Q-Learning agent using project config.""" return QLearningAgent( state_size=cfg.STATE_SIZE, action_size=cfg.ACTION_SIZE, config=cfg.Q_LEARNING_CONFIG, ) def make_dqn_agent(): """Instantiate a DQN agent using project config (requires PyTorch).""" if not DQN_AVAILABLE: logger.error("PyTorch is not installed — DQN unavailable.") logger.error("Install with: pip install torch") sys.exit(1) return DQNAgent( state_size=cfg.STATE_SIZE, action_size=cfg.ACTION_SIZE, config=cfg.DQN_CONFIG, ) # ═══════════════════════════════════════════════════════════════════════════════ # Fixed-signal baseline # ═══════════════════════════════════════════════════════════════════════════════ class FixedSignalAgent: """ Round-robin fixed-timing signal — cycles phases every 30 steps. Used as the comparison baseline. """ def __init__(self, switch_interval: int = 30): self.switch_interval = switch_interval self._step = 0 def select_action(self, state, training: bool = False) -> int: self._step += 1 return 1 if self._step % self.switch_interval == 0 else 0 def train_step(self, *args, **kwargs): return None def save(self, filepath): pass def load(self, filepath): pass def reset(self): self._step = 0 def run_fixed_baseline(num_episodes: int = 10) -> tuple[list[float], dict]: """ Evaluate the fixed-timing signal for *num_episodes* episodes. Returns: (episode_rewards, summary_dict) """ logger.info("=" * 60) logger.info(f"FIXED-SIGNAL BASELINE ({num_episodes} episodes)") logger.info("=" * 60) agent = FixedSignalAgent(switch_interval=30) env = make_env() rewards: list[float] = [] info: dict = {} for ep in range(1, num_episodes + 1): state, _ = env.reset() agent.reset() ep_reward = 0.0 done = False while not done: action = agent.select_action(state) state, reward, terminated, truncated, info = env.step(action) done = terminated or truncated ep_reward += reward rewards.append(ep_reward) logger.info(f" Episode {ep:3d}/{num_episodes} reward={ep_reward:.2f}") mean_r = float(np.mean(rewards)) logger.info(f"Baseline mean reward: {mean_r:.2f}") return rewards, { "mean_reward": mean_r, "std_reward": float(np.std(rewards)), "best_reward": float(np.max(rewards)), "mean_waiting_time": float(info.get("average_waiting_time", 0)), "mean_queue_length": float(info.get("total_queue_length", 0)), } # ═══════════════════════════════════════════════════════════════════════════════ # Training mode # ═══════════════════════════════════════════════════════════════════════════════ def run_training(agent_type: str, num_episodes: int): """ Train the specified agent and save the best model. Args: agent_type: "q_learning" or "dqn". num_episodes: Number of training episodes. """ logger.info("=" * 60) logger.info(f"TRAINING agent={agent_type} episodes={num_episodes}") logger.info("=" * 60) env = make_env() if agent_type == "q_learning": cfg.AGENT_TYPE = "q_learning" agent = make_q_learning_agent() elif agent_type == "dqn": cfg.AGENT_TYPE = "dqn" agent = make_dqn_agent() else: logger.error(f"Unknown agent type: {agent_type!r}") sys.exit(1) trainer = Trainer(env, agent, cfg) trainer.train(num_episodes) logger.info(f"Training complete. Best reward: {trainer.best_reward:.2f}") return trainer # ═══════════════════════════════════════════════════════════════════════════════ # Evaluation mode # ═══════════════════════════════════════════════════════════════════════════════ def run_evaluation(agent_type: str, model_path: str, num_episodes: int = 10) -> dict: """ Load a saved model and evaluate it. Args: agent_type: "q_learning" or "dqn". model_path: Path to saved model file. num_episodes: Evaluation episodes. Returns: Evaluation results dictionary. """ logger.info("=" * 60) logger.info(f"EVALUATION agent={agent_type} model={model_path}") logger.info("=" * 60) env = make_env() if agent_type == "q_learning": cfg.AGENT_TYPE = "q_learning" agent = make_q_learning_agent() else: cfg.AGENT_TYPE = "dqn" agent = make_dqn_agent() agent.load(model_path) evaluator = Evaluator(env, agent, cfg) results = evaluator.evaluate(num_episodes) logger.info("Evaluation results:") for k, v in results.items(): logger.info(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}") return results # ═══════════════════════════════════════════════════════════════════════════════ # Automated pipeline # ═══════════════════════════════════════════════════════════════════════════════ def run_auto_pipeline(): """ Full automated pipeline: 1. Fixed-signal baseline 2. Q-Learning training (50 episodes) 3. DQN training (150 episodes) 4. Evaluation of all methods 5. Comparison plots """ logger.info("╔" + "═" * 58 + "╗") logger.info("║ AUTOMATED RL TRAFFIC SIGNAL CONTROL PIPELINE ║") logger.info("╚" + "═" * 58 + "╝") summary: dict[str, dict] = {} # ── 1. Fixed-signal baseline ────────────────────────────────────── baseline_rewards, baseline_results = run_fixed_baseline(num_episodes=10) summary["Fixed Signal"] = baseline_results # ── 2. Q-Learning ──────────────────────────────────────────────── ql_trainer = run_training("q_learning", num_episodes=50) summary["Q-Learning"] = { "mean_reward": ql_trainer.metrics.get_mean("episode_reward"), "best_reward": ql_trainer.best_reward, "std_reward": ql_trainer.metrics.get_std("episode_reward"), } # ── 3. DQN ─────────────────────────────────────────────────────── if DQN_AVAILABLE: dqn_trainer = run_training("dqn", num_episodes=150) summary["DQN"] = { "mean_reward": dqn_trainer.metrics.get_mean("episode_reward"), "best_reward": dqn_trainer.best_reward, "std_reward": dqn_trainer.metrics.get_std("episode_reward"), } else: logger.warning("DQN skipped (PyTorch not available).") # ── 4. Print comparison table ───────────────────────────────────── _print_comparison_table(summary) # ── 5. Plots ────────────────────────────────────────────────────── _generate_comparison_plots(summary) logger.info("Pipeline complete.") def _print_comparison_table(summary: dict): """Print a neat comparison table to stdout.""" print("\n") print("=" * 60) print(f"{'Method':<18} {'Mean Reward':>14} {'Best Reward':>14}") print("-" * 60) baseline_mean = summary.get("Fixed Signal", {}).get("mean_reward", 0) for method, res in summary.items(): mean_r = res.get("mean_reward", 0) best_r = res.get("best_reward", 0) delta = mean_r - baseline_mean if method != "Fixed Signal" else 0 delta_str = f" ({delta:+.2f})" if method != "Fixed Signal" else "" print(f"{method:<18} {mean_r:>14.2f} {best_r:>14.2f}{delta_str}") print("=" * 60) print() def _generate_comparison_plots(summary: dict): """Save bar-chart comparison of mean rewards.""" scores = {m: r.get("mean_reward", 0) for m, r in summary.items()} save_path = cfg.RESULTS_DIR / "plots" / "comparison_bar.png" plot_bar_comparison( scores, title="Mean Reward by Method (higher = better)", ylabel="Mean Reward", save_path=save_path, ) # ═══════════════════════════════════════════════════════════════════════════════ # CLI # ═══════════════════════════════════════════════════════════════════════════════ def _build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description="RL Traffic Signal Control", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python main.py --auto python main.py --mode train --agent q_learning --episodes 50 python main.py --mode train --agent dqn --episodes 150 python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth python main.py --mode fixed """, ) p.add_argument( "--auto", action="store_true", help="Run the full automated pipeline (recommended)", ) p.add_argument( "--mode", choices=["train", "eval", "fixed"], default="train", help="Mode to run (ignored when --auto is set)", ) p.add_argument( "--agent", choices=["q_learning", "dqn"], default="q_learning", help="Agent type", ) p.add_argument( "--episodes", type=int, default=50, help="Number of episodes", ) p.add_argument( "--model-path", type=str, default=None, help="Path to saved model file (required for --mode eval)", ) return p def main(): parser = _build_parser() args = parser.parse_args() if args.auto: run_auto_pipeline() return if args.mode == "fixed": run_fixed_baseline(num_episodes=args.episodes) elif args.mode == "train": run_training(args.agent, args.episodes) elif args.mode == "eval": if args.model_path is None: parser.error("--model-path is required for --mode eval") run_evaluation(args.agent, args.model_path, num_episodes=10) if __name__ == "__main__": main()