#!/usr/bin/env python3
"""Benchmark multiple backbones on modulation classification across dataset sizes.
For each desired training size (samples per MCS class) and repetition, the script:
1. Randomly samples spectrograms from distinct (modulation, rate, SNR, mobility) configs.
2. Builds train/val/test splits (val/test sizes are configurable).
3. Fine-tunes several backbones (LWM, ResNet18, EfficientNet-B0, MobileNet-V3-Small,
   and two small CNN baselines) using the same splits.
4. Reports accuracy statistics and stores checkpoints/metrics per experiment.
Input spectrograms are globally normalized using the dataset mean/std stored with
the specified pretrained checkpoint (defaults to the latest run under `models/`).
Usage example (defaults cover city_1_losangeles/LTE with all available SNR·mobility combos):
python task1/train_mcs_models.py --train-sizes 128 --models resnet18 mobilenet_v3_small
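A fuller sweep might look like this (all flags are defined in parse_args; values are illustrative):
    python task1/train_mcs_models.py --models lwm simple_cnn ieee_cnn --train-sizes 32 128 \
        --repetitions 3 --epochs 100 --val-per-class 256 --test-per-class 256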
"""
from __future__ import annotations
import argparse
import copy
import csv
import glob
import json
import os
import pickle
import random
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from contextlib import nullcontext
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
try:
from tqdm import tqdm
except ImportError: # pragma: no cover - optional dependency
def tqdm(iterable, *args, **kwargs):
return iterable
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from pretraining.pretrained_model import lwm as lwm_model
from utils import count_parameters
COMM_CANONICAL = {
"lte": "LTE",
"wifi": "WiFi",
"5g": "5G",
}
COMM_LOWER = {v: k for k, v in COMM_CANONICAL.items()}
try:
from sklearn.metrics import f1_score as sklearn_f1_score
HAVE_SKLEARN = True
except ImportError:
HAVE_SKLEARN = False
try:
import matplotlib.pyplot as plt
HAVE_MPL = True
except ImportError:
HAVE_MPL = False
try:
from task2.mobility_utils import LWMClassifierMinimal # type: ignore
except ImportError: # pragma: no cover - optional dependency
LWMClassifierMinimal = None # type: ignore[misc]
# HPU support detection
HPU_AVAILABLE = False
try:
import habana_frameworks.torch.core as htcore # type: ignore[import-not-found]
HPU_AVAILABLE = hasattr(torch, "hpu") and torch.hpu.is_available()
except (ImportError, AttributeError):
pass
def compute_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
if HAVE_SKLEARN:
return float(sklearn_f1_score(y_true, y_pred, average="macro"))
classes = np.unique(np.concatenate([y_true, y_pred]))
scores = []
for cls in classes:
tp = np.sum((y_true == cls) & (y_pred == cls))
fp = np.sum((y_true != cls) & (y_pred == cls))
fn = np.sum((y_true == cls) & (y_pred != cls))
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
denom = precision + recall
f1 = (2 * precision * recall / denom) if denom > 0 else 0.0
scores.append(f1)
return float(np.mean(scores))
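# Worked sanity check for the macro-F1 fallback above (illustrative values, not dataset output):
#   y_true = [0, 0, 1, 1], y_pred = [0, 1, 1, 1]
#   class 0: precision = 1.0, recall = 0.5, F1 = 2/3; class 1: precision = 2/3, recall = 1.0, F1 = 0.8
#   macro-F1 = (2/3 + 0.8) / 2 ≈ 0.733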
MODULATION_LABELS = {
"BPSK": 0,
"QPSK": 1,
"QAM16": 2,
"QAM64": 3,
"QAM256": 4,
}
LABEL_NAMES = {idx: name for name, idx in MODULATION_LABELS.items()}
DEFAULT_LWM_TRAINABLE_LAYERS = 2 # fine-tune the last two transformer blocks
_SAMPLE_COUNT_CACHE: Dict[str, int] = {}
def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
if specs.size == 0:
return specs.astype(np.float32, copy=False)
means = specs.mean(axis=(1, 2), keepdims=True)
stds = specs.std(axis=(1, 2), keepdims=True)
stds = np.maximum(stds, eps)
normalized = (specs - means) / stds
return normalized.astype(np.float32, copy=False)
def apply_normalization(specs: np.ndarray, stats: Dict[str, object]) -> np.ndarray:
mode = str(stats.get("normalization", "dataset")).lower()
mean = float(stats.get("mean", 0.0))
std = float(stats.get("std", 1.0))
if abs(std) < 1e-6:
std = 1e-6
if mode == "dataset":
return ((specs.astype(np.float32, copy=False) - mean) / std).astype(np.float32, copy=False)
return normalize_per_sample(specs)
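# Usage sketch (synthetic array; in practice `stats` comes from dataset_stats.json via
# resolve_checkpoint_and_stats below):
#   specs = np.random.rand(4, 128, 128).astype(np.float32)
#   z_ds = apply_normalization(specs, {"mean": 0.5, "std": 0.1, "normalization": "dataset"})
#   z_ps = apply_normalization(specs, {"normalization": "per_sample"})  # per-sample zero-mean/unit-std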
def _unique_parameters(params: Iterable[nn.Parameter]) -> List[nn.Parameter]:
seen: set[int] = set()
unique: List[nn.Parameter] = []
for param in params:
pid = id(param)
if pid not in seen:
unique.append(param)
seen.add(pid)
return unique
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--data-root",
default=str(PROJECT_ROOT / "spectrograms"),
help="Root directory containing city folders (default: project_root/spectrograms)",
)
parser.add_argument(
"--cities",
nargs="*",
default=["city_1_losangeles"],
help="City directories to include (default: %(default)s)",
)
parser.add_argument("--comm-types", nargs="*", default=["LTE"], help="Communication standards to include (default: %(default)s)")
parser.add_argument("--LTE", dest="select_lte", action="store_true", help="Shortcut for --comm-types LTE")
parser.add_argument("--WiFi", dest="select_wifi", action="store_true", help="Shortcut for --comm-types WiFi")
parser.add_argument("--5G", dest="select_5g", action="store_true", help="Shortcut for --comm-types 5G")
parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include for training (default: all available)")
parser.add_argument("--val-snrs", nargs="*", default=None, help="SNR folders for validation/test (default: all available)")
parser.add_argument(
"--mobilities",
nargs="*",
default=None,
help="Mobility folders to include for training (default: all available)",
)
parser.add_argument(
"--val-mobilities",
nargs="*",
default=None,
help="Mobility folders for validation/test (default: all available)",
)
parser.add_argument("--fft-folder", default="512FFT", help="FFT folder name (default: %(default)s)")
parser.add_argument(
"--device",
type=str,
default="auto",
choices=["auto", "cuda", "hpu", "cpu"],
help="Device to use for training (default: auto - detects HPU, then CUDA, then CPU)",
)
parser.add_argument(
"--gpu-ids",
type=int,
nargs="*",
default=None,
help="Specific GPU device IDs to use (only for CUDA, default: all visible GPUs)",
)
parser.add_argument(
"--train-sizes",
type=int,
nargs="*",
default=[2, 4, 8, 16, 32, 64, 128, 256],
help="Training samples per class to benchmark",
)
parser.add_argument("--val-per-class", type=int, default=512, help="Validation samples per class")
parser.add_argument("--test-per-class", type=int, default=512, help="Test samples per class")
parser.add_argument("--repetitions", type=int, default=1, help="Repetitions per train size")
parser.add_argument("--epochs", type=int, default=200, help="Epochs per run")
parser.add_argument("--batch-size", type=int, default=32, help="Mini-batch size")
parser.add_argument("--lr", type=float, default=8e-4, help="Learning rate for fine-tuning")
parser.add_argument("--weight-decay", type=float, default=3e-2, help="Weight decay")
parser.add_argument(
"--no-epoch-history",
action="store_true",
help="Disable aggregated per-epoch history tracking",
)
parser.add_argument(
"--no-epoch-plot",
action="store_true",
help="Disable per-repetition metric plots",
)
parser.add_argument(
"--save-epoch-checkpoints",
action="store_true",
help="Persist per-epoch checkpoints (default: disabled)",
)
parser.add_argument(
"--backbone-lr-factor",
type=float,
default=0.3,
help="Relative LR multiplier applied to unfrozen backbone parameters (default: %(default)s)",
)
parser.add_argument(
"--early-patience",
type=int,
default=5,
help="Early stopping patience based on validation F1 (default: %(default)s)",
)
parser.add_argument(
"--early-min-epochs",
type=int,
default=10,
help="Minimum number of epochs to run before early stopping can trigger (default: %(default)s)",
)
parser.add_argument(
"--finetune-epochs",
type=int,
default=0,
help="Additional fine-tuning epochs to run after the main schedule (default: %(default)s)",
)
parser.add_argument(
"--finetune-lr-factor",
type=float,
default=0.1,
help="Multiplier applied to the base learning rate during fine-tuning (default: %(default)s)",
)
parser.add_argument(
"--finetune-patience",
type=int,
default=3,
help="Early stopping patience for the fine-tuning phase (default: %(default)s)",
)
parser.add_argument(
"--finetune-min-epochs",
type=int,
default=0,
help="Minimum epochs to execute in the fine-tuning phase before early stopping is considered (default: %(default)s)",
)
parser.add_argument(
"--debug-eval-batches",
type=int,
default=0,
help="Log detailed stats for the first N evaluation batches (0 disables logging)",
)
parser.add_argument(
"--debug-eval-interval",
type=int,
default=1,
        help="Interval, in batches, between logged evaluation batches when debug logging is enabled (default: %(default)s)",
)
parser.add_argument(
"--debug-eval-softmax",
action="store_true",
help="When debugging evaluations, also log softmax statistics per batch",
)
parser.add_argument(
"--models",
nargs="*",
default=["lwm", "resnet18", "efficientnet_b0", "mobilenet_v3_small", "simple_cnn", "ieee_cnn"],
help="Models to benchmark",
)
parser.add_argument(
"--raw-input-models",
nargs="*",
default=None,
help=(
"Models that should receive raw spectrograms without additional normalization "
"(default: all non-LWM models)"
),
)
parser.add_argument(
"--lwm-trainable-layers",
type=int,
default=2,
help="Number of transformer layers (from the end) to fine-tune in LWM (default: %(default)s)",
)
parser.add_argument(
"--lwm-classifier-dim",
type=int,
default=64,
help="Hidden width for the LWM classifier MLP head (default: %(default)s; ignored for linear head)",
)
parser.add_argument(
"--lwm-head-dropout",
type=float,
default=0.0,
help="Dropout applied inside the LWM classifier head (default: %(default)s)",
)
parser.add_argument(
"--lwm-head-type",
choices=("linear", "mlp", "res1dcnn"),
default="res1dcnn",
help="Classifier head architecture for LWM (default: %(default)s)",
)
parser.add_argument(
"--lwm-backbone-lr-factor",
type=float,
default=0.2,
help="LR multiplier for unfrozen LWM backbone layers (default: %(default)s)",
)
parser.add_argument(
"--resnet-head-width",
type=int,
default=512,
help="Hidden width for the ResNet18 classifier head (default: %(default)s)",
)
parser.add_argument(
"--efficientnet-head-width",
type=int,
default=296,
help="Hidden width for the EfficientNet-B0 classifier head (default: %(default)s)",
)
parser.add_argument(
"--mobilenet-head-width",
type=int,
default=576,
help="Hidden width for the MobileNetV3-Small classifier head (default: %(default)s)",
)
parser.add_argument(
"--imagenet-head-dropout",
type=float,
default=0.6,
help="Dropout probability used inside ImageNet backbone classifier heads (default: %(default)s)",
)
parser.add_argument(
"--imagenet-weight-decay-scale",
type=float,
default=2.0,
help="Multiplier applied to weight decay for ImageNet backbone trainable parameters (default: %(default)s)",
)
parser.add_argument(
"--simple-cnn-hidden-dims",
type=int,
nargs="*",
default=[272, 128],
help="Hidden layer widths for the Simple CNN classifier (default: %(default)s)",
)
parser.add_argument(
"--ieee-cnn-hidden-dims",
type=int,
nargs="*",
default=[512, 256],
help="Hidden layer widths for the IEEE CNN classifier (default: %(default)s)",
)
parser.add_argument(
"--ieee-cnn-dropout",
type=float,
default=0.3,
help="Dropout rate for the IEEE CNN model (default: %(default)s)",
)
parser.add_argument("--checkpoint", type=Path, default=None, help="Path to pretrained LWM checkpoint (.pth)")
parser.add_argument("--stats", type=Path, default=None, help="dataset_stats.json path")
parser.add_argument(
"--models-root",
type=Path,
default=PROJECT_ROOT / "models",
help="Root with pretrained runs (default: project_root/models)",
)
parser.add_argument(
"--output-dir",
type=Path,
default=PROJECT_ROOT / "task1" / "mcs_benchmarks",
help="Results root directory (per-run subfolder created automatically)",
)
parser.add_argument(
"--export-full-model",
type=Path,
default=None,
help="Directory where best full-model checkpoints (backbone + head) will be exported per run",
)
parser.add_argument("--seed", type=int, default=42, help="Base random seed")
args = parser.parse_args()
args.output_root = args.output_dir
quick_comm: List[str] = []
if getattr(args, "select_lte", False):
quick_comm.append("LTE")
if getattr(args, "select_wifi", False):
quick_comm.append("WiFi")
if getattr(args, "select_5g", False):
quick_comm.append("5G")
if quick_comm:
args.comm_types = quick_comm
normalized: List[str] = []
for comm in args.comm_types:
upper = comm.upper()
if upper == "WIFI":
normalized.append("WiFi")
elif upper == "LTE":
normalized.append("LTE")
elif upper == "5G":
normalized.append("5G")
else:
normalized.append(comm)
args.comm_types = normalized
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
comm_tokens: List[str] = []
for comm in args.comm_types:
canonical = COMM_LOWER.get(comm, comm.lower())
token = re.sub(r"[^a-z0-9]+", "-", canonical.lower()).strip("-")
comm_tokens.append(token or "unknown")
comm_suffix = "-".join(comm_tokens) if comm_tokens else "unknown"
args.run_timestamp = timestamp
args.output_dir = args.output_root / comm_suffix / timestamp
args.comm_suffix = comm_suffix
if args.gpu_ids is not None and len(args.gpu_ids) == 0:
args.gpu_ids = None
if not args.simple_cnn_hidden_dims:
args.simple_cnn_hidden_dims = [512, 256]
if not args.ieee_cnn_hidden_dims:
args.ieee_cnn_hidden_dims = [512, 256]
if args.raw_input_models is None:
args.raw_input_models = [
model.lower()
for model in args.models
if model.lower() not in {"lwm"}
]
else:
args.raw_input_models = [model.lower() for model in args.raw_input_models]
args.models = [model for model in args.models]
args.save_epoch_history = not args.no_epoch_history
args.plot_epoch_history = not args.no_epoch_plot
args.imagenet_head_dropout = float(max(0.0, min(args.imagenet_head_dropout, 0.95)))
args.imagenet_weight_decay_scale = float(max(0.0, args.imagenet_weight_decay_scale))
return args
def find_latest_run(models_root: Path) -> Path:
run_dirs = [p for p in models_root.iterdir() if p.is_dir()]
run_dirs = [p for p in run_dirs if not p.name.lower().endswith("_models")]
valid_runs = [p for p in run_dirs if any(p.glob("*.pth"))]
if valid_runs:
return max(valid_runs, key=lambda p: p.stat().st_mtime)
checkpoints = list(models_root.glob("*.pth"))
if checkpoints:
print(f"[INFO] No checkpoint-bearing subdirectories under {models_root}; using root as run directory.")
return models_root
raise FileNotFoundError(f"No checkpoints found under {models_root}")
def find_best_checkpoint(run_dir: Path) -> Path:
candidates = list(run_dir.glob("*.pth"))
if not candidates:
raise FileNotFoundError(f"No checkpoints in {run_dir}")
def metric(path: Path) -> float:
match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name)
if match:
try:
return float(match.group(1))
except ValueError:
pass
return float("inf")
best = min(candidates, key=metric)
return best
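# Selection note: the checkpoint whose parsed "_val<number>" suffix is smallest wins (the value
# is treated as lower-is-better, e.g. a validation loss); files without the pattern sort last.
# Example (hypothetical filenames): "epoch20_val0.3987.pth" beats "epoch10_val0.4321.pth".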
def resolve_models_directory(args: argparse.Namespace) -> Path:
base = args.models_root.expanduser().resolve()
if not base.exists():
raise FileNotFoundError(f"Models root not found: {base}")
matches: List[Path] = []
for comm in args.comm_types:
subdir = base / f"{comm}_models"
if subdir.exists():
matches.append(subdir)
else:
print(f"[WARN] Models directory for {comm} not found at {subdir}")
if len(matches) == 1:
print(f"[INFO] Using models directory for {args.comm_types[0]}: {matches[0]}")
return matches[0]
if len(matches) > 1:
raise ValueError(
"Multiple communication-specific model directories detected; please provide --checkpoint explicitly."
)
print(f"[INFO] Using shared models directory: {base}")
return base
def resolve_checkpoint_and_stats(args: argparse.Namespace, require_checkpoint: bool) -> Tuple[Path | None, Dict[str, object]]:
checkpoint: Path | None = None
models_dir = resolve_models_directory(args)
user_provided_stats = args.stats is not None
if args.checkpoint is not None:
checkpoint = args.checkpoint.expanduser().resolve()
if not checkpoint.exists():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
stats_path = args.stats.expanduser().resolve() if user_provided_stats else checkpoint.parent / "dataset_stats.json"
else:
run_dir = find_latest_run(models_dir)
stats_path = run_dir / "dataset_stats.json"
if require_checkpoint:
checkpoint = find_best_checkpoint(run_dir)
else:
checkpoint = None
if stats_path.exists():
try:
with open(stats_path, "r", encoding="utf-8") as f:
stats = json.load(f)
except json.JSONDecodeError as exc:
if user_provided_stats:
raise ValueError(
f"Failed to parse dataset_stats.json at {stats_path}: {exc}"
) from exc
print(
f"[WARN] Corrupt dataset_stats.json at {stats_path}; "
"falling back to mean=0/std=1 per-sample normalization."
)
stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
else:
if "mean" not in stats or "std" not in stats:
raise ValueError("dataset_stats.json must contain 'mean' and 'std'")
stats.setdefault("normalization", stats.get("mode", "dataset"))
else:
if user_provided_stats:
raise FileNotFoundError(f"dataset_stats.json not found: {stats_path}")
stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
print(f"[WARN] dataset_stats.json not found at {stats_path}. Falling back to per-sample normalization.")
if checkpoint is not None:
print(f"[INFO] Using checkpoint: {checkpoint}")
elif require_checkpoint:
raise FileNotFoundError("LWM requested but no checkpoint available")
else:
print("[INFO] No LWM checkpoint required for selected models")
norm_mode = str(stats.get("normalization", "dataset"))
if norm_mode.lower() == "dataset":
print(f"[INFO] Dataset stats -> mean={stats['mean']:.4f}, std={stats['std']:.4f}")
else:
print("[INFO] Normalization mode: per_sample")
return checkpoint, {
"mean": float(stats.get("mean", 0.0)),
"std": float(stats.get("std", 1.0)),
"normalization": norm_mode,
}
def identify_modulation(path: str) -> tuple[int | None, str | None]:
for mod_name, label in MODULATION_LABELS.items():
if mod_name in path:
return label, mod_name
return None, None
def _extract_metadata(parts: Sequence[str]) -> Tuple[str, str, str]:
rate = next((part for part in parts if part.startswith("rate")), "rate_unknown")
snr = next((part for part in parts if part.startswith("SNR")), "SNR_unknown")
mobility = next((part for part in parts if part in {"static", "pedestrian", "vehicular"}), "mobility_unknown")
return rate, snr, mobility
def discover_snr_mobility(
data_root: Path,
cities: Sequence[str],
comm_types: Sequence[str],
fft_folder: str,
) -> Tuple[List[str], List[str]]:
snrs: set[str] = set()
mobilities: set[str] = set()
for city in cities:
for comm in comm_types:
base = data_root / city / comm
if not base.exists():
continue
for root, dirs, _ in os.walk(base):
parts = Path(root).parts
for part in parts:
if part.startswith("SNR") and part.endswith("dB"):
snrs.add(part)
elif part in {"static", "pedestrian", "vehicular"}:
mobilities.add(part)
if not snrs:
snrs.add("SNR20dB")
if not mobilities:
mobilities.add("static")
return sorted(snrs), sorted(mobilities)
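# Expected on-disk layout, inferred from the glob patterns used below (actual trees may differ):
#   <data_root>/<city>/<comm>/.../<SNR..dB>/<static|pedestrian|vehicular>/.../<fft_folder>/.../spectrograms/*.pkl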
def build_config_map(
data_root: Path,
cities: Sequence[str],
comm_types: Sequence[str],
snrs: Sequence[str],
mobilities: Sequence[str],
fft_folder: str,
) -> Dict[int, Dict[str, List[str]]]:
class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
for city in cities:
for comm in comm_types:
base = data_root / city / comm
for snr in snrs:
for mobility in mobilities:
pattern = str(base / "**" / snr / mobility / "**" / fft_folder / "**" / "spectrograms" / "*.pkl")
for path_str in glob.glob(pattern, recursive=True):
cls, modulation_name = identify_modulation(path_str)
if cls is None:
continue
rate, _, _ = _extract_metadata(Path(path_str).parts)
config_name = f"{modulation_name}_{rate}_{snr}_{mobility}"
class_configs[cls][config_name].append(path_str)
return class_configs
def build_global_config_map(
data_root: Path,
cities: Sequence[str],
comm_types: Sequence[str],
fft_folder: str,
) -> Dict[int, Dict[str, List[str]]]:
class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
for city in cities:
for comm in comm_types:
base = data_root / city / comm
pattern = str(base / "**" / fft_folder / "**" / "spectrograms" / "*.pkl")
for path_str in glob.glob(pattern, recursive=True):
cls, modulation_name = identify_modulation(path_str)
if cls is None:
continue
rate, snr_part, mobility_part = _extract_metadata(Path(path_str).parts)
config_name = f"{modulation_name}_{rate}_{snr_part}_{mobility_part}"
class_configs[cls][config_name].append(path_str)
return class_configs
def _count_samples_in_path(path: str) -> int:
cached = _SAMPLE_COUNT_CACHE.get(path)
if cached is not None:
return cached
arr = load_all_samples(path)
count = int(arr.shape[0])
_SAMPLE_COUNT_CACHE[path] = count
return count
class LazyConfigArray:
"""Lazily views spectrograms spread across multiple pickled files."""
__slots__ = ("paths", "_counts", "_offsets", "_total", "shape", "dtype", "ndim")
def __init__(self, paths: Sequence[str]) -> None:
filtered_paths: List[str] = []
counts: List[int] = []
for path in sorted(paths):
count = _count_samples_in_path(path)
if count <= 0:
continue
filtered_paths.append(path)
counts.append(count)
self.paths: Tuple[str, ...] = tuple(filtered_paths)
if counts:
self._counts = np.array(counts, dtype=np.int64)
self._offsets = np.concatenate(([0], np.cumsum(self._counts)))
self._total = int(self._offsets[-1])
else:
self._counts = np.empty(0, dtype=np.int64)
self._offsets = np.array([0], dtype=np.int64)
self._total = 0
self.shape = (self._total, 128, 128)
self.dtype = np.float32
self.ndim = 3
def __len__(self) -> int:
return self._total
def _resolve_index(self, index: int) -> Tuple[int, int]:
if self._total == 0:
raise IndexError("attempting to index empty LazyConfigArray")
if index < 0:
index += self._total
if index < 0 or index >= self._total:
raise IndexError("index out of range for LazyConfigArray")
path_idx = int(np.searchsorted(self._offsets[1:], index, side="right"))
start = int(self._offsets[path_idx])
return path_idx, int(index - start)
def _load_path(self, path_idx: int) -> np.ndarray:
path = self.paths[path_idx]
return load_all_samples(path)
def __getitem__(self, item: Any) -> np.ndarray:
if isinstance(item, (int, np.integer)):
path_idx, local_idx = self._resolve_index(int(item))
data = self._load_path(path_idx)
sample = data[local_idx].copy()
return sample
indices = np.asarray(item, dtype=np.int64)
if indices.ndim == 0:
indices = indices.reshape(1)
else:
indices = indices.reshape(-1)
if indices.size == 0:
return np.empty((0, 128, 128), dtype=np.float32)
resolved: Dict[int, List[Tuple[int, int]]] = {}
for pos, raw_idx in enumerate(indices):
path_idx, local_idx = self._resolve_index(int(raw_idx))
resolved.setdefault(path_idx, []).append((pos, local_idx))
result = np.empty((indices.size, 128, 128), dtype=np.float32)
for path_idx, items in resolved.items():
data = self._load_path(path_idx)
local_positions = [loc for _, loc in items]
chunk = data[local_positions]
for offset, (pos, _) in enumerate(items):
result[pos] = chunk[offset]
return result
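# Usage sketch (paths are placeholders): a LazyConfigArray looks like a read-only (N, 128, 128)
# array but only unpickles the file(s) backing the requested indices.
#   arr = LazyConfigArray(["/data/.../spectrograms/a.pkl", "/data/.../spectrograms/b.pkl"])
#   batch = arr[np.array([0, 5, 17])]  # shape (3, 128, 128); loads at most two files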
def load_config_arrays(class_configs: Dict[int, Dict[str, List[str]]]) -> Dict[int, Dict[str, LazyConfigArray]]:
loaded: Dict[int, Dict[str, LazyConfigArray]] = {}
for cls, configs in class_configs.items():
arrays_for_cls: Dict[str, LazyConfigArray] = {}
for config_name, paths in configs.items():
lazy_array = LazyConfigArray(paths)
if len(lazy_array) == 0:
continue
arrays_for_cls[config_name] = lazy_array
if arrays_for_cls:
loaded[cls] = arrays_for_cls
return loaded
def load_all_samples(path: str) -> np.ndarray:
with open(path, "rb") as f:
data = pickle.load(f)
if isinstance(data, dict) and "spectrograms" in data:
arr = data["spectrograms"]
elif isinstance(data, np.ndarray):
arr = data
else:
return np.empty((0, 128, 128), dtype=np.float32)
arr = np.asarray(arr, dtype=np.float32)
if arr.ndim == 2:
arr = arr[None, ...]
if arr.shape[1:] != (128, 128):
return np.empty((0, 128, 128), dtype=np.float32)
return arr
def sample_from_paths(
paths: Sequence[str],
n_samples: int,
rng: np.random.Generator,
used_map: Dict[str, set[int]],
) -> Tuple[np.ndarray, List[Tuple[str, np.ndarray]]]:
if not paths:
raise RuntimeError("No files available for sampling")
paths_array = np.array(paths, dtype=object)
order = rng.permutation(len(paths_array))
remaining = n_samples
collected: List[np.ndarray] = []
info: List[Tuple[str, np.ndarray]] = []
for idx in order:
if remaining <= 0:
break
path = str(paths_array[idx])
samples = load_all_samples(path)
total = samples.shape[0]
used = used_map[path]
if used:
used_idx = np.fromiter(used, dtype=np.int64, count=len(used))
available = np.setdiff1d(np.arange(total), used_idx, assume_unique=True)
else:
available = np.arange(total)
if available.size == 0:
continue
take = min(remaining, available.size)
chosen = rng.choice(available, size=take, replace=False)
collected.append(samples[chosen])
used_map[path].update(int(i) for i in chosen)
info.append((path, chosen))
remaining -= take
if remaining > 0:
raise RuntimeError("Insufficient samples remaining to satisfy request")
result = np.concatenate(collected, axis=0) if len(collected) > 1 else collected[0]
return result, info
def _ensure_available(total_needed: int, availability: Dict[str, set]) -> None:
remaining = sum(len(indices) for indices in availability.values())
if remaining < total_needed:
raise RuntimeError(
f"Insufficient samples: need {total_needed}, only {remaining} available across configs"
)
def _sample_from_availability(
arrays_map: Dict[str, LazyConfigArray],
availability: Dict[str, set[int]],
total_needed: int,
rng: np.random.Generator,
) -> Tuple[np.ndarray, Dict[str, set[int]]]:
if total_needed <= 0:
return np.empty((0, 128, 128), dtype=np.float32), {cfg: set() for cfg in arrays_map}
_ensure_available(total_needed, availability)
remaining = total_needed
configs = [cfg for cfg, indices in availability.items() if indices]
used: Dict[str, set[int]] = {cfg: set() for cfg in arrays_map}
collected: List[np.ndarray] = []
while remaining > 0 and configs:
cfg = rng.choice(configs)
available_indices = np.array(list(availability[cfg]), dtype=np.int64)
if available_indices.size == 0:
configs = [c for c in configs if c != cfg]
continue
take = min(max(1, remaining // max(len(configs), 1)), remaining, available_indices.size)
chosen = rng.choice(available_indices, size=take, replace=False)
collected.append(arrays_map[cfg][chosen])
chosen_set = {int(idx) for idx in chosen}
used[cfg].update(chosen_set)
availability[cfg].difference_update(chosen_set)
remaining -= take
configs = [c for c in configs if availability[c]]
if remaining > 0:
raise RuntimeError("Sampling failed to collect the requested number of samples")
stacked = np.concatenate(collected, axis=0) if collected else np.empty((0, 128, 128), dtype=np.float32)
return stacked.astype(np.float32, copy=False), used
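# Sampling note: configs are drawn at random, and each draw takes roughly remaining/len(configs)
# indices without replacement, so a request is spread across (modulation, rate, SNR, mobility)
# configs while `availability` tracks which indices have already been consumed.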
def sample_train_arrays(
arrays_map: Dict[str, LazyConfigArray],
availability: Dict[str, set[int]],
train_size: int,
rng: np.random.Generator,
) -> Tuple[np.ndarray, Dict[str, set[int]]]:
return _sample_from_availability(arrays_map, availability, train_size, rng)
def sample_global_arrays(
arrays_map: Dict[str, LazyConfigArray],
availability: Dict[str, set[int]],
per_class: int,
rng: np.random.Generator,
) -> Tuple[np.ndarray, Dict[str, set[int]]]:
return _sample_from_availability(arrays_map, availability, per_class, rng)
class SpectrogramDataset(Dataset):
def __init__(self, specs: np.ndarray, labels: np.ndarray):
self.specs = specs.astype(np.float32, copy=False)
self.labels = labels.astype(np.int64, copy=False)
def __len__(self) -> int:
return len(self.labels)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
return torch.from_numpy(self.specs[idx]), int(self.labels[idx])
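# Usage sketch: wrap sampled arrays for a DataLoader (batch size is illustrative):
#   ds = SpectrogramDataset(specs, labels)               # specs: (N, 128, 128) float32, labels: (N,)
#   loader = DataLoader(ds, batch_size=32, shuffle=True)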
def normalize_batch(specs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
mean = specs.mean()
std = specs.std(unbiased=False)
std = torch.clamp(std, min=eps)
return (specs - mean) / std
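# Example: z-score an entire mini-batch with one global mean/std (used when batch_normalize=True
# in the train/eval loops below; values are illustrative):
#   batch = normalize_batch(torch.rand(16, 128, 128))  # mean ~0, std ~1 over the whole batch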
def apply_spec_augment(
specs: torch.Tensor,
*,
freq_mask_width: int = 12,
time_mask_width: int = 16,
freq_masks: int = 2,
time_masks: int = 2,
mask_prob: float = 0.5,
noise_std: float = 0.0,
) -> torch.Tensor:
"""Apply light-weight SpecAugment-style masking to a batch of spectrograms.
The function accepts tensors shaped ``[B, H, W]`` or ``[B, 1, H, W]`` and
returns an augmented tensor with the same shape. Masks use the sample mean
to avoid introducing large bias and are applied per-sample with the given
probability. Gaussian noise (if requested) is injected before masking.
"""
if mask_prob <= 0.0 and noise_std <= 0.0:
return specs
if specs.dim() not in (3, 4):
raise ValueError(f"Spectrograms must be rank-3 or rank-4, got shape {tuple(specs.shape)}")
needs_squeeze = specs.dim() == 3
augmented = specs.unsqueeze(1) if needs_squeeze else specs
batch_size, _, freq_dim, time_dim = augmented.shape
if mask_prob < 1.0:
apply_mask = torch.rand(batch_size, device=augmented.device) < mask_prob
else:
apply_mask = torch.ones(batch_size, dtype=torch.bool, device=augmented.device)
freq_mask_width = max(0, int(freq_mask_width))
time_mask_width = max(0, int(time_mask_width))
freq_masks = max(0, int(freq_masks))
time_masks = max(0, int(time_masks))
for idx in range(batch_size):
if not apply_mask[idx]:
continue
sample = augmented[idx]
if noise_std > 0.0:
sample = sample + noise_std * torch.randn_like(sample)
fill_value = sample.mean()
if freq_mask_width > 0 and freq_masks > 0:
max_width = min(freq_mask_width, freq_dim)
for _ in range(freq_masks):
width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item())
if width == 0 or width > freq_dim:
continue
start = 0 if freq_dim == width else int(torch.randint(0, freq_dim - width + 1, (1,), device=augmented.device).item())
sample[:, start:start + width, :] = fill_value
if time_mask_width > 0 and time_masks > 0:
max_width = min(time_mask_width, time_dim)
for _ in range(time_masks):
width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item())
if width == 0 or width > time_dim:
continue
start = 0 if time_dim == width else int(torch.randint(0, time_dim - width + 1, (1,), device=augmented.device).item())
sample[:, :, start:start + width] = fill_value
augmented[idx] = sample
return augmented.squeeze(1) if needs_squeeze else augmented
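# Usage sketch (illustrative parameters; accepts [B, H, W] or [B, 1, H, W]):
#   specs = torch.randn(8, 128, 128)
#   specs_aug = apply_spec_augment(specs, freq_mask_width=8, time_mask_width=8, mask_prob=0.7)
#   assert specs_aug.shape == specs.shape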
def _write_epoch_history(
rep_root: Path,
records: Sequence[Dict[str, object]],
enable_csv: bool,
enable_plot: bool,
) -> None:
if not records:
return
rep_root.mkdir(parents=True, exist_ok=True)
if enable_csv:
base_fields = [
"model",
"epoch",
"phase",
"train_loss",
"val_loss",
"val_acc",
"val_f1",
"lr",
"train_size_requested",
"train_size_effective",
]
extra_fields = sorted(
{key for rec in records for key in rec.keys() if key not in base_fields}
)
fieldnames = base_fields + extra_fields
sorted_records = sorted(records, key=lambda r: (r["epoch"], r["model"], r.get("phase", "")))
history_path = rep_root / "epoch_history.csv"
with open(history_path, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction="ignore")
writer.writeheader()
writer.writerows(sorted_records)
if enable_plot and HAVE_MPL:
models_in_run = sorted({rec["model"] for rec in records})
fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
for ax in axes:
ax.grid(True, linestyle='--', alpha=0.3)
for model_name_plot in models_in_run:
model_records = [rec for rec in records if rec["model"] == model_name_plot]
if not model_records:
continue
epochs = [rec["epoch"] for rec in model_records]
val_loss_values = [rec["val_loss"] for rec in model_records]
val_f1_values = [rec["val_f1"] for rec in model_records]
axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot)
axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot)
axes[0].set_ylabel('Val Loss')
axes[1].set_ylabel('Val F1')
axes[1].set_xlabel('Epoch')
axes[0].legend(loc='best')
axes[0].set_title('Per-epoch validation metrics')
fig.tight_layout()
fig.savefig(rep_root / 'epoch_history.png', dpi=150)
plt.close(fig)
class ResidualBlock1D(nn.Module):
"""Lightweight residual block used by the res1dcnn head."""
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm1d(out_channels)
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm1d(out_channels)
self.shortcut = nn.Identity()
if in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm1d(out_channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = F.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
x = x + self.shortcut(residual)
x = F.relu(x)
return x
class Res1DCNNHead(nn.Module):
"""Residual 1D CNN classifier head that operates on 128-d LWM features."""
def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1) -> None:
super().__init__()
self.input_dim = int(input_dim)
hidden_dim = 64
self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm1d(hidden_dim)
self.res_block = ResidualBlock1D(hidden_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_dim, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.unsqueeze(1)
x = F.relu(self.bn1(self.conv1(x)))
x = self.res_block(x)
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
x = self.dropout(x)
return self.fc(x)
class LWMClassifier(nn.Module):
def __init__(
self,
backbone: nn.Module,
trainable_layers: int,
num_classes: int,
classifier_dim: int = 128,
head_dropout: float = 0.1,
head_type: str = "mlp",
):
super().__init__()
self.backbone = backbone
self.patch_size = 4
self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)
head_dropout = max(0.0, float(head_dropout))
head_type = head_type.lower().strip()
if head_type == "linear":
head_layers: List[nn.Module] = [nn.LayerNorm(128)]
if head_dropout > 0:
head_layers.append(nn.Dropout(head_dropout))
head_layers.append(nn.Linear(128, num_classes))
self.classifier = nn.Sequential(*head_layers)
elif head_type == "res1dcnn":
self.classifier = nn.Sequential(
nn.LayerNorm(128),
Res1DCNNHead(128, num_classes, dropout=head_dropout),
)
else:
head_layers = [
nn.LayerNorm(128),
nn.Linear(128, classifier_dim),
nn.GELU(),
]
if head_dropout > 0:
head_layers.append(nn.Dropout(head_dropout))
head_layers.append(nn.Linear(classifier_dim, num_classes))
self.classifier = nn.Sequential(*head_layers)
for param in self.backbone.parameters():
param.requires_grad = False
if trainable_layers > 0:
for layer in self.backbone.layers[-trainable_layers:]:
for param in layer.parameters():
param.requires_grad = True
# Enable gradient checkpointing for memory efficiency
if hasattr(layer, 'gradient_checkpointing'):
layer.gradient_checkpointing = True
def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor:
x = x.unsqueeze(1)
patches = self.unfold(x).transpose(1, 2)
cls = torch.full(
(patches.size(0), 1, patches.size(-1)), 0.2, dtype=patches.dtype, device=patches.device
)
return torch.cat([cls, patches], dim=1)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
tokens = self.spectrogram_to_tokens(x)
outputs = self.backbone(tokens)
if outputs.size(1) <= 1:
return outputs[:, 0, :]
return outputs[:, 1:, :].mean(dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
cls = self.forward_features(x)
return self.classifier(cls)
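# Token geometry note (derived from the defaults in build_model below): a 128x128 spectrogram
# unfolded into non-overlapping 4x4 patches yields (128 // 4) ** 2 = 1024 patch tokens of length
# 16, plus one CLS token = 1025 tokens, matching lwm_model(element_length=16, ..., max_len=1025).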
def create_simple_cnn(
num_classes: int,
hidden_dims: Tuple[int, ...] = (192,),
dropout: float = 0.3,
) -> nn.Module:
"""Create baseline CNN with configurable classifier width."""
if not hidden_dims:
raise ValueError("hidden_dims must contain at least one value")
layers: List[nn.Module] = [
nn.Conv2d(1, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(16, 32, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4)),
nn.Flatten(),
nn.Dropout(dropout),
]
in_dim = 4 * 4 * 64
fc_layers: List[nn.Module] = []
for idx, hidden_dim in enumerate(hidden_dims):
fc_layers.append(nn.Linear(in_dim, hidden_dim))
fc_layers.append(nn.ReLU())
fc_layers.append(nn.Dropout(dropout))
in_dim = hidden_dim
fc_layers.append(nn.Linear(in_dim, num_classes))
return nn.Sequential(*layers, *fc_layers)
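# Instantiation sketch (input shape follows the 1-channel 128x128 spectrograms used throughout):
#   model = create_simple_cnn(num_classes=5, hidden_dims=(272, 128))
#   logits = model(torch.randn(2, 1, 128, 128))  # -> shape (2, 5)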
def create_ieee_cnn(
num_classes: int,
hidden_dims: Tuple[int, ...] = (512, 256),
dropout: float = 0.3,
) -> nn.Module:
"""CNN inspired by IEEE 2021 joint SNR/mobility classifier."""
if not hidden_dims:
raise ValueError("hidden_dims must contain at least one value")
layers: List[nn.Module] = [
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Dropout2d(p=dropout),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((4, 4)),
nn.Flatten(),
nn.Dropout(dropout),
]
in_dim = 4 * 4 * 256
fc_layers: List[nn.Module] = []
for hidden_dim in hidden_dims:
fc_layers.append(nn.Linear(in_dim, hidden_dim))
fc_layers.append(nn.BatchNorm1d(hidden_dim))
fc_layers.append(nn.ReLU(inplace=True))
fc_layers.append(nn.Dropout(dropout))
in_dim = hidden_dim
fc_layers.append(nn.Linear(in_dim, num_classes))
return nn.Sequential(*layers, *fc_layers)
def build_model(
name: str,
num_classes: int,
    checkpoint: Path | None,
device: torch.device,
trainable_layers: int,
backbone_lr_factor: float,
overrides: Dict[str, object] | None = None,
) -> Tuple[nn.Module, List[Dict[str, object]]]:
name = name.lower()
param_groups: List[Dict[str, object]] = []
overrides = overrides or {}
if name == "lwm":
backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
if checkpoint is None:
raise FileNotFoundError("Checkpoint is required for LWM-based models")
try:
state = torch.load(checkpoint, map_location="cpu", weights_only=True)
except TypeError:
# Older torch versions do not support weights_only
state = torch.load(checkpoint, map_location="cpu")
if any(k.startswith("module.") for k in state):
state = {k.replace("module.", ""): v for k, v in state.items()}
if any(k.startswith("backbone.") for k in state):
backbone_state = {
k.split("backbone.", 1)[1]: v
for k, v in state.items()
if k.startswith("backbone.")
}
else:
backbone_state = {
k: v
for k, v in state.items()
if not k.startswith("classifier.") and not k.startswith("projection_head.")
}
backbone.load_state_dict(backbone_state, strict=False)
classifier_dim = int(overrides.get("lwm_classifier_dim", 96))
head_dropout = float(overrides.get("lwm_head_dropout", 0.1))
head_type = str(overrides.get("lwm_head_type", "mlp")).lower()
model = LWMClassifier(
backbone,
trainable_layers=trainable_layers,
num_classes=num_classes,
classifier_dim=classifier_dim,
head_dropout=head_dropout,
head_type=head_type,
)
head_params = list(model.classifier.parameters())
param_groups.append({"params": head_params, "scale": 1.0})
if trainable_layers > 0:
backbone_params: List[nn.Parameter] = []
for layer in model.backbone.layers[-trainable_layers:]:
backbone_params.extend(layer.parameters())
backbone_params = _unique_parameters(backbone_params)
if backbone_params:
param_groups.append({"params": backbone_params, "scale": backbone_lr_factor})
elif name == "resnet18":
from torchvision import models
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
nn.init.kaiming_normal_(backbone.conv1.weight, mode='fan_out', nonlinearity='relu')
in_features = backbone.fc.in_features
head_width = int(overrides.get("resnet_head_width", 384))
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45))
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9))
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9))
backbone.fc = nn.Sequential(
nn.Dropout(p=pre_fc_dropout),
nn.Linear(in_features, head_width),
nn.LayerNorm(head_width),
nn.ReLU(inplace=True),
nn.Dropout(p=imagenet_head_dropout),
nn.Linear(head_width, num_classes),
)
for param in backbone.parameters():
param.requires_grad = False
head_params = list(backbone.fc.parameters())
for param in head_params:
param.requires_grad = True
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None)
head_group: Dict[str, object] = {"params": head_params, "scale": 1.0}
if imagenet_weight_decay is not None:
head_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(head_group)
adapt_params: List[nn.Parameter] = []
if hasattr(backbone.layer4[0], "downsample") and backbone.layer4[0].downsample is not None:
adapt_params.extend(backbone.layer4[0].downsample[0].parameters())
if len(backbone.layer4[0].downsample) > 1:
adapt_params.extend(backbone.layer4[0].downsample[1].parameters())
for module in backbone.layer4[-1].modules():
if isinstance(module, nn.BatchNorm2d):
adapt_params.extend(module.parameters())
adapt_params = _unique_parameters(adapt_params)
for param in adapt_params:
param.requires_grad = True
if adapt_params:
adapt_group: Dict[str, object] = {"params": adapt_params, "scale": backbone_lr_factor}
if imagenet_weight_decay is not None:
adapt_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(adapt_group)
model = backbone
elif name == "efficientnet_b0":
from torchvision import models
backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
first_conv = backbone.features[0][0]
backbone.features[0][0] = nn.Conv2d(1, first_conv.out_channels, kernel_size=3, stride=2, padding=1, bias=False)
nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu')
in_features = backbone.classifier[-1].in_features
head_width = int(overrides.get("efficientnet_head_width", 192))
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45))
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9))
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9))
backbone.classifier = nn.Sequential(
nn.Dropout(p=pre_fc_dropout),
nn.Linear(in_features, head_width),
nn.LayerNorm(head_width),
nn.ReLU(inplace=True),
nn.Dropout(p=imagenet_head_dropout),
nn.Linear(head_width, num_classes),
)
for param in backbone.parameters():
param.requires_grad = False
head_params = list(backbone.classifier.parameters())
for param in head_params:
param.requires_grad = True
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None)
head_group = {"params": head_params, "scale": 1.0}
if imagenet_weight_decay is not None:
head_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(head_group)
adapt_params: List[nn.Parameter] = []
final_block = backbone.features[7][0]
# Depthwise conv + associated norms for the last MBConv block
depthwise = final_block.block[1][0]
adapt_params.extend(depthwise.parameters())
for idx in (0, 1, 3):
for module in final_block.block[idx].modules():
if isinstance(module, nn.BatchNorm2d):
adapt_params.extend(module.parameters())
adapt_params = _unique_parameters(adapt_params)
for param in adapt_params:
param.requires_grad = True
if adapt_params:
adapt_group = {"params": adapt_params, "scale": backbone_lr_factor}
if imagenet_weight_decay is not None:
adapt_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(adapt_group)
model = backbone
elif name == "mobilenet_v3_small":
from torchvision import models
backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
backbone.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu')
with torch.no_grad():
dummy = torch.zeros(1, 1, 128, 128)
features = backbone.features(dummy)
pooled = backbone.avgpool(features)
flattened = torch.flatten(pooled, 1)
in_features = flattened.shape[1]
head_width = int(overrides.get("mobilenet_head_width", 320))
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45))
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9))
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9))
backbone.classifier = nn.Sequential(
nn.Dropout(p=pre_fc_dropout),
nn.Linear(in_features, head_width),
nn.LayerNorm(head_width),
nn.Hardswish(),
nn.Dropout(p=imagenet_head_dropout),
nn.Linear(head_width, num_classes),
)
for param in backbone.parameters():
param.requires_grad = False
head_params = list(backbone.classifier.parameters())
for param in head_params:
param.requires_grad = True
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None)
head_group = {"params": head_params, "scale": 1.0}
if imagenet_weight_decay is not None:
head_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(head_group)
adapt_params: List[nn.Parameter] = []
adapt_params.extend(backbone.features[-1][0].parameters())
for module in backbone.features[-2].modules():
if isinstance(module, nn.BatchNorm2d):
adapt_params.extend(module.parameters())
adapt_params = _unique_parameters(adapt_params)
for param in adapt_params:
param.requires_grad = True
if adapt_params:
adapt_group = {"params": adapt_params, "scale": backbone_lr_factor}
if imagenet_weight_decay is not None:
adapt_group["weight_decay"] = float(imagenet_weight_decay)
param_groups.append(adapt_group)
model = backbone
elif name in {"simple_cnn", "simplecnn"}:
hidden_dims = overrides.get("simple_cnn_hidden_dims", (192,))
if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str):
simple_dims = tuple(int(dim) for dim in hidden_dims)
else:
simple_dims = (int(hidden_dims),)
model = create_simple_cnn(num_classes, hidden_dims=simple_dims)
head_params = list(model.parameters())
param_groups.append({"params": head_params, "scale": 1.0})
elif name in {"ieee_cnn", "ieeecnn"}:
hidden_dims = overrides.get("ieee_cnn_hidden_dims", (512, 256))
if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str):
ieee_dims = tuple(int(dim) for dim in hidden_dims)
else:
ieee_dims = (int(hidden_dims),)
dropout = float(overrides.get("ieee_cnn_dropout", 0.3))
model = create_ieee_cnn(num_classes, hidden_dims=ieee_dims, dropout=dropout)
head_params = list(model.parameters())
param_groups.append({"params": head_params, "scale": 1.0})
else:
raise ValueError(f"Unknown model: {name}")
return model.to(device), param_groups
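# Call sketch (arguments are illustrative; checkpoint may be None for non-LWM models):
#   model, param_groups = build_model("resnet18", num_classes=5, checkpoint=None,
#                                     device=torch.device("cpu"), trainable_layers=0,
#                                     backbone_lr_factor=0.3, overrides=None)
#   # each group's "scale" is a relative LR multiplier (see --backbone-lr-factor)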
def _unwrap_module(model: nn.Module) -> nn.Module:
return model.module if isinstance(model, nn.DataParallel) else model
def _strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if not state_dict:
return state_dict
needs_strip = any(key.startswith("module.") for key in state_dict)
if not needs_strip:
return state_dict
stripped = state_dict.__class__() if hasattr(state_dict, "__class__") else {}
for key, value in state_dict.items():
new_key = key.split("module.", 1)[1] if key.startswith("module.") else key
stripped[new_key] = value
return stripped
def _model_forward(
model: nn.Module,
specs: torch.Tensor,
input_stats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
base_model = _unwrap_module(model)
is_lwm_like = isinstance(base_model, LWMClassifier)
if not is_lwm_like and LWMClassifierMinimal is not None:
is_lwm_like = isinstance(base_model, LWMClassifierMinimal)
if not is_lwm_like and hasattr(base_model, "spectrogram_to_tokens"):
is_lwm_like = True
if is_lwm_like:
while specs.dim() > 3 and specs.size(1) == 1:
specs = specs.squeeze(1)
if specs.dim() != 3:
specs = specs.view(specs.size(0), specs.size(-2), specs.size(-1))
if not is_lwm_like and specs.dim() == 3:
specs = specs.unsqueeze(1)
if input_stats is not None:
input_stats = input_stats.to(specs.device, non_blocking=True)
supports_stats = bool(
is_lwm_like and hasattr(base_model, "append_input_stats") and getattr(base_model, "append_input_stats")
)
if supports_stats and input_stats is not None:
return model(specs, input_stats=input_stats)
return model(specs)
def train_one_epoch(model, loader, optimizer, device, scaler=None, batch_normalize: bool = False):
criterion = nn.CrossEntropyLoss(reduction='mean')
model.train()
total_loss = 0.0
total = 0
for specs, labels in loader:
specs = specs.to(device, non_blocking=True)
if batch_normalize:
specs = normalize_batch(specs)
labels = labels.to(device, non_blocking=True)
optimizer.zero_grad(set_to_none=True)
# Use autocast only for CUDA
if scaler is not None and device.type == 'cuda':
with autocast(device_type='cuda'):
logits = _model_forward(model, specs)
loss = criterion(logits, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
# HPU and CPU use standard forward/backward
logits = _model_forward(model, specs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * labels.size(0)
total += labels.size(0)
# Clear cache periodically to prevent memory fragmentation
if device.type == 'cuda':
torch.cuda.empty_cache()
return total_loss / max(total, 1)
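# Wiring sketch (the scaler only matters on CUDA; see evaluate() below for the matching eval call):
#   scaler = GradScaler() if device.type == "cuda" else None
#   train_loss = train_one_epoch(model, train_loader, optimizer, device, scaler=scaler)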
@torch.no_grad()
def evaluate(
model,
loader,
device,
debug: Optional[Dict[str, object]] = None,
batch_normalize: bool = False,
) -> Tuple[float, float, float]:
criterion = nn.CrossEntropyLoss(reduction='mean')
model.eval()
total_loss = 0.0
correct = 0
total = 0
all_preds: List[np.ndarray] = []
all_labels: List[np.ndarray] = []
debug_batches = int(debug.get("log_batches", 0)) if debug else 0
debug_every = max(1, int(debug.get("log_every", 1))) if debug else 1
log_softmax = bool(debug.get("log_softmax", False)) if debug else False
debug_logged = 0
for batch_idx, batch in enumerate(loader, start=1):
stats_batch: Optional[torch.Tensor]
if isinstance(batch, (list, tuple)) and len(batch) == 3:
specs, stats_batch, labels = batch
stats_batch = stats_batch.to(device, non_blocking=True)
else:
specs, labels = batch # type: ignore[misc]
stats_batch = None
specs = specs.to(device, non_blocking=True)
if batch_normalize:
specs = normalize_batch(specs)
labels = labels.to(device, non_blocking=True)
# Use autocast only for CUDA, not for HPU or CPU
if device.type == 'cuda':
context = autocast(device_type='cuda')
else:
context = nullcontext()
with context:
logits = _model_forward(model, specs, stats_batch)
loss = criterion(logits, labels)
preds = logits.argmax(dim=1)
total_loss += loss.item() * labels.size(0)
correct += (preds == labels).sum().item()
total += labels.size(0)
all_preds.append(preds.detach().cpu().numpy())
all_labels.append(labels.detach().cpu().numpy())
should_log = (
debug_batches > 0
and debug_logged < debug_batches
and ((batch_idx - 1) % debug_every == 0)
)
if should_log:
specs_cpu = specs.detach().cpu()
logits_cpu = logits.detach().cpu()
loss_scalar = float(loss.detach().cpu().item())
finite_specs = torch.isfinite(specs).all().item()
finite_logits = torch.isfinite(logits).all().item()
print(
f" [DEBUG][eval][batch {batch_idx}] loss={loss_scalar:.6f} "
f"reduction={criterion.reduction} labels_shape={tuple(labels.shape)}"
)
print(
f" specs dtype={specs.dtype} mean={specs_cpu.mean():.4f} std={specs_cpu.std():.4f} "
f"min={specs_cpu.min():.4f} max={specs_cpu.max():.4f} finite={bool(finite_specs)}"
)
print(
f" logits dtype={logits.dtype} mean={logits_cpu.mean():.4f} std={logits_cpu.std():.4f} "
f"min={logits_cpu.min():.4f} max={logits_cpu.max():.4f} finite={bool(finite_logits)}"
)
unique_labels, counts = torch.unique(labels.detach().cpu(), return_counts=True)
label_info = ", ".join(
f"{int(lbl)}:{int(cnt)}" for lbl, cnt in zip(unique_labels, counts)
)
print(f" label distribution -> {label_info}")
if log_softmax:
probs = torch.softmax(logits_cpu, dim=1)
print(
f" softmax mean={probs.mean():.4f} std={probs.std():.4f} "
f"min={probs.min():.4f} max={probs.max():.4f}"
)
debug_logged += 1
# Clear cache periodically
if device.type == 'cuda':
torch.cuda.empty_cache()
y_true = np.concatenate(all_labels) if all_labels else np.empty(0)
y_pred = np.concatenate(all_preds) if all_preds else np.empty(0)
f1 = compute_f1(y_true, y_pred) if y_true.size > 0 else 0.0
return total_loss / max(total, 1), correct / max(total, 1), f1
def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# HPU seed setting (if available)
if HPU_AVAILABLE and hasattr(torch.hpu, "manual_seed"):
torch.hpu.manual_seed(seed)
def main() -> None:
# Set CUDA memory allocation configuration to reduce fragmentation
    if torch.cuda.is_available():
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
args = parse_args()
if args.early_min_epochs < 10:
print(
f"[INFO] Requested early_min_epochs={args.early_min_epochs} < 10; enforcing minimum of 10"
)
args.early_min_epochs = 10
set_seed(args.seed)
require_checkpoint = any(model.lower() == "lwm" for model in args.models)
checkpoint, stats = resolve_checkpoint_and_stats(args, require_checkpoint=require_checkpoint)
normalization_mode = str(stats.get("normalization", "dataset")).lower()
print(f"[INFO] Normalization mode from stats: {normalization_mode}")
data_root = Path(args.data_root)
available_snrs, available_mobilities = discover_snr_mobility(
data_root, args.cities, args.comm_types, args.fft_folder
)
train_snrs = args.snrs if args.snrs else available_snrs
train_mobilities = args.mobilities if args.mobilities else available_mobilities
val_snrs = args.val_snrs if args.val_snrs else available_snrs
val_mobilities = args.val_mobilities if args.val_mobilities else available_mobilities
class_configs = build_config_map(
data_root, args.cities, args.comm_types, train_snrs, train_mobilities, args.fft_folder
)
active_labels = [cls for cls, configs in class_configs.items() if any(configs.values())]
if not active_labels:
raise RuntimeError("No modulation classes found with the provided filters.")
class_configs = {cls: class_configs[cls] for cls in active_labels}
label_to_local = {cls: idx for idx, cls in enumerate(sorted(active_labels))}
num_classes = len(active_labels)
print("[INFO] Active modulation classes:", ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in sorted(active_labels)))
config_arrays = load_config_arrays(class_configs)
global_config_map = build_config_map(
Path(args.data_root), args.cities, args.comm_types, val_snrs, val_mobilities, args.fft_folder
)
global_config_arrays = load_config_arrays(global_config_map)
print("[INFO] Training SNRs:", ", ".join(train_snrs))
print("[INFO] Training mobilities:", ", ".join(train_mobilities))
print("[INFO] Validation/Test SNRs:", ", ".join(val_snrs))
print("[INFO] Validation/Test mobilities:", ", ".join(val_mobilities))
per_class_totals: Dict[int, int] = {}
for cls in sorted(active_labels):
configs = config_arrays[cls]
total_samples = sum(arr.shape[0] for arr in configs.values())
per_class_totals[cls] = total_samples
print(f"[INFO] Class {LABEL_NAMES.get(cls, str(cls))}: {len(configs)} configs, {total_samples} samples")
if cls not in global_config_arrays or not global_config_arrays[cls]:
raise RuntimeError(f"No global data found for modulation {LABEL_NAMES.get(cls, str(cls))}")
min_class_total = min(per_class_totals.values())
max_train_per_class = min_class_total - args.val_per_class - args.test_per_class
if max_train_per_class <= 0:
raise RuntimeError(
"Requested val/test splits leave no data for training. "
f"Minimum class has {min_class_total} samples; "
f"val={args.val_per_class}, test={args.test_per_class}."
)
if any(size > max_train_per_class for size in args.train_sizes):
adjusted: List[int] = []
for size in args.train_sizes:
if size > max_train_per_class:
print(
f"[WARN] Requested train size {size} exceeds available "
f"{max_train_per_class} per class after val/test splits; capping."
)
capped = min(size, max_train_per_class)
if capped not in adjusted:
adjusted.append(capped)
args.train_sizes = adjusted if adjusted else [max_train_per_class]
print(f"[INFO] Effective train sizes per class: {args.train_sizes}")
# Device selection: auto, cuda, hpu, or cpu
requested_device = args.device.lower()
if requested_device == "auto":
if HPU_AVAILABLE:
requested_device = "hpu"
elif torch.cuda.is_available():
requested_device = "cuda"
else:
requested_device = "cpu"
# Setup device based on selection
if requested_device == "hpu":
if not HPU_AVAILABLE:
raise RuntimeError(
"HPU device requested but not available. "
"Install Habana PyTorch or select --device cuda/cpu."
)
device = torch.device("hpu")
# Set HPU device (typically device 0 for single-process)
if hasattr(torch.hpu, "set_device"):
torch.hpu.set_device(0)
print(f"[INFO] Using HPU device")
active_gpu_ids = [] # Not applicable for HPU
multi_gpu = False
elif requested_device == "cuda":
cuda_available = torch.cuda.is_available()
if not cuda_available:
raise RuntimeError("CUDA device requested but not available.")
available_gpu_ids = list(range(torch.cuda.device_count()))
if args.gpu_ids is not None:
invalid_ids = [gpu_id for gpu_id in args.gpu_ids if gpu_id not in available_gpu_ids]
if invalid_ids:
raise ValueError(
f"Requested GPU IDs not available: {invalid_ids}; available: {available_gpu_ids}"
)
active_gpu_ids = list(dict.fromkeys(args.gpu_ids))
else:
active_gpu_ids = available_gpu_ids
if active_gpu_ids:
primary_gpu = active_gpu_ids[0]
torch.cuda.set_device(primary_gpu)
device = torch.device(f"cuda:{primary_gpu}")
print(f"[INFO] Using CUDA device(s): {', '.join(str(i) for i in active_gpu_ids)}")
else:
device = torch.device("cpu")
print("[INFO] CUDA requested but no GPUs available, using CPU")
multi_gpu = len(active_gpu_ids) > 1
if multi_gpu:
print(f"[INFO] Enabling DataParallel across GPUs: {', '.join(str(i) for i in active_gpu_ids)}")
else: # cpu
device = torch.device("cpu")
if args.gpu_ids is not None:
print("[WARN] GPU IDs specified but using CPU")
print("[INFO] Using CPU")
active_gpu_ids = []
multi_gpu = False
print(f"[INFO] Using device: {device}")
args.output_dir.mkdir(parents=True, exist_ok=True)
print(f"[INFO] Saving outputs under: {args.output_dir}")
eval_debug_config: Optional[Dict[str, object]] = None
if args.debug_eval_batches > 0:
eval_debug_config = {
"log_batches": int(args.debug_eval_batches),
"log_every": max(1, int(args.debug_eval_interval)),
"log_softmax": bool(args.debug_eval_softmax),
}
print(
"[INFO] Evaluation debug logging enabled -> batches:"
f" {eval_debug_config['log_batches']}, interval: {eval_debug_config['log_every']}"
)
summary_device = torch.device("cpu") if device.type == "cuda" else device
model_overrides: Dict[str, object] = {
"resnet_head_width": args.resnet_head_width,
"efficientnet_head_width": args.efficientnet_head_width,
"mobilenet_head_width": args.mobilenet_head_width,
"simple_cnn_hidden_dims": tuple(args.simple_cnn_hidden_dims),
"ieee_cnn_hidden_dims": tuple(args.ieee_cnn_hidden_dims),
"ieee_cnn_dropout": args.ieee_cnn_dropout,
"lwm_classifier_dim": args.lwm_classifier_dim,
"lwm_head_dropout": args.lwm_head_dropout,
"lwm_head_type": args.lwm_head_type,
"imagenet_head_dropout": args.imagenet_head_dropout,
"imagenet_weight_decay": args.weight_decay * args.imagenet_weight_decay_scale,
}
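# Dry run: instantiate each requested backbone once on summary_device (CPU when
# training on CUDA) purely to report total/trainable parameter counts, then
# discard it before the real training runs.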
print("\n[INFO] Parameter counts per model (total/trainable):")
for model_name in args.models:
lower_name = model_name.lower()
trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0
model_checkpoint = checkpoint
backbone_lr_factor = args.backbone_lr_factor
if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None:
backbone_lr_factor = args.lwm_backbone_lr_factor
model, _ = build_model(
model_name,
num_classes,
model_checkpoint,
summary_device,
trainable_layers=trainable_layers,
backbone_lr_factor=backbone_lr_factor,
overrides=model_overrides,
)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" {model_name}: {total_params:,} / {trainable_params:,}")
del model
if device.type == 'cuda':
torch.cuda.empty_cache()
raw_input_models = set(args.raw_input_models)
active_raw_models = [model for model in args.models if model.lower() in raw_input_models]
if active_raw_models:
print(
"[INFO] Raw spectrogram input (per-batch normalization) for models: "
+ ", ".join(active_raw_models)
)
normalized_models = [model for model in args.models if model.lower() not in raw_input_models]
requires_normalized_inputs = len(normalized_models) > 0
if requires_normalized_inputs:
print(
"[INFO] Applying normalization for models: "
+ ", ".join(normalized_models)
)
else:
print("[INFO] All selected models consume raw spectrograms; normalization skipped")
summary: Dict[str, Dict[int, Dict[str, List[float]]]] = {
model: {size: {"acc": [], "f1": [], "val_f1": [], "val_loss": []} for size in args.train_sizes}
for model in args.models
}
train_sizes_sorted = sorted(args.train_sizes)
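# Train sizes are processed in ascending order within each repetition so that
# the per-class index selections can be reused and extended as the size grows.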
for repetition in range(1, args.repetitions + 1):
selection_for_repetition: Dict[int, Dict[str, set[int]]] = {}
val_rng_seed = args.seed + repetition * 100000
val_rng = np.random.default_rng(val_rng_seed)
fixed_val_samples: Dict[int, np.ndarray] = {}
fixed_test_samples: Dict[int, np.ndarray] = {}
val_reserved_indices: Dict[int, Dict[str, set[int]]] = {}
test_reserved_indices: Dict[int, Dict[str, set[int]]] = {}
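# Reserve fixed validation/test samples per class for this repetition, drawn
# from the global (val-filter) arrays with a repetition-seeded RNG; the used
# indices are recorded so they can be excluded from the training pool below.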
for cls in sorted(config_arrays.keys()):
global_arrays = global_config_arrays[cls]
global_avail = {cfg: set(range(arr.shape[0])) for cfg, arr in global_arrays.items()}
val_samples, val_used = sample_global_arrays(global_arrays, global_avail, args.val_per_class, val_rng)
test_samples, test_used = sample_global_arrays(global_arrays, global_avail, args.test_per_class, val_rng)
fixed_val_samples[cls] = val_samples
fixed_test_samples[cls] = test_samples
val_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in val_used.items()}
test_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in test_used.items()}
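# Each train size extends the indices already selected for smaller sizes in this
# repetition, so smaller training sets are nested subsets of larger ones.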
for train_size in train_sizes_sorted:
rep_seed = args.seed + train_size * 1000 + repetition
rng = np.random.default_rng(rep_seed)
repetition_records: List[Dict[str, object]] = []
per_size_val_metrics: List[Tuple[str, float, float, float]] = []
train_specs, train_labels = [], []
val_specs, val_labels = [], []
test_specs, test_labels = [], []
class_contexts: Dict[int, Dict[str, Any]] = {}
class_capacities: Dict[int, int] = {}
for cls in sorted(config_arrays.keys()):
arrays_map = config_arrays[cls]
if not arrays_map:
raise RuntimeError(f"No data for class {LABEL_NAMES[cls]}")
if cls not in selection_for_repetition:
selection_for_repetition[cls] = defaultdict(set)
prev_config_indices = selection_for_repetition[cls]
prev_total = sum(len(sel_indices) for sel_indices in prev_config_indices.values())
if prev_total > train_size:
raise ValueError(
f"Requested train size {train_size} is smaller than previously selected {prev_total} "
f"for class {LABEL_NAMES[cls]}"
)
train_avail = {config: set(range(arr.shape[0])) for config, arr in arrays_map.items()}
for config, sel_indices in prev_config_indices.items():
if sel_indices and config in train_avail:
train_avail[config].difference_update(sel_indices)
val_reserved = val_reserved_indices.get(cls, {})
test_reserved = test_reserved_indices.get(cls, {})
for config, reserved in val_reserved.items():
if reserved and config in train_avail:
train_avail[config].difference_update(reserved)
for config, reserved in test_reserved.items():
if reserved and config in train_avail:
train_avail[config].difference_update(reserved)
available_now = sum(len(indices) for indices in train_avail.values())
capacity = prev_total + available_now
class_contexts[cls] = {
"arrays_map": arrays_map,
"prev_indices": prev_config_indices,
"train_avail": train_avail,
}
class_capacities[cls] = capacity
if not class_capacities:
raise RuntimeError("No modulation classes available for training")
min_capacity = min(class_capacities.values())
limiting_classes = sorted(cls for cls, cap in class_capacities.items() if cap == min_capacity)
effective_train_size = min(train_size, min_capacity)
if effective_train_size < train_size:
limiting_labels = ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in limiting_classes)
if not limiting_labels:
limiting_labels = "unknown"
print(
f"[WARN] Requested train size {train_size} exceeds available "
f"{min_capacity} after reserving val/test samples; using {effective_train_size} "
f"(limited by {limiting_labels})"
)
if effective_train_size <= 0:
raise RuntimeError("No training samples available after reserving val/test splits")
for cls in sorted(config_arrays.keys()):
ctx = class_contexts[cls]
arrays_map = ctx["arrays_map"]
prev_config_indices = ctx["prev_indices"]
train_avail = ctx["train_avail"]
selected_arrays: List[np.ndarray] = []
prev_total = 0
for config, sel_indices in prev_config_indices.items():
if sel_indices:
idx_sorted = sorted(sel_indices)
selected_arrays.append(arrays_map[config][idx_sorted])
prev_total += len(sel_indices)
needed = max(effective_train_size - prev_total, 0)
if needed > 0:
additional_samples, train_used = sample_train_arrays(arrays_map, train_avail, needed, rng)
if additional_samples.size == 0:
raise RuntimeError("Failed to collect additional training samples")
selected_arrays.append(additional_samples)
for config, indices in train_used.items():
prev_config_indices[config].update(int(idx) for idx in indices)
if not selected_arrays:
raise RuntimeError("No training samples collected")
train_samples = np.concatenate(selected_arrays, axis=0)
if train_samples.shape[0] != effective_train_size:
print(
f"[WARN] Collected {train_samples.shape[0]} training samples for "
f"{LABEL_NAMES.get(cls, str(cls))}, expected {effective_train_size}"
)
val_samples = fixed_val_samples[cls]
test_samples = fixed_test_samples[cls]
train_specs.append(train_samples)
val_specs.append(val_samples)
test_specs.append(test_samples)
local_label = label_to_local[cls]
train_labels.append(np.full(train_samples.shape[0], local_label, dtype=np.int64))
val_labels.append(np.full(val_samples.shape[0], local_label, dtype=np.int64))
test_labels.append(np.full(test_samples.shape[0], local_label, dtype=np.int64))
train_specs_raw = np.concatenate(train_specs)
val_specs_raw = np.concatenate(val_specs)
test_specs_raw = np.concatenate(test_specs)
train_labels = np.concatenate(train_labels)
val_labels = np.concatenate(val_labels)
test_labels = np.concatenate(test_labels)
# Sanity-check the splits: report sizes and per-class counts. Overlap between
# splits is avoided by construction (reserved val/test indices are removed from
# the training availability sets), so only the counts are checked here.
print(
f"[INFO] Verifying data splits for train_size={train_size} "
f"(effective {effective_train_size}), rep={repetition}..."
)
print(
f" Train: {len(train_labels)} samples "
f"(~{effective_train_size} per class expected)"
)
print(f" Val: {len(val_labels)} samples ({args.val_per_class} per class)")
print(f" Test: {len(test_labels)} samples ({args.test_per_class} per class)")
# Check class balance
train_class_counts = Counter(train_labels)
val_class_counts = Counter(val_labels)
test_class_counts = Counter(test_labels)
print(f"[INFO] Train class distribution: {dict(train_class_counts)}")
print(f"[INFO] Val class distribution: {dict(val_class_counts)}")
print(f"[INFO] Test class distribution: {dict(test_class_counts)}")
# Verify all classes have expected counts
expected_train_per_class = effective_train_size
for cls_idx in range(num_classes):
if train_class_counts[cls_idx] != expected_train_per_class:
print(f"[WARN] Class {cls_idx} has {train_class_counts[cls_idx]} train samples, expected {expected_train_per_class}")
if val_class_counts[cls_idx] != args.val_per_class:
print(f"[WARN] Class {cls_idx} has {val_class_counts[cls_idx]} val samples, expected {args.val_per_class}")
if test_class_counts[cls_idx] != args.test_per_class:
print(f"[WARN] Class {cls_idx} has {test_class_counts[cls_idx]} test samples, expected {args.test_per_class}")
print(f"[INFO] ✓ All splits have balanced class distribution")
train_ds_raw = SpectrogramDataset(train_specs_raw, train_labels)
val_ds_raw = SpectrogramDataset(val_specs_raw, val_labels)
test_ds_raw = SpectrogramDataset(test_specs_raw, test_labels)
train_loader_raw = DataLoader(
train_ds_raw,
batch_size=args.batch_size,
shuffle=True,
num_workers=2,
pin_memory=False,
)
val_loader_raw = DataLoader(
val_ds_raw,
batch_size=args.batch_size,
shuffle=False,
num_workers=2,
pin_memory=False,
)
test_loader_raw = DataLoader(
test_ds_raw,
batch_size=args.batch_size,
shuffle=False,
num_workers=2,
pin_memory=False,
)
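# Dataset-normalized copies of the three splits are only materialized when at
# least one selected model expects them; raw-input models reuse the raw loaders
# and normalize per batch inside train_one_epoch/evaluate.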
train_loader_normalized: Optional[DataLoader] = None
val_loader_normalized: Optional[DataLoader] = None
test_loader_normalized: Optional[DataLoader] = None
if requires_normalized_inputs:
train_specs_normalized = apply_normalization(train_specs_raw, stats)
val_specs_normalized = apply_normalization(val_specs_raw, stats)
test_specs_normalized = apply_normalization(test_specs_raw, stats)
train_ds_normalized = SpectrogramDataset(train_specs_normalized, train_labels)
val_ds_normalized = SpectrogramDataset(val_specs_normalized, val_labels)
test_ds_normalized = SpectrogramDataset(test_specs_normalized, test_labels)
train_loader_normalized = DataLoader(
train_ds_normalized,
batch_size=args.batch_size,
shuffle=True,
num_workers=2,
pin_memory=False,
)
val_loader_normalized = DataLoader(
val_ds_normalized,
batch_size=args.batch_size,
shuffle=False,
num_workers=2,
pin_memory=False,
)
test_loader_normalized = DataLoader(
test_ds_normalized,
batch_size=args.batch_size,
shuffle=False,
num_workers=2,
pin_memory=False,
)
rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}"
rep_root.mkdir(parents=True, exist_ok=True)
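# Per-model artifacts (checkpoint, metrics.json, optional per-epoch checkpoints)
# are written under output_dir/size_{N}/rep_{M}/{model_name}/.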
for model_name in args.models:
model_root = rep_root / model_name
model_root.mkdir(parents=True, exist_ok=True)
epoch_ckpt_dir: Optional[Path] = None
if args.save_epoch_checkpoints:
epoch_ckpt_dir = model_root / "epoch_checkpoints"
epoch_ckpt_dir.mkdir(parents=True, exist_ok=True)
for old_path in epoch_ckpt_dir.glob("epoch_*.pth"):
if old_path.is_file():
old_path.unlink()
print(
f"\n[INFO] Size {train_size} (effective {effective_train_size}), "
f"repetition {repetition}, model {model_name}"
)
# Derive the per-model seed offset deterministically: Python's built-in hash()
# of a string varies across interpreter runs unless PYTHONHASHSEED is fixed.
model_seed_offset = sum(ord(ch) for ch in model_name) % 1000
set_seed(rep_seed + model_seed_offset)
lower_name = model_name.lower()
use_raw_input = lower_name in raw_input_models
if use_raw_input:
train_loader = train_loader_raw
val_loader = val_loader_raw
test_loader = test_loader_raw
print(" [INFO] Feeding raw spectrograms with per-batch normalization")
else:
if (
train_loader_normalized is None
or val_loader_normalized is None
or test_loader_normalized is None
):
raise RuntimeError(
"Normalized loaders were requested but could not be constructed."
)
train_loader = train_loader_normalized
val_loader = val_loader_normalized
test_loader = test_loader_normalized
trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0
backbone_lr_factor = args.backbone_lr_factor
if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None:
backbone_lr_factor = args.lwm_backbone_lr_factor
model_checkpoint = checkpoint
model, param_groups = build_model(
model_name,
num_classes,
model_checkpoint,
device,
trainable_layers=trainable_layers,
backbone_lr_factor=backbone_lr_factor,
overrides=model_overrides,
)
if multi_gpu:
model = nn.DataParallel(model, device_ids=active_gpu_ids)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(
f"[INFO] Parameters (total/trainable): {total_params:,} / {trainable_params:,}"
)
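# AdamW with per-group learning rates: build_model may return parameter groups
# carrying a "scale" factor (e.g. a reduced backbone LR) and an optional
# per-group weight decay; otherwise a single group of all trainable params is used.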
def make_optimizer(base_lr: float) -> torch.optim.Optimizer:
optim_groups: List[Dict[str, object]] = []
if param_groups:
for group in param_groups:
scale = float(group.get("scale", 1.0))
params = [p for p in group.get("params", []) if p.requires_grad]
if params:
group_cfg: Dict[str, object] = {
"params": list(params),
"lr": base_lr * scale,
}
if "weight_decay" in group:
group_cfg["weight_decay"] = float(group["weight_decay"])
optim_groups.append(group_cfg)
if not optim_groups:
optim_groups.append({
"params": [p for p in model.parameters() if p.requires_grad],
"lr": base_lr,
})
return torch.optim.AdamW(optim_groups, lr=base_lr, weight_decay=args.weight_decay)
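# ReduceLROnPlateau on validation loss: halve the LR after roughly half the
# early-stopping patience without improvement, down to 1% of the phase's base LR.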
def make_scheduler(optimizer: torch.optim.Optimizer, base_lr: float, patience_limit: int):
plateau_patience = max(2, patience_limit // 2)
return torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.5,
patience=plateau_patience,
min_lr=base_lr * 0.01,
)
# Mixed-precision GradScaler is used only on CUDA (not on HPU or CPU)
scaler = GradScaler('cuda') if device.type == 'cuda' else None
best_val_loss = float("inf")
best_val_acc = 0.0
best_state = None
epoch_history: List[Dict[str, object]] = []
best_val_f1 = 0.0
best_epoch = 0
total_epochs_ran = 0
overall_early_stopped = False
phase_configs = [
{
"name": "main",
"max_epochs": args.epochs,
"base_lr": args.lr,
"patience": max(1, args.early_patience),
"min_epochs": max(0, args.early_min_epochs),
}
]
ft_epochs = max(0, args.finetune_epochs)
ft_lr_factor = args.finetune_lr_factor
ft_patience = max(1, args.finetune_patience)
ft_min_epochs = max(0, args.finetune_min_epochs)
if ft_epochs > 0:
phase_configs.append(
{
"name": "finetune",
"max_epochs": ft_epochs,
"base_lr": args.lr * ft_lr_factor,
"patience": ft_patience,
"min_epochs": ft_min_epochs,
}
)
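# Training runs in phases: the "main" phase at args.lr, optionally followed by a
# low-LR "finetune" phase that restarts from the best checkpoint seen so far.
# Each phase gets its own optimizer, scheduler, and early-stopping counter.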
for phase_idx, phase in enumerate(phase_configs):
if phase["max_epochs"] <= 0:
continue
if phase_idx > 0:
print(
f"\n [INFO] Starting {phase['name']} phase: lr={phase['base_lr']:.2e}, "
f"max_epochs={phase['max_epochs']}"
)
if best_state is not None:
model.load_state_dict(best_state["model"])
optimizer = make_optimizer(phase["base_lr"])
scheduler = make_scheduler(optimizer, phase["base_lr"], phase["patience"])
patience_counter = 0
phase_early_stopped = False
phase_epochs_completed = 0
phase_min_epochs = max(0, phase["min_epochs"])
for local_epoch in range(1, phase["max_epochs"] + 1):
overall_epoch = total_epochs_ran + local_epoch
train_loss = train_one_epoch(
model,
train_loader,
optimizer,
device,
scaler,
batch_normalize=use_raw_input,
)
val_loss, val_acc, val_f1 = evaluate(
model,
val_loader,
device,
eval_debug_config,
batch_normalize=use_raw_input,
)
scheduler.step(val_loss)
current_lr = optimizer.param_groups[0]["lr"]
print(
f" [{phase['name']}] Epoch {overall_epoch:02d}: "
f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
f"val_acc={val_acc:.4%} val_f1={val_f1:.4f}"
)
epoch_history.append(
{
"epoch": int(overall_epoch),
"train_loss": float(train_loss),
"val_loss": float(val_loss),
"val_acc": float(val_acc),
"val_f1": float(val_f1),
"lr": float(current_lr),
"phase": phase["name"],
}
)
repetition_records.append(
{
"model": model_name,
"epoch": int(overall_epoch),
"phase": phase["name"],
"train_loss": float(train_loss),
"val_loss": float(val_loss),
"val_acc": float(val_acc),
"val_f1": float(val_f1),
"lr": float(current_lr),
"train_size_requested": int(train_size),
"train_size_effective": int(effective_train_size),
}
)
_write_epoch_history(rep_root, repetition_records, args.save_epoch_history, args.plot_epoch_history)
raw_epoch_state = _strip_module_prefix(model.state_dict())
if epoch_ckpt_dir is not None:
epoch_state = raw_epoch_state.__class__()
for key, value in raw_epoch_state.items():
epoch_state[key] = value.detach().cpu()
epoch_ckpt_path = epoch_ckpt_dir / f"epoch_{overall_epoch:03d}.pth"
torch.save(epoch_state, epoch_ckpt_path)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_val_acc = val_acc
best_val_f1 = val_f1
best_model_state = {
key: value.detach().cpu()
for key, value in model.state_dict().items()
}
best_state = {
"model": best_model_state,
"val_loss": val_loss,
"val_acc": val_acc,
"val_f1": val_f1,
"epoch": int(overall_epoch),
"lr": current_lr,
"phase": phase["name"],
}
best_epoch = int(overall_epoch)
patience_counter = 0
else:
if local_epoch >= phase_min_epochs:
patience_counter += 1
if patience_counter >= phase["patience"]:
print(
f" [INFO] Early stopping ({phase['name']}) at epoch {overall_epoch:02d} "
f"after {patience_counter} epochs without val loss improvement"
)
overall_early_stopped = True
phase_early_stopped = True
phase_epochs_completed = local_epoch
break
phase_epochs_completed = local_epoch
total_epochs_ran += phase_epochs_completed
if not phase_early_stopped and phase_epochs_completed < phase["max_epochs"]:
# Loop exited early via break without setting the flag (should not happen)
phase_early_stopped = True
if best_state is None:
raise RuntimeError("Training finished without recording a validation improvement")
model.load_state_dict(best_state["model"])
test_loss, test_acc, test_f1 = evaluate(
model,
test_loader,
device,
eval_debug_config,
batch_normalize=use_raw_input,
)
print(
f" -> Test loss={test_loss:.4f} Test acc={test_acc:.4%} Test f1={test_f1:.4f}"
)
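# Optionally export the full model (backbone + head) when an export directory is
# configured via args.export_full_model.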
export_dir = getattr(args, "export_full_model", None)
if export_dir is not None:
export_dir = export_dir.expanduser().resolve()
export_dir.mkdir(parents=True, exist_ok=True)
comm_token = getattr(args, "comm_suffix", "multi")
filename = f"{comm_token}_{model_name}_size{train_size}_rep{repetition}.pth"
export_path = export_dir / filename
full_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
torch.save(full_state, export_path)
print(f" [INFO] Saved full model (backbone + head) to {export_path}")
summary[model_name][train_size]["acc"].append(test_acc)
summary[model_name][train_size]["f1"].append(test_f1)
summary[model_name][train_size]["val_f1"].append(best_state["val_f1"])
summary[model_name][train_size]["val_loss"].append(best_state["val_loss"])
per_size_val_metrics.append(
(model_name, best_state["val_f1"], best_state["val_loss"], test_f1)
)
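# Persist the best checkpoint (with any DataParallel "module." prefix stripped)
# and a metrics.json containing the final scores plus the per-epoch history.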
result_dir = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" / model_name
result_dir.mkdir(parents=True, exist_ok=True)
state_to_save = copy.deepcopy(best_state)
state_to_save["model"] = _strip_module_prefix(state_to_save["model"])
torch.save(state_to_save, result_dir / "checkpoint.pt")
with open(result_dir / "metrics.json", "w", encoding="utf-8") as f:
json.dump(
{
"train_size_per_class": effective_train_size,
"train_size_per_class_requested": train_size,
"repetition": repetition,
"model": model_name,
"best_val_loss": best_state.get("val_loss", None),
"best_val_acc": best_val_acc,
"best_val_f1": best_state["val_f1"],
"test_loss": test_loss,
"test_acc": test_acc,
"test_f1": test_f1,
"best_epoch": best_epoch,
"epochs_ran": total_epochs_ran,
"early_stopped": overall_early_stopped,
"history": epoch_history,
},
f,
indent=2,
)
rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}"
rep_root.mkdir(parents=True, exist_ok=True)
if args.plot_epoch_history and HAVE_MPL and repetition_records:
models_in_run = sorted({rec["model"] for rec in repetition_records})
fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
for ax in axes:
ax.grid(True, linestyle='--', alpha=0.3)
for model_name_plot in models_in_run:
model_records = [rec for rec in repetition_records if rec["model"] == model_name_plot]
if not model_records:
continue
epochs = [rec["epoch"] for rec in model_records]
val_loss_values = [rec["val_loss"] for rec in model_records]
val_f1_values = [rec["val_f1"] for rec in model_records]
axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot)
axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot)
axes[0].set_ylabel('Val Loss')
axes[1].set_ylabel('Val F1')
axes[1].set_xlabel('Epoch')
axes[0].legend(loc='best')
axes[0].set_title(f'Size {train_size} / Rep {repetition} per-epoch metrics')
fig.tight_layout()
fig.savefig(rep_root / 'epoch_history.png', dpi=150)
plt.close(fig)
# Clean up memory after each model
del model, optimizer, scheduler, best_state
if scaler is not None:
del scaler
if device.type == 'cuda':
torch.cuda.empty_cache()
torch.cuda.synchronize()
if per_size_val_metrics:
print(f"\n[INFO] Validation summary for train_size={train_size}, rep={repetition}:")
for model_name, val_f1, val_loss, test_f1 in sorted(
per_size_val_metrics, key=lambda item: item[1], reverse=True
):
print(
f" {model_name:<16} val_f1={val_f1:.4f} "
f"val_loss={val_loss:.4f} test_f1={test_f1:.4f}"
)
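# Aggregate all repetitions into a single summary.json under the output
# directory (note: json.dump serializes the integer train-size keys as strings).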
summary_path = args.output_dir / "summary.json"
summary_path.parent.mkdir(parents=True, exist_ok=True)
serializable_summary = {
model_name: {
size: {
"acc": metrics["acc"],
"f1": metrics["f1"],
"val_f1": metrics["val_f1"],
"val_loss": metrics["val_loss"],
}
for size, metrics in size_dict.items()
}
for model_name, size_dict in summary.items()
}
with open(summary_path, "w", encoding="utf-8") as f:
json.dump(serializable_summary, f, indent=2)
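# Illustrative only (not executed here): the summary can be inspected offline, e.g.
#   with open(args.output_dir / "summary.json") as fh:
#       data = json.load(fh)
#   accs = data["<model_name>"]["<train_size>"]["acc"]  # size keys are strings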
print("\n[INFO] Final accuracy summary:")
for model_name, results in summary.items():
for size, metrics in results.items():
if metrics["acc"]:
acc_mean = float(np.mean(metrics["acc"]))
acc_std = float(np.std(metrics["acc"]))
f1_mean = float(np.mean(metrics["f1"]))
f1_std = float(np.std(metrics["f1"]))
n = len(metrics["acc"])
print(
f" {model_name} @ {size:4d}/class -> "
f"acc={acc_mean:.4%} ± {acc_std:.4%}, f1={f1_mean:.4f} ± {f1_std:.4f} (n={n})"
)
print("\n[INFO] Final validation F1 summary:")
for model_name, results in summary.items():
for size, metrics in results.items():
if metrics["val_f1"]:
val_mean = float(np.mean(metrics["val_f1"]))
val_std = float(np.std(metrics["val_f1"]))
n = len(metrics["val_f1"])
print(
f" {model_name} @ {size:4d}/class -> "
f"val_f1={val_mean:.4f} ± {val_std:.4f} (n={n})"
)
print("\n[INFO] Final validation loss summary:")
for model_name, results in summary.items():
for size, metrics in results.items():
if metrics["val_loss"]:
loss_mean = float(np.mean(metrics["val_loss"]))
loss_std = float(np.std(metrics["val_loss"]))
n = len(metrics["val_loss"])
print(
f" {model_name} @ {size:4d}/class -> "
f"val_loss={loss_mean:.4f} ± {loss_std:.4f} (n={n})"
)
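# Final plot: mean validation F1 versus training samples per class, one curve
# per model, saved as val_f1_summary.png.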
if HAVE_MPL:
train_sizes_sorted = sorted(args.train_sizes)
plt.figure(figsize=(8, 5))
plotted = False
for model_name in args.models:
model_results = summary.get(model_name, {})
means: List[float] = []
for size in train_sizes_sorted:
val_list = model_results.get(size, {}).get("val_f1", [])
means.append(float(np.mean(val_list)) if val_list else float("nan"))
if not any(np.isfinite(means)):
continue
plt.plot(train_sizes_sorted, means, marker="o", linewidth=2, label=model_name)
plotted = True
if plotted:
plt.title("Validation F1 vs. Training Size")
plt.xlabel("Training samples per class")
plt.ylabel("Validation F1 (macro)")
plt.xticks(train_sizes_sorted)
plt.ylim(0.0, 1.0)
plt.grid(True, which="both", linestyle="--", alpha=0.4)
plt.legend(title="Model", frameon=False)
plt.tight_layout()
plot_path = args.output_dir / "val_f1_summary.png"
plt.savefig(plot_path, dpi=200)
plt.close()
print(f"[INFO] Saved validation F1 plot to {plot_path}")
else:
plt.close()
print("[WARN] No validation F1 data available to plot.")
else:
print("[WARN] Matplotlib not available; skipping validation F1 plot.")
if __name__ == "__main__":
main()