""" Trainer — manages the training loop for any BaseAgent. Features: • Per-episode logging (episode, reward, waiting time, queue, throughput) • Automatic best-model saving • Periodic checkpoints • Early stopping • DQN target-network updates • Graceful error recovery (episode-level try/except) • Optional tqdm progress bar """ from __future__ import annotations import sys import traceback from pathlib import Path import numpy as np try: from tqdm import tqdm _TQDM = True except ImportError: _TQDM = False from utils.logger import setup_logger from utils.metrics import MetricsTracker class Trainer: """ Orchestrates the RL training loop. Args: env: A Gymnasium-compatible environment. agent: Any agent that inherits :class:`BaseAgent`. config: The project config module. """ def __init__(self, env, agent, config): self.env = env self.agent = agent self.config = config # Directories config.RESULTS_DIR.mkdir(parents=True, exist_ok=True) config.MODELS_DIR.mkdir(parents=True, exist_ok=True) (config.RESULTS_DIR / "logs").mkdir(exist_ok=True) (config.RESULTS_DIR / "plots").mkdir(exist_ok=True) (config.RESULTS_DIR / "checkpoints").mkdir(exist_ok=True) # Logger log_file = config.RESULTS_DIR / "logs" / "training.log" self.logger = setup_logger("trainer", log_file=str(log_file)) # Metrics self.metrics = MetricsTracker() # State self.best_reward: float = -np.inf self.episodes_without_improvement: int = 0 self.current_episode: int = 0 self.total_steps: int = 0 self.logger.info("=" * 70) self.logger.info("TRAINER READY") self.logger.info(f" Agent type : {config.AGENT_TYPE}") self.logger.info(f" Results dir: {config.RESULTS_DIR}") self.logger.info("=" * 70) # ------------------------------------------------------------------ # Public # ------------------------------------------------------------------ def train(self, num_episodes: int): """ Run the training loop for *num_episodes* episodes. Args: num_episodes: Number of episodes to train. """ self.logger.info(f"Starting training — {num_episodes} episodes") iterator = ( tqdm(range(1, num_episodes + 1), desc="Training", unit="ep") if _TQDM else range(1, num_episodes + 1) ) try: for episode in iterator: self.current_episode = episode try: ep_reward, ep_info = self._run_episode(training=True) except KeyboardInterrupt: self.logger.info("Training interrupted by user.") self._save_checkpoint(episode, emergency=True) raise except Exception as exc: self.logger.error(f"Episode {episode} error: {exc}") self.logger.debug(traceback.format_exc()) continue # Record metrics self.metrics.add("episode_reward", ep_reward) self.metrics.add("average_waiting_time", ep_info.get("average_waiting_time", 0.0)) self.metrics.add("average_queue_length", ep_info.get("total_queue_length", 0.0)) self.metrics.add("throughput", ep_info.get("vehicles_passed", 0)) # Per-episode log self.logger.info( f"Ep {episode:4d}/{num_episodes} " f"reward={ep_reward:8.2f} " f"wait={ep_info.get('average_waiting_time', 0):7.1f} " f"queue={ep_info.get('total_queue_length', 0):6.1f} " f"thru={ep_info.get('vehicles_passed', 0):4d}" ) # DQN: sync target network if hasattr(self.agent, "update_target_network"): freq = self.config.DQN_CONFIG.get("target_update", 10) if episode % freq == 0: self.agent.update_target_network() # Save best model if ep_reward > self.best_reward: self.best_reward = ep_reward self.episodes_without_improvement = 0 self._save_best_model(episode, ep_reward) else: self.episodes_without_improvement += 1 # Periodic checkpoint if episode % self.config.SAVE_FREQUENCY == 0: self._save_checkpoint(episode) # Periodic summary if episode % 100 == 0: self._log_summary(episode, num_episodes) # Early stopping if self.episodes_without_improvement >= self.config.EARLY_STOPPING_PATIENCE: self.logger.info( f"Early stopping at episode {episode} " f"(no improvement for " f"{self.config.EARLY_STOPPING_PATIENCE} episodes)." ) break except KeyboardInterrupt: self.logger.info("Exiting gracefully.") sys.exit(0) self.logger.info("=" * 70) self.logger.info("TRAINING COMPLETE") self._log_final_summary() self._save_metrics() self._plot_results() # ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ def _run_episode(self, training: bool = True) -> tuple[float, dict]: """Execute one full episode.""" state, _ = self.env.reset() ep_reward = 0.0 done = False info: dict = {} max_steps = self.config.EPISODE_LENGTH * 2 steps = 0 while not done and steps < max_steps: action = self.agent.select_action(state, training=training) next_state, reward, terminated, truncated, info = self.env.step(action) done = terminated or truncated if training: loss = self.agent.train_step(state, action, reward, next_state, done) if loss is not None: self.metrics.add("loss", float(loss)) state = next_state ep_reward += reward steps += 1 self.total_steps += 1 return ep_reward, info def _save_best_model(self, episode: int, reward: float): path = self.config.MODELS_DIR / f"{self.config.AGENT_TYPE}_best.pth" try: self.agent.save(str(path)) self.logger.info( f"[OK] Best model saved reward={reward:.2f} (episode {episode})" ) except Exception as exc: self.logger.error(f"[FAIL] Could not save best model: {exc}") def _save_checkpoint(self, episode: int, emergency: bool = False): tag = "emergency" if emergency else f"ep{episode}" path = ( self.config.RESULTS_DIR / "checkpoints" / f"{self.config.AGENT_TYPE}_{tag}.pth" ) try: self.agent.save(str(path)) self.logger.info(f"[OK] Checkpoint saved -> {path}") except Exception as exc: self.logger.error(f"[FAIL] Could not save checkpoint: {exc}") def _log_summary(self, episode: int, total: int): n = min(100, episode) self.logger.info("-" * 70) self.logger.info(f"Summary ep {episode}/{total}") self.logger.info( f" Avg reward (last {n}): " f"{self.metrics.get_mean('episode_reward', last_n=n):8.2f}" ) self.logger.info( f" Avg wait (last {n}): " f"{self.metrics.get_mean('average_waiting_time', last_n=n):8.2f}" ) self.logger.info(f" Best reward so far : {self.best_reward:8.2f}") self.logger.info("-" * 70) def _log_final_summary(self): all_r = self.metrics.get("episode_reward") if not all_r: return self.logger.info("FINAL STATISTICS") self.logger.info(f" Total episodes : {len(all_r)}") self.logger.info(f" Best reward : {self.best_reward:.2f}") self.logger.info(f" Mean reward : {np.mean(all_r):.2f}") self.logger.info(f" Std reward : {np.std(all_r):.2f}") def _save_metrics(self): path = self.config.RESULTS_DIR / "metrics.json" try: self.metrics.save(path) self.logger.info(f"[OK] Metrics saved -> {path}") except Exception as exc: self.logger.warning(f"Could not save metrics: {exc}") def _plot_results(self): try: from utils.visualizer import plot_training_curves save = self.config.RESULTS_DIR / "plots" / f"{self.config.AGENT_TYPE}_training.png" plot_training_curves(self.metrics, save_path=save) except Exception as exc: self.logger.warning(f"Could not plot results: {exc}")