from __future__ import annotations from collections import deque from concurrent.futures import ProcessPoolExecutor, as_completed import json from itertools import islice import os import random from dataclasses import asdict, dataclass from pathlib import Path from time import perf_counter from typing import Any import numpy as np import torch from torch import nn from tqdm.auto import tqdm try: from torch.utils.tensorboard import SummaryWriter except ImportError: SummaryWriter = None from agents.local_policy import FixedCyclePolicy, HoldPhasePolicy, QueueGreedyPolicy, RandomPhasePolicy from env.observation_builder import ObservationConfig from env.reward import RewardConfig from env.traffic_env import EnvConfig, TrafficEnv from training.cityflow_dataset import CityFlowDataset, ScenarioSpec from training.device import configure_torch_runtime, resolve_torch_device from training.models import POLICY_ARCHES, RunningNormalizer, TrafficControlQNetwork from training.rollout import evaluate_policy _EVAL_CONTEXT: dict[str, Any] = {} @dataclass(frozen=True) class DQNConfig: policy_arch: str = "single_head_with_district_feature" total_updates: int = 200 learning_rate: float = 1e-4 gamma: float = 0.99 n_step: int = 3 replay_capacity: int = 500_000 minibatch_size: int = 1024 learning_starts: int = 10_000 gradient_steps: int = 64 target_tau: float = 0.01 max_grad_norm: float = 10.0 hidden_dim: int = 256 hidden_layers: int = 2 dueling: bool = True seed: int = 7 eval_every: int = 40 checkpoint_every: int = 40 checkpoint_on_eval: bool = True val_scenarios_per_city: int | None = 1 max_val_cities: int | None = 5 max_train_cities: int | None = None num_rollout_workers: int = 4 rollout_episodes_per_update: int | None = None train_city_id: str | None = None train_scenario_name: str | None = None overfit_val_on_train_scenario: bool = False rollout_decision_steps: int | None = 256 resume_from: str | None = None use_observation_normalization: bool = True epsilon_start: float = 1.0 
epsilon_end: float = 0.05 epsilon_decay_steps: int = 50_000 prioritized_replay_alpha: float = 0.6 prioritized_replay_beta_start: float = 0.4 prioritized_replay_beta_end: float = 1.0 prioritized_replay_beta_steps: int = 200_000 compare_baselines: bool = True skip_failed_validation_episodes: bool = True verbose_progress: bool = False eval_num_workers: int = -1 enable_tensorboard: bool = True tensorboard_log_dir: str | None = None rolling_window_size: int = 20 use_tqdm: bool = True @dataclass class TrainerState: update_index: int = 0 best_validation_score: float = float("-inf") total_decision_steps: int = 0 total_transitions: int = 0 gradient_steps: int = 0 @dataclass(frozen=True) class StepRecord: observation: np.ndarray district_type_index: int action_mask: np.ndarray action: int reward: float next_observation: np.ndarray next_district_type_index: int next_action_mask: np.ndarray done: bool class PrioritizedReplayBuffer: def __init__( self, capacity: int, prioritized_alpha: float = 0.6, epsilon: float = 1e-6, ): self.capacity = int(capacity) self.prioritized_alpha = float(prioritized_alpha) self.epsilon = float(epsilon) self.position = 0 self.size = 0 self.max_priority = 1.0 self.observations: np.ndarray | None = None self.next_observations: np.ndarray | None = None self.district_type_indices: np.ndarray | None = None self.next_district_type_indices: np.ndarray | None = None self.action_masks: np.ndarray | None = None self.next_action_masks: np.ndarray | None = None self.actions: np.ndarray | None = None self.rewards: np.ndarray | None = None self.dones: np.ndarray | None = None self.discounts: np.ndarray | None = None self.priorities = np.zeros(self.capacity, dtype=np.float32) def add( self, observation: np.ndarray, district_type_index: int, action_mask: np.ndarray, action: int, reward: float, next_observation: np.ndarray, next_district_type_index: int, next_action_mask: np.ndarray, done: bool, discount: float, ) -> None: if self.observations is None: obs_dim = 
observation.shape[0] action_dim = action_mask.shape[0] self.observations = np.zeros((self.capacity, obs_dim), dtype=np.float32) self.next_observations = np.zeros((self.capacity, obs_dim), dtype=np.float32) self.district_type_indices = np.zeros(self.capacity, dtype=np.int64) self.next_district_type_indices = np.zeros(self.capacity, dtype=np.int64) self.action_masks = np.zeros((self.capacity, action_dim), dtype=np.float32) self.next_action_masks = np.zeros((self.capacity, action_dim), dtype=np.float32) self.actions = np.zeros(self.capacity, dtype=np.int64) self.rewards = np.zeros(self.capacity, dtype=np.float32) self.dones = np.zeros(self.capacity, dtype=np.float32) self.discounts = np.zeros(self.capacity, dtype=np.float32) index = self.position self.observations[index] = observation.astype(np.float32) self.next_observations[index] = next_observation.astype(np.float32) self.district_type_indices[index] = int(district_type_index) self.next_district_type_indices[index] = int(next_district_type_index) self.action_masks[index] = action_mask.astype(np.float32) self.next_action_masks[index] = next_action_mask.astype(np.float32) self.actions[index] = int(action) self.rewards[index] = float(reward) self.dones[index] = float(done) self.discounts[index] = float(discount) self.priorities[index] = self.max_priority self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) def sample(self, batch_size: int, beta: float) -> dict[str, np.ndarray]: if self.size <= 0: raise ValueError("Cannot sample from an empty replay buffer.") replace = self.size < batch_size if self.prioritized_alpha > 0.0: scaled_priorities = np.power( np.maximum(self.priorities[: self.size], self.epsilon), self.prioritized_alpha, ) probabilities = scaled_priorities / scaled_priorities.sum() indices = np.random.choice( self.size, size=batch_size, replace=replace, p=probabilities, ) weights = np.power(self.size * probabilities[indices], -beta).astype(np.float32) weights /= 
max(1.0, float(weights.max())) else: indices = np.random.choice(self.size, size=batch_size, replace=replace) weights = np.ones(batch_size, dtype=np.float32) return { "indices": indices.astype(np.int64), "weights": weights.astype(np.float32), "observations": self.observations[indices], "next_observations": self.next_observations[indices], "district_type_indices": self.district_type_indices[indices], "next_district_type_indices": self.next_district_type_indices[indices], "action_masks": self.action_masks[indices], "next_action_masks": self.next_action_masks[indices], "actions": self.actions[indices], "rewards": self.rewards[indices], "dones": self.dones[indices], "discounts": self.discounts[indices], } def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray) -> None: updated_priorities = np.abs(td_errors).astype(np.float32) + self.epsilon self.priorities[indices] = updated_priorities if updated_priorities.size: self.max_priority = max(self.max_priority, float(updated_priorities.max())) def state_dict(self) -> dict[str, Any]: return { "capacity": self.capacity, "prioritized_alpha": self.prioritized_alpha, "epsilon": self.epsilon, "position": self.position, "size": self.size, "max_priority": self.max_priority, "observations": self.observations, "next_observations": self.next_observations, "district_type_indices": self.district_type_indices, "next_district_type_indices": self.next_district_type_indices, "action_masks": self.action_masks, "next_action_masks": self.next_action_masks, "actions": self.actions, "rewards": self.rewards, "dones": self.dones, "discounts": self.discounts, "priorities": self.priorities, } def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.capacity = int(state_dict["capacity"]) self.prioritized_alpha = float(state_dict["prioritized_alpha"]) self.epsilon = float(state_dict["epsilon"]) self.position = int(state_dict["position"]) self.size = int(state_dict["size"]) self.max_priority = float(state_dict["max_priority"]) 
self.observations = state_dict["observations"] self.next_observations = state_dict["next_observations"] self.district_type_indices = state_dict["district_type_indices"] self.next_district_type_indices = state_dict["next_district_type_indices"] self.action_masks = state_dict["action_masks"] self.next_action_masks = state_dict["next_action_masks"] self.actions = state_dict["actions"] self.rewards = state_dict["rewards"] self.dones = state_dict["dones"] self.discounts = state_dict["discounts"] self.priorities = state_dict["priorities"] class DQNTrainer: def __init__( self, dataset: CityFlowDataset, env_config: EnvConfig, dqn_config: DQNConfig, output_dir: str | Path = "artifacts/dqn_shared", device: str | None = None, ): self.dataset = dataset self.env_config = env_config self.dqn_config = dqn_config self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.checkpoint_dir = self.output_dir / "checkpoints" self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.rng = random.Random(dqn_config.seed) np.random.seed(dqn_config.seed) torch.manual_seed(dqn_config.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(dqn_config.seed) self.device = resolve_torch_device(device) configure_torch_runtime(self.device) if self.dqn_config.policy_arch not in POLICY_ARCHES: raise ValueError( f"Unsupported policy architecture: {self.dqn_config.policy_arch}. " f"Expected one of {POLICY_ARCHES}." 
) self.train_city_ids = self.dataset.load_split("train") if self.dqn_config.max_train_cities is not None: self.train_city_ids = self.train_city_ids[: self.dqn_config.max_train_cities] self.fixed_train_scenario_spec = self._resolve_fixed_train_scenario() if not self.train_city_ids: raise ValueError("No training cities available for DQN training.") sample_spec = self._sample_train_scenario() sample_env = self._make_env(sample_spec) observation_dim = sample_env.observation_dim self.q_network = TrafficControlQNetwork( observation_dim=observation_dim, hidden_dim=dqn_config.hidden_dim, num_layers=dqn_config.hidden_layers, policy_arch=dqn_config.policy_arch, dueling=dqn_config.dueling, ).to(self.device) self.target_network = TrafficControlQNetwork( observation_dim=observation_dim, hidden_dim=dqn_config.hidden_dim, num_layers=dqn_config.hidden_layers, policy_arch=dqn_config.policy_arch, dueling=dqn_config.dueling, ).to(self.device) self.target_network.load_state_dict(self.q_network.state_dict()) self.target_network.eval() self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=dqn_config.learning_rate) self.obs_normalizer = RunningNormalizer() if dqn_config.use_observation_normalization else None self.replay_buffer = PrioritizedReplayBuffer( capacity=dqn_config.replay_capacity, prioritized_alpha=dqn_config.prioritized_replay_alpha, ) self.state = TrainerState() self.training_log_path = self.output_dir / "training_log.jsonl" self.validation_log_path = self.output_dir / "validation_log.jsonl" self.tensorboard_log_dir = Path( dqn_config.tensorboard_log_dir or (self.output_dir / "tensorboard") ) self.writer = self._build_tensorboard_writer() self._rolling_metrics: dict[tuple[str, str], deque[float]] = {} self.rollout_executor: ProcessPoolExecutor | None = None if self.dqn_config.num_rollout_workers > 1: self.rollout_executor = ProcessPoolExecutor( max_workers=self.dqn_config.num_rollout_workers, ) print( "[setup] " f"torch_device={self.device.type} " 
f"algorithm=ps_d3qn " f"policy_arch={self.dqn_config.policy_arch} " f"reward_variant={self.env_config.reward.variant} " f"rollout_workers={self.dqn_config.num_rollout_workers}" ) if self.fixed_train_scenario_spec is not None: print( "[setup] " f"fixed_train_city={self.fixed_train_scenario_spec.city_id} " f"fixed_train_scenario={self.fixed_train_scenario_spec.scenario_name} " f"overfit_val_on_train_scenario={self.dqn_config.overfit_val_on_train_scenario}" ) if dqn_config.resume_from: self.load_checkpoint(dqn_config.resume_from) def fit(self) -> None: progress_bar: tqdm | None = None try: if self.dqn_config.use_tqdm: progress_bar = tqdm( total=self.dqn_config.total_updates, initial=self.state.update_index, desc="train", dynamic_ncols=True, ) for update_index in range(self.state.update_index, self.dqn_config.total_updates): rollout_start = perf_counter() episode_records = self._collect_rollout_batch() rollout_seconds = perf_counter() - rollout_start update_start = perf_counter() losses = self._optimize() update_seconds = perf_counter() - update_start self.state.update_index = update_index + 1 validation_seconds = 0.0 checkpoint_seconds = 0.0 train_record = self._summarize_rollout_batch(episode_records) train_record.update( { "update": self.state.update_index, "algorithm": "ps_d3qn", "policy_arch": self.dqn_config.policy_arch, "reward_variant": self.env_config.reward.variant, "replay_size": float(self.replay_buffer.size), "epsilon": float(self._epsilon()), **losses, } ) self._attach_rolling_metrics( namespace="train", record=train_record, keys=( "episode_return", "total_episode_return", "mean_waiting_vehicles", "throughput", "td_loss", "mean_q_value", "mean_abs_td_error", ), ) self._append_jsonl(self.training_log_path, train_record) self._print_train_log(train_record) self._log_tensorboard_scalars("train", train_record, self.state.update_index) if progress_bar is not None: progress_bar.set_postfix( ret=f"{train_record['episode_return']:.3f}", 
wait=f"{train_record['mean_waiting_vehicles']:.2f}", td=f"{train_record['td_loss']:.4f}", eps=f"{train_record['epsilon']:.3f}", ) progress_bar.update(1) should_evaluate = self.state.update_index % self.dqn_config.eval_every == 0 should_periodic_checkpoint = ( self.state.update_index % self.dqn_config.checkpoint_every == 0 ) if should_periodic_checkpoint and not ( should_evaluate and self.dqn_config.checkpoint_on_eval ): print(f"[train] saving checkpoint at update={self.state.update_index}") checkpoint_start = perf_counter() self.save_checkpoint(self.checkpoint_dir / f"update_{self.state.update_index:04d}.pt") checkpoint_seconds += perf_counter() - checkpoint_start print(f"[train] finished checkpoint at update={self.state.update_index}") if should_evaluate: print(f"[train] starting validation at update={self.state.update_index}") validation_start = perf_counter() validation_record = self.evaluate_split("val") validation_seconds = perf_counter() - validation_start validation_record["update"] = self.state.update_index validation_record["algorithm"] = "ps_d3qn" validation_record["policy_arch"] = self.dqn_config.policy_arch validation_record["reward_variant"] = self.env_config.reward.variant self._attach_rolling_metrics( namespace="eval", record=validation_record, keys=( "mean_episode_return", "mean_total_episode_return", "mean_mean_waiting_vehicles", "mean_throughput", ), ) self._append_jsonl(self.validation_log_path, validation_record) self._print_eval_log(validation_record) self._log_tensorboard_scalars("eval", validation_record, self.state.update_index) print(f"[train] finished validation at update={self.state.update_index}") if self.dqn_config.checkpoint_on_eval: print(f"[train] saving checkpoint at update={self.state.update_index}") checkpoint_start = perf_counter() self.save_checkpoint( self.checkpoint_dir / f"update_{self.state.update_index:04d}.pt" ) checkpoint_seconds += perf_counter() - checkpoint_start print(f"[train] finished checkpoint at 
update={self.state.update_index}") validation_score = float(validation_record["mean_episode_return"]) if validation_score > self.state.best_validation_score: self.state.best_validation_score = validation_score print(f"[train] saving checkpoint at update={self.state.update_index}") checkpoint_start = perf_counter() self.save_checkpoint(self.output_dir / "best_validation.pt") checkpoint_seconds += perf_counter() - checkpoint_start print(f"[train] finished checkpoint at update={self.state.update_index}") print( "[timing] " f"rollout={rollout_seconds:.2f}s " f"update={update_seconds:.2f}s " f"validation={validation_seconds:.2f}s " f"checkpoint={checkpoint_seconds:.2f}s" ) print(f"[train] saving checkpoint at update={self.state.update_index}") final_checkpoint_start = perf_counter() self.save_checkpoint(self.output_dir / "last.pt") final_checkpoint_seconds = perf_counter() - final_checkpoint_start print(f"[train] finished checkpoint at update={self.state.update_index}") print(f"[timing] final_checkpoint={final_checkpoint_seconds:.2f}s") finally: if progress_bar is not None: progress_bar.close() if self.rollout_executor is not None: self.rollout_executor.shutdown(wait=True, cancel_futures=False) if self.writer is not None: self.writer.close() def evaluate_split(self, split_name: str) -> dict[str, float]: if split_name == "val" and self.dqn_config.overfit_val_on_train_scenario: if self.fixed_train_scenario_spec is None: raise ValueError( "--overfit-val-on-train-scenario requires a fixed training city/scenario." 
) scenario_specs = [self.fixed_train_scenario_spec] else: scenario_specs = self.dataset.iter_scenarios( split_name=split_name, scenarios_per_city=self.dqn_config.val_scenarios_per_city, max_cities=self.dqn_config.max_val_cities, diversify_single_scenario=True, ) if self._resolved_eval_workers(len(scenario_specs)) > 1: episode_metrics = self._evaluate_policy_parallel(scenario_specs) else: episode_metrics = self._evaluate_policy_sequential(scenario_specs) if not episode_metrics: raise RuntimeError("Validation produced no successful episodes.") aggregate = aggregate_metrics(episode_metrics) aggregate.update(aggregate_metrics_by_scenario(episode_metrics)) if self.dqn_config.compare_baselines: if self._resolved_eval_workers(len(scenario_specs)) > 1: aggregate.update(self._evaluate_baselines_parallel(scenario_specs)) else: aggregate.update(self._evaluate_baselines(scenario_specs)) if "fixed_mean_episode_return" in aggregate: aggregate["learner_minus_fixed_return"] = ( aggregate["mean_episode_return"] - aggregate["fixed_mean_episode_return"] ) if "random_mean_episode_return" in aggregate: aggregate["learner_minus_random_return"] = ( aggregate["mean_episode_return"] - aggregate["random_mean_episode_return"] ) return aggregate def save_checkpoint(self, path: str | Path) -> None: checkpoint = { "algorithm": "ps_d3qn", "q_network_state_dict": self.q_network.state_dict(), "target_network_state_dict": self.target_network.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), "trainer_state": asdict(self.state), "dqn_config": asdict(self.dqn_config), "network_architecture": { "observation_dim": self.q_network.observation_dim, "action_dim": self.q_network.action_dim, "district_types": self.q_network.district_types, "policy_arch": self.q_network.policy_arch, "dueling": self.q_network.dueling, }, "env_config": { "simulator_interval": self.env_config.simulator_interval, "decision_interval": self.env_config.decision_interval, "min_green_time": 
self.env_config.min_green_time, "thread_num": self.env_config.thread_num, "max_episode_seconds": self.env_config.max_episode_seconds, "observation": asdict(self.env_config.observation), "reward": asdict(self.env_config.reward), }, "obs_normalizer": self.obs_normalizer.state_dict() if self.obs_normalizer else None, } torch.save(checkpoint, path) def load_checkpoint(self, path: str | Path) -> None: checkpoint = torch.load( path, map_location=self.device, weights_only=False, ) q_state_dict = checkpoint.get("q_network_state_dict") or checkpoint.get("policy_state_dict") if q_state_dict is None: raise ValueError(f"Checkpoint at {path} does not contain a Q-network state dict.") self.q_network.load_state_dict(q_state_dict) target_state_dict = checkpoint.get("target_network_state_dict") or q_state_dict self.target_network.load_state_dict(target_state_dict) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) self.state = TrainerState(**checkpoint["trainer_state"]) if self.obs_normalizer and checkpoint.get("obs_normalizer"): self.obs_normalizer.load_state_dict(checkpoint["obs_normalizer"]) def _make_env(self, scenario_spec: ScenarioSpec) -> TrafficEnv: return TrafficEnv( city_id=scenario_spec.city_id, scenario_name=scenario_spec.scenario_name, city_dir=scenario_spec.city_dir, scenario_dir=scenario_spec.scenario_dir, config_path=scenario_spec.config_path, roadnet_path=scenario_spec.roadnet_path, district_map_path=scenario_spec.district_map_path, metadata_path=scenario_spec.metadata_path, env_config=self.env_config, ) def _collect_rollout_batch(self) -> list[dict[str, float | str]]: episodes_per_update = self.dqn_config.rollout_episodes_per_update or max( 1, self.dqn_config.num_rollout_workers, ) scenario_specs = [self._sample_train_scenario() for _ in range(episodes_per_update)] if self.rollout_executor is None or episodes_per_update <= 1: episode_record = self._collect_episode(self._make_env(scenario_specs[0])) return [episode_record] return 
self._collect_rollouts_parallel(scenario_specs) def _sample_train_scenario(self) -> ScenarioSpec: if self.fixed_train_scenario_spec is not None: return self.fixed_train_scenario_spec selected_city = self.rng.choice(self.train_city_ids) selected_scenario = self.rng.choice(self.dataset.scenarios_for_city(selected_city)) if self.dqn_config.verbose_progress: print(f"[train] sampled city={selected_city} scenario={selected_scenario}") return self.dataset.build_scenario_spec(selected_city, selected_scenario) def _resolve_fixed_train_scenario(self) -> ScenarioSpec | None: if self.dqn_config.train_city_id is None: return None city_id = self.dqn_config.train_city_id available_train_cities = set(self.dataset.load_split("train")) if city_id not in available_train_cities: raise ValueError( f"Fixed train city {city_id!r} is not in the train split." ) scenario_names = self.dataset.scenarios_for_city(city_id) scenario_name = self.dqn_config.train_scenario_name if scenario_name is None: scenario_name = scenario_names[0] if scenario_name not in scenario_names: raise ValueError( f"Scenario {scenario_name!r} not found for train city {city_id!r}. 
" f"Available: {scenario_names}" ) self.train_city_ids = [city_id] return self.dataset.build_scenario_spec(city_id, scenario_name) def _collect_rollouts_parallel( self, scenario_specs: list[ScenarioSpec], ) -> list[dict[str, float | str]]: if self.rollout_executor is None: raise RuntimeError("Parallel rollout collection requested without a rollout executor.") context = self._build_parallel_rollout_context() epsilon = self._epsilon() total_specs = len(scenario_specs) episode_records: list[dict[str, float | str]] = [] futures = { self.rollout_executor.submit( _parallel_rollout_collection_worker, spec, context, epsilon, self.dqn_config.rollout_decision_steps, self.dqn_config.gamma, self.dqn_config.n_step, ): (index, spec) for index, spec in enumerate(scenario_specs, start=1) } for future in as_completed(futures): index, spec = futures[future] result = future.result() self._ingest_transition_batch(result["transitions"]) self.state.total_decision_steps += int(result["episode_record"]["decision_steps"]) self.state.total_transitions += int(result["episode_record"]["transitions"]) episode_records.append(result["episode_record"]) if self.dqn_config.verbose_progress: print( f"[rollout] city={spec.city_id} scenario={spec.scenario_name} " f"i={index}/{total_specs}" ) return episode_records def _build_parallel_rollout_context(self) -> dict[str, Any]: return { "env_config": _env_config_to_payload(self.env_config), "network_architecture": { "observation_dim": self.q_network.observation_dim, "action_dim": self.q_network.action_dim, "hidden_dim": self.q_network.hidden_dim, "num_layers": self.q_network.num_layers, "district_types": self.q_network.district_types, "policy_arch": self.q_network.policy_arch, "dueling": self.q_network.dueling, }, "q_network_state_dict": { key: value.detach().cpu() for key, value in self.q_network.state_dict().items() }, "obs_normalizer": self.obs_normalizer.state_dict() if self.obs_normalizer else None, } def _ingest_transition_batch(self, transitions: 
dict[str, np.ndarray]) -> None: if transitions["observations"].size == 0: return if self.obs_normalizer is not None: self.obs_normalizer.update(transitions["observations"]) transition_count = transitions["actions"].shape[0] for index in range(transition_count): self.replay_buffer.add( observation=transitions["observations"][index], district_type_index=int(transitions["district_type_indices"][index]), action_mask=transitions["action_masks"][index], action=int(transitions["actions"][index]), reward=float(transitions["rewards"][index]), next_observation=transitions["next_observations"][index], next_district_type_index=int(transitions["next_district_type_indices"][index]), next_action_mask=transitions["next_action_masks"][index], done=bool(transitions["dones"][index]), discount=float(transitions["discounts"][index]), ) def _summarize_rollout_batch( self, episode_records: list[dict[str, float | str]], ) -> dict[str, float | str]: if len(episode_records) == 1: record = dict(episode_records[0]) record["num_rollout_episodes"] = 1.0 return record aggregate = aggregate_metrics(episode_records) city_ids = sorted({str(record["city_id"]) for record in episode_records}) scenario_names = sorted({str(record["scenario_name"]) for record in episode_records}) summary: dict[str, float | str] = { "city_id": city_ids[0] if len(city_ids) == 1 else f"{len(city_ids)}_cities", "scenario_name": scenario_names[0] if len(scenario_names) == 1 else f"{len(scenario_names)}_scenarios", "num_rollout_episodes": float(len(episode_records)), } for source_key, target_key in ( ("mean_decision_steps", "decision_steps"), ("mean_transitions", "transitions"), ("mean_episode_return", "episode_return"), ("mean_total_episode_return", "total_episode_return"), ("mean_mean_waiting_vehicles", "mean_waiting_vehicles"), ("mean_throughput", "throughput"), ("mean_mean_q_value", "mean_q_value"), ("mean_epsilon", "epsilon"), ("mean_replay_size", "replay_size"), ): if source_key in aggregate: summary[target_key] = 
aggregate[source_key] for key, value in aggregate.items(): if key not in summary: summary[key] = value return summary def _collect_episode(self, env: TrafficEnv) -> dict[str, float | str]: observation_batch = env.reset() decision_steps = 0 transitions_added = 0 q_value_samples: list[float] = [] n_step_buffers = [ deque() for _ in range(len(observation_batch["intersection_ids"])) ] epsilon = self._epsilon() last_info = env.last_info done = False while not done: if ( self.dqn_config.rollout_decision_steps is not None and decision_steps >= self.dqn_config.rollout_decision_steps ): break raw_obs = observation_batch["observations"].astype(np.float32) if self.obs_normalizer is not None: self.obs_normalizer.update(raw_obs) normalized_obs = self.obs_normalizer.normalize(raw_obs) else: normalized_obs = raw_obs obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=self.device) district_type_tensor = torch.as_tensor( observation_batch["district_type_indices"], dtype=torch.int64, device=self.device, ) action_mask_tensor = torch.as_tensor( observation_batch["action_mask"], dtype=torch.float32, device=self.device, ) with torch.no_grad(): q_values = self.q_network.forward( observations=obs_tensor, district_type_indices=district_type_tensor, action_mask=action_mask_tensor, ) action_tensor = self.q_network.act( observations=obs_tensor, district_type_indices=district_type_tensor, action_mask=action_mask_tensor, deterministic=False, epsilon=epsilon, ) q_value_samples.append(float(q_values.max(dim=-1).values.mean().detach().cpu())) actions = action_tensor.detach().cpu().numpy() next_observation_batch, rewards, done, info = env.step(actions) transitions_added += self._append_step_records( buffers=n_step_buffers, observation_batch=observation_batch, actions=actions, rewards=np.asarray(rewards, dtype=np.float32), next_observation_batch=next_observation_batch, done=done, ) observation_batch = next_observation_batch last_info = info decision_steps += 1 
self.state.total_decision_steps += 1 epsilon = self._epsilon() transitions_added += self._flush_n_step_buffers(n_step_buffers) self.state.total_transitions += transitions_added episode_metrics = { key: float(value) for key, value in last_info["metrics"].items() if value is not None and isinstance(value, (int, float)) } episode_metrics.update( { "city_id": env.city_id, "scenario_name": env.scenario_name, "decision_steps": decision_steps, "transitions": transitions_added, "episode_return": float(env.episode_return), "total_episode_return": float(env.total_episode_return), "epsilon": float(epsilon), "replay_size": float(self.replay_buffer.size), "mean_q_value": float(np.mean(q_value_samples)) if q_value_samples else 0.0, } ) return episode_metrics def _append_step_records( self, buffers: list[deque[StepRecord]], observation_batch: dict[str, Any], actions: np.ndarray, rewards: np.ndarray, next_observation_batch: dict[str, Any], done: bool, ) -> int: transitions_added = 0 for row_index, buffer in enumerate(buffers): record = StepRecord( observation=observation_batch["observations"][row_index].astype(np.float32), district_type_index=int(observation_batch["district_type_indices"][row_index]), action_mask=observation_batch["action_mask"][row_index].astype(np.float32), action=int(actions[row_index]), reward=float(rewards[row_index]), next_observation=next_observation_batch["observations"][row_index].astype(np.float32), next_district_type_index=int(next_observation_batch["district_type_indices"][row_index]), next_action_mask=next_observation_batch["action_mask"][row_index].astype(np.float32), done=bool(done), ) buffer.append(record) if len(buffer) >= self.dqn_config.n_step: self._push_n_step_transition(buffer, steps=self.dqn_config.n_step) transitions_added += 1 return transitions_added def _flush_n_step_buffers(self, buffers: list[deque[StepRecord]]) -> int: transitions_added = 0 for buffer in buffers: while buffer: self._push_n_step_transition(buffer, steps=len(buffer)) 
transitions_added += 1 return transitions_added def _push_n_step_transition(self, buffer: deque[StepRecord], steps: int) -> None: records = list(islice(buffer, 0, steps)) reward = 0.0 for step_index, record in enumerate(records): reward += (self.dqn_config.gamma ** step_index) * float(record.reward) first_record = records[0] last_record = records[-1] discount = self.dqn_config.gamma ** len(records) self.replay_buffer.add( observation=first_record.observation, district_type_index=first_record.district_type_index, action_mask=first_record.action_mask, action=first_record.action, reward=reward, next_observation=last_record.next_observation, next_district_type_index=last_record.next_district_type_index, next_action_mask=last_record.next_action_mask, done=last_record.done, discount=discount, ) buffer.popleft() def _optimize(self) -> dict[str, float]: minimum_replay = max(self.dqn_config.learning_starts, self.dqn_config.minibatch_size) if self.replay_buffer.size < minimum_replay: return { "td_loss": 0.0, "mean_abs_td_error": 0.0, "mean_target_q": 0.0, "mean_q_value": 0.0, "beta": self._beta(), "gradient_steps": 0.0, } batch_size = min(self.dqn_config.minibatch_size, self.replay_buffer.size) td_losses: list[float] = [] td_errors: list[float] = [] target_values: list[float] = [] q_values: list[float] = [] beta = self._beta() for _ in range(self.dqn_config.gradient_steps): batch = self.replay_buffer.sample(batch_size=batch_size, beta=beta) observations = batch["observations"] next_observations = batch["next_observations"] if self.obs_normalizer is not None: observations = self.obs_normalizer.normalize(observations) next_observations = self.obs_normalizer.normalize(next_observations) obs_tensor = torch.as_tensor(observations, dtype=torch.float32, device=self.device) next_obs_tensor = torch.as_tensor(next_observations, dtype=torch.float32, device=self.device) district_type_tensor = torch.as_tensor( batch["district_type_indices"], dtype=torch.int64, device=self.device, ) 
next_district_type_tensor = torch.as_tensor( batch["next_district_type_indices"], dtype=torch.int64, device=self.device, ) action_mask_tensor = torch.as_tensor(batch["action_masks"], dtype=torch.float32, device=self.device) next_action_mask_tensor = torch.as_tensor( batch["next_action_masks"], dtype=torch.float32, device=self.device, ) action_tensor = torch.as_tensor(batch["actions"], dtype=torch.int64, device=self.device) reward_tensor = torch.as_tensor(batch["rewards"], dtype=torch.float32, device=self.device) done_tensor = torch.as_tensor(batch["dones"], dtype=torch.float32, device=self.device) discount_tensor = torch.as_tensor(batch["discounts"], dtype=torch.float32, device=self.device) weight_tensor = torch.as_tensor(batch["weights"], dtype=torch.float32, device=self.device) predicted_q = self.q_network.q_values_for_actions( observations=obs_tensor, district_type_indices=district_type_tensor, actions=action_tensor, action_mask=action_mask_tensor, ) with torch.no_grad(): next_online_q = self.q_network.forward( observations=next_obs_tensor, district_type_indices=next_district_type_tensor, action_mask=next_action_mask_tensor, ) next_actions = next_online_q.argmax(dim=-1) next_target_q = self.target_network.forward( observations=next_obs_tensor, district_type_indices=next_district_type_tensor, action_mask=next_action_mask_tensor, ).gather(dim=1, index=next_actions.view(-1, 1)).squeeze(1) target_q = reward_tensor + (1.0 - done_tensor) * discount_tensor * next_target_q td_error = target_q - predicted_q per_sample_loss = nn.functional.smooth_l1_loss( predicted_q, target_q, reduction="none", ) loss = (weight_tensor * per_sample_loss).mean() self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.q_network.parameters(), self.dqn_config.max_grad_norm) self.optimizer.step() self._soft_update_target() self.replay_buffer.update_priorities( batch["indices"], td_errors=np.abs(td_error.detach().cpu().numpy()), ) self.state.gradient_steps += 1 
td_losses.append(float(loss.detach().cpu())) td_errors.append(float(torch.abs(td_error).mean().detach().cpu())) target_values.append(float(target_q.mean().detach().cpu())) q_values.append(float(predicted_q.mean().detach().cpu())) return { "td_loss": float(np.mean(td_losses)), "mean_abs_td_error": float(np.mean(td_errors)), "mean_target_q": float(np.mean(target_values)), "mean_q_value": float(np.mean(q_values)), "beta": float(beta), "gradient_steps": float(self.dqn_config.gradient_steps), } def _soft_update_target(self) -> None: tau = float(self.dqn_config.target_tau) with torch.no_grad(): for target_param, online_param in zip( self.target_network.parameters(), self.q_network.parameters(), strict=True, ): target_param.data.mul_(1.0 - tau).add_(online_param.data, alpha=tau) def _epsilon(self) -> float: if self.dqn_config.epsilon_decay_steps <= 0: return float(self.dqn_config.epsilon_end) progress = min(1.0, self.state.total_decision_steps / float(self.dqn_config.epsilon_decay_steps)) return float( self.dqn_config.epsilon_start + progress * (self.dqn_config.epsilon_end - self.dqn_config.epsilon_start) ) def _beta(self) -> float: if self.dqn_config.prioritized_replay_beta_steps <= 0: return float(self.dqn_config.prioritized_replay_beta_end) progress = min( 1.0, self.state.total_decision_steps / float(self.dqn_config.prioritized_replay_beta_steps), ) return float( self.dqn_config.prioritized_replay_beta_start + progress * ( self.dqn_config.prioritized_replay_beta_end - self.dqn_config.prioritized_replay_beta_start ) ) def _evaluate_policy_sequential( self, scenario_specs: list[ScenarioSpec], ) -> list[dict[str, float | str]]: episode_metrics: list[dict[str, float | str]] = [] total_specs = len(scenario_specs) iterator = enumerate(scenario_specs, start=1) if self.dqn_config.use_tqdm: iterator = tqdm( iterator, total=total_specs, desc="eval:learned", leave=False, dynamic_ncols=True, ) for index, spec in iterator: print(f"[eval] city={spec.city_id} 
scenario={spec.scenario_name} i={index}/{total_specs}") try: episode_metrics.append( evaluate_policy( env_factory=lambda spec=spec: self._make_env(spec), actor=self.q_network, device=self.device, obs_normalizer=self.obs_normalizer, deterministic=True, ) ) except Exception as exc: self._handle_eval_failure("validation", spec, exc) return episode_metrics def _evaluate_policy_parallel( self, scenario_specs: list[ScenarioSpec], ) -> list[dict[str, float | str]]: resolved_workers = self._resolved_eval_workers(len(scenario_specs)) print(f"[eval] learned_workers={resolved_workers}") return self._run_parallel_eval( scenario_specs=scenario_specs, worker_kind="learned", initializer=_init_parallel_learned_eval_worker, initargs=(self._build_parallel_learned_eval_context(),), max_workers=resolved_workers, ) def _append_jsonl(self, path: Path, record: dict) -> None: with path.open("a") as handle: handle.write(json.dumps(record) + "\n") def _build_tensorboard_writer(self) -> SummaryWriter | None: if not self.dqn_config.enable_tensorboard: return None if SummaryWriter is None: print("[setup] tensorboard_disabled=torch.utils.tensorboard unavailable") return None self.tensorboard_log_dir.mkdir(parents=True, exist_ok=True) return SummaryWriter(log_dir=str(self.tensorboard_log_dir)) def _log_tensorboard_scalars( self, namespace: str, record: dict[str, Any], step: int, ) -> None: if self.writer is None: return for key, value in record.items(): if isinstance(value, (int, float)): self.writer.add_scalar(f"{namespace}/{key}", float(value), step) self.writer.flush() def _attach_rolling_metrics( self, namespace: str, record: dict[str, Any], keys: tuple[str, ...], ) -> None: for key in keys: value = record.get(key) if not isinstance(value, (int, float)): continue window = self._rolling_metrics.setdefault( (namespace, key), deque(maxlen=self.dqn_config.rolling_window_size), ) window.append(float(value)) record[f"rolling_{key}"] = float(np.mean(window)) def _evaluate_baselines(self, 
scenario_specs: list[ScenarioSpec]) -> dict[str, float]: baseline_metrics: dict[str, float] = {} for baseline_name in ("random", "fixed"): metrics: list[dict[str, float | str]] = [] total_specs = len(scenario_specs) for offset, spec in enumerate(scenario_specs, start=1): print( f"[eval] baseline={baseline_name} city={spec.city_id} " f"scenario={spec.scenario_name} i={offset}/{total_specs}" ) try: actor = ( RandomPhasePolicy(seed=self.dqn_config.seed + offset) if baseline_name == "random" else FixedCyclePolicy(green_time=max(20, self.env_config.min_green_time * 2)) ) metrics.append( evaluate_policy( env_factory=lambda spec=spec: self._make_env(spec), actor=actor, ) ) except Exception as exc: message = ( f"[warn] baseline={baseline_name} failed for city={spec.city_id} " f"scenario={spec.scenario_name}: {exc}" ) if self.dqn_config.skip_failed_validation_episodes: print(message) continue raise RuntimeError(message) from exc if not metrics: continue aggregate = aggregate_metrics(metrics) for key, value in aggregate.items(): baseline_metrics[f"{baseline_name}_{key}"] = value return baseline_metrics def _evaluate_baselines_parallel(self, scenario_specs: list[ScenarioSpec]) -> dict[str, float]: baseline_metrics: dict[str, float] = {} resolved_workers = self._resolved_eval_workers(len(scenario_specs)) print(f"[eval] baseline_workers={resolved_workers}") for baseline_name in ("random", "fixed"): metrics = self._run_parallel_eval( scenario_specs=scenario_specs, worker_kind=baseline_name, initializer=_init_parallel_baseline_worker, initargs=(self._build_parallel_baseline_context(baseline_name),), max_workers=resolved_workers, ) if not metrics: continue aggregate = aggregate_metrics(metrics) for key, value in aggregate.items(): baseline_metrics[f"{baseline_name}_{key}"] = value return baseline_metrics def _run_parallel_eval( self, scenario_specs: list[ScenarioSpec], worker_kind: str, initializer, initargs: tuple[Any, ...], max_workers: int, ) -> list[dict[str, float | str]]: 
metrics: list[dict[str, float | str]] = [] total_specs = len(scenario_specs) with ProcessPoolExecutor( max_workers=max_workers, initializer=initializer, initargs=initargs, ) as executor: futures = { executor.submit(_parallel_eval_worker, spec, index, worker_kind): (spec, index) for index, spec in enumerate(scenario_specs, start=1) } iterator = as_completed(futures) if self.dqn_config.use_tqdm: iterator = tqdm( iterator, total=total_specs, desc=f"eval:{worker_kind}", leave=False, dynamic_ncols=True, ) for future in iterator: spec, index = futures[future] try: result = future.result() except Exception as exc: self._handle_eval_failure(worker_kind, spec, exc) continue prefix = f"[eval] baseline={worker_kind}" print(f"{prefix} city={spec.city_id} scenario={spec.scenario_name} i={index}/{total_specs}") metrics.append(result) return metrics def _handle_eval_failure( self, phase: str, spec: ScenarioSpec, exc: Exception, ) -> None: message = f"[warn] {phase} failed for city={spec.city_id} scenario={spec.scenario_name}: {exc}" if self.dqn_config.skip_failed_validation_episodes: print(message) return raise RuntimeError(message) from exc def _build_parallel_baseline_context(self, baseline_name: str) -> dict[str, Any]: return { "env_config": _env_config_to_payload(self.env_config), "baseline_name": baseline_name, "fixed_green_time": max(20, self.env_config.min_green_time * 2), "seed": self.dqn_config.seed, } def _build_parallel_learned_eval_context(self) -> dict[str, Any]: return { "env_config": _env_config_to_payload(self.env_config), "network_architecture": { "observation_dim": self.q_network.observation_dim, "action_dim": self.q_network.action_dim, "hidden_dim": self.q_network.hidden_dim, "num_layers": self.q_network.num_layers, "district_types": self.q_network.district_types, "policy_arch": self.q_network.policy_arch, "dueling": self.q_network.dueling, }, "q_network_state_dict": { key: value.detach().cpu() for key, value in self.q_network.state_dict().items() }, 
"obs_normalizer": self.obs_normalizer.state_dict() if self.obs_normalizer else None, } def _resolved_eval_workers(self, total_specs: int) -> int: requested = self.dqn_config.eval_num_workers if requested == -1: requested = os.cpu_count() or 1 if requested <= 1: return 1 return min(requested, total_specs) def _print_train_log(self, record: dict[str, float | str]) -> None: message = ( "[train] " f"update={record['update']} algo={record['algorithm']} arch={record['policy_arch']} " f"reward={record['reward_variant']} episodes={int(record.get('num_rollout_episodes', 1.0))} " f"city={record['city_id']} scenario={record['scenario_name']} " f"mean_return={record['episode_return']:.3f} " f"(avg={record.get('rolling_episode_return', record['episode_return']):.3f}) " f"wait={record['mean_waiting_vehicles']:.3f} " f"(avg={record.get('rolling_mean_waiting_vehicles', record['mean_waiting_vehicles']):.3f}) " f"throughput={record['throughput']:.1f} " f"(avg={record.get('rolling_throughput', record['throughput']):.1f}) " f"epsilon={record['epsilon']:.3f} replay={int(record['replay_size'])} " f"td_loss={record['td_loss']:.4f} " f"(avg={record.get('rolling_td_loss', record['td_loss']):.4f}) " f"q={record['mean_q_value']:.4f} " f"td_err={record['mean_abs_td_error']:.4f}" ) if self.dqn_config.use_tqdm: tqdm.write(message) else: print(message) def _print_eval_log(self, record: dict[str, float]) -> None: message = ( "[eval] " f"algo={record['algorithm']} arch={record['policy_arch']} reward={record['reward_variant']} " f"episodes={int(record['num_episodes'])} " f"mean_return={record['mean_episode_return']:.3f} " f"(avg={record.get('rolling_mean_episode_return', record['mean_episode_return']):.3f}) " f"wait={record['mean_mean_waiting_vehicles']:.3f} " f"throughput={record['mean_throughput']:.1f} " f"travel_time={record.get('mean_average_travel_time', float('nan')):.3f}" ) if self.dqn_config.compare_baselines: message += ( f" fixed={record.get('fixed_mean_episode_return', 
float('nan')):.3f}" f" random={record.get('random_mean_episode_return', float('nan')):.3f}" f" vs_fixed={record.get('learner_minus_fixed_return', float('nan')):.3f}" f" vs_random={record.get('learner_minus_random_return', float('nan')):.3f}" ) if self.dqn_config.use_tqdm: tqdm.write(message) else: print(message) scenario_summaries = [] for scenario_name in ( "accident", "construction", "district_overload", "evening_rush", "event_spike", "morning_rush", "normal", ): key = f"scenario_{scenario_name}_mean_episode_return" if key in record: scenario_summaries.append(f"{scenario_name}={record[key]:.3f}") if scenario_summaries: if self.dqn_config.use_tqdm: tqdm.write("[eval] scenario_returns " + " ".join(scenario_summaries)) else: print("[eval] scenario_returns " + " ".join(scenario_summaries)) def aggregate_metrics(metrics: list[dict[str, float | str]]) -> dict[str, float]: numeric_keys = { key for item in metrics for key, value in item.items() if isinstance(value, (int, float)) } aggregate = {"num_episodes": float(len(metrics))} for key in sorted(numeric_keys): aggregate[f"mean_{key}"] = float( np.mean([float(item[key]) for item in metrics if key in item]) ) return aggregate def aggregate_metrics_by_scenario(metrics: list[dict[str, float | str]]) -> dict[str, float]: scenario_names = sorted( { str(item["scenario_name"]) for item in metrics if isinstance(item.get("scenario_name"), str) } ) aggregate: dict[str, float] = {} for scenario_name in scenario_names: scenario_metrics = [item for item in metrics if item.get("scenario_name") == scenario_name] if not scenario_metrics: continue scenario_aggregate = aggregate_metrics(scenario_metrics) for key, value in scenario_aggregate.items(): aggregate[f"scenario_{scenario_name}_{key}"] = value return aggregate def _env_config_to_payload(env_config: EnvConfig) -> dict[str, Any]: return { "simulator_interval": env_config.simulator_interval, "decision_interval": env_config.decision_interval, "min_green_time": 
env_config.min_green_time, "thread_num": env_config.thread_num, "max_episode_seconds": env_config.max_episode_seconds, "observation": asdict(env_config.observation), "reward": asdict(env_config.reward), } def _env_config_from_payload(payload: dict[str, Any]) -> EnvConfig: return EnvConfig( simulator_interval=payload["simulator_interval"], decision_interval=payload["decision_interval"], min_green_time=payload["min_green_time"], thread_num=payload["thread_num"], max_episode_seconds=payload["max_episode_seconds"], observation=ObservationConfig(**payload["observation"]), reward=RewardConfig(**payload["reward"]), ) def _init_parallel_baseline_worker(context: dict[str, Any]) -> None: _init_parallel_eval_worker_from_context(context) def _init_parallel_learned_eval_worker(context: dict[str, Any]) -> None: _init_parallel_eval_worker_from_context(context) def _build_standalone_eval_context( env_config: EnvConfig, actor: TrafficControlQNetwork | RandomPhasePolicy | FixedCyclePolicy | HoldPhasePolicy | QueueGreedyPolicy, obs_normalizer: RunningNormalizer | None, device: torch.device, seed: int, fixed_green_time: int, baseline_name: str | None, ) -> dict[str, Any]: del device if baseline_name is not None: return { "env_config": _env_config_to_payload(env_config), "baseline_name": baseline_name, "fixed_green_time": fixed_green_time, "seed": seed, } if not isinstance(actor, TrafficControlQNetwork): raise ValueError("Standalone parallel learned evaluation requires a Q-network actor.") return { "env_config": _env_config_to_payload(env_config), "network_architecture": { "observation_dim": actor.observation_dim, "action_dim": actor.action_dim, "hidden_dim": actor.hidden_dim, "num_layers": actor.num_layers, "district_types": actor.district_types, "policy_arch": actor.policy_arch, "dueling": actor.dueling, }, "q_network_state_dict": { key: value.detach().cpu() for key, value in actor.state_dict().items() }, "obs_normalizer": obs_normalizer.state_dict() if obs_normalizer else None, } 
# One packed n-step transition:
# (observation, district_type_index, action_mask, action, reward,
#  next_observation, next_district_type_index, next_action_mask, done, discount)
_TransitionRecord = tuple[np.ndarray, int, np.ndarray, int, float, np.ndarray, int, np.ndarray, bool, float]


