# Traffic-Control / main.py
# Source: commit b00d5d5 ("Add files") by Dhaerya.
"""
main.py — RL Traffic Signal Control entry point.
Automated pipeline (recommended):
python main.py --auto
Manual usage:
python main.py --mode train --agent q_learning --episodes 50
python main.py --mode train --agent dqn --episodes 150
python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth
python main.py --mode eval --agent dqn --model-path models/dqn_best.pth
python main.py --mode fixed # Fixed-signal baseline only
The --auto flag runs the full pipeline:
1. Fixed-signal baseline (10 episodes)
2. Q-Learning training (50 episodes)
3. DQN training (150 episodes)
4. Comparison table & plots (RL rows use rewards recorded during training)
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
# ── Project imports ───────────────────────────────────────────────────────────
import config as cfg
from environment import TrafficEnvironment
from agent import QLearningAgent, DQNAgent, DQN_AVAILABLE
from training import Trainer, Evaluator
from utils import setup_logger, MetricsTracker, plot_training_curves
from utils.visualizer import plot_comparison, plot_bar_comparison
# Module-level logger used by every function in this entry point.
logger = setup_logger("main")
# ═══════════════════════════════════════════════════════════════════════════════
# Factory helpers
# ═══════════════════════════════════════════════════════════════════════════════
def make_env() -> TrafficEnvironment:
    """Build and return a brand-new ``TrafficEnvironment`` configured from ``cfg``."""
    environment = TrafficEnvironment(cfg)
    return environment
def make_q_learning_agent() -> QLearningAgent:
    """
    Build a tabular Q-Learning agent from the project configuration.

    Sizes come from ``cfg.STATE_SIZE`` / ``cfg.ACTION_SIZE`` and the
    hyper-parameters from ``cfg.Q_LEARNING_CONFIG``.
    """
    agent_kwargs = dict(
        state_size=cfg.STATE_SIZE,
        action_size=cfg.ACTION_SIZE,
        config=cfg.Q_LEARNING_CONFIG,
    )
    return QLearningAgent(**agent_kwargs)
def make_dqn_agent() -> "DQNAgent":
    """
    Instantiate a DQN agent using project config (requires PyTorch).

    Sizes come from ``cfg.STATE_SIZE`` / ``cfg.ACTION_SIZE`` and the
    hyper-parameters from ``cfg.DQN_CONFIG``.

    If ``DQN_AVAILABLE`` is false, an error with install instructions is
    logged and the process exits with status 1.
    """
    if not DQN_AVAILABLE:
        # Plain ASCII so the message survives any console/file encoding.
        logger.error("PyTorch is not installed - DQN unavailable.")
        logger.error("Install with: pip install torch")
        sys.exit(1)
    return DQNAgent(
        state_size=cfg.STATE_SIZE,
        action_size=cfg.ACTION_SIZE,
        config=cfg.DQN_CONFIG,
    )
# ═══════════════════════════════════════════════════════════════════════════════
# Fixed-signal baseline
# ═══════════════════════════════════════════════════════════════════════════════
class FixedSignalAgent:
    """
    Fixed-timing signal controller used as the comparison baseline.

    It ignores the observed state. On every ``switch_interval``-th call to
    :meth:`select_action` it returns action ``1``; on every other call it
    returns action ``0``. Presumably ``1`` means "switch phase" in
    ``TrafficEnvironment``'s action space (TODO confirm). It never learns:
    ``train_step``, ``save`` and ``load`` are no-ops, so it can be driven
    through the same calls as the learning agents.
    """

    def __init__(self, switch_interval: int = 30):
        """
        Args:
            switch_interval: Number of steps between phase switches; must be >= 1.

        Raises:
            ValueError: If *switch_interval* is less than 1. Zero would
                otherwise only fail later, inside ``select_action``, with
                ``ZeroDivisionError``.
        """
        if switch_interval < 1:
            raise ValueError(f"switch_interval must be >= 1, got {switch_interval}")
        self.switch_interval = switch_interval
        self._step = 0

    def select_action(self, state, training: bool = False) -> int:
        """Return 1 on every ``switch_interval``-th call, else 0 (state is ignored)."""
        self._step += 1
        return 1 if self._step % self.switch_interval == 0 else 0

    def train_step(self, *args, **kwargs):
        """No-op: the fixed controller does not learn."""
        return None

    def save(self, filepath):
        """No-op: there are no parameters to persist."""
        pass

    def load(self, filepath):
        """No-op: there are no parameters to restore."""
        pass

    def reset(self):
        """Restart the step counter (call at the start of each episode)."""
        self._step = 0
def run_fixed_baseline(
    num_episodes: int = 10,
    switch_interval: int = 30,
) -> tuple[list[float], dict]:
    """
    Evaluate the fixed-timing signal for *num_episodes* episodes.

    Args:
        num_episodes: Number of episodes to run; must be >= 1.
        switch_interval: Steps between phase switches for the fixed controller.

    Returns:
        (episode_rewards, summary_dict). ``summary_dict`` holds the mean, std
        and best episode reward. It also holds the mean, across episodes, of
        the ``"average_waiting_time"`` and ``"total_queue_length"`` values the
        environment reports in ``info`` at the end of each episode (0 when a
        key is absent).

    Raises:
        ValueError: If *num_episodes* is less than 1.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    logger.info("=" * 60)
    logger.info(f"FIXED-SIGNAL BASELINE ({num_episodes} episodes)")
    logger.info("=" * 60)

    agent = FixedSignalAgent(switch_interval=switch_interval)
    env = make_env()
    rewards: list[float] = []
    final_waiting_times: list[float] = []
    final_queue_lengths: list[float] = []

    for ep in range(1, num_episodes + 1):
        state, _ = env.reset()
        agent.reset()
        ep_reward = 0.0
        info: dict = {}
        done = False
        while not done:
            action = agent.select_action(state)
            state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            ep_reward += reward
        rewards.append(ep_reward)
        # Keep each episode's end-of-episode metrics. This way the summary
        # averages over all episodes instead of reflecting only the final step
        # of the final episode.
        final_waiting_times.append(float(info.get("average_waiting_time", 0)))
        final_queue_lengths.append(float(info.get("total_queue_length", 0)))
        logger.info(f" Episode {ep:3d}/{num_episodes} reward={ep_reward:.2f}")

    mean_r = float(np.mean(rewards))
    logger.info(f"Baseline mean reward: {mean_r:.2f}")
    return rewards, {
        "mean_reward": mean_r,
        "std_reward": float(np.std(rewards)),
        "best_reward": float(np.max(rewards)),
        "mean_waiting_time": float(np.mean(final_waiting_times)),
        "mean_queue_length": float(np.mean(final_queue_lengths)),
    }
