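"""Q-network and running observation normalizer for intersection-level traffic control."""
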
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from env.intersection_config import DISTRICT_TYPES
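
# Head layouts supported by TrafficControlQNetwork: one advantage/value pair
# per district type ("multi_head"), a single shared pair ("single_head"), or a
# single shared pair whose input observation is assumed to already encode the
# district type ("single_head_with_district_feature").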
POLICY_ARCHES: tuple[str, ...] = (
"multi_head",
"single_head",
"single_head_with_district_feature",
)


class TrafficControlQNetwork(nn.Module):
"""Parameter-shared dueling Q-network for intersection-level control."""

    def __init__(
self,
observation_dim: int,
action_dim: int = 2,
hidden_dim: int = 256,
num_layers: int = 2,
district_types: tuple[str, ...] = DISTRICT_TYPES,
policy_arch: str = "single_head_with_district_feature",
dueling: bool = True,
):
super().__init__()
if policy_arch not in POLICY_ARCHES:
raise ValueError(
f"Unsupported policy architecture: {policy_arch}. "
f"Expected one of {POLICY_ARCHES}."
)
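        # Shared MLP trunk: a stack of Linear + ReLU blocks applied to every
        # intersection's observation.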
layers: list[nn.Module] = []
input_dim = observation_dim
for _ in range(num_layers):
layers.extend(
[
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
]
)
input_dim = hidden_dim
self.observation_dim = int(observation_dim)
self.action_dim = int(action_dim)
self.hidden_dim = int(hidden_dim)
self.num_layers = int(num_layers)
self.district_types = tuple(district_types)
self.policy_arch = policy_arch
self.dueling = bool(dueling)
self.backbone = nn.Sequential(*layers)
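        # "multi_head" builds one advantage/value head pair per district type;
        # the single-head variants share one pair across all districts.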
if self.policy_arch == "multi_head":
self.advantage_heads = nn.ModuleList(
[nn.Linear(hidden_dim, action_dim) for _ in self.district_types]
)
self.value_heads = nn.ModuleList(
[nn.Linear(hidden_dim, 1) for _ in self.district_types]
)
self.advantage_head = None
self.value_head = None
else:
self.advantage_heads = None
self.value_heads = None
self.advantage_head = nn.Linear(hidden_dim, action_dim)
self.value_head = nn.Linear(hidden_dim, 1)

    def forward(
self,
observations: torch.Tensor,
district_type_indices: torch.Tensor,
action_mask: torch.Tensor | None = None,
) -> torch.Tensor:
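        """Return per-action Q-values of shape (batch, action_dim).

        With ``dueling=True`` the streams are combined as
        Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'); otherwise the advantage
        stream is used directly as the Q-values. If ``action_mask`` is given,
        invalid actions are filled with a large negative value.
        """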
features = self.backbone(observations)
advantages, values = self._q_streams(features, district_type_indices)
if self.dueling:
q_values = values + advantages - advantages.mean(dim=-1, keepdim=True)
else:
q_values = advantages
if action_mask is not None:
q_values = self._apply_action_mask(q_values, action_mask)
return q_values

    def act(
self,
observations: torch.Tensor,
district_type_indices: torch.Tensor,
action_mask: torch.Tensor | None = None,
deterministic: bool = False,
epsilon: float = 0.0,
) -> torch.Tensor:
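        """Select actions epsilon-greedily from the (optionally masked) Q-values.

        With ``deterministic=True`` or ``epsilon <= 0`` every row receives its
        greedy action; otherwise each row is resampled uniformly over its
        valid actions with probability ``epsilon``.
        """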
q_values = self.forward(
observations=observations,
district_type_indices=district_type_indices,
action_mask=action_mask,
)
greedy_actions = q_values.argmax(dim=-1)
if deterministic or epsilon <= 0.0:
return greedy_actions
random_mask = torch.rand(greedy_actions.shape[0], device=greedy_actions.device) < float(
epsilon
)
if not random_mask.any():
return greedy_actions
actions = greedy_actions.clone()
valid_action_mask = (
action_mask if action_mask is not None else torch.ones_like(q_values, dtype=torch.float32)
)
random_rows = torch.nonzero(random_mask, as_tuple=False).flatten()
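        # Resample each exploring row uniformly over its valid actions; rows
        # with no valid action fall back to action 0, matching
        # _apply_action_mask.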
for row_index in random_rows.tolist():
valid_actions = torch.nonzero(valid_action_mask[row_index] > 0.0, as_tuple=False).flatten()
if valid_actions.numel() == 0:
actions[row_index] = 0
continue
sample_index = torch.randint(
low=0,
high=valid_actions.numel(),
size=(1,),
device=actions.device,
)
actions[row_index] = valid_actions[sample_index].item()
return actions

    def q_values_for_actions(
self,
observations: torch.Tensor,
district_type_indices: torch.Tensor,
actions: torch.Tensor,
action_mask: torch.Tensor | None = None,
) -> torch.Tensor:
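        """Return Q(s, a) for each row's chosen action."""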
q_values = self.forward(
observations=observations,
district_type_indices=district_type_indices,
action_mask=action_mask,
)
return q_values.gather(dim=1, index=actions.view(-1, 1)).squeeze(1)

    def _q_streams(
self,
features: torch.Tensor,
district_type_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
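        """Return the (advantage, value) streams for each row.

        For "multi_head", every head is evaluated and each row's output is
        gathered by its district-type index; the single-head variants share
        one head pair across all rows.
        """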
if self.policy_arch == "multi_head":
all_advantages = torch.stack(
[head(features) for head in self.advantage_heads],
dim=1,
)
all_values = torch.stack(
[head(features) for head in self.value_heads],
dim=1,
)
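            # all_advantages: (batch, num_districts, action_dim) and
            # all_values: (batch, num_districts, 1); gather each row's slice
            # along dim=1 using its district-type index.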
gather_adv = district_type_indices.view(-1, 1, 1).expand(-1, 1, self.action_dim)
gather_val = district_type_indices.view(-1, 1, 1)
advantages = all_advantages.gather(dim=1, index=gather_adv).squeeze(1)
values = all_values.gather(dim=1, index=gather_val).squeeze(1)
return advantages, values
return self.advantage_head(features), self.value_head(features)

    def _apply_action_mask(
self,
q_values: torch.Tensor,
action_mask: torch.Tensor,
) -> torch.Tensor:
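        """Fill invalid actions (mask <= 0) with a large negative value.

        Rows whose mask is entirely invalid get Q = 0 at action 0 so that a
        downstream argmax stays well-defined.
        """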
masked_q_values = q_values.masked_fill(action_mask <= 0.0, -1.0e9)
all_invalid = action_mask.sum(dim=-1) <= 0.0
if all_invalid.any():
masked_q_values[all_invalid, 0] = 0.0
return masked_q_values


@dataclass
class RunningNormalizer:
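    """Streaming observation normalizer based on Welford's online algorithm."""
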
epsilon: float = 1e-6

    def __post_init__(self) -> None:
self.count = 0
self.mean: np.ndarray | None = None
self.m2: np.ndarray | None = None

    def update(self, batch: np.ndarray) -> None:
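        """Fold a 2D batch of observations into the running statistics."""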
array = np.asarray(batch, dtype=np.float64)
if array.ndim != 2:
raise ValueError("Normalizer expects a 2D batch of observations.")
if self.mean is None:
self.mean = np.zeros(array.shape[1], dtype=np.float64)
self.m2 = np.zeros(array.shape[1], dtype=np.float64)
for row in array:
self.count += 1
delta = row - self.mean
self.mean += delta / self.count
delta2 = row - self.mean
self.m2 += delta * delta2

    def normalize(self, batch: np.ndarray) -> np.ndarray:
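        """Standardize a batch with the running mean and (floored) variance.

        Until at least two samples have been observed the batch is returned
        unchanged, since the variance estimate is not yet defined.
        """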
array = np.asarray(batch, dtype=np.float32)
if self.mean is None or self.m2 is None or self.count < 2:
return array
variance = self.m2 / max(1, self.count - 1)
std = np.sqrt(np.maximum(variance, self.epsilon))
        # self.mean and std are float64; compute in float64, cast once at the end.
        return ((array - self.mean) / std).astype(np.float32)

    def state_dict(self) -> dict:
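        """Return a serializable snapshot of the normalizer state."""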
return {
"count": self.count,
"mean": self.mean,
"m2": self.m2,
"epsilon": self.epsilon,
}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore state previously produced by ``state_dict``."""
        self.count = int(state_dict["count"])
        # Coerce to arrays so the in-place Welford updates keep working
        # even if the checkpoint stored plain lists.
        mean = state_dict["mean"]
        m2 = state_dict["m2"]
        self.mean = None if mean is None else np.asarray(mean, dtype=np.float64)
        self.m2 = None if m2 is None else np.asarray(m2, dtype=np.float64)
        self.epsilon = float(state_dict.get("epsilon", self.epsilon))
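

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline: the
    # observation dimension, batch size, and all-ones mask below are
    # illustrative assumptions rather than values from the environment config.
    torch.manual_seed(0)
    obs_dim, batch = 8, 4
    net = TrafficControlQNetwork(observation_dim=obs_dim)
    obs = torch.randn(batch, obs_dim)
    district_idx = torch.zeros(batch, dtype=torch.long)
    mask = torch.ones(batch, net.action_dim)
    q = net(obs, district_idx, action_mask=mask)
    actions = net.act(obs, district_idx, action_mask=mask, epsilon=0.1)
    print(q.shape, actions.shape)  # torch.Size([4, 2]) torch.Size([4])

    normalizer = RunningNormalizer()
    normalizer.update(np.random.randn(32, obs_dim))
    print(normalizer.normalize(np.random.randn(batch, obs_dim)).shape)  # (4, 8)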