# agentic-traffic / training /trainer.py
# Aditya2162's picture
# Upload folder using huggingface_hub
# 3d2dbcf verified
from __future__ import annotations
from collections import deque
from concurrent.futures import ProcessPoolExecutor, as_completed
import json
from itertools import islice
import os
import random
from dataclasses import asdict, dataclass
from pathlib import Path
from time import perf_counter
from typing import Any
import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy, QueueGreedyPolicy, RandomPhasePolicy
from env.observation_builder import ObservationConfig
from env.reward import RewardConfig
from env.traffic_env import EnvConfig, TrafficEnv
from training.cityflow_dataset import CityFlowDataset, ScenarioSpec
from training.device import configure_torch_runtime, resolve_torch_device
from training.models import POLICY_ARCHES, RunningNormalizer, TrafficControlQNetwork
from training.rollout import evaluate_policy
# Module-level dictionary, created empty at import time. Nothing in the visible
# part of this file reads or writes it -- presumably it holds per-process state
# for parallel evaluation workers; TODO confirm against the code that uses it.
_EVAL_CONTEXT: dict[str, Any] = {}
@dataclass(frozen=True)
class DQNConfig:
    """Immutable run configuration consumed by ``DQNTrainer``.

    Instances are frozen, so values cannot be reassigned after construction.
    The field order below is part of the interface because the dataclass
    accepts positional arguments.
    """

    # Network selection; the trainer rejects values outside POLICY_ARCHES.
    policy_arch: str = "single_head_with_district_feature"

    # Number of rollout + optimisation iterations performed by ``fit``.
    total_updates: int = 200
    learning_rate: float = 1e-4

    # Discount factor and horizon used when folding step rewards into
    # n-step transitions.
    gamma: float = 0.99
    n_step: int = 3

    # Replay sizing; optimisation is skipped until the buffer holds at least
    # max(learning_starts, minibatch_size) transitions.
    replay_capacity: int = 500_000
    minibatch_size: int = 1024
    learning_starts: int = 10_000

    # Minibatch updates performed per call to the optimiser step.
    gradient_steps: int = 64
    target_tau: float = 0.01
    max_grad_norm: float = 10.0

    # Q-network shape.
    hidden_dim: int = 256
    hidden_layers: int = 2
    dueling: bool = True

    # Seed applied to ``random``, NumPy and torch by the trainer.
    seed: int = 7

    # Validation / checkpoint cadence, measured in updates.
    eval_every: int = 40
    checkpoint_every: int = 40
    checkpoint_on_eval: bool = True

    # Validation and training set sizing.
    val_scenarios_per_city: int | None = 1
    max_val_cities: int | None = 5
    max_train_cities: int | None = None

    # Rollout collection; a process pool is created when more than one
    # worker is requested.
    num_rollout_workers: int = 4
    rollout_episodes_per_update: int | None = None

    # Optional pinning of training (and validation) to a single scenario.
    train_city_id: str | None = None
    train_scenario_name: str | None = None
    overfit_val_on_train_scenario: bool = False

    # Per-episode cap on decision steps; ``None`` removes the cap.
    rollout_decision_steps: int | None = 256

    # Checkpoint path to restore from at construction time, if any.
    resume_from: str | None = None
    use_observation_normalization: bool = True

    # Exploration schedule parameters.
    epsilon_start: float = 1.0
    epsilon_end: float = 0.05
    epsilon_decay_steps: int = 50_000

    # Prioritized replay parameters.
    prioritized_replay_alpha: float = 0.6
    prioritized_replay_beta_start: float = 0.4
    prioritized_replay_beta_end: float = 1.0
    prioritized_replay_beta_steps: int = 200_000

    # Evaluation behaviour.
    compare_baselines: bool = True
    skip_failed_validation_episodes: bool = True
    verbose_progress: bool = False
    eval_num_workers: int = -1

    # Logging / progress reporting.
    enable_tensorboard: bool = True
    tensorboard_log_dir: str | None = None
    rolling_window_size: int = 20
    use_tqdm: bool = True
@dataclass
class TrainerState:
    """Mutable training counters that are stored in and restored from checkpoints."""

    # Number of completed training updates.
    update_index: int = 0
    # Highest validation ``mean_episode_return`` seen so far.
    best_validation_score: float = float("-inf")
    # Cumulative decision steps taken across all collected episodes.
    total_decision_steps: int = 0
    # Cumulative number of transitions produced by rollouts.
    total_transitions: int = 0
    gradient_steps: int = 0
@dataclass(frozen=True)
class StepRecord:
    """One single-step experience for one row of an observation batch.

    Records are queued per row and later folded into n-step transitions
    before being written to the replay buffer.
    """

    # State the action was taken in.
    observation: np.ndarray
    district_type_index: int
    action_mask: np.ndarray

    # Decision and its immediate outcome.
    action: int
    reward: float

    # State reached after the environment step.
    next_observation: np.ndarray
    next_district_type_index: int
    next_action_mask: np.ndarray

    # Episode-termination flag reported by the environment for this step.
    done: bool