def _build_q_network_from_context(context: dict[str, Any]) -> TrafficControlQNetwork:
    """Rebuild the Q-network described by ``context`` on CPU in eval mode."""
    architecture = context["network_architecture"]
    q_network = TrafficControlQNetwork(
        observation_dim=architecture["observation_dim"],
        action_dim=architecture["action_dim"],
        hidden_dim=architecture["hidden_dim"],
        num_layers=architecture["num_layers"],
        district_types=tuple(architecture["district_types"]),
        policy_arch=architecture["policy_arch"],
        # Contexts without a "dueling" entry fall back to True.
        dueling=bool(architecture.get("dueling", True)),
    ).to(torch.device("cpu"))
    q_network.load_state_dict(context["q_network_state_dict"])
    q_network.eval()
    return q_network


def _build_obs_normalizer_from_context(context: dict[str, Any]) -> RunningNormalizer | None:
    """Restore the observation normalizer from ``context``, or ``None`` if absent/empty."""
    normalizer_state = context.get("obs_normalizer")
    if not normalizer_state:
        return None
    obs_normalizer = RunningNormalizer()
    obs_normalizer.load_state_dict(normalizer_state)
    return obs_normalizer


def _make_env_from_spec(scenario_spec: ScenarioSpec, env_config: EnvConfig) -> TrafficEnv:
    """Construct the ``TrafficEnv`` for ``scenario_spec`` with ``env_config``."""
    return TrafficEnv(
        city_id=scenario_spec.city_id,
        scenario_name=scenario_spec.scenario_name,
        city_dir=scenario_spec.city_dir,
        scenario_dir=scenario_spec.scenario_dir,
        config_path=scenario_spec.config_path,
        roadnet_path=scenario_spec.roadnet_path,
        district_map_path=scenario_spec.district_map_path,
        metadata_path=scenario_spec.metadata_path,
        env_config=env_config,
    )


