Spaces:
Sleeping
Sleeping
| """ | |
| main.py — RL Traffic Signal Control entry point. | |
| Automated pipeline (recommended): | |
| python main.py --auto | |
| Manual usage: | |
| python main.py --mode train --agent q_learning --episodes 50 | |
| python main.py --mode train --agent dqn --episodes 150 | |
| python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth | |
| python main.py --mode eval --agent dqn --model-path models/dqn_best.pth | |
| python main.py --mode fixed # Fixed-signal baseline only | |
| The --auto flag runs the full pipeline: | |
| 1. Fixed-signal baseline (10 episodes) | |
| 2. Q-Learning training (50 episodes) | |
| 3. DQN training (150 episodes) | |
| 4. Evaluation & comparison plots | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| # ── Project imports ─────────────────────────────────────────────────────────── | |
| import config as cfg | |
| from environment import TrafficEnvironment | |
| from agent import QLearningAgent, DQNAgent, DQN_AVAILABLE | |
| from training import Trainer, Evaluator | |
| from utils import setup_logger, MetricsTracker, plot_training_curves | |
| from utils.visualizer import plot_comparison, plot_bar_comparison | |
| logger = setup_logger("main") | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Factory helpers | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def make_env() -> TrafficEnvironment:
    """Build a new ``TrafficEnvironment`` from the project config module."""
    environment = TrafficEnvironment(cfg)
    return environment
def make_q_learning_agent() -> QLearningAgent:
    """Construct a Q-Learning agent from ``cfg`` sizes and ``cfg.Q_LEARNING_CONFIG``."""
    agent_kwargs = {
        "state_size": cfg.STATE_SIZE,
        "action_size": cfg.ACTION_SIZE,
        "config": cfg.Q_LEARNING_CONFIG,
    }
    return QLearningAgent(**agent_kwargs)
def make_dqn_agent() -> DQNAgent:
    """
    Instantiate a DQN agent using project config (requires PyTorch).

    Returns:
        A ``DQNAgent`` built from ``cfg.STATE_SIZE``, ``cfg.ACTION_SIZE`` and
        ``cfg.DQN_CONFIG``.

    Raises:
        SystemExit: with status 1 (after logging two error lines) when
            ``DQN_AVAILABLE`` is false.
    """
    if not DQN_AVAILABLE:
        # The separator below was a mis-encoded glyph in the source; restored
        # as an em dash.
        logger.error("PyTorch is not installed — DQN unavailable.")
        logger.error("Install with: pip install torch")
        sys.exit(1)
    return DQNAgent(
        state_size=cfg.STATE_SIZE,
        action_size=cfg.ACTION_SIZE,
        config=cfg.DQN_CONFIG,
    )
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Fixed-signal baseline | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class FixedSignalAgent:
    """
    Fixed-timing signal used as the comparison baseline.

    The agent ignores the observed state: it returns action ``1`` on every
    ``switch_interval``-th call to :meth:`select_action` and ``0`` otherwise.
    NOTE(review): presumably 1 means "advance to the next phase" and 0 means
    "hold" — confirm against ``TrafficEnvironment.step``.

    The no-op ``train_step`` / ``save`` / ``load`` methods exist so the class
    can be called like the learning agents.
    """

    def __init__(self, switch_interval: int = 30):
        """
        Args:
            switch_interval: Number of calls between two ``1`` actions.

        Raises:
            ValueError: if ``switch_interval`` is smaller than 1 (a zero
                interval would otherwise fail later with ZeroDivisionError).
        """
        if switch_interval < 1:
            raise ValueError(
                f"switch_interval must be >= 1, got {switch_interval}"
            )
        self.switch_interval = switch_interval
        self._step = 0

    def select_action(self, state, training: bool = False) -> int:
        """Return 1 on every ``switch_interval``-th call, else 0; *state* is unused."""
        self._step += 1
        return 1 if self._step % self.switch_interval == 0 else 0

    def train_step(self, *args, **kwargs) -> None:
        """No-op: the fixed baseline does not learn."""
        return None

    def save(self, filepath) -> None:
        """No-op: there is nothing to persist."""

    def load(self, filepath) -> None:
        """No-op: there is nothing to restore."""

    def reset(self) -> None:
        """Restart the step counter (called at the start of each episode)."""
        self._step = 0
def run_fixed_baseline(
    num_episodes: int = 10,
    switch_interval: int = 30,
) -> tuple[list[float], dict]:
    """
    Evaluate the fixed-timing signal for *num_episodes* episodes.

    Args:
        num_episodes: Number of episodes to roll out (must be >= 1).
        switch_interval: Passed to ``FixedSignalAgent``; defaults to the
            previously hard-coded value of 30.

    Returns:
        (episode_rewards, summary_dict) where ``summary_dict`` holds
        ``mean_reward``, ``std_reward``, ``best_reward``,
        ``mean_waiting_time`` and ``mean_queue_length``.

    Raises:
        ValueError: if ``num_episodes`` is smaller than 1 (the statistics
            below are undefined for an empty reward list).
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    logger.info("=" * 60)
    logger.info(f"FIXED-SIGNAL BASELINE ({num_episodes} episodes)")
    logger.info("=" * 60)

    agent = FixedSignalAgent(switch_interval=switch_interval)
    env = make_env()

    rewards: list[float] = []
    waiting_times: list[float] = []
    queue_lengths: list[float] = []

    for ep in range(1, num_episodes + 1):
        state, _ = env.reset()
        agent.reset()
        ep_reward = 0.0
        info: dict = {}
        done = False
        while not done:
            action = agent.select_action(state)
            state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            ep_reward += reward

        rewards.append(ep_reward)
        # Record the end-of-episode traffic metrics for every episode so the
        # "mean_*" entries are real means, not just the last episode's value.
        # NOTE(review): assumes the final-step info carries the episode-level
        # figures — confirm against TrafficEnvironment.step.
        waiting_times.append(float(info.get("average_waiting_time", 0)))
        queue_lengths.append(float(info.get("total_queue_length", 0)))
        logger.info(f" Episode {ep:3d}/{num_episodes} reward={ep_reward:.2f}")

    mean_r = float(np.mean(rewards))
    logger.info(f"Baseline mean reward: {mean_r:.2f}")
    return rewards, {
        "mean_reward": mean_r,
        "std_reward": float(np.std(rewards)),
        "best_reward": float(np.max(rewards)),
        "mean_waiting_time": float(np.mean(waiting_times)),
        "mean_queue_length": float(np.mean(queue_lengths)),
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Training mode | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def run_training(agent_type: str, num_episodes: int) -> Trainer:
    """
    Train the specified agent and return the trainer that ran it.

    Args:
        agent_type: "q_learning" or "dqn".
        num_episodes: Number of training episodes.

    Returns:
        The ``Trainer`` instance after ``train(num_episodes)`` has completed.

    Raises:
        SystemExit: with status 1 when ``agent_type`` is not recognised
            (or, via ``make_dqn_agent``, when DQN is unavailable).
    """
    logger.info("=" * 60)
    logger.info(f"TRAINING agent={agent_type} episodes={num_episodes}")
    logger.info("=" * 60)

    agent_factories = {
        "q_learning": make_q_learning_agent,
        "dqn": make_dqn_agent,
    }
    # Reject an unknown agent before constructing the environment, so a typo
    # does not pay for (or leave behind) an environment instance.
    if agent_type not in agent_factories:
        logger.error(f"Unknown agent type: {agent_type!r}")
        sys.exit(1)

    # Order kept from the original: environment first, then cfg.AGENT_TYPE,
    # then the agent.
    # NOTE(review): confirm TrafficEnvironment does not read cfg.AGENT_TYPE
    # at construction time.
    env = make_env()
    cfg.AGENT_TYPE = agent_type
    agent = agent_factories[agent_type]()

    trainer = Trainer(env, agent, cfg)
    trainer.train(num_episodes)
    logger.info(f"Training complete. Best reward: {trainer.best_reward:.2f}")
    return trainer
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Evaluation mode | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def run_evaluation(agent_type: str, model_path: str, num_episodes: int = 10) -> dict:
    """
    Load a saved model and evaluate it.

    Args:
        agent_type: "q_learning" or "dqn".
        model_path: Path to saved model file.
        num_episodes: Evaluation episodes.

    Returns:
        Evaluation results dictionary as produced by ``Evaluator.evaluate``.

    Raises:
        SystemExit: with status 1 when ``agent_type`` is not recognised.
            Previously any unknown value was silently evaluated as DQN.
    """
    logger.info("=" * 60)
    logger.info(f"EVALUATION agent={agent_type} model={model_path}")
    logger.info("=" * 60)

    agent_factories = {
        "q_learning": make_q_learning_agent,
        "dqn": make_dqn_agent,
    }
    # Same failure mode as run_training for an unknown agent type.
    if agent_type not in agent_factories:
        logger.error(f"Unknown agent type: {agent_type!r}")
        sys.exit(1)

    # Order kept from the original: environment, cfg.AGENT_TYPE, agent, load.
    env = make_env()
    cfg.AGENT_TYPE = agent_type
    agent = agent_factories[agent_type]()
    agent.load(model_path)

    evaluator = Evaluator(env, agent, cfg)
    results = evaluator.evaluate(num_episodes)

    logger.info("Evaluation results:")
    for key, value in results.items():
        # Floats get four decimals; everything else is logged as-is.
        if isinstance(value, float):
            logger.info(f" {key}: {value:.4f}")
        else:
            logger.info(f" {key}: {value}")
    return results
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Automated pipeline | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def run_auto_pipeline(
    baseline_episodes: int = 10,
    q_learning_episodes: int = 50,
    dqn_episodes: int = 150,
) -> None:
    """
    Full automated pipeline:
      1. Fixed-signal baseline
      2. Q-Learning training
      3. DQN training (skipped when ``DQN_AVAILABLE`` is false)
      4. Comparison table printed to stdout
      5. Comparison bar plot

    The Q-Learning and DQN rows are built from the trainers' own training
    metrics; no separate evaluation run is performed here.

    Args:
        baseline_episodes: Episodes for the fixed-signal baseline.
        q_learning_episodes: Training episodes for Q-Learning.
        dqn_episodes: Training episodes for DQN.
    """
    # Banner glyphs were mis-encoded in the source; restored as box-drawing
    # characters around a centred title.
    banner_width = 58
    title = "AUTOMATED RL TRAFFIC SIGNAL CONTROL PIPELINE"
    logger.info("╔" + "═" * banner_width + "╗")
    logger.info("║" + title.center(banner_width) + "║")
    logger.info("╚" + "═" * banner_width + "╝")

    summary: dict[str, dict] = {}

    # -- 1. Fixed-signal baseline ------------------------------------------
    # Only the summary dict is needed here; per-episode rewards are discarded.
    _, baseline_results = run_fixed_baseline(num_episodes=baseline_episodes)
    summary["Fixed Signal"] = baseline_results

    # -- 2. Q-Learning -----------------------------------------------------
    ql_trainer = run_training("q_learning", num_episodes=q_learning_episodes)
    summary["Q-Learning"] = {
        "mean_reward": ql_trainer.metrics.get_mean("episode_reward"),
        "best_reward": ql_trainer.best_reward,
        "std_reward": ql_trainer.metrics.get_std("episode_reward"),
    }

    # -- 3. DQN ------------------------------------------------------------
    if DQN_AVAILABLE:
        dqn_trainer = run_training("dqn", num_episodes=dqn_episodes)
        summary["DQN"] = {
            "mean_reward": dqn_trainer.metrics.get_mean("episode_reward"),
            "best_reward": dqn_trainer.best_reward,
            "std_reward": dqn_trainer.metrics.get_std("episode_reward"),
        }
    else:
        logger.warning("DQN skipped (PyTorch not available).")

    # -- 4. Print comparison table -----------------------------------------
    _print_comparison_table(summary)

    # -- 5. Plots ----------------------------------------------------------
    _generate_comparison_plots(summary)

    logger.info("Pipeline complete.")
def _print_comparison_table(summary: dict) -> None:
    """
    Print a comparison table of mean/best reward per method to stdout.

    Args:
        summary: Mapping of method name to a result dict; ``mean_reward`` and
            ``best_reward`` are read from each (missing keys print as 0).

    A ``(+x.xx)`` delta against the "Fixed Signal" row is appended to every
    other row. When no baseline mean is available the delta is omitted, rather
    than being computed against a made-up baseline of 0.
    """
    baseline_mean = summary.get("Fixed Signal", {}).get("mean_reward")

    print("\n")
    print("=" * 60)
    print(f"{'Method':<18} {'Mean Reward':>14} {'Best Reward':>14}")
    print("-" * 60)
    for method, res in summary.items():
        mean_r = res.get("mean_reward", 0)
        best_r = res.get("best_reward", 0)
        if method == "Fixed Signal" or baseline_mean is None:
            delta_str = ""
        else:
            delta_str = f" ({mean_r - baseline_mean:+.2f})"
        print(f"{method:<18} {mean_r:>14.2f} {best_r:>14.2f}{delta_str}")
    print("=" * 60)
    print()
def _generate_comparison_plots(summary: dict) -> None:
    """
    Save a bar-chart comparison of mean rewards.

    Args:
        summary: Mapping of method name to a result dict; only
            ``mean_reward`` is read (missing values plot as 0).

    The chart is written to ``cfg.RESULTS_DIR / "plots" / "comparison_bar.png"``.
    """
    scores = {m: r.get("mean_reward", 0) for m, r in summary.items()}
    save_path = cfg.RESULTS_DIR / "plots" / "comparison_bar.png"
    # Create the target directory up front so a fresh checkout does not fail
    # at save time; this is harmless if the plotting helper also creates it.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plot_bar_comparison(
        scores,
        title="Mean Reward by Method (higher = better)",
        ylabel="Mean Reward",
        save_path=save_path,
    )
