""" Enhanced training script for the Double DQN (DDQN) bus routing agent. Upgrades: - Best-model saving (tracks max cumulative reward) - Expanded metric tracking (Loss, Avg Q-Values) - Improved terminal telemetry - Multi-task support with OpenEnv compliance """ from __future__ import annotations import argparse import os from typing import Dict, List import numpy as np import torch from environment import BusRoutingEnv from agent import DQNAgent, DQNConfig from tasks import get_task def train( task_name: str = "medium", episodes: int = 200, # Increased default for better convergence seed: int = 0, model_out: str = "models/dqn_bus.pt", metrics_out: str = "models/training_metrics.csv", ) -> Dict[str, List[float]]: """Train a DDQN agent on the specified task and save the best model.""" task_cfg = get_task(task_name) task_cfg.seed = seed env = task_cfg.build_env() # Initialize Agent with optimized Hackathon-level config agent = DQNAgent(env.obs_size, env.num_actions, config=DQNConfig(), seed=seed) history: Dict[str, List[float]] = { "reward": [], "avg_wait": [], "fuel_used": [], "loss": [], "epsilon": [] } best_reward = -float("inf") best_model_path = model_out.replace(".pt", "_best.pt") print(f"šŸš€ Training Hackathon-Level DDQN on task: {task_cfg.name}") print(f" Stops: {task_cfg.num_stops} | Max Steps: {task_cfg.max_steps} | Capacity: {task_cfg.bus_capacity}") print(f" Episodes: {episodes} | Seed: {seed}") print("-" * 60) for ep in range(1, int(episodes) + 1): obs_model = env.reset() obs = obs_model.to_array() done = False episode_losses = [] while not done: # select_action uses the new internal pipeline (preprocess -> select) action = agent.act(obs, greedy=False) obs_model, reward_model, done, _info = env.step(action) obs2 = obs_model.to_array() agent.observe(obs, action, reward_model.value, obs2, done) obs = obs2 if agent.can_train(): metrics = agent.train_step() if not np.isnan(metrics["loss"]): episode_losses.append(metrics["loss"]) # Episode stats calculation avg_wait = ( env.total_wait_time_picked / env.total_picked if env.total_picked > 0 else 20.0 # Penalty/default for no pickups ) total_reward = float(env.total_reward) avg_loss = np.mean(episode_losses) if episode_losses else 0.0 history["reward"].append(total_reward) history["avg_wait"].append(float(avg_wait)) history["fuel_used"].append(float(env.total_fuel_used)) history["loss"].append(float(avg_loss)) history["epsilon"].append(agent.epsilon()) agent.on_episode_end() # [BEST MODEL SAVING] if total_reward > best_reward and ep > 20: best_reward = total_reward os.makedirs(os.path.dirname(best_model_path) or ".", exist_ok=True) agent.save(best_model_path) # print(f" [New Best!] Ep {ep:03d} | Reward: {total_reward:.2f}") # Logging periodic status if ep % 20 == 0 or ep == 1 or ep == episodes: print( f"ep={ep:03d} | rew={total_reward:7.1f} | wait={avg_wait:5.2f} | " f"fuel={env.total_fuel_used:5.1f} | loss={avg_loss:6.4f} | eps={agent.epsilon():.3f}" ) # Save final model os.makedirs(os.path.dirname(model_out) or ".", exist_ok=True) agent.save(model_out) print(f"\nāœ… Training Complete.") print(f" Final Model: {model_out}") print(f" Best Model: {best_model_path} (Reward: {best_reward:.2f})") if metrics_out: os.makedirs(os.path.dirname(metrics_out) or ".", exist_ok=True) with open(metrics_out, "w", encoding="utf-8") as f: f.write("episode,total_reward,avg_wait_time,fuel_used,loss,epsilon\n") for i in range(len(history["reward"])): f.write(f"{i+1},{history['reward'][i]},{history['avg_wait'][i]}," f"{history['fuel_used'][i]},{history['loss'][i]},{history['epsilon'][i]}\n") print(f" Metrics: {metrics_out}") return history def main() -> None: p = argparse.ArgumentParser(description="Train Double DQN agent on an OpenEnv task") p.add_argument("--task", type=str, default="medium", choices=["easy", "medium", "hard"]) p.add_argument("--episodes", type=int, default=200) p.add_argument("--seed", type=int, default=0) p.add_argument("--model-out", type=str, default="models/dqn_bus_v6.pt") p.add_argument("--metrics-out", type=str, default="models/training_metrics_v6.csv") args = p.parse_args() train( task_name=args.task, episodes=args.episodes, seed=args.seed, model_out=args.model_out, metrics_out=args.metrics_out, ) if __name__ == "__main__": main()