def _init_parallel_eval_worker_from_context(context: dict[str, Any]) -> None:
    """Install the per-process evaluation context in ``_EVAL_CONTEXT``.

    A context containing ``"baseline_name"`` selects one of the baseline
    policies; otherwise the learned Q-network (and optional observation
    normalizer) is rebuilt from the serialized payload.

    Raises:
        ValueError: If the baseline name is not supported.
    """
    global _EVAL_CONTEXT
    env_config = _env_config_from_payload(context["env_config"])
    if "baseline_name" in context:
        baseline_name = context["baseline_name"]
        if baseline_name == "random":
            actor = RandomPhasePolicy(seed=context["seed"])
        elif baseline_name == "fixed":
            actor = FixedCyclePolicy(green_time=context["fixed_green_time"])
        elif baseline_name == "hold":
            actor = HoldPhasePolicy()
        elif baseline_name == "queue_greedy":
            actor = QueueGreedyPolicy()
        else:
            raise ValueError(f"Unsupported baseline worker kind: {baseline_name}")
        # Baselines are evaluated without observation normalization.
        obs_normalizer = None
    else:
        actor = _build_q_network_from_context(context)
        obs_normalizer = _build_obs_normalizer_from_context(context)
    _EVAL_CONTEXT = {
        "env_config": env_config,
        "actor": actor,
        "obs_normalizer": obs_normalizer,
    }