class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer of transitions with priority-weighted sampling.

    Storage arrays are allocated lazily on the first :meth:`add`, using the
    lengths of that first observation and action mask. Once ``capacity``
    transitions are stored, new entries overwrite the oldest slot.
    """

    def __init__(
        self,
        capacity: int,
        prioritized_alpha: float = 0.6,
        epsilon: float = 1e-6,
    ):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of stored transitions; must be positive.
            prioritized_alpha: Exponent applied to priorities when forming the
                sampling distribution. ``0`` (or less) selects uniform sampling.
            epsilon: Small positive floor added to / enforced on priorities.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        self.capacity = int(capacity)
        if self.capacity <= 0:
            # Fail fast: a zero-sized ring would otherwise only blow up later,
            # inside ``add``, with an unrelated-looking indexing error.
            raise ValueError(f"Replay capacity must be positive, got {capacity!r}.")
        self.prioritized_alpha = float(prioritized_alpha)
        self.epsilon = float(epsilon)
        self.position = 0
        self.size = 0
        self.max_priority = 1.0
        self.observations: np.ndarray | None = None
        self.next_observations: np.ndarray | None = None
        self.district_type_indices: np.ndarray | None = None
        self.next_district_type_indices: np.ndarray | None = None
        self.action_masks: np.ndarray | None = None
        self.next_action_masks: np.ndarray | None = None
        self.actions: np.ndarray | None = None
        self.rewards: np.ndarray | None = None
        self.dones: np.ndarray | None = None
        self.discounts: np.ndarray | None = None
        self.priorities = np.zeros(self.capacity, dtype=np.float32)

    def add(
        self,
        observation: np.ndarray,
        district_type_index: int,
        action_mask: np.ndarray,
        action: int,
        reward: float,
        next_observation: np.ndarray,
        next_district_type_index: int,
        next_action_mask: np.ndarray,
        done: bool,
        discount: float,
    ) -> None:
        """Store one transition at the current write position.

        The new entry receives the largest priority seen so far, and the write
        position then advances, wrapping around at ``capacity``.
        """
        if self.observations is None:
            # First insert: size the storage from this transition's shapes.
            obs_dim = observation.shape[0]
            action_dim = action_mask.shape[0]
            self.observations = np.zeros((self.capacity, obs_dim), dtype=np.float32)
            self.next_observations = np.zeros((self.capacity, obs_dim), dtype=np.float32)
            self.district_type_indices = np.zeros(self.capacity, dtype=np.int64)
            self.next_district_type_indices = np.zeros(self.capacity, dtype=np.int64)
            self.action_masks = np.zeros((self.capacity, action_dim), dtype=np.float32)
            self.next_action_masks = np.zeros((self.capacity, action_dim), dtype=np.float32)
            self.actions = np.zeros(self.capacity, dtype=np.int64)
            self.rewards = np.zeros(self.capacity, dtype=np.float32)
            self.dones = np.zeros(self.capacity, dtype=np.float32)
            self.discounts = np.zeros(self.capacity, dtype=np.float32)

        index = self.position
        self.observations[index] = observation.astype(np.float32)
        self.next_observations[index] = next_observation.astype(np.float32)
        self.district_type_indices[index] = int(district_type_index)
        self.next_district_type_indices[index] = int(next_district_type_index)
        self.action_masks[index] = action_mask.astype(np.float32)
        self.next_action_masks[index] = next_action_mask.astype(np.float32)
        self.actions[index] = int(action)
        self.rewards[index] = float(reward)
        self.dones[index] = float(done)
        self.discounts[index] = float(discount)
        self.priorities[index] = self.max_priority

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, beta: float) -> dict[str, np.ndarray]:
        """Draw a batch of stored transitions.

        Args:
            batch_size: Number of transitions to draw. Sampling is done with
                replacement only when fewer than ``batch_size`` are stored.
            beta: Exponent of the importance-sampling correction.

        Returns:
            A dict holding the sampled ``indices``, their ``weights`` and one
            array per stored transition field.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size <= 0:
            raise ValueError("Cannot sample from an empty replay buffer.")
        replace = self.size < batch_size
        if self.prioritized_alpha > 0.0:
            scaled_priorities = np.power(
                np.maximum(self.priorities[: self.size], self.epsilon),
                self.prioritized_alpha,
            )
            probabilities = scaled_priorities / scaled_priorities.sum()
            indices = np.random.choice(
                self.size,
                size=batch_size,
                replace=replace,
                p=probabilities,
            )
            weights = np.power(self.size * probabilities[indices], -beta).astype(np.float32)
            # Weights are only ever scaled down: the divisor is never below 1.
            weights /= max(1.0, float(weights.max()))
        else:
            indices = np.random.choice(self.size, size=batch_size, replace=replace)
            weights = np.ones(batch_size, dtype=np.float32)
        return {
            "indices": indices.astype(np.int64),
            "weights": weights.astype(np.float32),
            "observations": self.observations[indices],
            "next_observations": self.next_observations[indices],
            "district_type_indices": self.district_type_indices[indices],
            "next_district_type_indices": self.next_district_type_indices[indices],
            "action_masks": self.action_masks[indices],
            "next_action_masks": self.next_action_masks[indices],
            "actions": self.actions[indices],
            "rewards": self.rewards[indices],
            "dones": self.dones[indices],
            "discounts": self.discounts[indices],
        }

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray) -> None:
        """Set the priorities of ``indices`` to ``|td_errors| + epsilon``.

        Non-finite TD errors (NaN or infinity) are not stored as-is: a single
        NaN priority would turn the whole sampling distribution into NaN and
        make every later :meth:`sample` call fail. Such entries receive the
        current maximum priority instead, and do not raise ``max_priority``.
        """
        updated_priorities = np.abs(td_errors).astype(np.float32) + self.epsilon
        finite = np.isfinite(updated_priorities)
        if finite.any():
            self.max_priority = max(
                self.max_priority,
                float(updated_priorities[finite].max()),
            )
        if not finite.all():
            updated_priorities = np.where(
                finite,
                updated_priorities,
                np.float32(self.max_priority),
            )
        self.priorities[indices] = updated_priorities

    def state_dict(self) -> dict[str, Any]:
        """Return the scalar settings and (un-copied) storage arrays as a dict."""
        return {
            "capacity": self.capacity,
            "prioritized_alpha": self.prioritized_alpha,
            "epsilon": self.epsilon,
            "position": self.position,
            "size": self.size,
            "max_priority": self.max_priority,
            "observations": self.observations,
            "next_observations": self.next_observations,
            "district_type_indices": self.district_type_indices,
            "next_district_type_indices": self.next_district_type_indices,
            "action_masks": self.action_masks,
            "next_action_masks": self.next_action_masks,
            "actions": self.actions,
            "rewards": self.rewards,
            "dones": self.dones,
            "discounts": self.discounts,
            "priorities": self.priorities,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the buffer from a dict produced by :meth:`state_dict`.

        Arrays are adopted by reference, not copied.
        """
        self.capacity = int(state_dict["capacity"])
        self.prioritized_alpha = float(state_dict["prioritized_alpha"])
        self.epsilon = float(state_dict["epsilon"])
        self.position = int(state_dict["position"])
        self.size = int(state_dict["size"])
        self.max_priority = float(state_dict["max_priority"])
        self.observations = state_dict["observations"]
        self.next_observations = state_dict["next_observations"]
        self.district_type_indices = state_dict["district_type_indices"]
        self.next_district_type_indices = state_dict["next_district_type_indices"]
        self.action_masks = state_dict["action_masks"]
        self.next_action_masks = state_dict["next_action_masks"]
        self.actions = state_dict["actions"]
        self.rewards = state_dict["rewards"]
        self.dones = state_dict["dones"]
        self.discounts = state_dict["discounts"]
        self.priorities = state_dict["priorities"]
class DQNTrainer:
def __init__(
    self,
    dataset: CityFlowDataset,
    env_config: EnvConfig,
    dqn_config: DQNConfig,
    output_dir: str | Path = "artifacts/dqn_shared",
    device: str | None = None,
):
    """Build networks, replay buffer, logging and (optionally) a rollout pool.

    Creates ``output_dir`` and its ``checkpoints`` sub-directory, seeds the
    Python/NumPy/torch RNGs from ``dqn_config.seed``, builds one environment
    from a sampled training scenario to learn the observation size, and
    finally restores ``dqn_config.resume_from`` when it is set.

    Raises:
        ValueError: If ``dqn_config.policy_arch`` is not in ``POLICY_ARCHES``
            or no training city is available.
    """
    self.dataset = dataset
    self.env_config = env_config
    self.dqn_config = dqn_config

    # Output locations are created eagerly so later saves cannot fail on a
    # missing directory.
    self.output_dir = Path(output_dir)
    self.output_dir.mkdir(parents=True, exist_ok=True)
    self.checkpoint_dir = self.output_dir / "checkpoints"
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Seed every RNG source the trainer touches.
    seed = dqn_config.seed
    self.rng = random.Random(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    self.device = resolve_torch_device(device)
    configure_torch_runtime(self.device)

    if dqn_config.policy_arch not in POLICY_ARCHES:
        raise ValueError(
            f"Unsupported policy architecture: {self.dqn_config.policy_arch}. "
            f"Expected one of {POLICY_ARCHES}."
        )

    # Training cities: optionally truncated, then possibly narrowed to one
    # city by the fixed-scenario resolver (which rewrites train_city_ids).
    train_cities = dataset.load_split("train")
    city_limit = dqn_config.max_train_cities
    if city_limit is not None:
        train_cities = train_cities[:city_limit]
    self.train_city_ids = train_cities
    self.fixed_train_scenario_spec = self._resolve_fixed_train_scenario()
    if not self.train_city_ids:
        raise ValueError("No training cities available for DQN training.")

    # A throwaway environment is built only to read the observation size.
    probe_env = self._make_env(self._sample_train_scenario())
    network_kwargs = {
        "observation_dim": probe_env.observation_dim,
        "hidden_dim": dqn_config.hidden_dim,
        "num_layers": dqn_config.hidden_layers,
        "policy_arch": dqn_config.policy_arch,
        "dueling": dqn_config.dueling,
    }
    self.q_network = TrafficControlQNetwork(**network_kwargs).to(self.device)
    self.target_network = TrafficControlQNetwork(**network_kwargs).to(self.device)
    # The target network starts as an exact copy and stays in eval mode.
    self.target_network.load_state_dict(self.q_network.state_dict())
    self.target_network.eval()

    self.optimizer = torch.optim.Adam(
        self.q_network.parameters(),
        lr=dqn_config.learning_rate,
    )
    if dqn_config.use_observation_normalization:
        self.obs_normalizer = RunningNormalizer()
    else:
        self.obs_normalizer = None
    self.replay_buffer = PrioritizedReplayBuffer(
        capacity=dqn_config.replay_capacity,
        prioritized_alpha=dqn_config.prioritized_replay_alpha,
    )
    self.state = TrainerState()

    # Log destinations.
    self.training_log_path = self.output_dir / "training_log.jsonl"
    self.validation_log_path = self.output_dir / "validation_log.jsonl"
    self.tensorboard_log_dir = Path(
        dqn_config.tensorboard_log_dir or (self.output_dir / "tensorboard")
    )
    self.writer = self._build_tensorboard_writer()
    self._rolling_metrics: dict[tuple[str, str], deque[float]] = {}

    # A process pool is only worth creating for more than one worker.
    self.rollout_executor: ProcessPoolExecutor | None = None
    worker_count = dqn_config.num_rollout_workers
    if worker_count > 1:
        self.rollout_executor = ProcessPoolExecutor(max_workers=worker_count)

    print(
        "[setup] "
        f"torch_device={self.device.type} "
        f"algorithm=ps_d3qn "
        f"policy_arch={self.dqn_config.policy_arch} "
        f"reward_variant={self.env_config.reward.variant} "
        f"rollout_workers={self.dqn_config.num_rollout_workers}"
    )
    pinned_spec = self.fixed_train_scenario_spec
    if pinned_spec is not None:
        print(
            "[setup] "
            f"fixed_train_city={self.fixed_train_scenario_spec.city_id} "
            f"fixed_train_scenario={self.fixed_train_scenario_spec.scenario_name} "
            f"overfit_val_on_train_scenario={self.dqn_config.overfit_val_on_train_scenario}"
        )

    if dqn_config.resume_from:
        self.load_checkpoint(dqn_config.resume_from)
def fit(self) -> None:
    """Run the training loop from the current update index to ``total_updates``.

    Each update collects a rollout batch, runs the optimisation step, logs a
    training record (JSONL, stdout, TensorBoard) and, when due, validates
    and/or checkpoints. ``last.pt`` is written once the loop finishes.

    ``eval_every`` / ``checkpoint_every`` values that are zero or negative
    disable the corresponding periodic action instead of raising
    ``ZeroDivisionError`` from the modulo test.

    The progress bar, the rollout executor and the TensorBoard writer are
    always released on exit, including on error.
    """

    def is_due(interval: int) -> bool:
        # Guard the modulo so an interval of 0 means "never".
        return interval > 0 and self.state.update_index % interval == 0

    def save_timed(path: Path) -> float:
        # Announce, save and return the wall-clock seconds spent saving.
        print(f"[train] saving checkpoint at update={self.state.update_index}")
        started = perf_counter()
        self.save_checkpoint(path)
        elapsed = perf_counter() - started
        print(f"[train] finished checkpoint at update={self.state.update_index}")
        return elapsed

    progress_bar: tqdm | None = None
    try:
        if self.dqn_config.use_tqdm:
            progress_bar = tqdm(
                total=self.dqn_config.total_updates,
                initial=self.state.update_index,
                desc="train",
                dynamic_ncols=True,
            )
        for update_index in range(self.state.update_index, self.dqn_config.total_updates):
            rollout_start = perf_counter()
            episode_records = self._collect_rollout_batch()
            rollout_seconds = perf_counter() - rollout_start

            update_start = perf_counter()
            losses = self._optimize()
            update_seconds = perf_counter() - update_start

            self.state.update_index = update_index + 1
            validation_seconds = 0.0
            checkpoint_seconds = 0.0

            train_record = self._summarize_rollout_batch(episode_records)
            train_record.update(
                {
                    "update": self.state.update_index,
                    "algorithm": "ps_d3qn",
                    "policy_arch": self.dqn_config.policy_arch,
                    "reward_variant": self.env_config.reward.variant,
                    "replay_size": float(self.replay_buffer.size),
                    "epsilon": float(self._epsilon()),
                    **losses,
                }
            )
            self._attach_rolling_metrics(
                namespace="train",
                record=train_record,
                keys=(
                    "episode_return",
                    "total_episode_return",
                    "mean_waiting_vehicles",
                    "throughput",
                    "td_loss",
                    "mean_q_value",
                    "mean_abs_td_error",
                ),
            )
            self._append_jsonl(self.training_log_path, train_record)
            self._print_train_log(train_record)
            self._log_tensorboard_scalars("train", train_record, self.state.update_index)
            if progress_bar is not None:
                progress_bar.set_postfix(
                    ret=f"{train_record['episode_return']:.3f}",
                    wait=f"{train_record['mean_waiting_vehicles']:.2f}",
                    td=f"{train_record['td_loss']:.4f}",
                    eps=f"{train_record['epsilon']:.3f}",
                )
                progress_bar.update(1)

            should_evaluate = is_due(self.dqn_config.eval_every)
            should_periodic_checkpoint = is_due(self.dqn_config.checkpoint_every)
            numbered_checkpoint = (
                self.checkpoint_dir / f"update_{self.state.update_index:04d}.pt"
            )

            # Skip the periodic save when the evaluation branch below is
            # going to write the same numbered checkpoint anyway.
            if should_periodic_checkpoint and not (
                should_evaluate and self.dqn_config.checkpoint_on_eval
            ):
                checkpoint_seconds += save_timed(numbered_checkpoint)

            if should_evaluate:
                print(f"[train] starting validation at update={self.state.update_index}")
                validation_start = perf_counter()
                validation_record = self.evaluate_split("val")
                validation_seconds = perf_counter() - validation_start
                validation_record["update"] = self.state.update_index
                validation_record["algorithm"] = "ps_d3qn"
                validation_record["policy_arch"] = self.dqn_config.policy_arch
                validation_record["reward_variant"] = self.env_config.reward.variant
                self._attach_rolling_metrics(
                    namespace="eval",
                    record=validation_record,
                    keys=(
                        "mean_episode_return",
                        "mean_total_episode_return",
                        "mean_mean_waiting_vehicles",
                        "mean_throughput",
                    ),
                )
                self._append_jsonl(self.validation_log_path, validation_record)
                self._print_eval_log(validation_record)
                self._log_tensorboard_scalars("eval", validation_record, self.state.update_index)
                print(f"[train] finished validation at update={self.state.update_index}")

                if self.dqn_config.checkpoint_on_eval:
                    checkpoint_seconds += save_timed(numbered_checkpoint)

                # Track the best validation return and keep its weights.
                validation_score = float(validation_record["mean_episode_return"])
                if validation_score > self.state.best_validation_score:
                    self.state.best_validation_score = validation_score
                    checkpoint_seconds += save_timed(self.output_dir / "best_validation.pt")

            print(
                "[timing] "
                f"rollout={rollout_seconds:.2f}s "
                f"update={update_seconds:.2f}s "
                f"validation={validation_seconds:.2f}s "
                f"checkpoint={checkpoint_seconds:.2f}s"
            )

        final_checkpoint_seconds = save_timed(self.output_dir / "last.pt")
        print(f"[timing] final_checkpoint={final_checkpoint_seconds:.2f}s")
    finally:
        if progress_bar is not None:
            progress_bar.close()
        if self.rollout_executor is not None:
            self.rollout_executor.shutdown(wait=True, cancel_futures=False)
        if self.writer is not None:
            self.writer.close()
def evaluate_split(self, split_name: str) -> dict[str, float]:
    """Evaluate the learner (and optionally baselines) on a dataset split.

    For ``"val"`` with ``overfit_val_on_train_scenario`` enabled, only the
    fixed training scenario is evaluated; otherwise scenarios come from the
    dataset. The returned dict merges overall and per-scenario aggregates,
    baseline aggregates when ``compare_baselines`` is set, and the learner's
    return margin over the fixed / random baselines when those are present.

    Raises:
        ValueError: If overfit validation is requested without a fixed
            training scenario.
        RuntimeError: If no evaluation episode succeeded.
    """
    config = self.dqn_config

    if split_name == "val" and config.overfit_val_on_train_scenario:
        pinned_spec = self.fixed_train_scenario_spec
        if pinned_spec is None:
            raise ValueError(
                "--overfit-val-on-train-scenario requires a fixed training city/scenario."
            )
        scenario_specs = [pinned_spec]
    else:
        scenario_specs = self.dataset.iter_scenarios(
            split_name=split_name,
            scenarios_per_city=config.val_scenarios_per_city,
            max_cities=config.max_val_cities,
            diversify_single_scenario=True,
        )

    # Pick the parallel or sequential learner evaluation path.
    run_learner = (
        self._evaluate_policy_parallel
        if self._resolved_eval_workers(len(scenario_specs)) > 1
        else self._evaluate_policy_sequential
    )
    episode_metrics = run_learner(scenario_specs)
    if not episode_metrics:
        raise RuntimeError("Validation produced no successful episodes.")

    aggregate = aggregate_metrics(episode_metrics)
    aggregate.update(aggregate_metrics_by_scenario(episode_metrics))

    if config.compare_baselines:
        run_baselines = (
            self._evaluate_baselines_parallel
            if self._resolved_eval_workers(len(scenario_specs)) > 1
            else self._evaluate_baselines
        )
        aggregate.update(run_baselines(scenario_specs))

    # Learner-vs-baseline margins, for whichever baselines were reported.
    margin_keys = (
        ("fixed_mean_episode_return", "learner_minus_fixed_return"),
        ("random_mean_episode_return", "learner_minus_random_return"),
    )
    for baseline_key, margin_key in margin_keys:
        if baseline_key in aggregate:
            aggregate[margin_key] = aggregate["mean_episode_return"] - aggregate[baseline_key]
    return aggregate
def save_checkpoint(self, path: str | Path) -> None:
    """Write networks, optimiser, trainer state and configuration to ``path``.

    The payload is first written to a sibling ``<name>.tmp`` file and then
    moved over ``path`` with ``os.replace``. An interrupted save therefore
    cannot leave a truncated file where a good checkpoint (for example
    ``last.pt`` or ``best_validation.pt``) used to be.

    The replay buffer contents are not part of the checkpoint.
    """
    env_config = self.env_config
    checkpoint = {
        "algorithm": "ps_d3qn",
        "q_network_state_dict": self.q_network.state_dict(),
        "target_network_state_dict": self.target_network.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "trainer_state": asdict(self.state),
        "dqn_config": asdict(self.dqn_config),
        "network_architecture": {
            "observation_dim": self.q_network.observation_dim,
            "action_dim": self.q_network.action_dim,
            "district_types": self.q_network.district_types,
            "policy_arch": self.q_network.policy_arch,
            "dueling": self.q_network.dueling,
        },
        "env_config": {
            "simulator_interval": env_config.simulator_interval,
            "decision_interval": env_config.decision_interval,
            "min_green_time": env_config.min_green_time,
            "thread_num": env_config.thread_num,
            "max_episode_seconds": env_config.max_episode_seconds,
            "observation": asdict(env_config.observation),
            "reward": asdict(env_config.reward),
        },
        "obs_normalizer": self.obs_normalizer.state_dict() if self.obs_normalizer else None,
    }

    destination = Path(path)
    staging = destination.with_name(destination.name + ".tmp")
    try:
        torch.save(checkpoint, staging)
        # Replaces any existing destination file in a single rename.
        os.replace(staging, destination)
    finally:
        # After a successful replace the staging file is already gone; this
        # only removes leftovers from a failed or interrupted save.
        staging.unlink(missing_ok=True)
def load_checkpoint(self, path: str | Path) -> None:
    """Restore networks, optimiser, trainer state and normaliser from ``path``.

    The online weights are read from ``q_network_state_dict``, falling back
    to ``policy_state_dict``. When the checkpoint has no target-network
    weights, the online weights are used for the target network too.

    Raises:
        ValueError: If the checkpoint holds no Q-network weights.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only checkpoints from a trusted source should be loaded here.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)

    online_weights = checkpoint.get("q_network_state_dict")
    if not online_weights:
        online_weights = checkpoint.get("policy_state_dict")
    if online_weights is None:
        raise ValueError(f"Checkpoint at {path} does not contain a Q-network state dict.")
    self.q_network.load_state_dict(online_weights)

    target_weights = checkpoint.get("target_network_state_dict")
    if not target_weights:
        target_weights = online_weights
    self.target_network.load_state_dict(target_weights)

    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.state = TrainerState(**checkpoint["trainer_state"])

    # Normaliser statistics are restored only when both sides have them.
    if self.obs_normalizer and checkpoint.get("obs_normalizer"):
        self.obs_normalizer.load_state_dict(checkpoint["obs_normalizer"])
