Traffic-Control / training /trainer.py
Dhaerya's picture
Add files
b00d5d5
"""
Trainer — manages the training loop for any BaseAgent.
Features:
• Per-episode logging (episode, reward, waiting time, queue, throughput)
• Automatic best-model saving
• Periodic checkpoints
• Early stopping
• DQN target-network updates
• Graceful error recovery (episode-level try/except)
• Optional tqdm progress bar
"""
from __future__ import annotations
import sys
import traceback
from pathlib import Path
import numpy as np
try:
from tqdm import tqdm
_TQDM = True
except ImportError:
_TQDM = False
from utils.logger import setup_logger
from utils.metrics import MetricsTracker
class Trainer:
    """
    Orchestrates the RL training loop.

    Args:
        env: A Gymnasium-compatible environment. ``reset()`` must return
            ``(state, info)`` and ``step(action)`` must return
            ``(next_state, reward, terminated, truncated, info)``.
        agent: Any agent that inherits :class:`BaseAgent`. The trainer calls
            ``select_action(state, training=...)``, ``train_step(...)`` and
            ``save(path)``; ``update_target_network()`` is optional and only
            used when present (DQN-style agents).
        config: The project config module.
    """

    # Episodes between rolling summaries; also the size of the rolling window
    # the summary averages over, so both stay in sync.
    SUMMARY_INTERVAL = 100

    def __init__(self, env, agent, config):
        self.env = env
        self.agent = agent
        self.config = config

        # Directories
        config.RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        config.MODELS_DIR.mkdir(parents=True, exist_ok=True)
        (config.RESULTS_DIR / "logs").mkdir(exist_ok=True)
        (config.RESULTS_DIR / "plots").mkdir(exist_ok=True)
        (config.RESULTS_DIR / "checkpoints").mkdir(exist_ok=True)

        # Logger
        log_file = config.RESULTS_DIR / "logs" / "training.log"
        self.logger = setup_logger("trainer", log_file=str(log_file))

        # Metrics
        self.metrics = MetricsTracker()

        # State
        self.best_reward: float = -np.inf
        self.episodes_without_improvement: int = 0
        self.current_episode: int = 0
        self.total_steps: int = 0

        self.logger.info("=" * 70)
        self.logger.info("TRAINER READY")
        self.logger.info(f" Agent type : {config.AGENT_TYPE}")
        self.logger.info(f" Results dir: {config.RESULTS_DIR}")
        self.logger.info("=" * 70)

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------

    def train(self, num_episodes: int):
        """
        Run the training loop for *num_episodes* episodes.

        After every successful episode the trainer:

        - records metrics,
        - syncs the DQN target network (if the agent has one),
        - saves the best model,
        - writes periodic checkpoints and summaries,
        - checks early stopping.

        An exception raised inside an episode is logged and that episode is
        skipped. On ``KeyboardInterrupt`` (anywhere in the loop) an emergency
        checkpoint and the metrics gathered so far are saved, then the
        process exits with status 0.

        Args:
            num_episodes: Number of episodes to train.
        """
        self.logger.info(f"Starting training — {num_episodes} episodes")

        episodes = range(1, num_episodes + 1)
        iterator = (
            tqdm(episodes, desc="Training", unit="ep") if _TQDM else episodes
        )

        try:
            for episode in iterator:
                self.current_episode = episode
                try:
                    ep_reward, ep_info = self._run_episode(training=True)
                except Exception as exc:
                    # Best-effort recovery: one failing episode must not kill
                    # the whole run. KeyboardInterrupt is not an Exception
                    # subclass, so it propagates to the outer handler.
                    self.logger.error(f"Episode {episode} error: {exc}")
                    self.logger.debug(traceback.format_exc())
                    continue

                self._record_episode_metrics(ep_reward, ep_info)
                self.logger.info(
                    self._format_episode_line(
                        episode, num_episodes, ep_reward, ep_info
                    )
                )
                self._maybe_update_target_network(episode)
                self._track_best(episode, ep_reward)

                # Periodic checkpoint
                if self._is_due(episode, self.config.SAVE_FREQUENCY):
                    self._save_checkpoint(episode)

                # Periodic summary
                if self._is_due(episode, self.SUMMARY_INTERVAL):
                    self._log_summary(episode, num_episodes)

                # Early stopping
                if (
                    self.episodes_without_improvement
                    >= self.config.EARLY_STOPPING_PATIENCE
                ):
                    self.logger.info(
                        f"Early stopping at episode {episode} "
                        f"(no improvement for "
                        f"{self.config.EARLY_STOPPING_PATIENCE} episodes)."
                    )
                    break
        except KeyboardInterrupt:
            self._handle_interrupt()

        self.logger.info("=" * 70)
        self.logger.info("TRAINING COMPLETE")
        self._log_final_summary()
        self._save_metrics()
        self._plot_results()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _is_due(episode: int, every: "int | None") -> bool:
        """Return True when *episode* is a multiple of a positive *every*.

        A missing (``None``) or non-positive interval disables the periodic
        action instead of raising ``ZeroDivisionError``/``TypeError``.
        """
        if not every or every <= 0:
            return False
        return episode % every == 0

    @staticmethod
    def _format_episode_line(
        episode: int, num_episodes: int, ep_reward: float, ep_info: dict
    ) -> str:
        """Build the one-line per-episode log message."""
        return (
            f"Ep {episode:4d}/{num_episodes} "
            f"reward={ep_reward:8.2f} "
            f"wait={ep_info.get('average_waiting_time', 0):7.1f} "
            f"queue={ep_info.get('total_queue_length', 0):6.1f} "
            # ``:4.0f`` prints ints exactly like ``:4d`` but does not raise
            # ValueError if the env reports throughput as a float.
            f"thru={ep_info.get('vehicles_passed', 0):4.0f}"
        )

    def _record_episode_metrics(self, ep_reward: float, ep_info: dict):
        """Append this episode's reward and env-reported stats to the tracker."""
        self.metrics.add("episode_reward", ep_reward)
        self.metrics.add(
            "average_waiting_time", ep_info.get("average_waiting_time", 0.0)
        )
        # NOTE(review): the env key is ``total_queue_length`` but it is stored
        # as ``average_queue_length``. This is kept as-is because consumers of
        # the metric name may rely on it; confirm which is intended.
        self.metrics.add(
            "average_queue_length", ep_info.get("total_queue_length", 0.0)
        )
        self.metrics.add("throughput", ep_info.get("vehicles_passed", 0))

    def _maybe_update_target_network(self, episode: int):
        """Sync the DQN target network every ``DQN_CONFIG['target_update']`` eps."""
        if not hasattr(self.agent, "update_target_network"):
            return
        freq = self.config.DQN_CONFIG.get("target_update", 10)
        if self._is_due(episode, freq):
            self.agent.update_target_network()

    def _track_best(self, episode: int, ep_reward: float):
        """Save a new best model or advance the early-stopping counter."""
        if ep_reward > self.best_reward:
            self.best_reward = ep_reward
            self.episodes_without_improvement = 0
            self._save_best_model(episode, ep_reward)
        else:
            self.episodes_without_improvement += 1

    def _handle_interrupt(self):
        """Persist progress after Ctrl-C, then exit the process with status 0."""
        self.logger.info("Training interrupted by user.")
        self._save_checkpoint(self.current_episode, emergency=True)
        # Keep the metrics collected so far instead of discarding them.
        self._log_final_summary()
        self._save_metrics()
        self.logger.info("Exiting gracefully.")
        # Exit status 0 is kept for compatibility with existing callers.
        sys.exit(0)

    def _run_episode(self, training: bool = True) -> tuple[float, dict]:
        """Execute one full episode.

        Args:
            training: If True, call ``agent.train_step`` after each env step
                and record any non-None loss under the ``"loss"`` metric.

        Returns:
            ``(episode_reward, info)``, where ``info`` is the dict returned
            by the last ``env.step`` call (empty if no step ran).
        """
        state, _ = self.env.reset()
        ep_reward = 0.0
        done = False
        info: dict = {}
        # Hard cap so an env that never signals terminated/truncated cannot
        # loop forever.
        max_steps = self.config.EPISODE_LENGTH * 2
        steps = 0

        while not done and steps < max_steps:
            action = self.agent.select_action(state, training=training)
            next_state, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated

            if training:
                loss = self.agent.train_step(state, action, reward, next_state, done)
                if loss is not None:
                    self.metrics.add("loss", float(loss))

            state = next_state
            ep_reward += reward
            steps += 1
            self.total_steps += 1

        return ep_reward, info

    def _save_best_model(self, episode: int, reward: float):
        """Save the agent to ``MODELS_DIR/<AGENT_TYPE>_best.pth``; log failures."""
        path = self.config.MODELS_DIR / f"{self.config.AGENT_TYPE}_best.pth"
        try:
            self.agent.save(str(path))
            self.logger.info(
                f"[OK] Best model saved reward={reward:.2f} (episode {episode})"
            )
        except Exception as exc:
            self.logger.error(f"[FAIL] Could not save best model: {exc}")

    def _save_checkpoint(self, episode: int, emergency: bool = False):
        """Save the agent under ``RESULTS_DIR/checkpoints``; log failures.

        The file is tagged ``emergency`` when *emergency* is True, otherwise
        ``ep<episode>``.
        """
        tag = "emergency" if emergency else f"ep{episode}"
        path = (
            self.config.RESULTS_DIR
            / "checkpoints"
            / f"{self.config.AGENT_TYPE}_{tag}.pth"
        )
        try:
            self.agent.save(str(path))
            self.logger.info(f"[OK] Checkpoint saved -> {path}")
        except Exception as exc:
            self.logger.error(f"[FAIL] Could not save checkpoint: {exc}")

    def _log_summary(self, episode: int, total: int):
        """Log rolling means over the last ``SUMMARY_INTERVAL`` episodes."""
        n = min(self.SUMMARY_INTERVAL, episode)
        self.logger.info("-" * 70)
        self.logger.info(f"Summary ep {episode}/{total}")
        self.logger.info(
            f" Avg reward (last {n}): "
            f"{self.metrics.get_mean('episode_reward', last_n=n):8.2f}"
        )
        self.logger.info(
            f" Avg wait (last {n}): "
            f"{self.metrics.get_mean('average_waiting_time', last_n=n):8.2f}"
        )
        self.logger.info(f" Best reward so far : {self.best_reward:8.2f}")
        self.logger.info("-" * 70)

    def _log_final_summary(self):
        """Log count, best, mean and std of episode rewards (no-op if none)."""
        all_r = self.metrics.get("episode_reward")
        if not all_r:
            return
        self.logger.info("FINAL STATISTICS")
        self.logger.info(f" Total episodes : {len(all_r)}")
        self.logger.info(f" Best reward : {self.best_reward:.2f}")
        self.logger.info(f" Mean reward : {np.mean(all_r):.2f}")
        self.logger.info(f" Std reward : {np.std(all_r):.2f}")

    def _save_metrics(self):
        """Write the metrics tracker to ``RESULTS_DIR/metrics.json`` (best-effort)."""
        path = self.config.RESULTS_DIR / "metrics.json"
        try:
            self.metrics.save(path)
            self.logger.info(f"[OK] Metrics saved -> {path}")
        except Exception as exc:
            self.logger.warning(f"Could not save metrics: {exc}")

    def _plot_results(self):
        """Render training curves to ``RESULTS_DIR/plots`` (best-effort)."""
        try:
            from utils.visualizer import plot_training_curves

            save = (
                self.config.RESULTS_DIR
                / "plots"
                / f"{self.config.AGENT_TYPE}_training.png"
            )
            plot_training_curves(self.metrics, save_path=save)
        except Exception as exc:
            self.logger.warning(f"Could not plot results: {exc}")