def _parallel_eval_worker(
    scenario_spec: ScenarioSpec,
    index: int,
    worker_kind: str,
) -> dict[str, float | str]:
    """Evaluate one scenario with the actor installed by the pool initializer.

    ``index`` and ``worker_kind`` are part of the submit signature but are
    not used inside the worker.
    """
    del index, worker_kind
    env_config = _EVAL_CONTEXT["env_config"]
    actor = _EVAL_CONTEXT["actor"]
    obs_normalizer = _EVAL_CONTEXT["obs_normalizer"]
    return evaluate_policy(
        env_factory=lambda: _make_env_from_spec(scenario_spec, env_config),
        actor=actor,
        device=torch.device("cpu"),
        obs_normalizer=obs_normalizer,
        deterministic=True,
    )


def _parallel_rollout_collection_worker(
    scenario_spec: ScenarioSpec,
    context: dict[str, Any],
    epsilon: float,
    max_decision_steps: int | None,
    gamma: float,
    n_step: int,
) -> dict[str, Any]:
    """Collect one epsilon-greedy episode on CPU from a serialized policy context.

    Returns:
        The dict produced by ``_collect_episode_trajectory``.
    """
    env_config = _env_config_from_payload(context["env_config"])
    q_network = _build_q_network_from_context(context)
    obs_normalizer = _build_obs_normalizer_from_context(context)
    env = _make_env_from_spec(scenario_spec, env_config)
    return _collect_episode_trajectory(
        env=env,
        q_network=q_network,
        obs_normalizer=obs_normalizer,
        epsilon=epsilon,
        max_decision_steps=max_decision_steps,
        gamma=gamma,
        n_step=n_step,
        device=torch.device("cpu"),
    )