def _make_env(self, scenario_spec: ScenarioSpec) -> TrafficEnv:
    """Build a ``TrafficEnv`` for ``scenario_spec`` using the trainer's env config."""
    # Every one of these spec attributes is forwarded under the same name.
    forwarded_fields = (
        "city_id",
        "scenario_name",
        "city_dir",
        "scenario_dir",
        "config_path",
        "roadnet_path",
        "district_map_path",
        "metadata_path",
    )
    env_kwargs = {name: getattr(scenario_spec, name) for name in forwarded_fields}
    return TrafficEnv(env_config=self.env_config, **env_kwargs)
def _collect_rollout_batch(self) -> list[dict[str, float | str]]:
episodes_per_update = self.dqn_config.rollout_episodes_per_update or max(
1,
self.dqn_config.num_rollout_workers,
)
scenario_specs = [self._sample_train_scenario() for _ in range(episodes_per_update)]
if self.rollout_executor is None or episodes_per_update <= 1:
episode_record = self._collect_episode(self._make_env(scenario_specs[0]))
return [episode_record]
return self._collect_rollouts_parallel(scenario_specs)
def _sample_train_scenario(self) -> ScenarioSpec:
if self.fixed_train_scenario_spec is not None:
return self.fixed_train_scenario_spec
selected_city = self.rng.choice(self.train_city_ids)
selected_scenario = self.rng.choice(self.dataset.scenarios_for_city(selected_city))
if self.dqn_config.verbose_progress:
print(f"[train] sampled city={selected_city} scenario={selected_scenario}")
return self.dataset.build_scenario_spec(selected_city, selected_scenario)
def _resolve_fixed_train_scenario(self) -> ScenarioSpec | None:
if self.dqn_config.train_city_id is None:
return None
city_id = self.dqn_config.train_city_id
available_train_cities = set(self.dataset.load_split("train"))
if city_id not in available_train_cities:
raise ValueError(
f"Fixed train city {city_id!r} is not in the train split."
)
scenario_names = self.dataset.scenarios_for_city(city_id)
scenario_name = self.dqn_config.train_scenario_name
if scenario_name is None:
scenario_name = scenario_names[0]
if scenario_name not in scenario_names:
raise ValueError(
f"Scenario {scenario_name!r} not found for train city {city_id!r}. "
f"Available: {scenario_names}"
)
self.train_city_ids = [city_id]
return self.dataset.build_scenario_spec(city_id, scenario_name)
def _collect_rollouts_parallel(
    self,
    scenario_specs: list[ScenarioSpec],
) -> list[dict[str, float | str]]:
    """Roll out ``scenario_specs`` on the process pool and ingest the results.

    One task is submitted per scenario, all sharing the same context snapshot
    and epsilon. Results are ingested in completion order: transitions go
    into the replay buffer and the trainer counters are advanced.

    Raises:
        RuntimeError: If the trainer has no rollout executor.
    """
    pool = self.rollout_executor
    if pool is None:
        raise RuntimeError("Parallel rollout collection requested without a rollout executor.")

    context = self._build_parallel_rollout_context()
    epsilon = self._epsilon()
    total_specs = len(scenario_specs)
    config = self.dqn_config

    # Map each future back to its 1-based submission index and scenario.
    pending = {}
    for position, spec in enumerate(scenario_specs, start=1):
        future = pool.submit(
            _parallel_rollout_collection_worker,
            spec,
            context,
            epsilon,
            config.rollout_decision_steps,
            config.gamma,
            config.n_step,
        )
        pending[future] = (position, spec)

    episode_records: list[dict[str, float | str]] = []
    for future in as_completed(pending):
        index, spec = pending[future]
        outcome = future.result()
        self._ingest_transition_batch(outcome["transitions"])
        self.state.total_decision_steps += int(outcome["episode_record"]["decision_steps"])
        self.state.total_transitions += int(outcome["episode_record"]["transitions"])
        episode_records.append(outcome["episode_record"])
        if self.dqn_config.verbose_progress:
            print(
                f"[rollout] city={spec.city_id} scenario={spec.scenario_name} "
                f"i={index}/{total_specs}"
            )
    return episode_records