# ═══════════════════════════════════════════════════════════════════════════════
# Training mode
# ═══════════════════════════════════════════════════════════════════════════════
def run_training(agent_type: str, num_episodes: int):
    """
    Train one agent and hand back the trainer that drove it.

    Args:
        agent_type: Either "q_learning" or "dqn". Anything else is logged as
            an error and the process exits with status 1.
        num_episodes: How many training episodes to run.

    Returns:
        The ``Trainer`` instance after ``train`` has finished; its
        ``best_reward`` and ``metrics`` are read by the auto pipeline.

    Side effects:
        Sets ``cfg.AGENT_TYPE`` to *agent_type* before building the agent.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info(f"TRAINING agent={agent_type} episodes={num_episodes}")
    logger.info(banner)

    env = make_env()

    agent_factories = {
        "q_learning": make_q_learning_agent,
        "dqn": make_dqn_agent,
    }
    build_agent = agent_factories.get(agent_type)
    if build_agent is None:
        logger.error(f"Unknown agent type: {agent_type!r}")
        sys.exit(1)

    cfg.AGENT_TYPE = agent_type
    agent = build_agent()

    trainer = Trainer(env, agent, cfg)
    trainer.train(num_episodes)
    logger.info(f"Training complete. Best reward: {trainer.best_reward:.2f}")
    return trainer
# ═══════════════════════════════════════════════════════════════════════════════
# Evaluation mode
# ═══════════════════════════════════════════════════════════════════════════════
def run_evaluation(agent_type: str, model_path: str, num_episodes: int = 10) -> dict:
    """
    Load a saved model and evaluate it.

    Args:
        agent_type: "q_learning" or "dqn". Any other value is logged as an
            error and the process exits with status 1, matching
            ``run_training`` rather than silently falling back to DQN.
        model_path: Path to saved model file.
        num_episodes: Evaluation episodes.

    Returns:
        Evaluation results dictionary as produced by ``Evaluator.evaluate``.

    Side effects:
        Sets ``cfg.AGENT_TYPE`` to *agent_type* before building the agent.
    """
    logger.info("=" * 60)
    logger.info(f"EVALUATION agent={agent_type} model={model_path}")
    logger.info("=" * 60)

    agent_factories = {
        "q_learning": make_q_learning_agent,
        "dqn": make_dqn_agent,
    }
    # Validate before creating the environment so a typo fails fast and does
    # not build an environment that would immediately be discarded.
    if agent_type not in agent_factories:
        logger.error(f"Unknown agent type: {agent_type!r}")
        sys.exit(1)

    env = make_env()
    cfg.AGENT_TYPE = agent_type
    agent = agent_factories[agent_type]()
    agent.load(model_path)

    evaluator = Evaluator(env, agent, cfg)
    results = evaluator.evaluate(num_episodes)
    logger.info("Evaluation results:")
    for k, v in results.items():
        logger.info(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
    return results
# ═══════════════════════════════════════════════════════════════════════════════
# Automated pipeline
# ═══════════════════════════════════════════════════════════════════════════════
def run_auto_pipeline(
    baseline_episodes: int = 10,
    q_learning_episodes: int = 50,
    dqn_episodes: int = 150,
):
    """
    Run the full automated pipeline end to end.

    Steps:
        1. Fixed-signal baseline (``baseline_episodes`` episodes).
        2. Q-Learning training (``q_learning_episodes`` episodes).
        3. DQN training (``dqn_episodes`` episodes). Skipped with a warning
           when PyTorch is unavailable.
        4. Comparison table printed to stdout.
        5. Bar-chart comparison saved under ``cfg.RESULTS_DIR / "plots"``.

    Note:
        The Q-Learning and DQN rows use the episode-reward metrics the trainers
        recorded during training. The baseline row comes from running the
        (non-learning) fixed controller. No separate evaluation pass of the
        trained agents is performed here.

    Args:
        baseline_episodes: Episodes for the fixed-signal baseline.
        q_learning_episodes: Training episodes for the Q-Learning agent.
        dqn_episodes: Training episodes for the DQN agent.
    """
    box_width = 58
    logger.info("╔" + "═" * box_width + "╗")
    logger.info("║" + "AUTOMATED RL TRAFFIC SIGNAL CONTROL PIPELINE".center(box_width) + "║")
    logger.info("╚" + "═" * box_width + "╝")

    def _trainer_summary(trainer) -> dict:
        # Rewards recorded by the trainer during training (see Note above).
        return {
            "mean_reward": trainer.metrics.get_mean("episode_reward"),
            "best_reward": trainer.best_reward,
            "std_reward": trainer.metrics.get_std("episode_reward"),
        }

    summary: dict[str, dict] = {}

    # ── 1. Fixed-signal baseline ──────────────────────────────────────
    _, baseline_results = run_fixed_baseline(num_episodes=baseline_episodes)
    summary["Fixed Signal"] = baseline_results

    # ── 2. Q-Learning ────────────────────────────────────────────────
    ql_trainer = run_training("q_learning", num_episodes=q_learning_episodes)
    summary["Q-Learning"] = _trainer_summary(ql_trainer)

    # ── 3. DQN ───────────────────────────────────────────────────────
    if DQN_AVAILABLE:
        dqn_trainer = run_training("dqn", num_episodes=dqn_episodes)
        summary["DQN"] = _trainer_summary(dqn_trainer)
    else:
        logger.warning("DQN skipped (PyTorch not available).")

    # ── 4. Print comparison table ─────────────────────────────────────
    _print_comparison_table(summary)

    # ── 5. Plots ──────────────────────────────────────────────────────
    _generate_comparison_plots(summary)
    logger.info("Pipeline complete.")
def _print_comparison_table(summary: dict):
"""Print a neat comparison table to stdout."""
print("\n")
print("=" * 60)
print(f"{'Method':<18} {'Mean Reward':>14} {'Best Reward':>14}")
print("-" * 60)
baseline_mean = summary.get("Fixed Signal", {}).get("mean_reward", 0)
for method, res in summary.items():
mean_r = res.get("mean_reward", 0)
best_r = res.get("best_reward", 0)
delta = mean_r - baseline_mean if method != "Fixed Signal" else 0
delta_str = f" ({delta:+.2f})" if method != "Fixed Signal" else ""
print(f"{method:<18} {mean_r:>14.2f} {best_r:>14.2f}{delta_str}")
print("=" * 60)
print()
def _generate_comparison_plots(summary: dict):
    """
    Save a bar chart comparing each method's mean reward.

    The chart is written to ``cfg.RESULTS_DIR / "plots" / "comparison_bar.png"``.

    Args:
        summary: Mapping of method name -> metrics dict. Methods without a
            ``"mean_reward"`` entry are plotted as 0.
    """
    scores = {method: metrics.get("mean_reward", 0) for method, metrics in summary.items()}
    save_path = cfg.RESULTS_DIR / "plots" / "comparison_bar.png"
    # Ensure the output directory exists, so saving does not fail on a fresh
    # results tree. This is idempotent, so it is harmless if the plotting
    # helper also creates it.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plot_bar_comparison(
        scores,
        title="Mean Reward by Method (higher = better)",
        ylabel="Mean Reward",
        save_path=save_path,
    )
# ═══════════════════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════════════════
def _build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="RL Traffic Signal Control",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python main.py --auto
python main.py --mode train --agent q_learning --episodes 50
python main.py --mode train --agent dqn --episodes 150
python main.py --mode eval --agent q_learning --model-path models/q_learning_best.pth
python main.py --mode fixed
""",
)
p.add_argument(
"--auto",
action="store_true",
help="Run the full automated pipeline (recommended)",
)
p.add_argument(
"--mode",
choices=["train", "eval", "fixed"],
default="train",
help="Mode to run (ignored when --auto is set)",
)
p.add_argument(
"--agent",
choices=["q_learning", "dqn"],
default="q_learning",
help="Agent type",
)
p.add_argument(
"--episodes",
type=int,
default=50,
help="Number of episodes",
)
p.add_argument(
"--model-path",
type=str,
default=None,
help="Path to saved model file (required for --mode eval)",
)
return p
def main():
parser = _build_parser()
args = parser.parse_args()
if args.auto:
run_auto_pipeline()
return
if args.mode == "fixed":
run_fixed_baseline(num_episodes=args.episodes)
elif args.mode == "train":
run_training(args.agent, args.episodes)
elif args.mode == "eval":
if args.model_path is None:
parser.error("--model-path is required for --mode eval")
run_evaluation(args.agent, args.model_path, num_episodes=10)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()