def _collect_episode_trajectory(
    env: TrafficEnv,
    q_network: TrafficControlQNetwork,
    obs_normalizer: RunningNormalizer | None,
    epsilon: float,
    max_decision_steps: int | None,
    gamma: float,
    n_step: int,
    device: torch.device,
) -> dict[str, Any]:
    """Roll out one episode and return its summary plus packed n-step transitions.

    The loop stops when the environment reports ``done`` or after
    ``max_decision_steps`` decisions (when not ``None``). Records still held
    in the per-intersection n-step buffers are flushed as shorter transitions
    at the end.

    Returns:
        ``{"episode_record": ..., "transitions": ...}`` where ``transitions``
        is the array dict built by ``_pack_transition_records``.
    """
    observation_batch = env.reset()
    # One n-step window per row of the observation batch.
    n_step_buffers = [
        deque()
        for _ in range(len(observation_batch["intersection_ids"]))
    ]
    # Mask width taken from the data so empty results keep a compatible shape.
    action_dim = int(np.asarray(observation_batch["action_mask"]).shape[-1])
    q_value_samples: list[float] = []
    transition_records: list[_TransitionRecord] = []
    done = False
    decision_steps = 0
    last_info = env.last_info

    while not done:
        if max_decision_steps is not None and decision_steps >= max_decision_steps:
            break
        raw_obs = observation_batch["observations"].astype(np.float32)
        # Transitions store raw observations; normalization is only for acting.
        normalized_obs = obs_normalizer.normalize(raw_obs) if obs_normalizer else raw_obs
        obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=device)
        district_type_tensor = torch.as_tensor(
            observation_batch["district_type_indices"],
            dtype=torch.int64,
            device=device,
        )
        action_mask_tensor = torch.as_tensor(
            observation_batch["action_mask"],
            dtype=torch.float32,
            device=device,
        )
        with torch.no_grad():
            q_values = q_network.forward(
                observations=obs_tensor,
                district_type_indices=district_type_tensor,
                action_mask=action_mask_tensor,
            )
            actions = q_network.act(
                observations=obs_tensor,
                district_type_indices=district_type_tensor,
                action_mask=action_mask_tensor,
                deterministic=False,
                epsilon=epsilon,
            ).cpu().numpy()
        q_value_samples.append(float(q_values.max(dim=-1).values.mean().detach().cpu()))

        next_observation_batch, rewards, done, info = env.step(actions)
        transition_records.extend(
            _build_n_step_transitions(
                buffers=n_step_buffers,
                observation_batch=observation_batch,
                actions=actions,
                rewards=np.asarray(rewards, dtype=np.float32),
                next_observation_batch=next_observation_batch,
                done=done,
                gamma=gamma,
                n_step=n_step,
            )
        )
        observation_batch = next_observation_batch
        last_info = info
        decision_steps += 1

    transition_records.extend(
        _flush_n_step_transition_buffers(
            buffers=n_step_buffers,
            gamma=gamma,
        )
    )
    # isinstance already rejects None, so no separate None check is needed.
    episode_metrics = {
        key: float(value)
        for key, value in last_info["metrics"].items()
        if isinstance(value, (int, float))
    }
    episode_record = {
        **episode_metrics,
        "city_id": env.city_id,
        "scenario_name": env.scenario_name,
        "decision_steps": decision_steps,
        "transitions": len(transition_records),
        "episode_return": float(env.episode_return),
        "total_episode_return": float(env.total_episode_return),
        "epsilon": float(epsilon),
        "mean_q_value": float(np.mean(q_value_samples)) if q_value_samples else 0.0,
    }
    return {
        "episode_record": episode_record,
        "transitions": _pack_transition_records(
            transition_records,
            env.observation_dim,
            action_dim=action_dim,
        ),
    }