def _build_parallel_rollout_context(self) -> dict[str, Any]:
    """Snapshot what a rollout worker needs to rebuild the current policy.

    The snapshot holds the env-config payload, the network architecture
    attributes, CPU copies of the online network weights and the observation
    normaliser state (``None`` when normalisation is off).
    """
    env_payload = _env_config_to_payload(self.env_config)

    network = self.q_network
    architecture = {
        "observation_dim": network.observation_dim,
        "action_dim": network.action_dim,
        "hidden_dim": network.hidden_dim,
        "num_layers": network.num_layers,
        "district_types": network.district_types,
        "policy_arch": network.policy_arch,
        "dueling": network.dueling,
    }

    # Detach and move every tensor to the CPU before handing it to workers.
    cpu_weights = {}
    for key, value in network.state_dict().items():
        cpu_weights[key] = value.detach().cpu()

    if self.obs_normalizer:
        normalizer_state = self.obs_normalizer.state_dict()
    else:
        normalizer_state = None

    return {
        "env_config": env_payload,
        "network_architecture": architecture,
        "q_network_state_dict": cpu_weights,
        "obs_normalizer": normalizer_state,
    }
def _ingest_transition_batch(self, transitions: dict[str, np.ndarray]) -> None:
if transitions["observations"].size == 0:
return
if self.obs_normalizer is not None:
self.obs_normalizer.update(transitions["observations"])
transition_count = transitions["actions"].shape[0]
for index in range(transition_count):
self.replay_buffer.add(
observation=transitions["observations"][index],
district_type_index=int(transitions["district_type_indices"][index]),
action_mask=transitions["action_masks"][index],
action=int(transitions["actions"][index]),
reward=float(transitions["rewards"][index]),
next_observation=transitions["next_observations"][index],
next_district_type_index=int(transitions["next_district_type_indices"][index]),
next_action_mask=transitions["next_action_masks"][index],
done=bool(transitions["dones"][index]),
discount=float(transitions["discounts"][index]),
)
def _summarize_rollout_batch(
self,
episode_records: list[dict[str, float | str]],
) -> dict[str, float | str]:
if len(episode_records) == 1:
record = dict(episode_records[0])
record["num_rollout_episodes"] = 1.0
return record
aggregate = aggregate_metrics(episode_records)
city_ids = sorted({str(record["city_id"]) for record in episode_records})
scenario_names = sorted({str(record["scenario_name"]) for record in episode_records})
summary: dict[str, float | str] = {
"city_id": city_ids[0] if len(city_ids) == 1 else f"{len(city_ids)}_cities",
"scenario_name": scenario_names[0]
if len(scenario_names) == 1
else f"{len(scenario_names)}_scenarios",
"num_rollout_episodes": float(len(episode_records)),
}
for source_key, target_key in (
("mean_decision_steps", "decision_steps"),
("mean_transitions", "transitions"),
("mean_episode_return", "episode_return"),
("mean_total_episode_return", "total_episode_return"),
("mean_mean_waiting_vehicles", "mean_waiting_vehicles"),
("mean_throughput", "throughput"),
("mean_mean_q_value", "mean_q_value"),
("mean_epsilon", "epsilon"),
("mean_replay_size", "replay_size"),
):
if source_key in aggregate:
summary[target_key] = aggregate[source_key]
for key, value in aggregate.items():
if key not in summary:
summary[key] = value
return summary
# Roll out a single episode in `env` with the current Q-network, pushing
# n-step transitions into the replay buffer, and return the episode's metrics
# as a flat dict (numeric environment metrics plus rollout bookkeeping).
def _collect_episode(self, env: TrafficEnv) -> dict[str, float | str]:
observation_batch = env.reset()
decision_steps = 0
transitions_added = 0
q_value_samples: list[float] = []
# One pending-record queue per entry of the batch's "intersection_ids".
n_step_buffers = [
deque() for _ in range(len(observation_batch["intersection_ids"]))
]
epsilon = self._epsilon()
# Fallback metrics source in case the loop body never runs.
last_info = env.last_info
done = False
while not done:
# Optional cap on the number of decision steps taken in this episode.
if (
self.dqn_config.rollout_decision_steps is not None
and decision_steps >= self.dqn_config.rollout_decision_steps
):
break
raw_obs = observation_batch["observations"].astype(np.float32)
# The normaliser statistics are updated with the raw observations before
# those same observations are normalised.
if self.obs_normalizer is not None:
self.obs_normalizer.update(raw_obs)
normalized_obs = self.obs_normalizer.normalize(raw_obs)
else:
normalized_obs = raw_obs
obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=self.device)
district_type_tensor = torch.as_tensor(
observation_batch["district_type_indices"],
dtype=torch.int64,
device=self.device,
)
action_mask_tensor = torch.as_tensor(
observation_batch["action_mask"],
dtype=torch.float32,
device=self.device,
)
# The forward pass result is used only for the mean-max-Q diagnostic below;
# the actions themselves come from `act` with the current epsilon.
with torch.no_grad():
q_values = self.q_network.forward(
observations=obs_tensor,
district_type_indices=district_type_tensor,
action_mask=action_mask_tensor,
)
action_tensor = self.q_network.act(
observations=obs_tensor,
district_type_indices=district_type_tensor,
action_mask=action_mask_tensor,
deterministic=False,
epsilon=epsilon,
)
q_value_samples.append(float(q_values.max(dim=-1).values.mean().detach().cpu()))
actions = action_tensor.detach().cpu().numpy()
next_observation_batch, rewards, done, info = env.step(actions)
# The raw (un-normalised) observation batches are what gets recorded.
transitions_added += self._append_step_records(
buffers=n_step_buffers,
observation_batch=observation_batch,
actions=actions,
rewards=np.asarray(rewards, dtype=np.float32),
next_observation_batch=next_observation_batch,
done=done,
)
observation_batch = next_observation_batch
last_info = info
decision_steps += 1
self.state.total_decision_steps += 1
# NOTE(review): epsilon is re-read here; with the file's indentation lost it
# is unclear whether this happens per step or once after the loop -- confirm.
epsilon = self._epsilon()
# Emit the shorter-than-n-step transitions still queued in the buffers.
transitions_added += self._flush_n_step_buffers(n_step_buffers)
self.state.total_transitions += transitions_added
# Only int/float metric values are kept; each is converted to float.
episode_metrics = {
key: float(value)
for key, value in last_info["metrics"].items()
if value is not None and isinstance(value, (int, float))
}
episode_metrics.update(
{
"city_id": env.city_id,
"scenario_name": env.scenario_name,
"decision_steps": decision_steps,
"transitions": transitions_added,
"episode_return": float(env.episode_return),
"total_episode_return": float(env.total_episode_return),
"epsilon": float(epsilon),
"replay_size": float(self.replay_buffer.size),
"mean_q_value": float(np.mean(q_value_samples)) if q_value_samples else 0.0,
}
)
return episode_metrics
def _append_step_records(
    self,
    buffers: list[deque[StepRecord]],
    observation_batch: dict[str, Any],
    actions: np.ndarray,
    rewards: np.ndarray,
    next_observation_batch: dict[str, Any],
    done: bool,
) -> int:
    """Queue one step per buffer and emit every transition that is now full.

    Row ``i`` of the batches is appended to ``buffers[i]``. Whenever a buffer
    holds at least ``n_step`` records, one n-step transition is pushed to the
    replay buffer (which also drops that buffer's oldest record).

    Returns:
        The number of transitions pushed during this call.
    """
    horizon = self.dqn_config.n_step
    terminal = bool(done)
    emitted = 0
    for row, pending in enumerate(buffers):
        pending.append(
            StepRecord(
                observation=observation_batch["observations"][row].astype(np.float32),
                district_type_index=int(observation_batch["district_type_indices"][row]),
                action_mask=observation_batch["action_mask"][row].astype(np.float32),
                action=int(actions[row]),
                reward=float(rewards[row]),
                next_observation=next_observation_batch["observations"][row].astype(np.float32),
                next_district_type_index=int(next_observation_batch["district_type_indices"][row]),
                next_action_mask=next_observation_batch["action_mask"][row].astype(np.float32),
                done=terminal,
            )
        )
        if len(pending) < horizon:
            continue
        self._push_n_step_transition(pending, steps=horizon)
        emitted += 1
    return emitted
