"""
Enhanced training script for the Double DQN (DDQN) bus routing agent.
Upgrades:
- Best-model saving (tracks max cumulative reward)
- Expanded metric tracking (Loss, Avg Q-Values)
- Improved terminal telemetry
- Multi-task support with OpenEnv compliance
"""
from __future__ import annotations
import argparse
import os
from typing import Dict, List
import numpy as np
import torch
from environment import BusRoutingEnv
from agent import DQNAgent, DQNConfig
from tasks import get_task
def train(
    task_name: str = "medium",
    episodes: int = 200,  # Increased default for better convergence
    seed: int = 0,
    model_out: str = "models/dqn_bus.pt",
    metrics_out: str = "models/training_metrics.csv",
) -> Dict[str, List[float]]:
    """Train a DDQN agent on the specified task and save the best model.

    Args:
        task_name: Task preset name understood by ``get_task``.
        episodes: Number of training episodes to run.
        seed: RNG seed applied to the task config and the agent.
        model_out: Path for the final model checkpoint; the best-scoring
            checkpoint is written alongside it with a ``_best`` suffix.
        metrics_out: CSV path for per-episode metrics; falsy disables writing.

    Returns:
        Per-episode history with keys ``reward``, ``avg_wait``, ``fuel_used``,
        ``loss`` and ``epsilon``.
    """
    task_cfg = get_task(task_name)
    task_cfg.seed = seed
    env = task_cfg.build_env()

    # Initialize Agent with optimized Hackathon-level config
    agent = DQNAgent(env.obs_size, env.num_actions, config=DQNConfig(), seed=seed)

    history: Dict[str, List[float]] = {
        "reward": [],
        "avg_wait": [],
        "fuel_used": [],
        "loss": [],
        "epsilon": [],
    }
    best_reward = -float("inf")
    best_model_path = model_out.replace(".pt", "_best.pt")
    # FIX: track whether a best checkpoint was actually written, so the
    # summary does not advertise a non-existent file when episodes <= 20.
    best_saved = False

    print(f"🚀 Training Hackathon-Level DDQN on task: {task_cfg.name}")
    print(f"   Stops: {task_cfg.num_stops} | Max Steps: {task_cfg.max_steps} | Capacity: {task_cfg.bus_capacity}")
    print(f"   Episodes: {episodes} | Seed: {seed}")
    print("-" * 60)

    for ep in range(1, int(episodes) + 1):
        obs = env.reset().to_array()
        done = False
        episode_losses: List[float] = []

        while not done:
            # select_action uses the new internal pipeline (preprocess -> select)
            action = agent.act(obs, greedy=False)
            obs_model, reward_model, done, _info = env.step(action)
            obs2 = obs_model.to_array()
            agent.observe(obs, action, reward_model.value, obs2, done)
            obs = obs2
            if agent.can_train():
                metrics = agent.train_step()
                # Skip NaN losses so they don't poison the episode average.
                if not np.isnan(metrics["loss"]):
                    episode_losses.append(metrics["loss"])

        # Episode stats calculation
        avg_wait = (
            env.total_wait_time_picked / env.total_picked
            if env.total_picked > 0
            else 20.0  # Penalty/default for no pickups
        )
        total_reward = float(env.total_reward)
        avg_loss = float(np.mean(episode_losses)) if episode_losses else 0.0

        history["reward"].append(total_reward)
        history["avg_wait"].append(float(avg_wait))
        history["fuel_used"].append(float(env.total_fuel_used))
        history["loss"].append(avg_loss)
        history["epsilon"].append(agent.epsilon())
        agent.on_episode_end()

        # [BEST MODEL SAVING] — skip the first 20 episodes (exploration noise).
        if total_reward > best_reward and ep > 20:
            best_reward = total_reward
            os.makedirs(os.path.dirname(best_model_path) or ".", exist_ok=True)
            agent.save(best_model_path)
            best_saved = True

        # Logging periodic status
        if ep % 20 == 0 or ep == 1 or ep == episodes:
            print(
                f"ep={ep:03d} | rew={total_reward:7.1f} | wait={avg_wait:5.2f} | "
                f"fuel={env.total_fuel_used:5.1f} | loss={avg_loss:6.4f} | eps={agent.epsilon():.3f}"
            )

    # Save final model
    os.makedirs(os.path.dirname(model_out) or ".", exist_ok=True)
    agent.save(model_out)

    print("\n✅ Training Complete.")
    print(f"   Final Model: {model_out}")
    if best_saved:
        print(f"   Best Model:  {best_model_path} (Reward: {best_reward:.2f})")
    else:
        # No best checkpoint written (requires ep > 20); avoid pointing at a
        # file that does not exist.
        print("   Best Model:  not saved (training shorter than warmup)")

    if metrics_out:
        os.makedirs(os.path.dirname(metrics_out) or ".", exist_ok=True)
        with open(metrics_out, "w", encoding="utf-8") as f:
            f.write("episode,total_reward,avg_wait_time,fuel_used,loss,epsilon\n")
            rows = zip(
                history["reward"],
                history["avg_wait"],
                history["fuel_used"],
                history["loss"],
                history["epsilon"],
            )
            for i, (rew, wait, fuel, loss, eps) in enumerate(rows, start=1):
                f.write(f"{i},{rew},{wait},{fuel},{loss},{eps}\n")
        print(f"   Metrics: {metrics_out}")

    return history
def main() -> None:
    """CLI entry point: parse training options and run ``train``."""
    parser = argparse.ArgumentParser(description="Train Double DQN agent on an OpenEnv task")
    parser.add_argument("--task", type=str, default="medium", choices=["easy", "medium", "hard"])
    parser.add_argument("--episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--model-out", type=str, default="models/dqn_bus_v6.pt")
    parser.add_argument("--metrics-out", type=str, default="models/training_metrics_v6.csv")
    ns = parser.parse_args()

    train(
        task_name=ns.task,
        episodes=ns.episodes,
        seed=ns.seed,
        model_out=ns.model_out,
        metrics_out=ns.metrics_out,
    )


if __name__ == "__main__":
    main()