#!/usr/bin/env python3 """Benchmark multiple backbones on modulation classification across dataset sizes. For each desired training size (samples per MCS class) and repetition, the script: 1. Randomly samples spectrograms from distinct (modulation, rate, SNR, doppler) configs. 2. Builds train/val/test splits (val/test sizes are configurable). 3. Fine-tunes several backbones (LWM, ResNet18, EfficientNet-B0, MobileNet-V3, and a small CNN) using the same splits. 4. Reports accuracy statistics and stores checkpoints/metrics per experiment. Input spectrograms are globally normalized using the dataset mean/std stored with the specified pretrained checkpoint (defaults to the latest run under `models/`). Usage example (defaults cover city_1_losangeles/LTE with all available SNR·mobility combos): python task1/train_mcs_models.py --train-sizes 128 --models resnet18 mobilenet_v3_small """ from __future__ import annotations import argparse import copy import csv import glob import json import os import pickle import random import re import sys from collections import Counter, defaultdict from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from contextlib import nullcontext from datetime import datetime import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from torch.amp import autocast, GradScaler try: from tqdm import tqdm except ImportError: # pragma: no cover - optional dependency def tqdm(iterable, *args, **kwargs): return iterable PROJECT_ROOT = Path(__file__).resolve().parent.parent if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from pretraining.pretrained_model import lwm as lwm_model from utils import count_parameters COMM_CANONICAL = { "lte": "LTE", "wifi": "WiFi", "5g": "5G", } COMM_LOWER = {v: k for k, v in COMM_CANONICAL.items()} try: from sklearn.metrics import f1_score as sklearn_f1_score HAVE_SKLEARN = True except ImportError: HAVE_SKLEARN = False try: import matplotlib.pyplot as plt HAVE_MPL = True except ImportError: HAVE_MPL = False try: from task2.mobility_utils import LWMClassifierMinimal # type: ignore except ImportError: # pragma: no cover - optional dependency LWMClassifierMinimal = None # type: ignore[misc] # HPU support detection HPU_AVAILABLE = False try: import habana_frameworks.torch.core as htcore # type: ignore[import-not-found] HPU_AVAILABLE = hasattr(torch, "hpu") and torch.hpu.is_available() except (ImportError, AttributeError): pass def compute_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: if HAVE_SKLEARN: return float(sklearn_f1_score(y_true, y_pred, average="macro")) classes = np.unique(np.concatenate([y_true, y_pred])) scores = [] for cls in classes: tp = np.sum((y_true == cls) & (y_pred == cls)) fp = np.sum((y_true != cls) & (y_pred == cls)) fn = np.sum((y_true == cls) & (y_pred != cls)) precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 denom = precision + recall f1 = (2 * precision * recall / denom) if denom > 0 else 0.0 scores.append(f1) return float(np.mean(scores)) MODULATION_LABELS = { "BPSK": 0, "QPSK": 1, "QAM16": 2, "QAM64": 3, "QAM256": 4, } LABEL_NAMES = {idx: name for name, idx in MODULATION_LABELS.items()} DEFAULT_LWM_TRAINABLE_LAYERS = 2 # fine-tune the last two transformer blocks _SAMPLE_COUNT_CACHE: Dict[str, int] = {} def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray: if specs.size == 0: return specs.astype(np.float32, copy=False) means = specs.mean(axis=(1, 2), keepdims=True) stds = specs.std(axis=(1, 2), keepdims=True) stds = np.maximum(stds, eps) normalized = (specs - means) / stds return normalized.astype(np.float32, copy=False) def apply_normalization(specs: np.ndarray, stats: Dict[str, object]) -> np.ndarray: mode = str(stats.get("normalization", "dataset")).lower() mean = float(stats.get("mean", 0.0)) std = float(stats.get("std", 1.0)) if abs(std) < 1e-6: std = 1e-6 if mode == "dataset": return ((specs.astype(np.float32, copy=False) - mean) / std).astype(np.float32, copy=False) return normalize_per_sample(specs) def _unique_parameters(params: Iterable[nn.Parameter]) -> List[nn.Parameter]: seen: set[int] = set() unique: List[nn.Parameter] = [] for param in params: pid = id(param) if pid not in seen: unique.append(param) seen.add(pid) return unique def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--data-root", default=str(PROJECT_ROOT / "spectrograms"), help="Root directory containing city folders (default: project_root/spectrograms)", ) parser.add_argument( "--cities", nargs="*", default=["city_1_losangeles"], help="City directories to include (default: %(default)s)", ) parser.add_argument("--comm-types", nargs="*", default=["LTE"], help="Communication standards to include (default: %(default)s)") parser.add_argument("--LTE", dest="select_lte", action="store_true", help="Shortcut for --comm-types LTE") parser.add_argument("--WiFi", dest="select_wifi", action="store_true", help="Shortcut for --comm-types WiFi") parser.add_argument("--5G", dest="select_5g", action="store_true", help="Shortcut for --comm-types 5G") parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include for training (default: all available)") parser.add_argument("--val-snrs", nargs="*", default=None, help="SNR folders for validation/test (default: all available)") parser.add_argument( "--mobilities", nargs="*", default=None, help="Mobility folders to include for training (default: all available)", ) parser.add_argument( "--val-mobilities", nargs="*", default=None, help="Mobility folders for validation/test (default: all available)", ) parser.add_argument("--fft-folder", default="512FFT", help="FFT folder name (default: %(default)s)") parser.add_argument( "--device", type=str, default="auto", choices=["auto", "cuda", "hpu", "cpu"], help="Device to use for training (default: auto - detects HPU, then CUDA, then CPU)", ) parser.add_argument( "--gpu-ids", type=int, nargs="*", default=None, help="Specific GPU device IDs to use (only for CUDA, default: all visible GPUs)", ) parser.add_argument( "--train-sizes", type=int, nargs="*", default=[2, 4, 8, 16, 32, 64, 128, 256], help="Training samples per class to benchmark", ) parser.add_argument("--val-per-class", type=int, default=512, help="Validation samples per class") parser.add_argument("--test-per-class", type=int, default=512, help="Test samples per class") parser.add_argument("--repetitions", type=int, default=1, help="Repetitions per train size") parser.add_argument("--epochs", type=int, default=200, help="Epochs per run") parser.add_argument("--batch-size", type=int, default=32, help="Mini-batch size") parser.add_argument("--lr", type=float, default=8e-4, help="Learning rate for fine-tuning") parser.add_argument("--weight-decay", type=float, default=3e-2, help="Weight decay") parser.add_argument( "--no-epoch-history", action="store_true", help="Disable aggregated per-epoch history tracking", ) parser.add_argument( "--no-epoch-plot", action="store_true", help="Disable per-repetition metric plots", ) parser.add_argument( "--save-epoch-checkpoints", action="store_true", help="Persist per-epoch checkpoints (default: disabled)", ) parser.add_argument( "--backbone-lr-factor", type=float, default=0.3, help="Relative LR multiplier applied to unfrozen backbone parameters (default: %(default)s)", ) parser.add_argument( "--early-patience", type=int, default=5, help="Early stopping patience based on validation F1 (default: %(default)s)", ) parser.add_argument( "--early-min-epochs", type=int, default=10, help="Minimum number of epochs to run before early stopping can trigger (default: %(default)s)", ) parser.add_argument( "--finetune-epochs", type=int, default=0, help="Additional fine-tuning epochs to run after the main schedule (default: %(default)s)", ) parser.add_argument( "--finetune-lr-factor", type=float, default=0.1, help="Multiplier applied to the base learning rate during fine-tuning (default: %(default)s)", ) parser.add_argument( "--finetune-patience", type=int, default=3, help="Early stopping patience for the fine-tuning phase (default: %(default)s)", ) parser.add_argument( "--finetune-min-epochs", type=int, default=0, help="Minimum epochs to execute in the fine-tuning phase before early stopping is considered (default: %(default)s)", ) parser.add_argument( "--debug-eval-batches", type=int, default=0, help="Log detailed stats for the first N evaluation batches (0 disables logging)", ) parser.add_argument( "--debug-eval-interval", type=int, default=1, help="Evaluate logging interval in batches when debug logging is enabled (default: %(default)s)", ) parser.add_argument( "--debug-eval-softmax", action="store_true", help="When debugging evaluations, also log softmax statistics per batch", ) parser.add_argument( "--models", nargs="*", default=["lwm", "resnet18", "efficientnet_b0", "mobilenet_v3_small", "simple_cnn", "ieee_cnn"], help="Models to benchmark", ) parser.add_argument( "--raw-input-models", nargs="*", default=None, help=( "Models that should receive raw spectrograms without additional normalization " "(default: all non-LWM models)" ), ) parser.add_argument( "--lwm-trainable-layers", type=int, default=2, help="Number of transformer layers (from the end) to fine-tune in LWM (default: %(default)s)", ) parser.add_argument( "--lwm-classifier-dim", type=int, default=64, help="Hidden width for the LWM classifier MLP head (default: %(default)s; ignored for linear head)", ) parser.add_argument( "--lwm-head-dropout", type=float, default=0.0, help="Dropout applied inside the LWM classifier head (default: %(default)s)", ) parser.add_argument( "--lwm-head-type", choices=("linear", "mlp", "res1dcnn"), default="res1dcnn", help="Classifier head architecture for LWM (default: %(default)s)", ) parser.add_argument( "--lwm-backbone-lr-factor", type=float, default=0.2, help="LR multiplier for unfrozen LWM backbone layers (default: %(default)s)", ) parser.add_argument( "--resnet-head-width", type=int, default=512, help="Hidden width for the ResNet18 classifier head (default: %(default)s)", ) parser.add_argument( "--efficientnet-head-width", type=int, default=296, help="Hidden width for the EfficientNet-B0 classifier head (default: %(default)s)", ) parser.add_argument( "--mobilenet-head-width", type=int, default=576, help="Hidden width for the MobileNetV3-Small classifier head (default: %(default)s)", ) parser.add_argument( "--imagenet-head-dropout", type=float, default=0.6, help="Dropout probability used inside ImageNet backbone classifier heads (default: %(default)s)", ) parser.add_argument( "--imagenet-weight-decay-scale", type=float, default=2.0, help="Multiplier applied to weight decay for ImageNet backbone trainable parameters (default: %(default)s)", ) parser.add_argument( "--simple-cnn-hidden-dims", type=int, nargs="*", default=[272, 128], help="Hidden layer widths for the Simple CNN classifier (default: %(default)s)", ) parser.add_argument( "--ieee-cnn-hidden-dims", type=int, nargs="*", default=[512, 256], help="Hidden layer widths for the IEEE CNN classifier (default: %(default)s)", ) parser.add_argument( "--ieee-cnn-dropout", type=float, default=0.3, help="Dropout rate for the IEEE CNN model (default: %(default)s)", ) parser.add_argument("--checkpoint", type=Path, default=None, help="Path to pretrained LWM checkpoint (.pth)") parser.add_argument("--stats", type=Path, default=None, help="dataset_stats.json path") parser.add_argument( "--models-root", type=Path, default=PROJECT_ROOT / "models", help="Root with pretrained runs (default: project_root/models)", ) parser.add_argument( "--output-dir", type=Path, default=PROJECT_ROOT / "task1" / "mcs_benchmarks", help="Results root directory (per-run subfolder created automatically)", ) parser.add_argument( "--export-full-model", type=Path, default=None, help="Directory where best full-model checkpoints (backbone + head) will be exported per run", ) parser.add_argument("--seed", type=int, default=42, help="Base random seed") args = parser.parse_args() args.output_root = args.output_dir quick_comm: List[str] = [] if getattr(args, "select_lte", False): quick_comm.append("LTE") if getattr(args, "select_wifi", False): quick_comm.append("WiFi") if getattr(args, "select_5g", False): quick_comm.append("5G") if quick_comm: args.comm_types = quick_comm normalized: List[str] = [] for comm in args.comm_types: upper = comm.upper() if upper == "WIFI": normalized.append("WiFi") elif upper == "LTE": normalized.append("LTE") elif upper == "5G": normalized.append("5G") else: normalized.append(comm) args.comm_types = normalized timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") comm_tokens: List[str] = [] for comm in args.comm_types: canonical = COMM_LOWER.get(comm, comm.lower()) token = re.sub(r"[^a-z0-9]+", "-", canonical.lower()).strip("-") comm_tokens.append(token or "unknown") comm_suffix = "-".join(comm_tokens) if comm_tokens else "unknown" args.run_timestamp = timestamp args.output_dir = args.output_root / comm_suffix / timestamp args.comm_suffix = comm_suffix if args.gpu_ids is not None and len(args.gpu_ids) == 0: args.gpu_ids = None if not args.simple_cnn_hidden_dims: args.simple_cnn_hidden_dims = [512, 256] if not args.ieee_cnn_hidden_dims: args.ieee_cnn_hidden_dims = [512, 256] if args.raw_input_models is None: args.raw_input_models = [ model.lower() for model in args.models if model.lower() not in {"lwm"} ] else: args.raw_input_models = [model.lower() for model in args.raw_input_models] args.models = [model for model in args.models] args.save_epoch_history = not args.no_epoch_history args.plot_epoch_history = not args.no_epoch_plot args.imagenet_head_dropout = float(max(0.0, min(args.imagenet_head_dropout, 0.95))) args.imagenet_weight_decay_scale = float(max(0.0, args.imagenet_weight_decay_scale)) return args def find_latest_run(models_root: Path) -> Path: run_dirs = [p for p in models_root.iterdir() if p.is_dir()] run_dirs = [p for p in run_dirs if not p.name.lower().endswith("_models")] valid_runs = [p for p in run_dirs if any(p.glob("*.pth"))] if valid_runs: return max(valid_runs, key=lambda p: p.stat().st_mtime) checkpoints = list(models_root.glob("*.pth")) if checkpoints: print(f"[INFO] No checkpoint-bearing subdirectories under {models_root}; using root as run directory.") return models_root raise FileNotFoundError(f"No checkpoints found under {models_root}") def find_best_checkpoint(run_dir: Path) -> Path: candidates = list(run_dir.glob("*.pth")) if not candidates: raise FileNotFoundError(f"No checkpoints in {run_dir}") def metric(path: Path) -> float: match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name) if match: try: return float(match.group(1)) except ValueError: pass return float("inf") best = min(candidates, key=metric) return best def resolve_models_directory(args: argparse.Namespace) -> Path: base = args.models_root.expanduser().resolve() if not base.exists(): raise FileNotFoundError(f"Models root not found: {base}") matches: List[Path] = [] for comm in args.comm_types: subdir = base / f"{comm}_models" if subdir.exists(): matches.append(subdir) else: print(f"[WARN] Models directory for {comm} not found at {subdir}") if len(matches) == 1: print(f"[INFO] Using models directory for {args.comm_types[0]}: {matches[0]}") return matches[0] if len(matches) > 1: raise ValueError( "Multiple communication-specific model directories detected; please provide --checkpoint explicitly." ) print(f"[INFO] Using shared models directory: {base}") return base def resolve_checkpoint_and_stats(args: argparse.Namespace, require_checkpoint: bool) -> Tuple[Path | None, Dict[str, object]]: checkpoint: Path | None = None models_dir = resolve_models_directory(args) user_provided_stats = args.stats is not None if args.checkpoint is not None: checkpoint = args.checkpoint.expanduser().resolve() if not checkpoint.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") stats_path = args.stats.expanduser().resolve() if user_provided_stats else checkpoint.parent / "dataset_stats.json" else: run_dir = find_latest_run(models_dir) stats_path = run_dir / "dataset_stats.json" if require_checkpoint: checkpoint = find_best_checkpoint(run_dir) else: checkpoint = None if stats_path.exists(): try: with open(stats_path, "r", encoding="utf-8") as f: stats = json.load(f) except json.JSONDecodeError as exc: if user_provided_stats: raise ValueError( f"Failed to parse dataset_stats.json at {stats_path}: {exc}" ) from exc print( f"[WARN] Corrupt dataset_stats.json at {stats_path}; " "falling back to mean=0/std=1 per-sample normalization." ) stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} else: if "mean" not in stats or "std" not in stats: raise ValueError("dataset_stats.json must contain 'mean' and 'std'") stats.setdefault("normalization", stats.get("mode", "dataset")) else: if user_provided_stats: raise FileNotFoundError(f"dataset_stats.json not found: {stats_path}") stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} print(f"[WARN] dataset_stats.json not found at {stats_path}. Falling back to per-sample normalization.") if checkpoint is not None: print(f"[INFO] Using checkpoint: {checkpoint}") elif require_checkpoint: raise FileNotFoundError("LWM requested but no checkpoint available") else: print("[INFO] No LWM checkpoint required for selected models") norm_mode = str(stats.get("normalization", "dataset")) if norm_mode.lower() == "dataset": print(f"[INFO] Dataset stats -> mean={stats['mean']:.4f}, std={stats['std']:.4f}") else: print("[INFO] Normalization mode: per_sample") return checkpoint, { "mean": float(stats.get("mean", 0.0)), "std": float(stats.get("std", 1.0)), "normalization": norm_mode, } def identify_modulation(path: str) -> tuple[int | None, str | None]: for mod_name, label in MODULATION_LABELS.items(): if mod_name in path: return label, mod_name return None, None def _extract_metadata(parts: Sequence[str]) -> Tuple[str, str, str]: rate = next((part for part in parts if part.startswith("rate")), "rate_unknown") snr = next((part for part in parts if part.startswith("SNR")), "SNR_unknown") mobility = next((part for part in parts if part in {"static", "pedestrian", "vehicular"}), "mobility_unknown") return rate, snr, mobility def discover_snr_mobility( data_root: Path, cities: Sequence[str], comm_types: Sequence[str], fft_folder: str, ) -> Tuple[List[str], List[str]]: snrs: set[str] = set() mobilities: set[str] = set() for city in cities: for comm in comm_types: base = data_root / city / comm if not base.exists(): continue for root, dirs, _ in os.walk(base): parts = Path(root).parts for part in parts: if part.startswith("SNR") and part.endswith("dB"): snrs.add(part) elif part in {"static", "pedestrian", "vehicular"}: mobilities.add(part) if not snrs: snrs.add("SNR20dB") if not mobilities: mobilities.add("static") return sorted(snrs), sorted(mobilities) def build_config_map( data_root: Path, cities: Sequence[str], comm_types: Sequence[str], snrs: Sequence[str], mobilities: Sequence[str], fft_folder: str, ) -> Dict[int, Dict[str, List[str]]]: class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) for city in cities: for comm in comm_types: base = data_root / city / comm for snr in snrs: for mobility in mobilities: pattern = str(base / "**" / snr / mobility / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") for path_str in glob.glob(pattern, recursive=True): cls, modulation_name = identify_modulation(path_str) if cls is None: continue rate, _, _ = _extract_metadata(Path(path_str).parts) config_name = f"{modulation_name}_{rate}_{snr}_{mobility}" class_configs[cls][config_name].append(path_str) return class_configs def build_global_config_map( data_root: Path, cities: Sequence[str], comm_types: Sequence[str], fft_folder: str, ) -> Dict[int, Dict[str, List[str]]]: class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) for city in cities: for comm in comm_types: base = data_root / city / comm pattern = str(base / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") for path_str in glob.glob(pattern, recursive=True): cls, modulation_name = identify_modulation(path_str) if cls is None: continue rate, snr_part, mobility_part = _extract_metadata(Path(path_str).parts) config_name = f"{modulation_name}_{rate}_{snr_part}_{mobility_part}" class_configs[cls][config_name].append(path_str) return class_configs def _count_samples_in_path(path: str) -> int: cached = _SAMPLE_COUNT_CACHE.get(path) if cached is not None: return cached arr = load_all_samples(path) count = int(arr.shape[0]) _SAMPLE_COUNT_CACHE[path] = count return count class LazyConfigArray: """Lazily views spectrograms spread across multiple pickled files.""" __slots__ = ("paths", "_counts", "_offsets", "_total", "shape", "dtype", "ndim") def __init__(self, paths: Sequence[str]) -> None: filtered_paths: List[str] = [] counts: List[int] = [] for path in sorted(paths): count = _count_samples_in_path(path) if count <= 0: continue filtered_paths.append(path) counts.append(count) self.paths: Tuple[str, ...] = tuple(filtered_paths) if counts: self._counts = np.array(counts, dtype=np.int64) self._offsets = np.concatenate(([0], np.cumsum(self._counts))) self._total = int(self._offsets[-1]) else: self._counts = np.empty(0, dtype=np.int64) self._offsets = np.array([0], dtype=np.int64) self._total = 0 self.shape = (self._total, 128, 128) self.dtype = np.float32 self.ndim = 3 def __len__(self) -> int: return self._total def _resolve_index(self, index: int) -> Tuple[int, int]: if self._total == 0: raise IndexError("attempting to index empty LazyConfigArray") if index < 0: index += self._total if index < 0 or index >= self._total: raise IndexError("index out of range for LazyConfigArray") path_idx = int(np.searchsorted(self._offsets[1:], index, side="right")) start = int(self._offsets[path_idx]) return path_idx, int(index - start) def _load_path(self, path_idx: int) -> np.ndarray: path = self.paths[path_idx] return load_all_samples(path) def __getitem__(self, item: Any) -> np.ndarray: if isinstance(item, (int, np.integer)): path_idx, local_idx = self._resolve_index(int(item)) data = self._load_path(path_idx) sample = data[local_idx].copy() return sample indices = np.asarray(item, dtype=np.int64) if indices.ndim == 0: indices = indices.reshape(1) else: indices = indices.reshape(-1) if indices.size == 0: return np.empty((0, 128, 128), dtype=np.float32) resolved: Dict[int, List[Tuple[int, int]]] = {} for pos, raw_idx in enumerate(indices): path_idx, local_idx = self._resolve_index(int(raw_idx)) resolved.setdefault(path_idx, []).append((pos, local_idx)) result = np.empty((indices.size, 128, 128), dtype=np.float32) for path_idx, items in resolved.items(): data = self._load_path(path_idx) local_positions = [loc for _, loc in items] chunk = data[local_positions] for offset, (pos, _) in enumerate(items): result[pos] = chunk[offset] return result def load_config_arrays(class_configs: Dict[int, Dict[str, List[str]]]) -> Dict[int, Dict[str, LazyConfigArray]]: loaded: Dict[int, Dict[str, LazyConfigArray]] = {} for cls, configs in class_configs.items(): arrays_for_cls: Dict[str, LazyConfigArray] = {} for config_name, paths in configs.items(): lazy_array = LazyConfigArray(paths) if len(lazy_array) == 0: continue arrays_for_cls[config_name] = lazy_array if arrays_for_cls: loaded[cls] = arrays_for_cls return loaded def load_all_samples(path: str) -> np.ndarray: with open(path, "rb") as f: data = pickle.load(f) if isinstance(data, dict) and "spectrograms" in data: arr = data["spectrograms"] elif isinstance(data, np.ndarray): arr = data else: return np.empty((0, 128, 128), dtype=np.float32) arr = np.asarray(arr, dtype=np.float32) if arr.ndim == 2: arr = arr[None, ...] if arr.shape[1:] != (128, 128): return np.empty((0, 128, 128), dtype=np.float32) return arr def sample_from_paths( paths: Sequence[str], n_samples: int, rng: np.random.Generator, used_map: Dict[str, set[int]], ) -> Tuple[np.ndarray, List[Tuple[str, np.ndarray]]]: if not paths: raise RuntimeError("No files available for sampling") paths_array = np.array(paths, dtype=object) order = rng.permutation(len(paths_array)) remaining = n_samples collected: List[np.ndarray] = [] info: List[Tuple[str, np.ndarray]] = [] for idx in order: if remaining <= 0: break path = str(paths_array[idx]) samples = load_all_samples(path) total = samples.shape[0] used = used_map[path] if used: used_idx = np.fromiter(used, dtype=np.int64, count=len(used)) available = np.setdiff1d(np.arange(total), used_idx, assume_unique=True) else: available = np.arange(total) if available.size == 0: continue take = min(remaining, available.size) chosen = rng.choice(available, size=take, replace=False) collected.append(samples[chosen]) used_map[path].update(int(i) for i in chosen) info.append((path, chosen)) remaining -= take if remaining > 0: raise RuntimeError("Insufficient samples remaining to satisfy request") result = np.concatenate(collected, axis=0) if len(collected) > 1 else collected[0] return result, info def _ensure_available(total_needed: int, availability: Dict[str, set]) -> None: remaining = sum(len(indices) for indices in availability.values()) if remaining < total_needed: raise RuntimeError( f"Insufficient samples: need {total_needed}, only {remaining} available across configs" ) def _sample_from_availability( arrays_map: Dict[str, LazyConfigArray], availability: Dict[str, set[int]], total_needed: int, rng: np.random.Generator, ) -> Tuple[np.ndarray, Dict[str, set[int]]]: if total_needed <= 0: return np.empty((0, 128, 128), dtype=np.float32), {cfg: set() for cfg in arrays_map} _ensure_available(total_needed, availability) remaining = total_needed configs = [cfg for cfg, indices in availability.items() if indices] used: Dict[str, set[int]] = {cfg: set() for cfg in arrays_map} collected: List[np.ndarray] = [] while remaining > 0 and configs: cfg = rng.choice(configs) available_indices = np.array(list(availability[cfg]), dtype=np.int64) if available_indices.size == 0: configs = [c for c in configs if c != cfg] continue take = min(max(1, remaining // max(len(configs), 1)), remaining, available_indices.size) chosen = rng.choice(available_indices, size=take, replace=False) collected.append(arrays_map[cfg][chosen]) chosen_set = {int(idx) for idx in chosen} used[cfg].update(chosen_set) availability[cfg].difference_update(chosen_set) remaining -= take configs = [c for c in configs if availability[c]] if remaining > 0: raise RuntimeError("Sampling failed to collect the requested number of samples") stacked = np.concatenate(collected, axis=0) if collected else np.empty((0, 128, 128), dtype=np.float32) return stacked.astype(np.float32, copy=False), used def sample_train_arrays( arrays_map: Dict[str, LazyConfigArray], availability: Dict[str, set[int]], train_size: int, rng: np.random.Generator, ) -> Tuple[np.ndarray, Dict[str, set[int]]]: return _sample_from_availability(arrays_map, availability, train_size, rng) def sample_global_arrays( arrays_map: Dict[str, LazyConfigArray], availability: Dict[str, set[int]], per_class: int, rng: np.random.Generator, ) -> Tuple[np.ndarray, Dict[str, set[int]]]: return _sample_from_availability(arrays_map, availability, per_class, rng) class SpectrogramDataset(Dataset): def __init__(self, specs: np.ndarray, labels: np.ndarray): self.specs = specs.astype(np.float32, copy=False) self.labels = labels.astype(np.int64, copy=False) def __len__(self) -> int: return len(self.labels) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: return torch.from_numpy(self.specs[idx]), int(self.labels[idx]) def normalize_batch(specs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: mean = specs.mean() std = specs.std(unbiased=False) std = torch.clamp(std, min=eps) return (specs - mean) / std def apply_spec_augment( specs: torch.Tensor, *, freq_mask_width: int = 12, time_mask_width: int = 16, freq_masks: int = 2, time_masks: int = 2, mask_prob: float = 0.5, noise_std: float = 0.0, ) -> torch.Tensor: """Apply light-weight SpecAugment-style masking to a batch of spectrograms. The function accepts tensors shaped ``[B, H, W]`` or ``[B, 1, H, W]`` and returns an augmented tensor with the same shape. Masks use the sample mean to avoid introducing large bias and are applied per-sample with the given probability. Gaussian noise (if requested) is injected before masking. """ if mask_prob <= 0.0 and noise_std <= 0.0: return specs if specs.dim() not in (3, 4): raise ValueError(f"Spectrograms must be rank-3 or rank-4, got shape {tuple(specs.shape)}") needs_squeeze = specs.dim() == 3 augmented = specs.unsqueeze(1) if needs_squeeze else specs batch_size, _, freq_dim, time_dim = augmented.shape if mask_prob < 1.0: apply_mask = torch.rand(batch_size, device=augmented.device) < mask_prob else: apply_mask = torch.ones(batch_size, dtype=torch.bool, device=augmented.device) freq_mask_width = max(0, int(freq_mask_width)) time_mask_width = max(0, int(time_mask_width)) freq_masks = max(0, int(freq_masks)) time_masks = max(0, int(time_masks)) for idx in range(batch_size): if not apply_mask[idx]: continue sample = augmented[idx] if noise_std > 0.0: sample = sample + noise_std * torch.randn_like(sample) fill_value = sample.mean() if freq_mask_width > 0 and freq_masks > 0: max_width = min(freq_mask_width, freq_dim) for _ in range(freq_masks): width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item()) if width == 0 or width > freq_dim: continue start = 0 if freq_dim == width else int(torch.randint(0, freq_dim - width + 1, (1,), device=augmented.device).item()) sample[:, start:start + width, :] = fill_value if time_mask_width > 0 and time_masks > 0: max_width = min(time_mask_width, time_dim) for _ in range(time_masks): width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item()) if width == 0 or width > time_dim: continue start = 0 if time_dim == width else int(torch.randint(0, time_dim - width + 1, (1,), device=augmented.device).item()) sample[:, :, start:start + width] = fill_value augmented[idx] = sample return augmented.squeeze(1) if needs_squeeze else augmented def _write_epoch_history( rep_root: Path, records: Sequence[Dict[str, object]], enable_csv: bool, enable_plot: bool, ) -> None: if not records: return rep_root.mkdir(parents=True, exist_ok=True) if enable_csv: base_fields = [ "model", "epoch", "phase", "train_loss", "val_loss", "val_acc", "val_f1", "lr", "train_size_requested", "train_size_effective", ] extra_fields = sorted( {key for rec in records for key in rec.keys() if key not in base_fields} ) fieldnames = base_fields + extra_fields sorted_records = sorted(records, key=lambda r: (r["epoch"], r["model"], r.get("phase", ""))) history_path = rep_root / "epoch_history.csv" with open(history_path, "w", newline="", encoding="utf-8") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(sorted_records) if enable_plot and HAVE_MPL: models_in_run = sorted({rec["model"] for rec in records}) fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) for ax in axes: ax.grid(True, linestyle='--', alpha=0.3) for model_name_plot in models_in_run: model_records = [rec for rec in records if rec["model"] == model_name_plot] if not model_records: continue epochs = [rec["epoch"] for rec in model_records] val_loss_values = [rec["val_loss"] for rec in model_records] val_f1_values = [rec["val_f1"] for rec in model_records] axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot) axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot) axes[0].set_ylabel('Val Loss') axes[1].set_ylabel('Val F1') axes[1].set_xlabel('Epoch') axes[0].legend(loc='best') axes[0].set_title('Per-epoch validation metrics') fig.tight_layout() fig.savefig(rep_root / 'epoch_history.png', dpi=150) plt.close(fig) class ResidualBlock1D(nn.Module): """Lightweight residual block used by the res1dcnn head.""" def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm1d(out_channels) self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm1d(out_channels) self.shortcut = nn.Identity() if in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size=1), nn.BatchNorm1d(out_channels), ) def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = F.relu(self.bn1(self.conv1(x))) x = self.bn2(self.conv2(x)) x = x + self.shortcut(residual) x = F.relu(x) return x class Res1DCNNHead(nn.Module): """Residual 1D CNN classifier head that operates on 128-d LWM features.""" def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1) -> None: super().__init__() self.input_dim = int(input_dim) hidden_dim = 64 self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm1d(hidden_dim) self.res_block = ResidualBlock1D(hidden_dim, hidden_dim) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(hidden_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.unsqueeze(1) x = F.relu(self.bn1(self.conv1(x))) x = self.res_block(x) x = F.adaptive_avg_pool1d(x, 1).squeeze(-1) x = self.dropout(x) return self.fc(x) class LWMClassifier(nn.Module): def __init__( self, backbone: nn.Module, trainable_layers: int, num_classes: int, classifier_dim: int = 128, head_dropout: float = 0.1, head_type: str = "mlp", ): super().__init__() self.backbone = backbone self.patch_size = 4 self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size) head_dropout = max(0.0, float(head_dropout)) head_type = head_type.lower().strip() if head_type == "linear": head_layers: List[nn.Module] = [nn.LayerNorm(128)] if head_dropout > 0: head_layers.append(nn.Dropout(head_dropout)) head_layers.append(nn.Linear(128, num_classes)) self.classifier = nn.Sequential(*head_layers) elif head_type == "res1dcnn": self.classifier = nn.Sequential( nn.LayerNorm(128), Res1DCNNHead(128, num_classes, dropout=head_dropout), ) else: head_layers = [ nn.LayerNorm(128), nn.Linear(128, classifier_dim), nn.GELU(), ] if head_dropout > 0: head_layers.append(nn.Dropout(head_dropout)) head_layers.append(nn.Linear(classifier_dim, num_classes)) self.classifier = nn.Sequential(*head_layers) for param in self.backbone.parameters(): param.requires_grad = False if trainable_layers > 0: for layer in self.backbone.layers[-trainable_layers:]: for param in layer.parameters(): param.requires_grad = True # Enable gradient checkpointing for memory efficiency if hasattr(layer, 'gradient_checkpointing'): layer.gradient_checkpointing = True def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor: x = x.unsqueeze(1) patches = self.unfold(x).transpose(1, 2) cls = torch.full( (patches.size(0), 1, patches.size(-1)), 0.2, dtype=patches.dtype, device=patches.device ) return torch.cat([cls, patches], dim=1) def forward_features(self, x: torch.Tensor) -> torch.Tensor: tokens = self.spectrogram_to_tokens(x) outputs = self.backbone(tokens) if outputs.size(1) <= 1: return outputs[:, 0, :] return outputs[:, 1:, :].mean(dim=1) def forward(self, x: torch.Tensor) -> torch.Tensor: cls = self.forward_features(x) return self.classifier(cls) def create_simple_cnn( num_classes: int, hidden_dims: Tuple[int, ...] = (192,), dropout: float = 0.3, ) -> nn.Module: """Create baseline CNN with configurable classifier width.""" if not hidden_dims: raise ValueError("hidden_dims must contain at least one value") layers: List[nn.Module] = [ nn.Conv2d(1, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 5, padding=2), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4)), nn.Flatten(), nn.Dropout(dropout), ] in_dim = 4 * 4 * 64 fc_layers: List[nn.Module] = [] for idx, hidden_dim in enumerate(hidden_dims): fc_layers.append(nn.Linear(in_dim, hidden_dim)) fc_layers.append(nn.ReLU()) fc_layers.append(nn.Dropout(dropout)) in_dim = hidden_dim fc_layers.append(nn.Linear(in_dim, num_classes)) return nn.Sequential(*layers, *fc_layers) def create_ieee_cnn( num_classes: int, hidden_dims: Tuple[int, ...] = (512, 256), dropout: float = 0.3, ) -> nn.Module: """CNN inspired by IEEE 2021 joint SNR/mobility classifier.""" if not hidden_dims: raise ValueError("hidden_dims must contain at least one value") layers: List[nn.Module] = [ nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), nn.Dropout2d(p=dropout), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), nn.Dropout2d(p=dropout), nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), nn.Dropout2d(p=dropout), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), nn.Dropout2d(p=dropout), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((4, 4)), nn.Flatten(), nn.Dropout(dropout), ] in_dim = 4 * 4 * 256 fc_layers: List[nn.Module] = [] for hidden_dim in hidden_dims: fc_layers.append(nn.Linear(in_dim, hidden_dim)) fc_layers.append(nn.BatchNorm1d(hidden_dim)) fc_layers.append(nn.ReLU(inplace=True)) fc_layers.append(nn.Dropout(dropout)) in_dim = hidden_dim fc_layers.append(nn.Linear(in_dim, num_classes)) return nn.Sequential(*layers, *fc_layers) def build_model( name: str, num_classes: int, checkpoint: Path, device: torch.device, trainable_layers: int, backbone_lr_factor: float, overrides: Dict[str, object] | None = None, ) -> Tuple[nn.Module, List[Dict[str, object]]]: name = name.lower() param_groups: List[Dict[str, object]] = [] overrides = overrides or {} if name == "lwm": backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1) if checkpoint is None: raise FileNotFoundError("Checkpoint is required for LWM-based models") try: state = torch.load(checkpoint, map_location="cpu", weights_only=True) except TypeError: # Older torch versions do not support weights_only state = torch.load(checkpoint, map_location="cpu") if any(k.startswith("module.") for k in state): state = {k.replace("module.", ""): v for k, v in state.items()} if any(k.startswith("backbone.") for k in state): backbone_state = { k.split("backbone.", 1)[1]: v for k, v in state.items() if k.startswith("backbone.") } else: backbone_state = { k: v for k, v in state.items() if not k.startswith("classifier.") and not k.startswith("projection_head.") } backbone.load_state_dict(backbone_state, strict=False) classifier_dim = int(overrides.get("lwm_classifier_dim", 96)) head_dropout = float(overrides.get("lwm_head_dropout", 0.1)) head_type = str(overrides.get("lwm_head_type", "mlp")).lower() model = LWMClassifier( backbone, trainable_layers=trainable_layers, num_classes=num_classes, classifier_dim=classifier_dim, head_dropout=head_dropout, head_type=head_type, ) head_params = list(model.classifier.parameters()) param_groups.append({"params": head_params, "scale": 1.0}) if trainable_layers > 0: backbone_params: List[nn.Parameter] = [] for layer in model.backbone.layers[-trainable_layers:]: backbone_params.extend(layer.parameters()) backbone_params = _unique_parameters(backbone_params) if backbone_params: param_groups.append({"params": backbone_params, "scale": backbone_lr_factor}) elif name == "resnet18": from torchvision import models backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) nn.init.kaiming_normal_(backbone.conv1.weight, mode='fan_out', nonlinearity='relu') in_features = backbone.fc.in_features head_width = int(overrides.get("resnet_head_width", 384)) imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) backbone.fc = nn.Sequential( nn.Dropout(p=pre_fc_dropout), nn.Linear(in_features, head_width), nn.LayerNorm(head_width), nn.ReLU(inplace=True), nn.Dropout(p=imagenet_head_dropout), nn.Linear(head_width, num_classes), ) for param in backbone.parameters(): param.requires_grad = False head_params = list(backbone.fc.parameters()) for param in head_params: param.requires_grad = True imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) head_group: Dict[str, object] = {"params": head_params, "scale": 1.0} if imagenet_weight_decay is not None: head_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(head_group) adapt_params: List[nn.Parameter] = [] if hasattr(backbone.layer4[0], "downsample") and backbone.layer4[0].downsample is not None: adapt_params.extend(backbone.layer4[0].downsample[0].parameters()) if len(backbone.layer4[0].downsample) > 1: adapt_params.extend(backbone.layer4[0].downsample[1].parameters()) for module in backbone.layer4[-1].modules(): if isinstance(module, nn.BatchNorm2d): adapt_params.extend(module.parameters()) adapt_params = _unique_parameters(adapt_params) for param in adapt_params: param.requires_grad = True if adapt_params: adapt_group: Dict[str, object] = {"params": adapt_params, "scale": backbone_lr_factor} if imagenet_weight_decay is not None: adapt_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(adapt_group) model = backbone elif name == "efficientnet_b0": from torchvision import models backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT) first_conv = backbone.features[0][0] backbone.features[0][0] = nn.Conv2d(1, first_conv.out_channels, kernel_size=3, stride=2, padding=1, bias=False) nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu') in_features = backbone.classifier[-1].in_features head_width = int(overrides.get("efficientnet_head_width", 192)) imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) backbone.classifier = nn.Sequential( nn.Dropout(p=pre_fc_dropout), nn.Linear(in_features, head_width), nn.LayerNorm(head_width), nn.ReLU(inplace=True), nn.Dropout(p=imagenet_head_dropout), nn.Linear(head_width, num_classes), ) for param in backbone.parameters(): param.requires_grad = False head_params = list(backbone.classifier.parameters()) for param in head_params: param.requires_grad = True imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) head_group = {"params": head_params, "scale": 1.0} if imagenet_weight_decay is not None: head_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(head_group) adapt_params: List[nn.Parameter] = [] final_block = backbone.features[7][0] # Depthwise conv + associated norms for the last MBConv block depthwise = final_block.block[1][0] adapt_params.extend(depthwise.parameters()) for idx in (0, 1, 3): for module in final_block.block[idx].modules(): if isinstance(module, nn.BatchNorm2d): adapt_params.extend(module.parameters()) adapt_params = _unique_parameters(adapt_params) for param in adapt_params: param.requires_grad = True if adapt_params: adapt_group = {"params": adapt_params, "scale": backbone_lr_factor} if imagenet_weight_decay is not None: adapt_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(adapt_group) model = backbone elif name == "mobilenet_v3_small": from torchvision import models backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT) backbone.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False) nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu') with torch.no_grad(): dummy = torch.zeros(1, 1, 128, 128) features = backbone.features(dummy) pooled = backbone.avgpool(features) flattened = torch.flatten(pooled, 1) in_features = flattened.shape[1] head_width = int(overrides.get("mobilenet_head_width", 320)) imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) backbone.classifier = nn.Sequential( nn.Dropout(p=pre_fc_dropout), nn.Linear(in_features, head_width), nn.LayerNorm(head_width), nn.Hardswish(), nn.Dropout(p=imagenet_head_dropout), nn.Linear(head_width, num_classes), ) for param in backbone.parameters(): param.requires_grad = False head_params = list(backbone.classifier.parameters()) for param in head_params: param.requires_grad = True imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) head_group = {"params": head_params, "scale": 1.0} if imagenet_weight_decay is not None: head_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(head_group) adapt_params: List[nn.Parameter] = [] adapt_params.extend(backbone.features[-1][0].parameters()) for module in backbone.features[-2].modules(): if isinstance(module, nn.BatchNorm2d): adapt_params.extend(module.parameters()) adapt_params = _unique_parameters(adapt_params) for param in adapt_params: param.requires_grad = True if adapt_params: adapt_group = {"params": adapt_params, "scale": backbone_lr_factor} if imagenet_weight_decay is not None: adapt_group["weight_decay"] = float(imagenet_weight_decay) param_groups.append(adapt_group) model = backbone elif name in {"simple_cnn", "simplecnn"}: hidden_dims = overrides.get("simple_cnn_hidden_dims", (192,)) if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str): simple_dims = tuple(int(dim) for dim in hidden_dims) else: simple_dims = (int(hidden_dims),) model = create_simple_cnn(num_classes, hidden_dims=simple_dims) head_params = list(model.parameters()) param_groups.append({"params": head_params, "scale": 1.0}) elif name in {"ieee_cnn", "ieeecnn"}: hidden_dims = overrides.get("ieee_cnn_hidden_dims", (512, 256)) if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str): ieee_dims = tuple(int(dim) for dim in hidden_dims) else: ieee_dims = (int(hidden_dims),) dropout = float(overrides.get("ieee_cnn_dropout", 0.3)) model = create_ieee_cnn(num_classes, hidden_dims=ieee_dims, dropout=dropout) head_params = list(model.parameters()) param_groups.append({"params": head_params, "scale": 1.0}) else: raise ValueError(f"Unknown model: {name}") return model.to(device), param_groups def _unwrap_module(model: nn.Module) -> nn.Module: return model.module if isinstance(model, nn.DataParallel) else model def _strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if not state_dict: return state_dict needs_strip = any(key.startswith("module.") for key in state_dict) if not needs_strip: return state_dict stripped = state_dict.__class__() if hasattr(state_dict, "__class__") else {} for key, value in state_dict.items(): new_key = key.split("module.", 1)[1] if key.startswith("module.") else key stripped[new_key] = value return stripped def _model_forward( model: nn.Module, specs: torch.Tensor, input_stats: Optional[torch.Tensor] = None, ) -> torch.Tensor: base_model = _unwrap_module(model) is_lwm_like = isinstance(base_model, LWMClassifier) if not is_lwm_like and LWMClassifierMinimal is not None: is_lwm_like = isinstance(base_model, LWMClassifierMinimal) if not is_lwm_like and hasattr(base_model, "spectrogram_to_tokens"): is_lwm_like = True if is_lwm_like: while specs.dim() > 3 and specs.size(1) == 1: specs = specs.squeeze(1) if specs.dim() != 3: specs = specs.view(specs.size(0), specs.size(-2), specs.size(-1)) if not is_lwm_like and specs.dim() == 3: specs = specs.unsqueeze(1) if input_stats is not None: input_stats = input_stats.to(specs.device, non_blocking=True) supports_stats = bool( is_lwm_like and hasattr(base_model, "append_input_stats") and getattr(base_model, "append_input_stats") ) if supports_stats and input_stats is not None: return model(specs, input_stats=input_stats) return model(specs) def train_one_epoch(model, loader, optimizer, device, scaler=None, batch_normalize: bool = False): criterion = nn.CrossEntropyLoss(reduction='mean') model.train() total_loss = 0.0 total = 0 for specs, labels in loader: specs = specs.to(device, non_blocking=True) if batch_normalize: specs = normalize_batch(specs) labels = labels.to(device, non_blocking=True) optimizer.zero_grad(set_to_none=True) # Use autocast only for CUDA if scaler is not None and device.type == 'cuda': with autocast(device_type='cuda'): logits = _model_forward(model, specs) loss = criterion(logits, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: # HPU and CPU use standard forward/backward logits = _model_forward(model, specs) loss = criterion(logits, labels) loss.backward() optimizer.step() total_loss += loss.item() * labels.size(0) total += labels.size(0) # Clear cache periodically to prevent memory fragmentation if device.type == 'cuda': torch.cuda.empty_cache() return total_loss / max(total, 1) @torch.no_grad() def evaluate( model, loader, device, debug: Optional[Dict[str, object]] = None, batch_normalize: bool = False, ) -> Tuple[float, float, float]: criterion = nn.CrossEntropyLoss(reduction='mean') model.eval() total_loss = 0.0 correct = 0 total = 0 all_preds: List[np.ndarray] = [] all_labels: List[np.ndarray] = [] debug_batches = int(debug.get("log_batches", 0)) if debug else 0 debug_every = max(1, int(debug.get("log_every", 1))) if debug else 1 log_softmax = bool(debug.get("log_softmax", False)) if debug else False debug_logged = 0 for batch_idx, batch in enumerate(loader, start=1): stats_batch: Optional[torch.Tensor] if isinstance(batch, (list, tuple)) and len(batch) == 3: specs, stats_batch, labels = batch stats_batch = stats_batch.to(device, non_blocking=True) else: specs, labels = batch # type: ignore[misc] stats_batch = None specs = specs.to(device, non_blocking=True) if batch_normalize: specs = normalize_batch(specs) labels = labels.to(device, non_blocking=True) # Use autocast only for CUDA, not for HPU or CPU if device.type == 'cuda': context = autocast(device_type='cuda') else: context = nullcontext() with context: logits = _model_forward(model, specs, stats_batch) loss = criterion(logits, labels) preds = logits.argmax(dim=1) total_loss += loss.item() * labels.size(0) correct += (preds == labels).sum().item() total += labels.size(0) all_preds.append(preds.detach().cpu().numpy()) all_labels.append(labels.detach().cpu().numpy()) should_log = ( debug_batches > 0 and debug_logged < debug_batches and ((batch_idx - 1) % debug_every == 0) ) if should_log: specs_cpu = specs.detach().cpu() logits_cpu = logits.detach().cpu() loss_scalar = float(loss.detach().cpu().item()) finite_specs = torch.isfinite(specs).all().item() finite_logits = torch.isfinite(logits).all().item() print( f" [DEBUG][eval][batch {batch_idx}] loss={loss_scalar:.6f} " f"reduction={criterion.reduction} labels_shape={tuple(labels.shape)}" ) print( f" specs dtype={specs.dtype} mean={specs_cpu.mean():.4f} std={specs_cpu.std():.4f} " f"min={specs_cpu.min():.4f} max={specs_cpu.max():.4f} finite={bool(finite_specs)}" ) print( f" logits dtype={logits.dtype} mean={logits_cpu.mean():.4f} std={logits_cpu.std():.4f} " f"min={logits_cpu.min():.4f} max={logits_cpu.max():.4f} finite={bool(finite_logits)}" ) unique_labels, counts = torch.unique(labels.detach().cpu(), return_counts=True) label_info = ", ".join( f"{int(lbl)}:{int(cnt)}" for lbl, cnt in zip(unique_labels, counts) ) print(f" label distribution -> {label_info}") if log_softmax: probs = torch.softmax(logits_cpu, dim=1) print( f" softmax mean={probs.mean():.4f} std={probs.std():.4f} " f"min={probs.min():.4f} max={probs.max():.4f}" ) debug_logged += 1 # Clear cache periodically if device.type == 'cuda': torch.cuda.empty_cache() y_true = np.concatenate(all_labels) if all_labels else np.empty(0) y_pred = np.concatenate(all_preds) if all_preds else np.empty(0) f1 = compute_f1(y_true, y_pred) if y_true.size > 0 else 0.0 return total_loss / max(total, 1), correct / max(total, 1), f1 def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # HPU seed setting (if available) if HPU_AVAILABLE and hasattr(torch.hpu, "manual_seed"): torch.hpu.manual_seed(seed) def main() -> None: # Set CUDA memory allocation configuration to reduce fragmentation if torch.cuda.is_available(): import os os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' args = parse_args() if args.early_min_epochs < 10: print( f"[INFO] Requested early_min_epochs={args.early_min_epochs} < 10; enforcing minimum of 10" ) args.early_min_epochs = 10 set_seed(args.seed) require_checkpoint = any(model.lower() == "lwm" for model in args.models) checkpoint, stats = resolve_checkpoint_and_stats(args, require_checkpoint=require_checkpoint) normalization_mode = str(stats.get("normalization", "dataset")).lower() print(f"[INFO] Normalization mode from stats: {normalization_mode}") data_root = Path(args.data_root) available_snrs, available_mobilities = discover_snr_mobility( data_root, args.cities, args.comm_types, args.fft_folder ) train_snrs = args.snrs if args.snrs else available_snrs train_mobilities = args.mobilities if args.mobilities else available_mobilities val_snrs = args.val_snrs if args.val_snrs else available_snrs val_mobilities = args.val_mobilities if args.val_mobilities else available_mobilities class_configs = build_config_map( data_root, args.cities, args.comm_types, train_snrs, train_mobilities, args.fft_folder ) active_labels = [cls for cls, configs in class_configs.items() if any(configs.values())] if not active_labels: raise RuntimeError("No modulation classes found with the provided filters.") class_configs = {cls: class_configs[cls] for cls in active_labels} label_to_local = {cls: idx for idx, cls in enumerate(sorted(active_labels))} num_classes = len(active_labels) print("[INFO] Active modulation classes:", ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in sorted(active_labels))) config_arrays = load_config_arrays(class_configs) global_config_map = build_config_map( Path(args.data_root), args.cities, args.comm_types, val_snrs, val_mobilities, args.fft_folder ) global_config_arrays = load_config_arrays(global_config_map) print("[INFO] Training SNRs:", ", ".join(train_snrs)) print("[INFO] Training mobilities:", ", ".join(train_mobilities)) print("[INFO] Validation/Test SNRs:", ", ".join(val_snrs)) print("[INFO] Validation/Test mobilities:", ", ".join(val_mobilities)) per_class_totals: Dict[int, int] = {} for cls in sorted(active_labels): configs = config_arrays[cls] total_samples = sum(arr.shape[0] for arr in configs.values()) per_class_totals[cls] = total_samples print(f"[INFO] Class {LABEL_NAMES.get(cls, str(cls))}: {len(configs)} configs, {total_samples} samples") if cls not in global_config_arrays or not global_config_arrays[cls]: raise RuntimeError(f"No global data found for modulation {LABEL_NAMES.get(cls, str(cls))}") min_class_total = min(per_class_totals.values()) max_train_per_class = min_class_total - args.val_per_class - args.test_per_class if max_train_per_class <= 0: raise RuntimeError( "Requested val/test splits leave no data for training. " f"Minimum class has {min_class_total} samples; " f"val={args.val_per_class}, test={args.test_per_class}." ) if any(size > max_train_per_class for size in args.train_sizes): adjusted: List[int] = [] for size in args.train_sizes: if size > max_train_per_class: print( f"[WARN] Requested train size {size} exceeds available " f"{max_train_per_class} per class after val/test splits; capping." ) capped = min(size, max_train_per_class) if capped not in adjusted: adjusted.append(capped) args.train_sizes = adjusted if adjusted else [max_train_per_class] print(f"[INFO] Effective train sizes per class: {args.train_sizes}") # Device selection: auto, cuda, hpu, or cpu requested_device = args.device.lower() if requested_device == "auto": if HPU_AVAILABLE: requested_device = "hpu" elif torch.cuda.is_available(): requested_device = "cuda" else: requested_device = "cpu" # Setup device based on selection if requested_device == "hpu": if not HPU_AVAILABLE: raise RuntimeError( "HPU device requested but not available. " "Install Habana PyTorch or select --device cuda/cpu." ) device = torch.device("hpu") # Set HPU device (typically device 0 for single-process) if hasattr(torch.hpu, "set_device"): torch.hpu.set_device(0) print(f"[INFO] Using HPU device") active_gpu_ids = [] # Not applicable for HPU multi_gpu = False elif requested_device == "cuda": cuda_available = torch.cuda.is_available() if not cuda_available: raise RuntimeError("CUDA device requested but not available.") available_gpu_ids = list(range(torch.cuda.device_count())) if args.gpu_ids is not None: invalid_ids = [gpu_id for gpu_id in args.gpu_ids if gpu_id not in available_gpu_ids] if invalid_ids: raise ValueError( f"Requested GPU IDs not available: {invalid_ids}; available: {available_gpu_ids}" ) active_gpu_ids = list(dict.fromkeys(args.gpu_ids)) else: active_gpu_ids = available_gpu_ids if active_gpu_ids: primary_gpu = active_gpu_ids[0] torch.cuda.set_device(primary_gpu) device = torch.device(f"cuda:{primary_gpu}") print(f"[INFO] Using CUDA device(s): {', '.join(str(i) for i in active_gpu_ids)}") else: device = torch.device("cpu") print("[INFO] CUDA requested but no GPUs available, using CPU") multi_gpu = len(active_gpu_ids) > 1 if multi_gpu: print(f"[INFO] Enabling DataParallel across GPUs: {', '.join(str(i) for i in active_gpu_ids)}") else: # cpu device = torch.device("cpu") if args.gpu_ids is not None: print("[WARN] GPU IDs specified but using CPU") print("[INFO] Using CPU") active_gpu_ids = [] multi_gpu = False print(f"[INFO] Using device: {device}") args.output_dir.mkdir(parents=True, exist_ok=True) print(f"[INFO] Saving outputs under: {args.output_dir}") eval_debug_config: Optional[Dict[str, object]] = None if args.debug_eval_batches > 0: eval_debug_config = { "log_batches": int(args.debug_eval_batches), "log_every": max(1, int(args.debug_eval_interval)), "log_softmax": bool(args.debug_eval_softmax), } print( "[INFO] Evaluation debug logging enabled -> batches:" f" {eval_debug_config['log_batches']}, interval: {eval_debug_config['log_every']}" ) summary_device = torch.device("cpu") if device.type == "cuda" else device model_overrides: Dict[str, object] = { "resnet_head_width": args.resnet_head_width, "efficientnet_head_width": args.efficientnet_head_width, "mobilenet_head_width": args.mobilenet_head_width, "simple_cnn_hidden_dims": tuple(args.simple_cnn_hidden_dims), "ieee_cnn_hidden_dims": tuple(args.ieee_cnn_hidden_dims), "ieee_cnn_dropout": args.ieee_cnn_dropout, "lwm_classifier_dim": args.lwm_classifier_dim, "lwm_head_dropout": args.lwm_head_dropout, "lwm_head_type": args.lwm_head_type, "imagenet_head_dropout": args.imagenet_head_dropout, "imagenet_weight_decay": args.weight_decay * args.imagenet_weight_decay_scale, } print("\n[INFO] Parameter counts per model (total/trainable):") for model_name in args.models: lower_name = model_name.lower() trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 model_checkpoint = checkpoint backbone_lr_factor = args.backbone_lr_factor if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: backbone_lr_factor = args.lwm_backbone_lr_factor model, _ = build_model( model_name, num_classes, model_checkpoint, summary_device, trainable_layers=trainable_layers, backbone_lr_factor=backbone_lr_factor, overrides=model_overrides, ) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" {model_name}: {total_params:,} / {trainable_params:,}") del model if device.type == 'cuda': torch.cuda.empty_cache() raw_input_models = set(args.raw_input_models) active_raw_models = [model for model in args.models if model.lower() in raw_input_models] if active_raw_models: print( "[INFO] Raw spectrogram input (per-batch normalization) for models: " + ", ".join(active_raw_models) ) normalized_models = [model for model in args.models if model.lower() not in raw_input_models] requires_normalized_inputs = len(normalized_models) > 0 if requires_normalized_inputs: print( "[INFO] Applying normalization for models: " + ", ".join(normalized_models) ) else: print("[INFO] All selected models consume raw spectrograms; normalization skipped") summary: Dict[str, Dict[int, Dict[str, List[float]]]] = { model: {size: {"acc": [], "f1": [], "val_f1": [], "val_loss": []} for size in args.train_sizes} for model in args.models } train_sizes_sorted = sorted(args.train_sizes) for repetition in range(1, args.repetitions + 1): selection_for_repetition: Dict[int, Dict[str, set[int]]] = {} val_rng_seed = args.seed + repetition * 100000 val_rng = np.random.default_rng(val_rng_seed) fixed_val_samples: Dict[int, np.ndarray] = {} fixed_test_samples: Dict[int, np.ndarray] = {} val_reserved_indices: Dict[int, Dict[str, set[int]]] = {} test_reserved_indices: Dict[int, Dict[str, set[int]]] = {} for cls in sorted(config_arrays.keys()): global_arrays = global_config_arrays[cls] global_avail = {cfg: set(range(arr.shape[0])) for cfg, arr in global_arrays.items()} val_samples, val_used = sample_global_arrays(global_arrays, global_avail, args.val_per_class, val_rng) test_samples, test_used = sample_global_arrays(global_arrays, global_avail, args.test_per_class, val_rng) fixed_val_samples[cls] = val_samples fixed_test_samples[cls] = test_samples val_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in val_used.items()} test_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in test_used.items()} for train_size in train_sizes_sorted: rep_seed = args.seed + train_size * 1000 + repetition rng = np.random.default_rng(rep_seed) repetition_records: List[Dict[str, object]] = [] per_size_val_metrics: List[Tuple[str, float, float, float]] = [] train_specs, train_labels = [], [] val_specs, val_labels = [], [] test_specs, test_labels = [], [] class_contexts: Dict[int, Dict[str, Any]] = {} class_capacities: Dict[int, int] = {} for cls in sorted(config_arrays.keys()): arrays_map = config_arrays[cls] if not arrays_map: raise RuntimeError(f"No data for class {LABEL_NAMES[cls]}") if cls not in selection_for_repetition: selection_for_repetition[cls] = defaultdict(set) prev_config_indices = selection_for_repetition[cls] prev_total = sum(len(sel_indices) for sel_indices in prev_config_indices.values()) if prev_total > train_size: raise ValueError( f"Requested train size {train_size} is smaller than previously selected {prev_total} " f"for class {LABEL_NAMES[cls]}" ) train_avail = {config: set(range(arr.shape[0])) for config, arr in arrays_map.items()} for config, sel_indices in prev_config_indices.items(): if sel_indices and config in train_avail: train_avail[config].difference_update(sel_indices) val_reserved = val_reserved_indices.get(cls, {}) test_reserved = test_reserved_indices.get(cls, {}) for config, reserved in val_reserved.items(): if reserved and config in train_avail: train_avail[config].difference_update(reserved) for config, reserved in test_reserved.items(): if reserved and config in train_avail: train_avail[config].difference_update(reserved) available_now = sum(len(indices) for indices in train_avail.values()) capacity = prev_total + available_now class_contexts[cls] = { "arrays_map": arrays_map, "prev_indices": prev_config_indices, "train_avail": train_avail, } class_capacities[cls] = capacity if not class_capacities: raise RuntimeError("No modulation classes available for training") min_capacity = min(class_capacities.values()) limiting_classes = sorted(cls for cls, cap in class_capacities.items() if cap == min_capacity) effective_train_size = min(train_size, min_capacity) if effective_train_size < train_size: limiting_labels = ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in limiting_classes) if not limiting_labels: limiting_labels = "unknown" print( f"[WARN] Requested train size {train_size} exceeds available " f"{min_capacity} after reserving val/test samples; using {effective_train_size} " f"(limited by {limiting_labels})" ) if effective_train_size <= 0: raise RuntimeError("No training samples available after reserving val/test splits") for cls in sorted(config_arrays.keys()): ctx = class_contexts[cls] arrays_map = ctx["arrays_map"] prev_config_indices = ctx["prev_indices"] train_avail = ctx["train_avail"] selected_arrays: List[np.ndarray] = [] prev_total = 0 for config, sel_indices in prev_config_indices.items(): if sel_indices: idx_sorted = sorted(sel_indices) selected_arrays.append(arrays_map[config][idx_sorted]) prev_total += len(sel_indices) needed = max(effective_train_size - prev_total, 0) if needed > 0: additional_samples, train_used = sample_train_arrays(arrays_map, train_avail, needed, rng) if additional_samples.size == 0: raise RuntimeError("Failed to collect additional training samples") selected_arrays.append(additional_samples) for config, indices in train_used.items(): prev_config_indices[config].update(int(idx) for idx in indices) if not selected_arrays: raise RuntimeError("No training samples collected") train_samples = np.concatenate(selected_arrays, axis=0) if train_samples.shape[0] != effective_train_size: print( f"[WARN] Collected {train_samples.shape[0]} training samples for " f"{LABEL_NAMES.get(cls, str(cls))}, expected {effective_train_size}" ) val_samples = fixed_val_samples[cls] test_samples = fixed_test_samples[cls] train_specs.append(train_samples) val_specs.append(val_samples) test_specs.append(test_samples) local_label = label_to_local[cls] train_labels.append(np.full(train_samples.shape[0], local_label, dtype=np.int64)) val_labels.append(np.full(val_samples.shape[0], local_label, dtype=np.int64)) test_labels.append(np.full(test_samples.shape[0], local_label, dtype=np.int64)) train_specs_raw = np.concatenate(train_specs) val_specs_raw = np.concatenate(val_specs) test_specs_raw = np.concatenate(test_specs) train_labels = np.concatenate(train_labels) val_labels = np.concatenate(val_labels) test_labels = np.concatenate(test_labels) # Verify no data leakage (all splits are disjoint) # Note: Since we sample from different configs with availability tracking, # there should be no overlap, but we verify to be safe print( f"[INFO] Verifying data splits for train_size={train_size} " f"(effective {effective_train_size}), rep={repetition}..." ) print( f" Train: {len(train_labels)} samples " f"(~{effective_train_size} per class expected)" ) print(f" Val: {len(val_labels)} samples ({args.val_per_class} per class)") print(f" Test: {len(test_labels)} samples ({args.test_per_class} per class)") # Check class balance train_class_counts = Counter(train_labels) val_class_counts = Counter(val_labels) test_class_counts = Counter(test_labels) print(f"[INFO] Train class distribution: {dict(train_class_counts)}") print(f"[INFO] Val class distribution: {dict(val_class_counts)}") print(f"[INFO] Test class distribution: {dict(test_class_counts)}") # Verify all classes have expected counts expected_train_per_class = effective_train_size for cls_idx in range(num_classes): if train_class_counts[cls_idx] != expected_train_per_class: print(f"[WARN] Class {cls_idx} has {train_class_counts[cls_idx]} train samples, expected {expected_train_per_class}") if val_class_counts[cls_idx] != args.val_per_class: print(f"[WARN] Class {cls_idx} has {val_class_counts[cls_idx]} val samples, expected {args.val_per_class}") if test_class_counts[cls_idx] != args.test_per_class: print(f"[WARN] Class {cls_idx} has {test_class_counts[cls_idx]} test samples, expected {args.test_per_class}") print(f"[INFO] ✓ All splits have balanced class distribution") train_ds_raw = SpectrogramDataset(train_specs_raw, train_labels) val_ds_raw = SpectrogramDataset(val_specs_raw, val_labels) test_ds_raw = SpectrogramDataset(test_specs_raw, test_labels) train_loader_raw = DataLoader( train_ds_raw, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=False, ) val_loader_raw = DataLoader( val_ds_raw, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=False, ) test_loader_raw = DataLoader( test_ds_raw, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=False, ) train_loader_normalized: Optional[DataLoader] = None val_loader_normalized: Optional[DataLoader] = None test_loader_normalized: Optional[DataLoader] = None if requires_normalized_inputs: train_specs_normalized = apply_normalization(train_specs_raw, stats) val_specs_normalized = apply_normalization(val_specs_raw, stats) test_specs_normalized = apply_normalization(test_specs_raw, stats) train_ds_normalized = SpectrogramDataset(train_specs_normalized, train_labels) val_ds_normalized = SpectrogramDataset(val_specs_normalized, val_labels) test_ds_normalized = SpectrogramDataset(test_specs_normalized, test_labels) train_loader_normalized = DataLoader( train_ds_normalized, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=False, ) val_loader_normalized = DataLoader( val_ds_normalized, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=False, ) test_loader_normalized = DataLoader( test_ds_normalized, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=False, ) rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" rep_root.mkdir(parents=True, exist_ok=True) for model_name in args.models: model_root = rep_root / model_name model_root.mkdir(parents=True, exist_ok=True) epoch_ckpt_dir: Optional[Path] = None if args.save_epoch_checkpoints: epoch_ckpt_dir = model_root / "epoch_checkpoints" epoch_ckpt_dir.mkdir(parents=True, exist_ok=True) for old_path in epoch_ckpt_dir.glob("epoch_*.pth"): if old_path.is_file(): old_path.unlink() print( f"\n[INFO] Size {train_size} (effective {effective_train_size}), " f"repetition {repetition}, model {model_name}" ) set_seed(rep_seed + hash(model_name) % 1000) lower_name = model_name.lower() use_raw_input = lower_name in raw_input_models if use_raw_input: train_loader = train_loader_raw val_loader = val_loader_raw test_loader = test_loader_raw print(" [INFO] Feeding raw spectrograms with per-batch normalization") else: if ( train_loader_normalized is None or val_loader_normalized is None or test_loader_normalized is None ): raise RuntimeError( "Normalized loaders were requested but could not be constructed." ) train_loader = train_loader_normalized val_loader = val_loader_normalized test_loader = test_loader_normalized trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 backbone_lr_factor = args.backbone_lr_factor if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: backbone_lr_factor = args.lwm_backbone_lr_factor model_checkpoint = checkpoint model, param_groups = build_model( model_name, num_classes, model_checkpoint, device, trainable_layers=trainable_layers, backbone_lr_factor=backbone_lr_factor, overrides=model_overrides, ) if multi_gpu: model = nn.DataParallel(model, device_ids=active_gpu_ids) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print( f"[INFO] Parameters (total/trainable): {total_params:,} / {trainable_params:,}" ) def make_optimizer(base_lr: float) -> torch.optim.Optimizer: optim_groups: List[Dict[str, object]] = [] if param_groups: for group in param_groups: scale = float(group.get("scale", 1.0)) params = [p for p in group.get("params", []) if p.requires_grad] if params: group_cfg: Dict[str, object] = { "params": list(params), "lr": base_lr * scale, } if "weight_decay" in group: group_cfg["weight_decay"] = float(group["weight_decay"]) optim_groups.append(group_cfg) if not optim_groups: optim_groups.append({ "params": [p for p in model.parameters() if p.requires_grad], "lr": base_lr, }) return torch.optim.AdamW(optim_groups, lr=base_lr, weight_decay=args.weight_decay) def make_scheduler(optimizer: torch.optim.Optimizer, base_lr: float, patience_limit: int): plateau_patience = max(2, patience_limit // 2) return torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=plateau_patience, min_lr=base_lr * 0.01, ) # Initialize mixed precision scaler for CUDA # GradScaler only for CUDA, not for HPU or CPU scaler = GradScaler('cuda') if device.type == 'cuda' else None best_val_loss = float("inf") best_val_acc = 0.0 best_state = None epoch_history: List[Dict[str, object]] = [] best_val_f1 = 0.0 best_epoch = 0 total_epochs_ran = 0 overall_early_stopped = False phase_configs = [ { "name": "main", "max_epochs": args.epochs, "base_lr": args.lr, "patience": max(1, args.early_patience), "min_epochs": max(0, args.early_min_epochs), } ] ft_epochs = max(0, args.finetune_epochs) ft_lr_factor = args.finetune_lr_factor ft_patience = max(1, args.finetune_patience) ft_min_epochs = max(0, args.finetune_min_epochs) if ft_epochs > 0: phase_configs.append( { "name": "finetune", "max_epochs": ft_epochs, "base_lr": args.lr * ft_lr_factor, "patience": ft_patience, "min_epochs": ft_min_epochs, } ) for phase_idx, phase in enumerate(phase_configs): if phase["max_epochs"] <= 0: continue if phase_idx > 0: print( f"\n [INFO] Starting {phase['name']} phase: lr={phase['base_lr']:.2e}, " f"max_epochs={phase['max_epochs']}" ) if best_state is not None: model.load_state_dict(best_state["model"]) optimizer = make_optimizer(phase["base_lr"]) scheduler = make_scheduler(optimizer, phase["base_lr"], phase["patience"]) patience_counter = 0 phase_early_stopped = False phase_epochs_completed = 0 phase_min_epochs = max(0, phase["min_epochs"]) for local_epoch in range(1, phase["max_epochs"] + 1): overall_epoch = total_epochs_ran + local_epoch train_loss = train_one_epoch( model, train_loader, optimizer, device, scaler, batch_normalize=use_raw_input, ) val_loss, val_acc, val_f1 = evaluate( model, val_loader, device, eval_debug_config, batch_normalize=use_raw_input, ) scheduler.step(val_loss) current_lr = optimizer.param_groups[0]["lr"] print( f" [{phase['name']}] Epoch {overall_epoch:02d}: " f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} " f"val_acc={val_acc:.4%} val_f1={val_f1:.4f}" ) epoch_history.append( { "epoch": int(overall_epoch), "train_loss": float(train_loss), "val_loss": float(val_loss), "val_acc": float(val_acc), "val_f1": float(val_f1), "lr": float(current_lr), "phase": phase["name"], } ) repetition_records.append( { "model": model_name, "epoch": int(overall_epoch), "phase": phase["name"], "train_loss": float(train_loss), "val_loss": float(val_loss), "val_acc": float(val_acc), "val_f1": float(val_f1), "lr": float(current_lr), "train_size_requested": int(train_size), "train_size_effective": int(effective_train_size), } ) _write_epoch_history(rep_root, repetition_records, args.save_epoch_history, args.plot_epoch_history) raw_epoch_state = _strip_module_prefix(model.state_dict()) if epoch_ckpt_dir is not None: epoch_state = raw_epoch_state.__class__() for key, value in raw_epoch_state.items(): epoch_state[key] = value.detach().cpu() epoch_ckpt_path = epoch_ckpt_dir / f"epoch_{overall_epoch:03d}.pth" torch.save(epoch_state, epoch_ckpt_path) if val_loss < best_val_loss: best_val_loss = val_loss best_val_acc = val_acc best_val_f1 = val_f1 best_model_state = { key: value.detach().cpu() for key, value in model.state_dict().items() } best_state = { "model": best_model_state, "val_loss": val_loss, "val_acc": val_acc, "val_f1": val_f1, "epoch": int(overall_epoch), "lr": current_lr, "phase": phase["name"], } best_epoch = int(overall_epoch) patience_counter = 0 else: if local_epoch >= phase_min_epochs: patience_counter += 1 if patience_counter >= phase["patience"]: print( f" [INFO] Early stopping ({phase['name']}) at epoch {overall_epoch:02d} " f"after {patience_counter} epochs without val loss improvement" ) overall_early_stopped = True phase_early_stopped = True phase_epochs_completed = local_epoch break phase_epochs_completed = local_epoch total_epochs_ran += phase_epochs_completed if phase_early_stopped is False and phase_epochs_completed < phase["max_epochs"]: # Loop exited early via break without setting the flag (should not happen) phase_early_stopped = True if best_state is None: raise RuntimeError("Training finished without recording a validation improvement") model.load_state_dict(best_state["model"]) test_loss, test_acc, test_f1 = evaluate( model, test_loader, device, eval_debug_config, batch_normalize=use_raw_input, ) print( f" -> Test loss={test_loss:.4f} Test acc={test_acc:.4%} Test f1={test_f1:.4f}" ) export_dir = getattr(args, "export_full_model", None) if export_dir is not None: export_dir = export_dir.expanduser().resolve() export_dir.mkdir(parents=True, exist_ok=True) comm_token = getattr(args, "comm_suffix", "multi") filename = f"{comm_token}_{model_name}_size{train_size}_rep{repetition}.pth" export_path = export_dir / filename full_state = {k: v.detach().cpu() for k, v in model.state_dict().items()} torch.save(full_state, export_path) print(f" [INFO] Saved full model (backbone + head) to {export_path}") summary[model_name][train_size]["acc"].append(test_acc) summary[model_name][train_size]["f1"].append(test_f1) summary[model_name][train_size]["val_f1"].append(best_state["val_f1"]) summary[model_name][train_size]["val_loss"].append(best_state["val_loss"]) per_size_val_metrics.append( (model_name, best_state["val_f1"], best_state["val_loss"], test_f1) ) result_dir = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" / model_name result_dir.mkdir(parents=True, exist_ok=True) state_to_save = copy.deepcopy(best_state) state_to_save["model"] = _strip_module_prefix(state_to_save["model"]) torch.save(state_to_save, result_dir / "checkpoint.pt") with open(result_dir / "metrics.json", "w", encoding="utf-8") as f: json.dump( { "train_size_per_class": effective_train_size, "train_size_per_class_requested": train_size, "repetition": repetition, "model": model_name, "best_val_loss": best_state.get("val_loss", None), "best_val_acc": best_val_acc, "best_val_f1": best_state["val_f1"], "test_loss": test_loss, "test_acc": test_acc, "test_f1": test_f1, "best_epoch": best_epoch, "epochs_ran": total_epochs_ran, "early_stopped": overall_early_stopped, "history": epoch_history, }, f, indent=2, ) rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" rep_root.mkdir(parents=True, exist_ok=True) if args.plot_epoch_history and HAVE_MPL and repetition_records: models_in_run = sorted({rec["model"] for rec in repetition_records}) fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) for ax in axes: ax.grid(True, linestyle='--', alpha=0.3) for model_name_plot in models_in_run: model_records = [rec for rec in repetition_records if rec["model"] == model_name_plot] if not model_records: continue epochs = [rec["epoch"] for rec in model_records] val_loss_values = [rec["val_loss"] for rec in model_records] val_f1_values = [rec["val_f1"] for rec in model_records] axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot) axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot) axes[0].set_ylabel('Val Loss') axes[1].set_ylabel('Val F1') axes[1].set_xlabel('Epoch') axes[0].legend(loc='best') axes[0].set_title(f'Size {train_size} / Rep {repetition} per-epoch metrics') fig.tight_layout() fig.savefig(rep_root / 'epoch_history.png', dpi=150) plt.close(fig) # Clean up memory after each model del model, optimizer, scheduler, best_state if scaler is not None: del scaler if device.type == 'cuda': torch.cuda.empty_cache() torch.cuda.synchronize() if per_size_val_metrics: print(f"\n[INFO] Validation summary for train_size={train_size}, rep={repetition}:") for model_name, val_f1, val_loss, test_f1 in sorted( per_size_val_metrics, key=lambda item: item[1], reverse=True ): print( f" {model_name:<16} val_f1={val_f1:.4f} " f"val_loss={val_loss:.4f} test_f1={test_f1:.4f}" ) summary_path = args.output_dir / "summary.json" summary_path.parent.mkdir(parents=True, exist_ok=True) serializable_summary = { model_name: { size: { "acc": metrics["acc"], "f1": metrics["f1"], "val_f1": metrics["val_f1"], "val_loss": metrics["val_loss"], } for size, metrics in size_dict.items() } for model_name, size_dict in summary.items() } with open(summary_path, "w", encoding="utf-8") as f: json.dump(serializable_summary, f, indent=2) print("\n[INFO] Final accuracy summary:") for model_name, results in summary.items(): for size, metrics in results.items(): if metrics["acc"]: acc_mean = float(np.mean(metrics["acc"])) acc_std = float(np.std(metrics["acc"])) f1_mean = float(np.mean(metrics["f1"])) f1_std = float(np.std(metrics["f1"])) n = len(metrics["acc"]) print( f" {model_name} @ {size:4d}/class -> " f"acc={acc_mean:.4%} ± {acc_std:.4%}, f1={f1_mean:.4f} ± {f1_std:.4f} (n={n})" ) print("\n[INFO] Final validation F1 summary:") for model_name, results in summary.items(): for size, metrics in results.items(): if metrics["val_f1"]: val_mean = float(np.mean(metrics["val_f1"])) val_std = float(np.std(metrics["val_f1"])) n = len(metrics["val_f1"]) print( f" {model_name} @ {size:4d}/class -> " f"val_f1={val_mean:.4f} ± {val_std:.4f} (n={n})" ) print("\n[INFO] Final validation loss summary:") for model_name, results in summary.items(): for size, metrics in results.items(): if metrics["val_loss"]: loss_mean = float(np.mean(metrics["val_loss"])) loss_std = float(np.std(metrics["val_loss"])) n = len(metrics["val_loss"]) print( f" {model_name} @ {size:4d}/class -> " f"val_loss={loss_mean:.4f} ± {loss_std:.4f} (n={n})" ) if HAVE_MPL: train_sizes_sorted = sorted(args.train_sizes) plt.figure(figsize=(8, 5)) plotted = False for model_name in args.models: model_results = summary.get(model_name, {}) means: List[float] = [] for size in train_sizes_sorted: val_list = model_results.get(size, {}).get("val_f1", []) means.append(float(np.mean(val_list)) if val_list else float("nan")) if not any(np.isfinite(means)): continue plt.plot(train_sizes_sorted, means, marker="o", linewidth=2, label=model_name) plotted = True if plotted: plt.title("Validation F1 vs. Training Size") plt.xlabel("Training samples per class") plt.ylabel("Validation F1 (macro)") plt.xticks(train_sizes_sorted) plt.ylim(0.0, 1.0) plt.grid(True, which="both", linestyle="--", alpha=0.4) plt.legend(title="Model", frameon=False) plt.tight_layout() plot_path = args.output_dir / "val_f1_summary.png" plt.savefig(plot_path, dpi=200) plt.close() print(f"[INFO] Saved validation F1 plot to {plot_path}") else: plt.close() print("[WARN] No validation F1 data available to plot.") else: print("[WARN] Matplotlib not available; skipping validation F1 plot.") if __name__ == "__main__": main()