def _flush_n_step_buffers(self, buffers: list[deque[StepRecord]]) -> int:
transitions_added = 0
for buffer in buffers:
while buffer:
self._push_n_step_transition(buffer, steps=len(buffer))
transitions_added += 1
return transitions_added
def _push_n_step_transition(self, buffer: deque[StepRecord], steps: int) -> None:
records = list(islice(buffer, 0, steps))
reward = 0.0
for step_index, record in enumerate(records):
reward += (self.dqn_config.gamma ** step_index) * float(record.reward)
first_record = records[0]
last_record = records[-1]
discount = self.dqn_config.gamma ** len(records)
self.replay_buffer.add(
observation=first_record.observation,
district_type_index=first_record.district_type_index,
action_mask=first_record.action_mask,
action=first_record.action,
reward=reward,
next_observation=last_record.next_observation,
next_district_type_index=last_record.next_district_type_index,
next_action_mask=last_record.next_action_mask,
done=last_record.done,
discount=discount,
)
buffer.popleft()
def _optimize(self) -> dict[str, float]:
minimum_replay = max(self.dqn_config.learning_starts, self.dqn_config.minibatch_size)
if self.replay_buffer.size < minimum_replay:
return {
"td_loss": 0.0,
"mean_abs_td_error": 0.0,
"mean_target_q": 0.0,
"mean_q_value": 0.0,
"beta": self._beta(),
"gradient_steps": 0.0,
}
batch_size = min(self.dqn_config.minibatch_size, self.replay_buffer.size)
td_losses: list[float] = []
td_errors: list[float] = []
target_values: list[float] = []
q_values: list[float] = []
beta = self._beta()
for _ in range(self.dqn_config.gradient_steps):
batch = self.replay_buffer.sample(batch_size=batch_size, beta=beta)
observations = batch["observations"]
next_observations = batch["next_observations"]
if self.obs_normalizer is not None:
observations = self.obs_normalizer.normalize(observations)
next_observations = self.obs_normalizer.normalize(next_observations)
obs_tensor = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
next_obs_tensor = torch.as_tensor(next_observations, dtype=torch.float32, device=self.device)
district_type_tensor = torch.as_tensor(
batch["district_type_indices"],
dtype=torch.int64,
device=self.device,
)
next_district_type_tensor = torch.as_tensor(
batch["next_district_type_indices"],
dtype=torch.int64,
device=self.device,
)
action_mask_tensor = torch.as_tensor(batch["action_masks"], dtype=torch.float32, device=self.device)
next_action_mask_tensor = torch.as_tensor(
batch["next_action_masks"],
dtype=torch.float32,
device=self.device,
)
action_tensor = torch.as_tensor(batch["actions"], dtype=torch.int64, device=self.device)
reward_tensor = torch.as_tensor(batch["rewards"], dtype=torch.float32, device=self.device)
done_tensor = torch.as_tensor(batch["dones"], dtype=torch.float32, device=self.device)
discount_tensor = torch.as_tensor(batch["discounts"], dtype=torch.float32, device=self.device)
weight_tensor = torch.as_tensor(batch["weights"], dtype=torch.float32, device=self.device)
predicted_q = self.q_network.q_values_for_actions(
observations=obs_tensor,
district_type_indices=district_type_tensor,
actions=action_tensor,
action_mask=action_mask_tensor,
)
with torch.no_grad():
next_online_q = self.q_network.forward(
observations=next_obs_tensor,
district_type_indices=next_district_type_tensor,
action_mask=next_action_mask_tensor,
)
next_actions = next_online_q.argmax(dim=-1)
next_target_q = self.target_network.forward(
observations=next_obs_tensor,
district_type_indices=next_district_type_tensor,
action_mask=next_action_mask_tensor,
).gather(dim=1, index=next_actions.view(-1, 1)).squeeze(1)
target_q = reward_tensor + (1.0 - done_tensor) * discount_tensor * next_target_q
td_error = target_q - predicted_q
per_sample_loss = nn.functional.smooth_l1_loss(
predicted_q,
target_q,
reduction="none",
)
loss = (weight_tensor * per_sample_loss).mean()
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.q_network.parameters(), self.dqn_config.max_grad_norm)
self.optimizer.step()
self._soft_update_target()
self.replay_buffer.update_priorities(
batch["indices"],
td_errors=np.abs(td_error.detach().cpu().numpy()),
)
self.state.gradient_steps += 1
td_losses.append(float(loss.detach().cpu()))
td_errors.append(float(torch.abs(td_error).mean().detach().cpu()))
target_values.append(float(target_q.mean().detach().cpu()))
q_values.append(float(predicted_q.mean().detach().cpu()))
return {
"td_loss": float(np.mean(td_losses)),
"mean_abs_td_error": float(np.mean(td_errors)),
"mean_target_q": float(np.mean(target_values)),
"mean_q_value": float(np.mean(q_values)),
"beta": float(beta),
"gradient_steps": float(self.dqn_config.gradient_steps),
}
def _soft_update_target(self) -> None:
tau = float(self.dqn_config.target_tau)
with torch.no_grad():
for target_param, online_param in zip(
self.target_network.parameters(),
self.q_network.parameters(),
strict=True,
):
target_param.data.mul_(1.0 - tau).add_(online_param.data, alpha=tau)
def _epsilon(self) -> float:
if self.dqn_config.epsilon_decay_steps <= 0:
return float(self.dqn_config.epsilon_end)
progress = min(1.0, self.state.total_decision_steps / float(self.dqn_config.epsilon_decay_steps))
return float(
self.dqn_config.epsilon_start
+ progress * (self.dqn_config.epsilon_end - self.dqn_config.epsilon_start)
)
def _beta(self) -> float:
if self.dqn_config.prioritized_replay_beta_steps <= 0:
return float(self.dqn_config.prioritized_replay_beta_end)
progress = min(
1.0,
self.state.total_decision_steps / float(self.dqn_config.prioritized_replay_beta_steps),
)
return float(
self.dqn_config.prioritized_replay_beta_start
+ progress
* (
self.dqn_config.prioritized_replay_beta_end
- self.dqn_config.prioritized_replay_beta_start
)
)
def _evaluate_policy_sequential(
self,
scenario_specs: list[ScenarioSpec],
) -> list[dict[str, float | str]]:
episode_metrics: list[dict[str, float | str]] = []
total_specs = len(scenario_specs)
iterator = enumerate(scenario_specs, start=1)
if self.dqn_config.use_tqdm:
iterator = tqdm(
iterator,
total=total_specs,
desc="eval:learned",
leave=False,
dynamic_ncols=True,
)
for index, spec in iterator:
print(f"[eval] city={spec.city_id} scenario={spec.scenario_name} i={index}/{total_specs}")
try:
episode_metrics.append(
evaluate_policy(
env_factory=lambda spec=spec: self._make_env(spec),
actor=self.q_network,
device=self.device,
obs_normalizer=self.obs_normalizer,
deterministic=True,
)
)
except Exception as exc:
self._handle_eval_failure("validation", spec, exc)
return episode_metrics
def _evaluate_policy_parallel(
    self,
    scenario_specs: list[ScenarioSpec],
) -> list[dict[str, float | str]]:
    """Evaluate the learned policy across scenarios with a process pool.

    The worker count comes from ``_resolved_eval_workers`` and the worker
    initialiser receives the context built by
    ``_build_parallel_learned_eval_context``.
    """
    worker_count = self._resolved_eval_workers(len(scenario_specs))
    print(f"[eval] learned_workers={worker_count}")
    learned_context = self._build_parallel_learned_eval_context()
    return self._run_parallel_eval(
        scenario_specs=scenario_specs,
        worker_kind="learned",
        initializer=_init_parallel_learned_eval_worker,
        initargs=(learned_context,),
        max_workers=worker_count,
    )
