"""Q-network and observation-normalization utilities for intersection-level traffic control."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch
from torch import nn

from env.intersection_config import DISTRICT_TYPES

POLICY_ARCHES: tuple[str, ...] = (
    "multi_head",
    "single_head",
    "single_head_with_district_feature",
)


class TrafficControlQNetwork(nn.Module):
    """Parameter-shared dueling Q-network for intersection-level control."""

    def __init__(
        self,
        observation_dim: int,
        action_dim: int = 2,
        hidden_dim: int = 256,
        num_layers: int = 2,
        district_types: tuple[str, ...] = DISTRICT_TYPES,
        policy_arch: str = "single_head_with_district_feature",
        dueling: bool = True,
    ):
        super().__init__()
        if policy_arch not in POLICY_ARCHES:
            raise ValueError(
                f"Unsupported policy architecture: {policy_arch}. "
                f"Expected one of {POLICY_ARCHES}."
            )

        # Shared MLP backbone: `num_layers` Linear + ReLU blocks.
        layers: list[nn.Module] = []
        input_dim = observation_dim
        for _ in range(num_layers):
            layers.extend(
                [
                    nn.Linear(input_dim, hidden_dim),
                    nn.ReLU(),
                ]
            )
            input_dim = hidden_dim

        self.observation_dim = int(observation_dim)
        self.action_dim = int(action_dim)
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        self.district_types = tuple(district_types)
        self.policy_arch = policy_arch
        self.dueling = bool(dueling)

        self.backbone = nn.Sequential(*layers)

        if self.policy_arch == "multi_head":
            # One advantage/value head per district type; all heads share the backbone.
            self.advantage_heads = nn.ModuleList(
                [nn.Linear(hidden_dim, action_dim) for _ in self.district_types]
            )
            self.value_heads = nn.ModuleList(
                [nn.Linear(hidden_dim, 1) for _ in self.district_types]
            )
            self.advantage_head = None
            self.value_head = None
        else:
            self.advantage_heads = None
            self.value_heads = None
            self.advantage_head = nn.Linear(hidden_dim, action_dim)
            self.value_head = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        observations: torch.Tensor,
        district_type_indices: torch.Tensor,
        action_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        features = self.backbone(observations)
        advantages, values = self._q_streams(features, district_type_indices)
        if self.dueling:
            # Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
            q_values = values + advantages - advantages.mean(dim=-1, keepdim=True)
        else:
            q_values = advantages
        if action_mask is not None:
            q_values = self._apply_action_mask(q_values, action_mask)
        return q_values

    def act(
        self,
        observations: torch.Tensor,
        district_type_indices: torch.Tensor,
        action_mask: torch.Tensor | None = None,
        deterministic: bool = False,
        epsilon: float = 0.0,
    ) -> torch.Tensor:
        """Epsilon-greedy action selection that respects the action mask."""
        q_values = self.forward(
            observations=observations,
            district_type_indices=district_type_indices,
            action_mask=action_mask,
        )
        greedy_actions = q_values.argmax(dim=-1)
        if deterministic or epsilon <= 0.0:
            return greedy_actions

        random_mask = torch.rand(
            greedy_actions.shape[0], device=greedy_actions.device
        ) < float(epsilon)
        if not random_mask.any():
            return greedy_actions

        actions = greedy_actions.clone()
        valid_action_mask = (
            action_mask
            if action_mask is not None
            else torch.ones_like(q_values, dtype=torch.float32)
        )
        # Resample only the rows selected for exploration, restricted to valid actions.
        random_rows = torch.nonzero(random_mask, as_tuple=False).flatten()
        for row_index in random_rows.tolist():
            valid_actions = torch.nonzero(
                valid_action_mask[row_index] > 0.0, as_tuple=False
            ).flatten()
            if valid_actions.numel() == 0:
                # No valid action: fall back to action 0, matching _apply_action_mask.
                actions[row_index] = 0
                continue
            sample_index = torch.randint(
                low=0,
                high=valid_actions.numel(),
                size=(1,),
                device=actions.device,
            )
            actions[row_index] = valid_actions[sample_index].item()
        return actions

    def q_values_for_actions(
        self,
        observations: torch.Tensor,
        district_type_indices: torch.Tensor,
        actions: torch.Tensor,
        action_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q_values = self.forward(
            observations=observations,
            district_type_indices=district_type_indices,
            action_mask=action_mask,
        )
        return q_values.gather(dim=1, index=actions.view(-1, 1)).squeeze(1)

    def _q_streams(
        self,
        features: torch.Tensor,
        district_type_indices: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.policy_arch == "multi_head":
            # Evaluate every district head, then gather each row's own head output.
            all_advantages = torch.stack(
                [head(features) for head in self.advantage_heads],
                dim=1,
            )
            all_values = torch.stack(
                [head(features) for head in self.value_heads],
                dim=1,
            )
            gather_adv = district_type_indices.view(-1, 1, 1).expand(
                -1, 1, self.action_dim
            )
            gather_val = district_type_indices.view(-1, 1, 1)
            advantages = all_advantages.gather(dim=1, index=gather_adv).squeeze(1)
            values = all_values.gather(dim=1, index=gather_val).squeeze(1)
            return advantages, values
        return self.advantage_head(features), self.value_head(features)

    def _apply_action_mask(
        self,
        q_values: torch.Tensor,
        action_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Push invalid actions to a large negative value so argmax never picks them.
        masked_q_values = q_values.masked_fill(action_mask <= 0.0, -1.0e9)
        all_invalid = action_mask.sum(dim=-1) <= 0.0
        if all_invalid.any():
            # If a row has no valid action at all, make action 0 the fallback.
            masked_q_values[all_invalid, 0] = 0.0
        return masked_q_values


@dataclass
class RunningNormalizer:
    """Online per-feature normalizer based on Welford's streaming algorithm."""

    epsilon: float = 1e-6

    def __post_init__(self) -> None:
        self.count = 0
        self.mean: np.ndarray | None = None
        self.m2: np.ndarray | None = None

    def update(self, batch: np.ndarray) -> None:
        array = np.asarray(batch, dtype=np.float64)
        if array.ndim != 2:
            raise ValueError("Normalizer expects a 2D batch of observations.")
        if self.mean is None:
            self.mean = np.zeros(array.shape[1], dtype=np.float64)
            self.m2 = np.zeros(array.shape[1], dtype=np.float64)
        # Welford's single-pass mean/variance update, one observation row at a time.
        for row in array:
            self.count += 1
            delta = row - self.mean
            self.mean += delta / self.count
            delta2 = row - self.mean
            self.m2 += delta * delta2

    def normalize(self, batch: np.ndarray) -> np.ndarray:
        array = np.asarray(batch, dtype=np.float32)
        if self.mean is None or self.m2 is None or self.count < 2:
            # Not enough statistics accumulated yet; return the batch unchanged.
            return array
        variance = self.m2 / max(1, self.count - 1)
        std = np.sqrt(np.maximum(variance, self.epsilon))
        return (
            (array - self.mean.astype(np.float32)) / std.astype(np.float32)
        ).astype(np.float32)

    def state_dict(self) -> dict:
        return {
            "count": self.count,
            "mean": self.mean,
            "m2": self.m2,
            "epsilon": self.epsilon,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self.count = int(state_dict["count"])
        self.mean = state_dict["mean"]
        self.m2 = state_dict["m2"]
        self.epsilon = float(state_dict.get("epsilon", self.epsilon))
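

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The batch size, observation_dim,
# and the sampled district indices below are placeholder assumptions chosen
# for this smoke test, not values mandated by the surrounding pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size = 4
    observation_dim = 16  # assumed width of the per-intersection observation vector

    network = TrafficControlQNetwork(
        observation_dim=observation_dim,
        policy_arch="multi_head",
    )

    observations = torch.randn(batch_size, observation_dim)
    district_type_indices = torch.randint(0, len(DISTRICT_TYPES), (batch_size,))
    action_mask = torch.ones(batch_size, network.action_dim)
    action_mask[0, 1] = 0.0  # pretend action 1 is invalid at the first intersection

    with torch.no_grad():
        q_values = network(observations, district_type_indices, action_mask)
        actions = network.act(
            observations, district_type_indices, action_mask, epsilon=0.1
        )
    print("q_values:", tuple(q_values.shape), "actions:", actions.tolist())

    # Normalizer round trip on the same synthetic observations.
    normalizer = RunningNormalizer()
    normalizer.update(observations.numpy())
    print("normalized shape:", normalizer.normalize(observations.numpy()).shape)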