def _build_n_step_transitions(
    buffers: list[deque[StepRecord]],
    observation_batch: dict[str, Any],
    actions: np.ndarray,
    rewards: np.ndarray,
    next_observation_batch: dict[str, Any],
    done: bool,
    gamma: float,
    n_step: int,
) -> list[_TransitionRecord]:
    """Append one step per row to ``buffers`` and emit any completed n-step transitions.

    Row ``i`` of the batches is appended to ``buffers[i]``. Whenever a buffer
    holds at least ``n_step`` records, one transition is emitted and the
    buffer's oldest record is dropped.
    """
    transition_records: list[_TransitionRecord] = []
    for row_index, buffer in enumerate(buffers):
        record = StepRecord(
            observation=observation_batch["observations"][row_index].astype(np.float32),
            district_type_index=int(observation_batch["district_type_indices"][row_index]),
            action_mask=observation_batch["action_mask"][row_index].astype(np.float32),
            action=int(actions[row_index]),
            reward=float(rewards[row_index]),
            next_observation=next_observation_batch["observations"][row_index].astype(np.float32),
            next_district_type_index=int(next_observation_batch["district_type_indices"][row_index]),
            next_action_mask=next_observation_batch["action_mask"][row_index].astype(np.float32),
            # The same episode-level ``done`` flag is recorded for every row.
            done=bool(done),
        )
        buffer.append(record)
        if len(buffer) >= n_step:
            transition_records.append(_make_transition_from_buffer(buffer, steps=n_step, gamma=gamma))
            buffer.popleft()
    return transition_records


