Spaces:
Sleeping
Sleeping
| """ | |
| Trainer — manages the training loop for any BaseAgent. | |
| Features: | |
| • Per-episode logging (episode, reward, waiting time, queue, throughput) | |
| • Automatic best-model saving | |
| • Periodic checkpoints | |
| • Early stopping | |
| • DQN target-network updates | |
| • Graceful error recovery (episode-level try/except) | |
| • Optional tqdm progress bar | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| import numpy as np | |
| try: | |
| from tqdm import tqdm | |
| _TQDM = True | |
| except ImportError: | |
| _TQDM = False | |
| from utils.logger import setup_logger | |
| from utils.metrics import MetricsTracker | |
class Trainer:
    """
    Orchestrates the RL training loop.

    Args:
        env: A Gymnasium-compatible environment.
        agent: Any agent that inherits :class:`BaseAgent`.
        config: The project config module.
    """

    # Episodes between periodic summaries; also the averaging window they report.
    SUMMARY_INTERVAL: int = 100

    def __init__(self, env, agent, config):
        self.env = env
        self.agent = agent
        self.config = config

        # Directories
        config.RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        config.MODELS_DIR.mkdir(parents=True, exist_ok=True)
        (config.RESULTS_DIR / "logs").mkdir(exist_ok=True)
        (config.RESULTS_DIR / "plots").mkdir(exist_ok=True)
        (config.RESULTS_DIR / "checkpoints").mkdir(exist_ok=True)

        # Logger
        log_file = config.RESULTS_DIR / "logs" / "training.log"
        self.logger = setup_logger("trainer", log_file=str(log_file))

        # Metrics
        self.metrics = MetricsTracker()

        # State
        self.best_reward: float = -np.inf
        self.episodes_without_improvement: int = 0
        self.current_episode: int = 0
        self.total_steps: int = 0

        self.logger.info("=" * 70)
        self.logger.info("TRAINER READY")
        self.logger.info(f" Agent type : {config.AGENT_TYPE}")
        self.logger.info(f" Results dir: {config.RESULTS_DIR}")
        self.logger.info("=" * 70)

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------
    def train(self, num_episodes: int):
        """
        Run the training loop for *num_episodes* episodes.

        Episodes whose rollout raises are logged and skipped. On Ctrl-C an
        emergency checkpoint is written (when the interrupt lands inside an
        episode rollout), the metrics collected so far are saved, and the
        process exits via ``sys.exit(0)``.

        Args:
            num_episodes: Number of episodes to train.
        """
        self.logger.info(f"Starting training — {num_episodes} episodes")
        iterator = (
            tqdm(range(1, num_episodes + 1), desc="Training", unit="ep")
            if _TQDM
            else range(1, num_episodes + 1)
        )
        try:
            for episode in iterator:
                self.current_episode = episode
                try:
                    ep_reward, ep_info = self._run_episode(training=True)
                except KeyboardInterrupt:
                    self.logger.info("Training interrupted by user.")
                    self._save_checkpoint(episode, emergency=True)
                    raise
                except Exception as exc:
                    # One failing episode should not abort a long run.
                    self.logger.error(f"Episode {episode} error: {exc}")
                    self.logger.debug(traceback.format_exc())
                    continue

                self._record_episode(episode, num_episodes, ep_reward, ep_info)
                self._maybe_update_target_network(episode)
                self._track_improvement(episode, ep_reward)
                if self._is_due(episode, self.config.SAVE_FREQUENCY):
                    self._save_checkpoint(episode)
                if self._is_due(episode, self.SUMMARY_INTERVAL):
                    self._log_summary(episode, num_episodes)
                if self._should_stop_early(episode):
                    break
        except KeyboardInterrupt:
            self.logger.info("Exiting gracefully.")
            # Persist what was collected so far; previously it was discarded.
            self._log_final_summary()
            self._save_metrics()
            sys.exit(0)
        finally:
            # tqdm bars are not closed by `break` or an exception; close explicitly.
            close = getattr(iterator, "close", None)
            if close is not None:
                close()

        self.logger.info("=" * 70)
        self.logger.info("TRAINING COMPLETE")
        self._log_final_summary()
        self._save_metrics()
        self._plot_results()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------
    @staticmethod
    def _is_due(episode: int, every) -> bool:
        """Return True when *episode* is a multiple of *every*.

        A falsy interval (``0`` / ``None``) disables the periodic action
        instead of raising ``ZeroDivisionError``.
        """
        return bool(every) and episode % every == 0

    @staticmethod
    def _format_episode_line(episode: int, num_episodes: int,
                             ep_reward: float, ep_info: dict) -> str:
        """Build the one-line per-episode log message.

        Throughput uses ``4.0f`` rather than ``4d``. The output is identical for
        ints, but float counts (e.g. ``np.float64``) no longer raise ``ValueError``.
        """
        return (
            f"Ep {episode:4d}/{num_episodes} "
            f"reward={ep_reward:8.2f} "
            f"wait={ep_info.get('average_waiting_time', 0):7.1f} "
            f"queue={ep_info.get('total_queue_length', 0):6.1f} "
            f"thru={ep_info.get('vehicles_passed', 0):4.0f}"
        )

    def _record_episode(self, episode: int, num_episodes: int,
                        ep_reward: float, ep_info: dict):
        """Store per-episode metrics and emit the per-episode log line."""
        self.metrics.add("episode_reward", ep_reward)
        self.metrics.add("average_waiting_time",
                         ep_info.get("average_waiting_time", 0.0))
        # NOTE(review): stored as "average_queue_length" but read from the
        # env's "total_queue_length" key — confirm which quantity is intended.
        self.metrics.add("average_queue_length",
                         ep_info.get("total_queue_length", 0.0))
        self.metrics.add("throughput",
                         ep_info.get("vehicles_passed", 0))
        self.logger.info(
            self._format_episode_line(episode, num_episodes, ep_reward, ep_info)
        )

    def _maybe_update_target_network(self, episode: int):
        """Sync the DQN target network every ``DQN_CONFIG['target_update']`` episodes.

        Only applies to agents exposing ``update_target_network``. A config
        without ``DQN_CONFIG`` falls back to the default interval of 10.
        """
        if not hasattr(self.agent, "update_target_network"):
            return
        dqn_cfg = getattr(self.config, "DQN_CONFIG", None) or {}
        freq = dqn_cfg.get("target_update", 10)
        if self._is_due(episode, freq):
            self.agent.update_target_network()

    def _track_improvement(self, episode: int, ep_reward: float):
        """Save the model on a new best reward; otherwise bump the patience counter."""
        if ep_reward > self.best_reward:
            self.best_reward = ep_reward
            self.episodes_without_improvement = 0
            self._save_best_model(episode, ep_reward)
        else:
            self.episodes_without_improvement += 1

    def _should_stop_early(self, episode: int) -> bool:
        """Return True (and log why) once patience is exhausted."""
        patience = self.config.EARLY_STOPPING_PATIENCE
        if self.episodes_without_improvement < patience:
            return False
        self.logger.info(
            f"Early stopping at episode {episode} "
            f"(no improvement for "
            f"{patience} episodes)."
        )
        return True

    def _run_episode(self, training: bool = True) -> tuple[float, dict]:
        """Execute one full episode.

        Steps until the env reports ``terminated`` or ``truncated``, or until
        ``2 * config.EPISODE_LENGTH`` steps have run (a safety cap).

        Returns:
            ``(summed reward, info dict from the last step)``. The info is ``{}``
            if no step ran.
        """
        state, _ = self.env.reset()
        ep_reward = 0.0
        done = False
        info: dict = {}
        max_steps = self.config.EPISODE_LENGTH * 2
        steps = 0
        while not done and steps < max_steps:
            action = self.agent.select_action(state, training=training)
            next_state, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if training:
                loss = self.agent.train_step(state, action, reward, next_state, done)
                if loss is not None:
                    self.metrics.add("loss", float(loss))
            state = next_state
            ep_reward += reward
            steps += 1
            self.total_steps += 1
        return ep_reward, info

    def _save_best_model(self, episode: int, reward: float):
        """Write the agent to ``MODELS_DIR/<AGENT_TYPE>_best.pth``; failures are logged."""
        path = self.config.MODELS_DIR / f"{self.config.AGENT_TYPE}_best.pth"
        try:
            self.agent.save(str(path))
            self.logger.info(
                f"[OK] Best model saved reward={reward:.2f} (episode {episode})"
            )
        except Exception as exc:
            self.logger.error(f"[FAIL] Could not save best model: {exc}")

    def _save_checkpoint(self, episode: int, emergency: bool = False):
        """Write a checkpoint under ``RESULTS_DIR/checkpoints``; failures are logged.

        Emergency checkpoints share one filename, so each overwrites the previous one.
        """
        tag = "emergency" if emergency else f"ep{episode}"
        path = (
            self.config.RESULTS_DIR
            / "checkpoints"
            / f"{self.config.AGENT_TYPE}_{tag}.pth"
        )
        try:
            self.agent.save(str(path))
            self.logger.info(f"[OK] Checkpoint saved -> {path}")
        except Exception as exc:
            self.logger.error(f"[FAIL] Could not save checkpoint: {exc}")

    def _log_summary(self, episode: int, total: int):
        """Log mean reward/wait over the last ``SUMMARY_INTERVAL`` episodes and the best reward."""
        n = min(self.SUMMARY_INTERVAL, episode)
        self.logger.info("-" * 70)
        self.logger.info(f"Summary ep {episode}/{total}")
        self.logger.info(
            f" Avg reward (last {n}): "
            f"{self.metrics.get_mean('episode_reward', last_n=n):8.2f}"
        )
        self.logger.info(
            f" Avg wait (last {n}): "
            f"{self.metrics.get_mean('average_waiting_time', last_n=n):8.2f}"
        )
        self.logger.info(f" Best reward so far : {self.best_reward:8.2f}")
        self.logger.info("-" * 70)

    def _log_final_summary(self):
        """Log count/best/mean/std of episode rewards; no-op if none were recorded."""
        all_r = self.metrics.get("episode_reward")
        if not all_r:
            return
        self.logger.info("FINAL STATISTICS")
        self.logger.info(f" Total episodes : {len(all_r)}")
        self.logger.info(f" Best reward : {self.best_reward:.2f}")
        self.logger.info(f" Mean reward : {np.mean(all_r):.2f}")
        self.logger.info(f" Std reward : {np.std(all_r):.2f}")

    def _save_metrics(self):
        """Save collected metrics to ``RESULTS_DIR/metrics.json``; failures only warn."""
        path = self.config.RESULTS_DIR / "metrics.json"
        try:
            self.metrics.save(path)
            self.logger.info(f"[OK] Metrics saved -> {path}")
        except Exception as exc:
            self.logger.warning(f"Could not save metrics: {exc}")

    def _plot_results(self):
        """Best-effort training-curve plot; any failure (incl. import) only warns."""
        try:
            from utils.visualizer import plot_training_curves
            save = self.config.RESULTS_DIR / "plots" / f"{self.config.AGENT_TYPE}_training.png"
            plot_training_curves(self.metrics, save_path=save)
        except Exception as exc:
            self.logger.warning(f"Could not plot results: {exc}")