"""Shared mobility-classification utilities used across Task 2 helpers.

This module provides the lightweight LWM classifier head plus supporting
sampling and normalization helpers that were previously bundled inside the
stand-alone mobility fine-tuning scripts. They remain available so that
benchmarking, router training, and visualisation pipelines can reuse the same
logic without depending on a separate CLI.
"""

from __future__ import annotations

import glob
import json
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pretraining.pretrained_model import lwm as lwm_model
from task1.train_mcs_models import (
    _extract_metadata,
    identify_modulation,
    load_all_samples,
)

MOBILITY_LABELS = ["static", "pedestrian", "vehicular"]
BINARY_MOBILITY_LABELS = ["vehicular", "pedestrian"]
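
# NOTE: label order is significant -- downstream helpers (e.g. the
# mobility_to_idx mapping in _collect_balanced_arrays) derive integer class
# indices from a label's position in these lists.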


def load_dataset_stats(models_root: Path) -> Dict[str, float | str]:
    """Load dataset statistics (mean/std/normalization mode) from a models directory."""
    stats_path = models_root / "dataset_stats.json"
    if not stats_path.exists():
        print(
            f"[WARN] dataset_stats.json not found under {models_root}; "
            "falling back to per-sample normalization with mean=0/std=1.",
            flush=True,
        )
        return {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
    with open(stats_path, "r", encoding="utf-8") as f:
        stats = json.load(f)
    mean = float(stats.get("mean", 0.0))
    std = float(stats.get("std", 1.0))
    if std == 0.0:
        std = 1.0
    normalization = str(stats.get("normalization", stats.get("mode", "dataset")))
    return {
        "mean": mean,
        "std": std,
        "normalization": normalization,
    }
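
# For reference, dataset_stats.json is expected to look like (values are
# illustrative, not from a real run):
#     {"mean": -42.7, "std": 11.3, "normalization": "dataset"}
# A legacy "mode" key is accepted in place of "normalization".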


def gather_controlled_groups(
    data_root: Path,
    cities: Sequence[str],
    comm: str,
    mobilities: Sequence[str],
    snrs: Sequence[str] | None,
    fft_whitelist: Sequence[str] | None,
) -> Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]]:
    """Group spectrogram paths by (city, modulation, rate, SNR, FFT) so that
    mobilities can be balanced per configuration downstream."""
    groups: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    mobility_set = set(mobilities)
    snr_set = set(snrs) if snrs else None
    fft_set = set(fft_whitelist) if fft_whitelist else None

    for city in cities:
        base = data_root / city / comm
        if not base.exists():
            continue
        pattern = str(base / "**" / "spectrograms" / "*.pkl")
        for path_str in glob.iglob(pattern, recursive=True):
            path = Path(path_str)
            rate, snr, mobility = _extract_metadata(path.parts)
            if mobility not in mobility_set:
                continue
            if snr_set is not None and snr not in snr_set:
                continue
            # The FFT/window size is encoded in a "win*" path component; fall
            # back to a sentinel when it is absent.
            fft = next((part for part in path.parts if part.startswith("win")), "fft_unknown")
            if fft_set is not None and fft not in fft_set:
                continue
            _, modulation = identify_modulation(path_str)
            if modulation is None:
                continue
            key = (city, modulation, rate, snr, fft)
            groups[key][mobility].append(str(path))
    return {key: dict(mob_map) for key, mob_map in groups.items()}
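
# The returned mapping is keyed by full configuration; an illustrative entry:
#     ("city_a", "qpsk", "rate_2", "snr_10", "win256") ->
#         {"pedestrian": [<paths>], "vehicular": [<paths>]}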


def _collect_balanced_arrays(
    groups: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]],
    mobilities: Sequence[str],
    max_per_config: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
    """Load spectrogram arrays with per-configuration balance across mobilities."""
    features: List[np.ndarray] = []
    labels: List[np.ndarray] = []
    mobility_to_idx = {mob: idx for idx, mob in enumerate(mobilities)}
    per_mobility_totals = {mob: 0 for mob in mobilities}
    matched_configs = 0
    preview_configs: List[Tuple[str, str, str, str, str]] = []

    for key, mobility_map in groups.items():
        # Only use configurations observed under every requested mobility.
        if not all(mob in mobility_map for mob in mobilities):
            continue

        cached_arrays: Dict[str, np.ndarray] = {}
        per_mobility_counts: List[int] = []
        for mobility in mobilities:
            paths = mobility_map[mobility]
            collected: List[np.ndarray] = []
            for path in paths:
                arr = load_all_samples(path)
                if arr.size == 0:
                    continue
                collected.append(arr)
            if not collected:
                cached_arrays = {}
                break
            stacked = np.concatenate(collected, axis=0)
            cached_arrays[mobility] = stacked
            per_mobility_counts.append(stacked.shape[0])

        if len(cached_arrays) != len(mobilities):
            continue

        # Subsample every mobility down to the scarcest one (and the optional
        # cap) so each configuration contributes a balanced slice.
        limit = min(per_mobility_counts)
        if max_per_config > 0:
            limit = min(limit, max_per_config)
        if limit == 0:
            continue

        for mobility in mobilities:
            arr = cached_arrays[mobility]
            if arr.shape[0] > limit:
                indices = rng.permutation(arr.shape[0])[:limit]
                arr = arr[indices]
            features.append(arr)
            labels.append(np.full(arr.shape[0], mobility_to_idx[mobility], dtype=np.int64))
            per_mobility_totals[mobility] += arr.shape[0]

        if matched_configs < 5:
            preview_configs.append(key)
        matched_configs += 1

    if not features:
        return (
            np.empty((0, 128, 128), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
            {
                "per_mobility": per_mobility_totals,
                "matched_configs": matched_configs,
                "preview_configs": preview_configs,
            },
        )

    stacked_features = np.concatenate(features, axis=0).astype(np.float32, copy=False)
    stacked_labels = np.concatenate(labels, axis=0).astype(np.int64, copy=False)
    return stacked_features, stacked_labels, {
        "per_mobility": per_mobility_totals,
        "matched_configs": matched_configs,
        "preview_configs": preview_configs,
    }
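
# Typical end-to-end sampling flow (a sketch; the root path, city name, and
# per-config cap are illustrative):
#     groups = gather_controlled_groups(
#         Path("data"), ["city_a"], "comm", MOBILITY_LABELS, None, None
#     )
#     X, y, info = _collect_balanced_arrays(
#         groups, MOBILITY_LABELS, 64, np.random.default_rng(0)
#     )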


class ResidualBlock1D(nn.Module):
    """1D residual block used by the Res1DCNN classification head."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Project the identity path only when the channel count changes.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += self.shortcut(residual)
        return F.relu(x)


class Res1DCNNHead(nn.Module):
    """Compact ResNet-style 1D head for classifying 128-d embeddings."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.5) -> None:
        super().__init__()
        # input_dim is accepted for interface parity with the other heads; the
        # conv stack is length-agnostic and adaptive pooling collapses the
        # sequence dimension, so it is not used directly.
        hidden_dim = 64
        self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.res_block = ResidualBlock1D(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.res_block(x)
        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x = self.dropout(x)
        return self.fc(x)
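
# Shape flow through Res1DCNNHead, for a batch of flat embeddings:
#     (B, D) -> unsqueeze -> (B, 1, D) -> conv + residual block -> (B, 64, D)
#     -> adaptive avg pool -> (B, 64) -> dropout + linear -> (B, num_classes)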


class LWMClassifierMinimal(nn.Module):
    """LWM backbone wrapper with configurable classifier and optional projection head."""

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        classifier_dim: int,
        dropout: float,
        trainable_layers: int,
        projection_dim: int,
        append_input_stats: bool,
        normalization_stats: Dict[str, object] | None,
        head_type: str = "mlp",
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.patch_size = 4
        self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)
        self.head_type = head_type

        self.append_input_stats = bool(append_input_stats)
        stats_info = normalization_stats or {}
        self.normalization_mode = str(stats_info.get("normalization", "dataset")).lower()
        self.dataset_mean = float(stats_info.get("mean", 0.0))
        self.dataset_std = float(stats_info.get("std", 1.0))
        if abs(self.dataset_std) < 1e-6:
            self.dataset_std = 1e-6
        base_dim = 128
        stats_dim = 2 if self.append_input_stats else 0
        input_dim = base_dim + stats_dim

        classifier_dim = max(32, int(classifier_dim))
        dropout = max(0.0, float(dropout))

        if head_type == "linear":
            self.classifier = nn.Sequential(
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, num_classes),
            )
        elif head_type == "res1dcnn":
            self.classifier = nn.Sequential(
                nn.LayerNorm(input_dim),
                Res1DCNNHead(input_dim, num_classes, dropout=dropout),
            )
        else:
            head_layers: List[nn.Module] = [
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, classifier_dim),
                nn.GELU(),
            ]
            if dropout > 0:
                head_layers.append(nn.Dropout(dropout))
            head_layers.append(nn.Linear(classifier_dim, num_classes))
            self.classifier = nn.Sequential(*head_layers)

        proj_dim = int(projection_dim)
        if proj_dim > 0:
            self.projection_head = nn.Sequential(
                nn.Linear(128, proj_dim),
                nn.ReLU(inplace=True),
                nn.Linear(proj_dim, proj_dim),
            )
        else:
            self.projection_head = None

        # Freeze the backbone, then optionally unfreeze its last few layers.
        for param in self.backbone.parameters():
            param.requires_grad = False

        if trainable_layers > 0:
            layers = getattr(self.backbone, "layers", None)
            if layers is not None:
                trainable_layers = min(trainable_layers, len(layers))
                for layer in layers[-trainable_layers:]:
                    for param in layer.parameters():
                        param.requires_grad = True

    def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)
        patches = self.unfold(x).transpose(1, 2)
        # Prepend a constant-valued CLS token so the backbone sees one summary slot.
        cls_token = torch.full(
            (patches.size(0), 1, patches.size(-1)),
            0.2,
            dtype=patches.dtype,
            device=patches.device,
        )
        return torch.cat([cls_token, patches], dim=1)
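
    # Token arithmetic for a 128x128 spectrogram: non-overlapping 4x4 patches
    # give (128 / 4) ** 2 = 1024 tokens of dimension 4 * 4 = 16; prepending the
    # CLS slot yields 1025 tokens, matching element_length=16 and max_len=1025
    # in prepare_model below.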

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.spectrogram_to_tokens(x)
        outputs = self.backbone(tokens)
        # Mean-pool the patch tokens; fall back to the CLS slot if it is all we have.
        if outputs.size(1) <= 1:
            return outputs[:, 0, :]
        return outputs[:, 1:, :].mean(dim=1)

    def _collect_input_stats(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=(1, 2))
        std = x.std(dim=(1, 2), unbiased=False)
        if self.normalization_mode == "dataset":
            # Undo dataset-level normalization so the stats reflect the raw input scale.
            mean = mean * self.dataset_std + self.dataset_mean
            std = std * self.dataset_std
        return torch.stack([mean, std], dim=1)

    def forward(
        self,
        x: torch.Tensor,
        *,
        input_stats: torch.Tensor | None = None,
        return_projection: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor | None]:
        features = self.forward_features(x)
        classifier_input = features
        if self.append_input_stats:
            stats = input_stats if input_stats is not None else self._collect_input_stats(x)
            if stats.dtype != classifier_input.dtype:
                stats = stats.to(classifier_input.dtype)
            stats = stats.to(classifier_input.device)
            classifier_input = torch.cat([classifier_input, stats], dim=1)
        logits = self.classifier(classifier_input)
        if return_projection:
            projection = self.projection_head(features) if self.projection_head is not None else None
            return logits, projection
        return logits


def prepare_model(
    checkpoint: Path,
    num_classes: int,
    classifier_dim: int,
    dropout: float,
    trainable_layers: int,
    projection_dim: int,
    *,
    append_input_stats: bool = False,
    normalization_stats: Dict[str, object] | None = None,
    head_type: str = "mlp",
) -> nn.Module:
    """Instantiate an LWM backbone with the minimal classifier head."""
    backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    state = torch.load(checkpoint, map_location="cpu")
    # Strip the leading "module." prefix left behind by DataParallel checkpoints.
    if any(k.startswith("module.") for k in state):
        state = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state.items()
        }
    backbone.load_state_dict(state, strict=False)
    return LWMClassifierMinimal(
        backbone,
        num_classes=num_classes,
        classifier_dim=classifier_dim,
        dropout=dropout,
        trainable_layers=trainable_layers,
        projection_dim=projection_dim,
        append_input_stats=append_input_stats,
        normalization_stats=normalization_stats,
        head_type=head_type,
    )
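
# Example wiring (a sketch; the checkpoint path, stats directory, and
# hyper-parameters below are illustrative, not project defaults):
#     stats = load_dataset_stats(Path("models/task2_mobility"))
#     model = prepare_model(
#         Path("lwm_checkpoint.pt"),
#         num_classes=len(MOBILITY_LABELS),
#         classifier_dim=128,
#         dropout=0.3,
#         trainable_layers=0,
#         projection_dim=0,
#         normalization_stats=stats,
#     )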


def supervised_contrastive_loss(
    features: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> torch.Tensor:
    """Supervised contrastive loss over a batch of feature embeddings."""
    batch_size = features.size(0)
    if batch_size < 2:
        return features.new_tensor(0.0)

    features = F.normalize(features, dim=1)
    similarity = torch.div(torch.matmul(features, features.T), max(temperature, 1e-6))
    # Subtract the per-row max for numerical stability before exponentiating.
    logits_max, _ = similarity.max(dim=1, keepdim=True)
    similarity = similarity - logits_max.detach()

    device = features.device
    labels = labels.contiguous().view(-1, 1)
    # Positives share a label; logits_mask removes each sample's self-similarity.
    mask = torch.eq(labels, labels.T).float().to(device)
    logits_mask = torch.ones_like(mask) - torch.eye(batch_size, device=device)
    mask = mask * logits_mask

    exp_logits = torch.exp(similarity) * logits_mask
    log_prob = similarity - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)

    # Average log-probability over positives; skip anchors with no positive pair.
    mask_sum = mask.sum(dim=1)
    valid = mask_sum > 0
    if not torch.any(valid):
        return features.new_tensor(0.0)

    mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum.clamp_min(1e-12)
    loss = -mean_log_prob_pos[valid].mean()
    return loss
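
# This mirrors the L_out variant of Khosla et al., "Supervised Contrastive
# Learning" (NeurIPS 2020), with anchors that have no positive pair dropped
# from the batch mean.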


__all__ = [
    "BINARY_MOBILITY_LABELS",
    "LWMClassifierMinimal",
    "MOBILITY_LABELS",
    "Res1DCNNHead",
    "_collect_balanced_arrays",
    "gather_controlled_groups",
    "load_dataset_stats",
    "prepare_model",
    "supervised_contrastive_loss",
]