def _flush_n_step_transition_buffers(
    buffers: list[deque[StepRecord]],
    gamma: float,
) -> list[_TransitionRecord]:
    """Drain every buffer, emitting progressively shorter transitions.

    Each remaining record becomes the start of a transition that spans all
    records still in its buffer; all buffers are empty on return.
    """
    transition_records: list[_TransitionRecord] = []
    for buffer in buffers:
        while buffer:
            transition_records.append(
                _make_transition_from_buffer(buffer, steps=len(buffer), gamma=gamma)
            )
            buffer.popleft()
    return transition_records


def _make_transition_from_buffer(
    buffer: deque[StepRecord],
    steps: int,
    gamma: float,
) -> _TransitionRecord:
    """Combine the first ``steps`` records of ``buffer`` into one transition tuple.

    The reward is ``sum(gamma**i * r_i)`` over those records and the discount
    is ``gamma ** k`` for the ``k`` records actually taken. ``buffer`` is not
    modified.
    """
    records = list(islice(buffer, 0, steps))
    # Explicit left-to-right accumulation keeps float results reproducible.
    reward = 0.0
    for step_index, record in enumerate(records):
        reward += (gamma ** step_index) * float(record.reward)
    first_record = records[0]
    last_record = records[-1]
    discount = gamma ** len(records)
    return (
        first_record.observation,
        first_record.district_type_index,
        first_record.action_mask,
        first_record.action,
        reward,
        last_record.next_observation,
        last_record.next_district_type_index,
        last_record.next_action_mask,
        last_record.done,
        discount,
    )