def _append_jsonl(self, path: Path, record: dict) -> None:
with path.open("a") as handle:
handle.write(json.dumps(record) + "\n")
def _build_tensorboard_writer(self) -> SummaryWriter | None:
if not self.dqn_config.enable_tensorboard:
return None
if SummaryWriter is None:
print("[setup] tensorboard_disabled=torch.utils.tensorboard unavailable")
return None
self.tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
return SummaryWriter(log_dir=str(self.tensorboard_log_dir))
def _log_tensorboard_scalars(
self,
namespace: str,
record: dict[str, Any],
step: int,
) -> None:
if self.writer is None:
return
for key, value in record.items():
if isinstance(value, (int, float)):
self.writer.add_scalar(f"{namespace}/{key}", float(value), step)
self.writer.flush()
def _attach_rolling_metrics(
self,
namespace: str,
record: dict[str, Any],
keys: tuple[str, ...],
) -> None:
for key in keys:
value = record.get(key)
if not isinstance(value, (int, float)):
continue
window = self._rolling_metrics.setdefault(
(namespace, key),
deque(maxlen=self.dqn_config.rolling_window_size),
)
window.append(float(value))
record[f"rolling_{key}"] = float(np.mean(window))
def _evaluate_baselines(self, scenario_specs: list[ScenarioSpec]) -> dict[str, float]:
baseline_metrics: dict[str, float] = {}
for baseline_name in ("random", "fixed"):
metrics: list[dict[str, float | str]] = []
total_specs = len(scenario_specs)
for offset, spec in enumerate(scenario_specs, start=1):
print(
f"[eval] baseline={baseline_name} city={spec.city_id} "
f"scenario={spec.scenario_name} i={offset}/{total_specs}"
)
try:
actor = (
RandomPhasePolicy(seed=self.dqn_config.seed + offset)
if baseline_name == "random"
else FixedCyclePolicy(green_time=max(20, self.env_config.min_green_time * 2))
)
metrics.append(
evaluate_policy(
env_factory=lambda spec=spec: self._make_env(spec),
actor=actor,
)
)
except Exception as exc:
message = (
f"[warn] baseline={baseline_name} failed for city={spec.city_id} "
f"scenario={spec.scenario_name}: {exc}"
)
if self.dqn_config.skip_failed_validation_episodes:
print(message)
continue
raise RuntimeError(message) from exc
if not metrics:
continue
aggregate = aggregate_metrics(metrics)
for key, value in aggregate.items():
baseline_metrics[f"{baseline_name}_{key}"] = value
return baseline_metrics
def _evaluate_baselines_parallel(self, scenario_specs: list[ScenarioSpec]) -> dict[str, float]:
    """Evaluate the ``random`` and ``fixed`` baselines using process pools.

    One pool run is made per baseline; results are aggregated with
    ``aggregate_metrics`` and returned under ``<baseline>_<metric>`` keys.
    A baseline with no successful episodes contributes nothing.
    """
    summary: dict[str, float] = {}
    worker_count = self._resolved_eval_workers(len(scenario_specs))
    print(f"[eval] baseline_workers={worker_count}")
    for baseline_name in ("random", "fixed"):
        baseline_context = self._build_parallel_baseline_context(baseline_name)
        collected = self._run_parallel_eval(
            scenario_specs=scenario_specs,
            worker_kind=baseline_name,
            initializer=_init_parallel_baseline_worker,
            initargs=(baseline_context,),
            max_workers=worker_count,
        )
        if collected:
            summary.update(
                {
                    f"{baseline_name}_{key}": value
                    for key, value in aggregate_metrics(collected).items()
                }
            )
    return summary
def _run_parallel_eval(
self,
scenario_specs: list[ScenarioSpec],
worker_kind: str,
initializer,
initargs: tuple[Any, ...],
max_workers: int,
) -> list[dict[str, float | str]]:
metrics: list[dict[str, float | str]] = []
total_specs = len(scenario_specs)
with ProcessPoolExecutor(
max_workers=max_workers,
initializer=initializer,
initargs=initargs,
) as executor:
futures = {
executor.submit(_parallel_eval_worker, spec, index, worker_kind): (spec, index)
for index, spec in enumerate(scenario_specs, start=1)
}
iterator = as_completed(futures)
if self.dqn_config.use_tqdm:
iterator = tqdm(
iterator,
total=total_specs,
desc=f"eval:{worker_kind}",
leave=False,
dynamic_ncols=True,
)
for future in iterator:
spec, index = futures[future]
try:
result = future.result()
except Exception as exc:
self._handle_eval_failure(worker_kind, spec, exc)
continue
prefix = f"[eval] baseline={worker_kind}"
print(f"{prefix} city={spec.city_id} scenario={spec.scenario_name} i={index}/{total_specs}")
metrics.append(result)
return metrics
def _handle_eval_failure(
self,
phase: str,
spec: ScenarioSpec,
exc: Exception,
) -> None:
message = f"[warn] {phase} failed for city={spec.city_id} scenario={spec.scenario_name}: {exc}"
if self.dqn_config.skip_failed_validation_episodes:
print(message)
return
raise RuntimeError(message) from exc
def _build_parallel_baseline_context(self, baseline_name: str) -> dict[str, Any]:
    """Assemble the worker-initialiser context for a baseline evaluation run.

    The context holds the env config payload, the baseline name, the green
    time for the fixed-cycle policy and the trainer seed.
    """
    context: dict[str, Any] = {"env_config": _env_config_to_payload(self.env_config)}
    context["baseline_name"] = baseline_name
    context["fixed_green_time"] = max(20, self.env_config.min_green_time * 2)
    context["seed"] = self.dqn_config.seed
    return context
def _build_parallel_learned_eval_context(self) -> dict[str, Any]:
    """Assemble the worker-initialiser context for learned-policy evaluation.

    The context carries the env config payload, the constructor arguments
    needed to rebuild the Q-network, a CPU copy of its state dict and the
    observation normaliser state (``None`` when no normaliser is set).
    """
    network = self.q_network
    architecture = {
        "observation_dim": network.observation_dim,
        "action_dim": network.action_dim,
        "hidden_dim": network.hidden_dim,
        "num_layers": network.num_layers,
        "district_types": network.district_types,
        "policy_arch": network.policy_arch,
        "dueling": network.dueling,
    }
    # Detach and move to CPU before the tensors are handed to worker processes.
    cpu_weights = {
        name: tensor.detach().cpu()
        for name, tensor in network.state_dict().items()
    }
    normalizer_state = self.obs_normalizer.state_dict() if self.obs_normalizer else None
    return {
        "env_config": _env_config_to_payload(self.env_config),
        "network_architecture": architecture,
        "q_network_state_dict": cpu_weights,
        "obs_normalizer": normalizer_state,
    }
def _resolved_eval_workers(self, total_specs: int) -> int:
requested = self.dqn_config.eval_num_workers
if requested == -1:
requested = os.cpu_count() or 1
if requested <= 1:
return 1
return min(requested, total_specs)
def _print_train_log(self, record: dict[str, float | str]) -> None:
message = (
"[train] "
f"update={record['update']} algo={record['algorithm']} arch={record['policy_arch']} "
f"reward={record['reward_variant']} episodes={int(record.get('num_rollout_episodes', 1.0))} "
f"city={record['city_id']} scenario={record['scenario_name']} "
f"mean_return={record['episode_return']:.3f} "
f"(avg={record.get('rolling_episode_return', record['episode_return']):.3f}) "
f"wait={record['mean_waiting_vehicles']:.3f} "
f"(avg={record.get('rolling_mean_waiting_vehicles', record['mean_waiting_vehicles']):.3f}) "
f"throughput={record['throughput']:.1f} "
f"(avg={record.get('rolling_throughput', record['throughput']):.1f}) "
f"epsilon={record['epsilon']:.3f} replay={int(record['replay_size'])} "
f"td_loss={record['td_loss']:.4f} "
f"(avg={record.get('rolling_td_loss', record['td_loss']):.4f}) "
f"q={record['mean_q_value']:.4f} "
f"td_err={record['mean_abs_td_error']:.4f}"
)
if self.dqn_config.use_tqdm:
tqdm.write(message)
else:
print(message)
def _print_eval_log(self, record: dict[str, float]) -> None:
message = (
"[eval] "
f"algo={record['algorithm']} arch={record['policy_arch']} reward={record['reward_variant']} "
f"episodes={int(record['num_episodes'])} "
f"mean_return={record['mean_episode_return']:.3f} "
f"(avg={record.get('rolling_mean_episode_return', record['mean_episode_return']):.3f}) "
f"wait={record['mean_mean_waiting_vehicles']:.3f} "
f"throughput={record['mean_throughput']:.1f} "
f"travel_time={record.get('mean_average_travel_time', float('nan')):.3f}"
)
if self.dqn_config.compare_baselines:
message += (
f" fixed={record.get('fixed_mean_episode_return', float('nan')):.3f}"
f" random={record.get('random_mean_episode_return', float('nan')):.3f}"
f" vs_fixed={record.get('learner_minus_fixed_return', float('nan')):.3f}"
f" vs_random={record.get('learner_minus_random_return', float('nan')):.3f}"
)
if self.dqn_config.use_tqdm:
tqdm.write(message)
else:
print(message)
scenario_summaries = []
for scenario_name in (
"accident",
"construction",
"district_overload",
"evening_rush",
"event_spike",
"morning_rush",
"normal",
):
key = f"scenario_{scenario_name}_mean_episode_return"
if key in record:
scenario_summaries.append(f"{scenario_name}={record[key]:.3f}")
if scenario_summaries:
if self.dqn_config.use_tqdm:
tqdm.write("[eval] scenario_returns " + " ".join(scenario_summaries))
else:
print("[eval] scenario_returns " + " ".join(scenario_summaries))
def aggregate_metrics(metrics: list[dict[str, float | str]]) -> dict[str, float]:
    """Average every numeric metric across a list of episode records.

    A key counts as numeric if any record stores an ``int`` or ``float``
    under it. Each mean is taken over the records that contain the key.

    Returns:
        ``num_episodes`` plus one ``mean_<key>`` entry per numeric key,
        inserted in sorted key order.
    """
    numeric_keys: set[str] = set()
    for item in metrics:
        for key, value in item.items():
            if isinstance(value, (int, float)):
                numeric_keys.add(key)

    aggregate = {"num_episodes": float(len(metrics))}
    for key in sorted(numeric_keys):
        samples = [float(item[key]) for item in metrics if key in item]
        aggregate[f"mean_{key}"] = float(np.mean(samples))
    return aggregate