| # ───────────────────────────────────────────────────────────────────────────── | |
| # CLI | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _positive_int(text: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}"
        )
    return value


def _build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    Options: ``--auto``, ``--mode {train,eval,fixed}``,
    ``--agent {q_learning,dqn}``, ``--episodes N`` (N >= 1) and
    ``--model-path PATH``.
    """
    p = argparse.ArgumentParser(
        description="RL Traffic Signal Control",
        # Raw formatter keeps the line breaks of the examples in the epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python main.py --auto
  python main.py --mode train --agent q_learning --episodes 50
  python main.py --mode train --agent dqn --episodes 150
  python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth
  python main.py --mode fixed
""",
    )
    p.add_argument(
        "--auto",
        action="store_true",
        help="Run the full automated pipeline (recommended)",
    )
    p.add_argument(
        "--mode",
        choices=["train", "eval", "fixed"],
        default="train",
        help="Mode to run (ignored when --auto is set)",
    )
    p.add_argument(
        "--agent",
        choices=["q_learning", "dqn"],
        default="q_learning",
        help="Agent type",
    )
    p.add_argument(
        "--episodes",
        # Zero or negative counts are rejected at parse time with a usage
        # error instead of failing somewhere inside the run.
        type=_positive_int,
        default=50,
        help="Number of episodes for --mode train/fixed (eval always runs 10)",
    )
    p.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to saved model file (required for --mode eval)",
    )
    return p
def main(argv: list[str] | None = None) -> None:
    """
    Parse command-line arguments and dispatch to the selected mode.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour. This lets
            the entry point be called programmatically.

    ``--auto`` takes precedence over ``--mode``. In eval mode a missing
    ``--model-path`` is reported through ``parser.error``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.auto:
        run_auto_pipeline()
        return

    if args.mode == "fixed":
        run_fixed_baseline(num_episodes=args.episodes)
    elif args.mode == "train":
        run_training(args.agent, args.episodes)
    elif args.mode == "eval":
        if args.model_path is None:
            parser.error("--model-path is required for --mode eval")
        # Evaluation always runs 10 episodes; --episodes is not forwarded.
        run_evaluation(args.agent, args.model_path, num_episodes=10)


if __name__ == "__main__":
    main()