def _pack_transition_records(
    transition_records: list[_TransitionRecord],
    observation_dim: int,
    action_dim: int = 2,
) -> dict[str, np.ndarray]:
    """Stack transition tuples into a dict of arrays.

    Args:
        transition_records: Tuples as produced by ``_make_transition_from_buffer``.
        observation_dim: Observation width used for the empty result.
        action_dim: Action-mask width used for the empty result (default 2).

    Returns:
        A dict of arrays; float fields are ``float32``, index and action
        fields are ``int64``. With no records, every array has length 0.
    """
    if not transition_records:
        return {
            "observations": np.zeros((0, observation_dim), dtype=np.float32),
            "district_type_indices": np.zeros(0, dtype=np.int64),
            "action_masks": np.zeros((0, action_dim), dtype=np.float32),
            "actions": np.zeros(0, dtype=np.int64),
            "rewards": np.zeros(0, dtype=np.float32),
            "next_observations": np.zeros((0, observation_dim), dtype=np.float32),
            "next_district_type_indices": np.zeros(0, dtype=np.int64),
            "next_action_masks": np.zeros((0, action_dim), dtype=np.float32),
            "dones": np.zeros(0, dtype=np.float32),
            "discounts": np.zeros(0, dtype=np.float32),
        }

    observations = np.stack([record[0] for record in transition_records]).astype(np.float32)
    district_type_indices = np.asarray([record[1] for record in transition_records], dtype=np.int64)
    action_masks = np.stack([record[2] for record in transition_records]).astype(np.float32)
    actions = np.asarray([record[3] for record in transition_records], dtype=np.int64)
    rewards = np.asarray([record[4] for record in transition_records], dtype=np.float32)
    next_observations = np.stack([record[5] for record in transition_records]).astype(np.float32)
    next_district_type_indices = np.asarray([record[6] for record in transition_records], dtype=np.int64)
    next_action_masks = np.stack([record[7] for record in transition_records]).astype(np.float32)
    dones = np.asarray([record[8] for record in transition_records], dtype=np.float32)
    discounts = np.asarray([record[9] for record in transition_records], dtype=np.float32)
    return {
        "observations": observations,
        "district_type_indices": district_type_indices,
        "action_masks": action_masks,
        "actions": actions,
        "rewards": rewards,
        "next_observations": next_observations,
        "next_district_type_indices": next_district_type_indices,
        "next_action_masks": next_action_masks,
        "dones": dones,
        "discounts": discounts,
    }