def aggregate_metrics_by_scenario(metrics: list[dict[str, float | str]]) -> dict[str, float]:
    """Aggregate episode metrics separately for each scenario name.

    Records are grouped by their string ``scenario_name``; each group is
    passed to ``aggregate_metrics`` and its results are stored under
    ``scenario_<name>_<key>``. Records without a string scenario name are
    ignored.
    """
    names = {
        str(item["scenario_name"])
        for item in metrics
        if isinstance(item.get("scenario_name"), str)
    }
    aggregate: dict[str, float] = {}
    for scenario_name in sorted(names):
        members = [item for item in metrics if item.get("scenario_name") == scenario_name]
        if not members:
            continue
        aggregate.update(
            {
                f"scenario_{scenario_name}_{key}": value
                for key, value in aggregate_metrics(members).items()
            }
        )
    return aggregate
def _env_config_to_payload(env_config: EnvConfig) -> dict[str, Any]:
return {
"simulator_interval": env_config.simulator_interval,
"decision_interval": env_config.decision_interval,
"min_green_time": env_config.min_green_time,
"thread_num": env_config.thread_num,
"max_episode_seconds": env_config.max_episode_seconds,
"observation": asdict(env_config.observation),
"reward": asdict(env_config.reward),
}
def _env_config_from_payload(payload: dict[str, Any]) -> EnvConfig:
    """Rebuild an ``EnvConfig`` from the dict made by ``_env_config_to_payload``.

    Raises:
        KeyError: If any expected key is missing from ``payload``.
    """
    scalar_fields = {
        name: payload[name]
        for name in (
            "simulator_interval",
            "decision_interval",
            "min_green_time",
            "thread_num",
            "max_episode_seconds",
        )
    }
    observation = ObservationConfig(**payload["observation"])
    reward = RewardConfig(**payload["reward"])
    return EnvConfig(observation=observation, reward=reward, **scalar_fields)
def _init_parallel_baseline_worker(context: dict[str, Any]) -> None:
    """Process-pool initialiser for baseline evaluation workers.

    Delegates to ``_init_parallel_eval_worker_from_context``.
    """
    _init_parallel_eval_worker_from_context(context=context)
def _init_parallel_learned_eval_worker(context: dict[str, Any]) -> None:
    """Process-pool initialiser for learned-policy evaluation workers.

    Delegates to ``_init_parallel_eval_worker_from_context``.
    """
    _init_parallel_eval_worker_from_context(context=context)
def _build_standalone_eval_context(
    env_config: EnvConfig,
    actor: TrafficControlQNetwork | RandomPhasePolicy | FixedCyclePolicy | HoldPhasePolicy | QueueGreedyPolicy,
    obs_normalizer: RunningNormalizer | None,
    device: torch.device,
    seed: int,
    fixed_green_time: int,
    baseline_name: str | None,
) -> dict[str, Any]:
    """Build a worker-initialiser context without needing a trainer instance.

    With a ``baseline_name`` the context describes that baseline (name,
    fixed green time, seed). Otherwise ``actor`` must be a
    ``TrafficControlQNetwork`` and the context carries its architecture, a
    CPU copy of its weights and the normaliser state.

    Raises:
        ValueError: If no baseline is named and ``actor`` is not a Q-network.
    """
    del device  # accepted by the signature but not used
    env_payload = _env_config_to_payload(env_config)
    if baseline_name is not None:
        return {
            "env_config": env_payload,
            "baseline_name": baseline_name,
            "fixed_green_time": fixed_green_time,
            "seed": seed,
        }
    if not isinstance(actor, TrafficControlQNetwork):
        raise ValueError("Standalone parallel learned evaluation requires a Q-network actor.")

    architecture = {
        "observation_dim": actor.observation_dim,
        "action_dim": actor.action_dim,
        "hidden_dim": actor.hidden_dim,
        "num_layers": actor.num_layers,
        "district_types": actor.district_types,
        "policy_arch": actor.policy_arch,
        "dueling": actor.dueling,
    }
    cpu_weights = {
        name: tensor.detach().cpu()
        for name, tensor in actor.state_dict().items()
    }
    normalizer_state = obs_normalizer.state_dict() if obs_normalizer else None
    return {
        "env_config": env_payload,
        "network_architecture": architecture,
        "q_network_state_dict": cpu_weights,
        "obs_normalizer": normalizer_state,
    }
def _init_parallel_eval_worker_from_context(context: dict[str, Any]) -> None:
    """Populate the module-level ``_EVAL_CONTEXT`` for an evaluation worker.

    A context containing ``baseline_name`` yields the matching baseline
    policy and no normaliser. Any other context is treated as a learned
    policy: the Q-network is rebuilt on CPU from the stored architecture and
    weights, put in eval mode, and the normaliser is restored when its state
    is present.

    Raises:
        ValueError: If ``baseline_name`` is not a supported baseline.
    """
    global _EVAL_CONTEXT
    env_config = _env_config_from_payload(context["env_config"])
    obs_normalizer = None
    if "baseline_name" in context:
        baseline_name = context["baseline_name"]
        # Zero-argument factories so only the chosen baseline reads its settings.
        factories = {
            "random": lambda: RandomPhasePolicy(seed=context["seed"]),
            "fixed": lambda: FixedCyclePolicy(green_time=context["fixed_green_time"]),
            "hold": HoldPhasePolicy,
            "queue_greedy": QueueGreedyPolicy,
        }
        factory = factories.get(baseline_name)
        if factory is None:
            raise ValueError(f"Unsupported baseline worker kind: {baseline_name}")
        actor = factory()
    else:
        architecture = context["network_architecture"]
        actor = TrafficControlQNetwork(
            observation_dim=architecture["observation_dim"],
            action_dim=architecture["action_dim"],
            hidden_dim=architecture["hidden_dim"],
            num_layers=architecture["num_layers"],
            district_types=tuple(architecture["district_types"]),
            policy_arch=architecture["policy_arch"],
            dueling=bool(architecture.get("dueling", True)),
        ).to(torch.device("cpu"))
        actor.load_state_dict(context["q_network_state_dict"])
        actor.eval()
        normalizer_state = context.get("obs_normalizer")
        if normalizer_state:
            obs_normalizer = RunningNormalizer()
            obs_normalizer.load_state_dict(context["obs_normalizer"])
    _EVAL_CONTEXT = {
        "env_config": env_config,
        "actor": actor,
        "obs_normalizer": obs_normalizer,
    }
def _parallel_eval_worker(
    scenario_spec: ScenarioSpec,
    index: int,
    worker_kind: str,
) -> dict[str, float | str]:
    """Evaluate one scenario inside a pool worker using ``_EVAL_CONTEXT``.

    The environment, actor and normaliser come from the context installed by
    the worker initialiser. Evaluation runs deterministically on CPU.
    """
    del index, worker_kind  # part of the submitted task signature; unused here
    env_config = _EVAL_CONTEXT["env_config"]
    actor = _EVAL_CONTEXT["actor"]
    obs_normalizer = _EVAL_CONTEXT["obs_normalizer"]

    def make_env():
        return TrafficEnv(
            city_id=scenario_spec.city_id,
            scenario_name=scenario_spec.scenario_name,
            city_dir=scenario_spec.city_dir,
            scenario_dir=scenario_spec.scenario_dir,
            config_path=scenario_spec.config_path,
            roadnet_path=scenario_spec.roadnet_path,
            district_map_path=scenario_spec.district_map_path,
            metadata_path=scenario_spec.metadata_path,
            env_config=env_config,
        )

    return evaluate_policy(
        env_factory=make_env,
        actor=actor,
        device=torch.device("cpu"),
        obs_normalizer=obs_normalizer,
        deterministic=True,
    )
def _parallel_rollout_collection_worker(
    scenario_spec: ScenarioSpec,
    context: dict[str, Any],
    epsilon: float,
    max_decision_steps: int | None,
    gamma: float,
    n_step: int,
) -> dict[str, Any]:
    """Collect one episode of transitions for ``scenario_spec`` on CPU.

    The Q-network and optional normaliser are rebuilt from ``context``, a
    fresh ``TrafficEnv`` is created for the scenario, and the rollout is
    delegated to ``_collect_episode_trajectory``.
    """
    cpu = torch.device("cpu")
    env_config = _env_config_from_payload(context["env_config"])
    architecture = context["network_architecture"]
    q_network = TrafficControlQNetwork(
        observation_dim=architecture["observation_dim"],
        action_dim=architecture["action_dim"],
        hidden_dim=architecture["hidden_dim"],
        num_layers=architecture["num_layers"],
        district_types=tuple(architecture["district_types"]),
        policy_arch=architecture["policy_arch"],
        dueling=bool(architecture.get("dueling", True)),
    ).to(cpu)
    q_network.load_state_dict(context["q_network_state_dict"])
    q_network.eval()

    obs_normalizer = None
    normalizer_state = context.get("obs_normalizer")
    if normalizer_state:
        obs_normalizer = RunningNormalizer()
        obs_normalizer.load_state_dict(context["obs_normalizer"])

    env = TrafficEnv(
        city_id=scenario_spec.city_id,
        scenario_name=scenario_spec.scenario_name,
        city_dir=scenario_spec.city_dir,
        scenario_dir=scenario_spec.scenario_dir,
        config_path=scenario_spec.config_path,
        roadnet_path=scenario_spec.roadnet_path,
        district_map_path=scenario_spec.district_map_path,
        metadata_path=scenario_spec.metadata_path,
        env_config=env_config,
    )
    return _collect_episode_trajectory(
        env=env,
        q_network=q_network,
        obs_normalizer=obs_normalizer,
        epsilon=epsilon,
        max_decision_steps=max_decision_steps,
        gamma=gamma,
        n_step=n_step,
        device=cpu,
    )
def _collect_episode_trajectory(
    env: TrafficEnv,
    q_network: TrafficControlQNetwork,
    obs_normalizer: RunningNormalizer | None,
    epsilon: float,
    max_decision_steps: int | None,
    gamma: float,
    n_step: int,
    device: torch.device,
) -> dict[str, Any]:
    """Roll out one episode and return its record and n-step transitions.

    The loop stops when the environment reports ``done`` or when
    ``max_decision_steps`` (if given) is reached. Actions come from
    ``q_network.act`` with ``deterministic=False`` and the supplied
    ``epsilon``. One n-step buffer is kept per row of the observation batch;
    buffers are flushed after the loop so leftover steps still produce
    transitions.

    Returns:
        ``{"episode_record": ..., "transitions": ...}`` where transitions are
        packed by ``_pack_transition_records``.
    """
    observation_batch = env.reset()
    row_count = len(observation_batch["intersection_ids"])
    n_step_buffers = [deque() for _ in range(row_count)]
    q_value_samples: list[float] = []
    transition_records: list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]] = []
    done = False
    decision_steps = 0
    last_info = env.last_info

    while not done and (max_decision_steps is None or decision_steps < max_decision_steps):
        raw_obs = observation_batch["observations"].astype(np.float32)
        model_obs = obs_normalizer.normalize(raw_obs) if obs_normalizer else raw_obs
        policy_inputs = {
            "observations": torch.as_tensor(model_obs, dtype=torch.float32, device=device),
            "district_type_indices": torch.as_tensor(
                observation_batch["district_type_indices"],
                dtype=torch.int64,
                device=device,
            ),
            "action_mask": torch.as_tensor(
                observation_batch["action_mask"],
                dtype=torch.float32,
                device=device,
            ),
        }
        with torch.no_grad():
            # Q-values are only used for the logged statistic below.
            q_values = q_network.forward(**policy_inputs)
            actions = q_network.act(
                deterministic=False,
                epsilon=epsilon,
                **policy_inputs,
            ).cpu().numpy()
        q_value_samples.append(float(q_values.max(dim=-1).values.mean().detach().cpu()))

        next_observation_batch, rewards, done, info = env.step(actions)
        fresh_transitions = _build_n_step_transitions(
            buffers=n_step_buffers,
            observation_batch=observation_batch,
            actions=actions,
            rewards=np.asarray(rewards, dtype=np.float32),
            next_observation_batch=next_observation_batch,
            done=done,
            gamma=gamma,
            n_step=n_step,
        )
        transition_records.extend(fresh_transitions)
        observation_batch = next_observation_batch
        last_info = info
        decision_steps += 1

    # Emit the shorter-than-n transitions still waiting in the buffers.
    transition_records.extend(
        _flush_n_step_transition_buffers(
            buffers=n_step_buffers,
            gamma=gamma,
        )
    )

    episode_record: dict[str, Any] = {}
    for key, value in last_info["metrics"].items():
        if value is not None and isinstance(value, (int, float)):
            episode_record[key] = float(value)
    episode_record.update(
        {
            "city_id": env.city_id,
            "scenario_name": env.scenario_name,
            "decision_steps": decision_steps,
            "transitions": len(transition_records),
            "episode_return": float(env.episode_return),
            "total_episode_return": float(env.total_episode_return),
            "epsilon": float(epsilon),
            "mean_q_value": float(np.mean(q_value_samples)) if q_value_samples else 0.0,
        }
    )
    return {
        "episode_record": episode_record,
        "transitions": _pack_transition_records(transition_records, env.observation_dim),
    }
def _build_n_step_transitions(
    buffers: list[deque[StepRecord]],
    observation_batch: dict[str, Any],
    actions: np.ndarray,
    rewards: np.ndarray,
    next_observation_batch: dict[str, Any],
    done: bool,
    gamma: float,
    n_step: int,
) -> list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]]:
    """Append one step per row to its buffer and emit any full n-step windows.

    Row ``i`` of the batches is recorded into ``buffers[i]``. Once a buffer
    holds at least ``n_step`` records, a transition is built from its oldest
    ``n_step`` records and the oldest record is dropped.

    Returns:
        The transitions produced by this step (at most one per row).
    """
    current_obs = observation_batch["observations"]
    current_districts = observation_batch["district_type_indices"]
    current_masks = observation_batch["action_mask"]
    following_obs = next_observation_batch["observations"]
    following_districts = next_observation_batch["district_type_indices"]
    following_masks = next_observation_batch["action_mask"]
    terminal = bool(done)

    emitted: list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]] = []
    for row, pending in enumerate(buffers):
        pending.append(
            StepRecord(
                observation=current_obs[row].astype(np.float32),
                district_type_index=int(current_districts[row]),
                action_mask=current_masks[row].astype(np.float32),
                action=int(actions[row]),
                reward=float(rewards[row]),
                next_observation=following_obs[row].astype(np.float32),
                next_district_type_index=int(following_districts[row]),
                next_action_mask=following_masks[row].astype(np.float32),
                done=terminal,
            )
        )
        if len(pending) < n_step:
            continue
        emitted.append(_make_transition_from_buffer(pending, steps=n_step, gamma=gamma))
        pending.popleft()
    return emitted
def _flush_n_step_transition_buffers(
    buffers: list[deque[StepRecord]],
    gamma: float,
) -> list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]]:
    """Empty every buffer, emitting one transition per remaining record.

    Each transition spans all records still in the buffer at that moment, so
    successive transitions from the same buffer cover ever shorter windows.
    """
    drained: list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]] = []
    for pending in buffers:
        # One transition per record currently waiting; the buffer shrinks by one each pass.
        for _ in range(len(pending)):
            drained.append(
                _make_transition_from_buffer(pending, steps=len(pending), gamma=gamma)
            )
            pending.popleft()
    return drained
def _make_transition_from_buffer(
buffer: deque[StepRecord],
steps: int,
gamma: float,
) -> tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]:
records = list(islice(buffer, 0, steps))
reward = 0.0
for step_index, record in enumerate(records):
reward += (gamma ** step_index) * float(record.reward)
first_record = records[0]
last_record = records[-1]
discount = gamma ** len(records)
return (
first_record.observation,
first_record.district_type_index,
first_record.action_mask,
first_record.action,
reward,
last_record.next_observation,
last_record.next_district_type_index,
last_record.next_action_mask,
last_record.done,
discount,
)
def _pack_transition_records(
transition_records: list[tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]],
observation_dim: int,
) -> dict[str, np.ndarray]:
if not transition_records:
return {
"observations": np.zeros((0, observation_dim), dtype=np.float32),
"district_type_indices": np.zeros(0, dtype=np.int64),
"action_masks": np.zeros((0, 2), dtype=np.float32),
"actions": np.zeros(0, dtype=np.int64),
"rewards": np.zeros(0, dtype=np.float32),
"next_observations": np.zeros((0, observation_dim), dtype=np.float32),
"next_district_type_indices": np.zeros(0, dtype=np.int64),
"next_action_masks": np.zeros((0, 2), dtype=np.float32),
"dones": np.zeros(0, dtype=np.float32),
"discounts": np.zeros(0, dtype=np.float32),
}
observations = np.stack([record[0] for record in transition_records]).astype(np.float32)
district_type_indices = np.asarray([record[1] for record in transition_records], dtype=np.int64)
action_masks = np.stack([record[2] for record in transition_records]).astype(np.float32)
actions = np.asarray([record[3] for record in transition_records], dtype=np.int64)
rewards = np.asarray([record[4] for record in transition_records], dtype=np.float32)
next_observations = np.stack([record[5] for record in transition_records]).astype(np.float32)
next_district_type_indices = np.asarray([record[6] for record in transition_records], dtype=np.int64)
next_action_masks = np.stack([record[7] for record in transition_records]).astype(np.float32)
dones = np.asarray([record[8] for record in transition_records], dtype=np.float32)
discounts = np.asarray([record[9] for record in transition_records], dtype=np.float32)
return {
"observations": observations,
"district_type_indices": district_type_indices,
"action_masks": action_masks,
"actions": actions,
"rewards": rewards,
"next_observations": next_observations,
"next_district_type_indices": next_district_type_indices,
"next_action_masks": next_action_masks,
"dones": dones,
"discounts": discounts,
}