|
|
|
|
|
"""Benchmark multiple backbones on modulation classification across dataset sizes. |
|
|
|
|
|
For each desired training size (samples per modulation class) and repetition, the script:
|
|
1. Randomly samples spectrograms from distinct (modulation, rate, SNR, mobility) configs.
|
|
2. Builds train/val/test splits (val/test sizes are configurable). |
|
|
3. Fine-tunes several backbones (LWM, ResNet18, EfficientNet-B0, MobileNetV3-Small,
|
|
   and two CNN baselines) using the same splits.
|
|
4. Reports accuracy statistics and stores checkpoints/metrics per experiment. |
|
|
|
|
|
Input spectrograms for the LWM backbone are normalized with the dataset mean/std stored
|
|
alongside the specified pretrained checkpoint (defaults to the latest run under `models/`);
models listed in --raw-input-models (by default, all non-LWM backbones) receive raw
spectrograms instead.
|
|
|
|
|
Usage example (defaults cover city_1_losangeles/LTE with all available SNR·mobility combos): |
|
|
|
|
|
python task1/train_mcs_models.py --train-sizes 128 --models resnet18 mobilenet_v3_small |
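To also benchmark the LWM backbone, point --checkpoint at a pretrained run
(the path below is a placeholder, not a shipped file):

python task1/train_mcs_models.py --models lwm --checkpoint models/<run>/<checkpoint>.pth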
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import copy |
|
|
import csv |
|
|
import glob |
|
|
import json |
|
|
import os |
|
|
import pickle |
|
|
import random |
|
|
import re |
|
|
import sys |
|
|
from collections import Counter, defaultdict |
|
|
from pathlib import Path |
|
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple |
|
|
|
|
|
from contextlib import nullcontext |
|
|
from datetime import datetime |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.amp import autocast, GradScaler |
|
|
|
|
|
try: |
|
|
from tqdm import tqdm |
|
|
except ImportError: |
|
|
def tqdm(iterable, *args, **kwargs): |
|
|
return iterable |
|
|
|
|
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent |
|
|
if str(PROJECT_ROOT) not in sys.path: |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
|
|
|
from pretraining.pretrained_model import lwm as lwm_model |
|
|
from utils import count_parameters |
|
|
|
|
|
|
|
|
COMM_CANONICAL = { |
|
|
"lte": "LTE", |
|
|
"wifi": "WiFi", |
|
|
"5g": "5G", |
|
|
} |
|
|
COMM_LOWER = {v: k for k, v in COMM_CANONICAL.items()} |
|
|
try: |
|
|
from sklearn.metrics import f1_score as sklearn_f1_score |
|
|
HAVE_SKLEARN = True |
|
|
except ImportError: |
|
|
HAVE_SKLEARN = False |
|
|
|
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
HAVE_MPL = True |
|
|
except ImportError: |
|
|
HAVE_MPL = False |
|
|
|
|
|
try: |
|
|
from task2.mobility_utils import LWMClassifierMinimal |
|
|
except ImportError: |
|
|
LWMClassifierMinimal = None |
|
|
|
|
|
|
|
|
HPU_AVAILABLE = False |
|
|
try: |
|
|
import habana_frameworks.torch.core as htcore |
|
|
HPU_AVAILABLE = hasattr(torch, "hpu") and torch.hpu.is_available() |
|
|
except (ImportError, AttributeError): |
|
|
pass |
|
|
|
|
|
|
|
|
def compute_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: |
|
|
if HAVE_SKLEARN: |
|
|
return float(sklearn_f1_score(y_true, y_pred, average="macro")) |
|
|
classes = np.unique(np.concatenate([y_true, y_pred])) |
|
|
scores = [] |
|
|
for cls in classes: |
|
|
tp = np.sum((y_true == cls) & (y_pred == cls)) |
|
|
fp = np.sum((y_true != cls) & (y_pred == cls)) |
|
|
fn = np.sum((y_true == cls) & (y_pred != cls)) |
|
|
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 |
|
|
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 |
|
|
denom = precision + recall |
|
|
f1 = (2 * precision * recall / denom) if denom > 0 else 0.0 |
|
|
scores.append(f1) |
|
|
return float(np.mean(scores)) |
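# Example (values illustrative): compute_f1(np.array([0, 1, 1]), np.array([0, 1, 0]))
# averages the per-class F1 over classes {0, 1}; when scikit-learn is available the
# same value is returned by f1_score(..., average="macro").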
|
|
|
|
|
|
|
|
MODULATION_LABELS = { |
|
|
"BPSK": 0, |
|
|
"QPSK": 1, |
|
|
"QAM16": 2, |
|
|
"QAM64": 3, |
|
|
"QAM256": 4, |
|
|
} |
|
|
|
|
|
LABEL_NAMES = {idx: name for name, idx in MODULATION_LABELS.items()} |
|
|
|
|
|
DEFAULT_LWM_TRAINABLE_LAYERS = 2 |
|
|
|
|
|
_SAMPLE_COUNT_CACHE: Dict[str, int] = {} |
|
|
|
|
|
|
|
|
def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray: |
|
|
if specs.size == 0: |
|
|
return specs.astype(np.float32, copy=False) |
|
|
means = specs.mean(axis=(1, 2), keepdims=True) |
|
|
stds = specs.std(axis=(1, 2), keepdims=True) |
|
|
stds = np.maximum(stds, eps) |
|
|
normalized = (specs - means) / stds |
|
|
return normalized.astype(np.float32, copy=False) |
|
|
|
|
|
|
|
|
def apply_normalization(specs: np.ndarray, stats: Dict[str, object]) -> np.ndarray: |
|
|
mode = str(stats.get("normalization", "dataset")).lower() |
|
|
mean = float(stats.get("mean", 0.0)) |
|
|
std = float(stats.get("std", 1.0)) |
|
|
if abs(std) < 1e-6: |
|
|
std = 1e-6 |
|
|
if mode == "dataset": |
|
|
return ((specs.astype(np.float32, copy=False) - mean) / std).astype(np.float32, copy=False) |
|
|
return normalize_per_sample(specs) |
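# Example stats payload (values illustrative), mirroring dataset_stats.json:
#   {"mean": -41.7, "std": 12.3, "normalization": "dataset"}
# Any other "normalization" value falls back to the per-sample z-scoring above.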
|
|
|
|
|
|
|
|
def _unique_parameters(params: Iterable[nn.Parameter]) -> List[nn.Parameter]: |
|
|
seen: set[int] = set() |
|
|
unique: List[nn.Parameter] = [] |
|
|
for param in params: |
|
|
pid = id(param) |
|
|
if pid not in seen: |
|
|
unique.append(param) |
|
|
seen.add(pid) |
|
|
return unique |
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
|
|
parser = argparse.ArgumentParser(description=__doc__) |
|
|
parser.add_argument( |
|
|
"--data-root", |
|
|
default=str(PROJECT_ROOT / "spectrograms"), |
|
|
help="Root directory containing city folders (default: project_root/spectrograms)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--cities", |
|
|
nargs="*", |
|
|
default=["city_1_losangeles"], |
|
|
help="City directories to include (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument("--comm-types", nargs="*", default=["LTE"], help="Communication standards to include (default: %(default)s)") |
|
|
parser.add_argument("--LTE", dest="select_lte", action="store_true", help="Shortcut for --comm-types LTE") |
|
|
parser.add_argument("--WiFi", dest="select_wifi", action="store_true", help="Shortcut for --comm-types WiFi") |
|
|
parser.add_argument("--5G", dest="select_5g", action="store_true", help="Shortcut for --comm-types 5G") |
|
|
parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include for training (default: all available)") |
|
|
parser.add_argument("--val-snrs", nargs="*", default=None, help="SNR folders for validation/test (default: all available)") |
|
|
parser.add_argument( |
|
|
"--mobilities", |
|
|
nargs="*", |
|
|
default=None, |
|
|
help="Mobility folders to include for training (default: all available)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--val-mobilities", |
|
|
nargs="*", |
|
|
default=None, |
|
|
help="Mobility folders for validation/test (default: all available)", |
|
|
) |
|
|
parser.add_argument("--fft-folder", default="512FFT", help="FFT folder name (default: %(default)s)") |
|
|
parser.add_argument( |
|
|
"--device", |
|
|
type=str, |
|
|
default="auto", |
|
|
choices=["auto", "cuda", "hpu", "cpu"], |
|
|
help="Device to use for training (default: auto - detects HPU, then CUDA, then CPU)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--gpu-ids", |
|
|
type=int, |
|
|
nargs="*", |
|
|
default=None, |
|
|
help="Specific GPU device IDs to use (only for CUDA, default: all visible GPUs)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--train-sizes", |
|
|
type=int, |
|
|
nargs="*", |
|
|
default=[2, 4, 8, 16, 32, 64, 128, 256], |
|
|
help="Training samples per class to benchmark", |
|
|
) |
|
|
parser.add_argument("--val-per-class", type=int, default=512, help="Validation samples per class") |
|
|
parser.add_argument("--test-per-class", type=int, default=512, help="Test samples per class") |
|
|
parser.add_argument("--repetitions", type=int, default=1, help="Repetitions per train size") |
|
|
parser.add_argument("--epochs", type=int, default=200, help="Epochs per run") |
|
|
parser.add_argument("--batch-size", type=int, default=32, help="Mini-batch size") |
|
|
parser.add_argument("--lr", type=float, default=8e-4, help="Learning rate for fine-tuning") |
|
|
parser.add_argument("--weight-decay", type=float, default=3e-2, help="Weight decay") |
|
|
parser.add_argument( |
|
|
"--no-epoch-history", |
|
|
action="store_true", |
|
|
help="Disable aggregated per-epoch history tracking", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--no-epoch-plot", |
|
|
action="store_true", |
|
|
help="Disable per-repetition metric plots", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--save-epoch-checkpoints", |
|
|
action="store_true", |
|
|
help="Persist per-epoch checkpoints (default: disabled)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--backbone-lr-factor", |
|
|
type=float, |
|
|
default=0.3, |
|
|
help="Relative LR multiplier applied to unfrozen backbone parameters (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--early-patience", |
|
|
type=int, |
|
|
default=5, |
|
|
help="Early stopping patience based on validation F1 (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--early-min-epochs", |
|
|
type=int, |
|
|
default=10, |
|
|
help="Minimum number of epochs to run before early stopping can trigger (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--finetune-epochs", |
|
|
type=int, |
|
|
default=0, |
|
|
help="Additional fine-tuning epochs to run after the main schedule (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--finetune-lr-factor", |
|
|
type=float, |
|
|
default=0.1, |
|
|
help="Multiplier applied to the base learning rate during fine-tuning (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--finetune-patience", |
|
|
type=int, |
|
|
default=3, |
|
|
help="Early stopping patience for the fine-tuning phase (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--finetune-min-epochs", |
|
|
type=int, |
|
|
default=0, |
|
|
help="Minimum epochs to execute in the fine-tuning phase before early stopping is considered (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--debug-eval-batches", |
|
|
type=int, |
|
|
default=0, |
|
|
help="Log detailed stats for the first N evaluation batches (0 disables logging)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--debug-eval-interval", |
|
|
type=int, |
|
|
default=1, |
|
|
help="Evaluate logging interval in batches when debug logging is enabled (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--debug-eval-softmax", |
|
|
action="store_true", |
|
|
help="When debugging evaluations, also log softmax statistics per batch", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--models", |
|
|
nargs="*", |
|
|
default=["lwm", "resnet18", "efficientnet_b0", "mobilenet_v3_small", "simple_cnn", "ieee_cnn"], |
|
|
help="Models to benchmark", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--raw-input-models", |
|
|
nargs="*", |
|
|
default=None, |
|
|
help=( |
|
|
"Models that should receive raw spectrograms without additional normalization " |
|
|
"(default: all non-LWM models)" |
|
|
), |
|
|
) |
|
|
parser.add_argument( |
|
|
"--lwm-trainable-layers", |
|
|
type=int, |
|
|
        default=DEFAULT_LWM_TRAINABLE_LAYERS,
|
|
help="Number of transformer layers (from the end) to fine-tune in LWM (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--lwm-classifier-dim", |
|
|
type=int, |
|
|
default=64, |
|
|
help="Hidden width for the LWM classifier MLP head (default: %(default)s; ignored for linear head)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--lwm-head-dropout", |
|
|
type=float, |
|
|
default=0.0, |
|
|
help="Dropout applied inside the LWM classifier head (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--lwm-head-type", |
|
|
choices=("linear", "mlp", "res1dcnn"), |
|
|
default="res1dcnn", |
|
|
help="Classifier head architecture for LWM (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--lwm-backbone-lr-factor", |
|
|
type=float, |
|
|
default=0.2, |
|
|
help="LR multiplier for unfrozen LWM backbone layers (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--resnet-head-width", |
|
|
type=int, |
|
|
default=512, |
|
|
help="Hidden width for the ResNet18 classifier head (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--efficientnet-head-width", |
|
|
type=int, |
|
|
default=296, |
|
|
help="Hidden width for the EfficientNet-B0 classifier head (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--mobilenet-head-width", |
|
|
type=int, |
|
|
default=576, |
|
|
help="Hidden width for the MobileNetV3-Small classifier head (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--imagenet-head-dropout", |
|
|
type=float, |
|
|
default=0.6, |
|
|
help="Dropout probability used inside ImageNet backbone classifier heads (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--imagenet-weight-decay-scale", |
|
|
type=float, |
|
|
default=2.0, |
|
|
help="Multiplier applied to weight decay for ImageNet backbone trainable parameters (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--simple-cnn-hidden-dims", |
|
|
type=int, |
|
|
nargs="*", |
|
|
default=[272, 128], |
|
|
help="Hidden layer widths for the Simple CNN classifier (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--ieee-cnn-hidden-dims", |
|
|
type=int, |
|
|
nargs="*", |
|
|
default=[512, 256], |
|
|
help="Hidden layer widths for the IEEE CNN classifier (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--ieee-cnn-dropout", |
|
|
type=float, |
|
|
default=0.3, |
|
|
help="Dropout rate for the IEEE CNN model (default: %(default)s)", |
|
|
) |
|
|
parser.add_argument("--checkpoint", type=Path, default=None, help="Path to pretrained LWM checkpoint (.pth)") |
|
|
parser.add_argument("--stats", type=Path, default=None, help="dataset_stats.json path") |
|
|
parser.add_argument( |
|
|
"--models-root", |
|
|
type=Path, |
|
|
default=PROJECT_ROOT / "models", |
|
|
help="Root with pretrained runs (default: project_root/models)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output-dir", |
|
|
type=Path, |
|
|
default=PROJECT_ROOT / "task1" / "mcs_benchmarks", |
|
|
help="Results root directory (per-run subfolder created automatically)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--export-full-model", |
|
|
type=Path, |
|
|
default=None, |
|
|
help="Directory where best full-model checkpoints (backbone + head) will be exported per run", |
|
|
) |
|
|
parser.add_argument("--seed", type=int, default=42, help="Base random seed") |
|
|
args = parser.parse_args() |
|
|
args.output_root = args.output_dir |
|
|
|
|
|
quick_comm: List[str] = [] |
|
|
if getattr(args, "select_lte", False): |
|
|
quick_comm.append("LTE") |
|
|
if getattr(args, "select_wifi", False): |
|
|
quick_comm.append("WiFi") |
|
|
if getattr(args, "select_5g", False): |
|
|
quick_comm.append("5G") |
|
|
if quick_comm: |
|
|
args.comm_types = quick_comm |
|
|
|
|
|
normalized: List[str] = [] |
|
|
for comm in args.comm_types: |
|
|
upper = comm.upper() |
|
|
if upper == "WIFI": |
|
|
normalized.append("WiFi") |
|
|
elif upper == "LTE": |
|
|
normalized.append("LTE") |
|
|
elif upper == "5G": |
|
|
normalized.append("5G") |
|
|
else: |
|
|
normalized.append(comm) |
|
|
args.comm_types = normalized |
|
|
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") |
|
|
comm_tokens: List[str] = [] |
|
|
for comm in args.comm_types: |
|
|
canonical = COMM_LOWER.get(comm, comm.lower()) |
|
|
token = re.sub(r"[^a-z0-9]+", "-", canonical.lower()).strip("-") |
|
|
comm_tokens.append(token or "unknown") |
|
|
comm_suffix = "-".join(comm_tokens) if comm_tokens else "unknown" |
|
|
args.run_timestamp = timestamp |
|
|
args.output_dir = args.output_root / comm_suffix / timestamp |
|
|
args.comm_suffix = comm_suffix |
|
|
|
|
|
if args.gpu_ids is not None and len(args.gpu_ids) == 0: |
|
|
args.gpu_ids = None |
|
|
|
|
|
if not args.simple_cnn_hidden_dims: |
|
|
args.simple_cnn_hidden_dims = [512, 256] |
|
|
if not args.ieee_cnn_hidden_dims: |
|
|
args.ieee_cnn_hidden_dims = [512, 256] |
|
|
if args.raw_input_models is None: |
|
|
args.raw_input_models = [ |
|
|
model.lower() |
|
|
for model in args.models |
|
|
if model.lower() not in {"lwm"} |
|
|
] |
|
|
else: |
|
|
args.raw_input_models = [model.lower() for model in args.raw_input_models] |
|
|
|
|
|
args.models = [model for model in args.models] |
|
|
args.save_epoch_history = not args.no_epoch_history |
|
|
args.plot_epoch_history = not args.no_epoch_plot |
|
|
|
|
|
args.imagenet_head_dropout = float(max(0.0, min(args.imagenet_head_dropout, 0.95))) |
|
|
args.imagenet_weight_decay_scale = float(max(0.0, args.imagenet_weight_decay_scale)) |
|
|
|
|
|
return args |
|
|
|
|
|
|
|
|
def find_latest_run(models_root: Path) -> Path: |
|
|
run_dirs = [p for p in models_root.iterdir() if p.is_dir()] |
|
|
run_dirs = [p for p in run_dirs if not p.name.lower().endswith("_models")] |
|
|
valid_runs = [p for p in run_dirs if any(p.glob("*.pth"))] |
|
|
if valid_runs: |
|
|
return max(valid_runs, key=lambda p: p.stat().st_mtime) |
|
|
|
|
|
checkpoints = list(models_root.glob("*.pth")) |
|
|
if checkpoints: |
|
|
print(f"[INFO] No checkpoint-bearing subdirectories under {models_root}; using root as run directory.") |
|
|
return models_root |
|
|
|
|
|
raise FileNotFoundError(f"No checkpoints found under {models_root}") |
|
|
|
|
|
|
|
|
def find_best_checkpoint(run_dir: Path) -> Path: |
|
|
candidates = list(run_dir.glob("*.pth")) |
|
|
if not candidates: |
|
|
raise FileNotFoundError(f"No checkpoints in {run_dir}") |
|
|
|
|
|
def metric(path: Path) -> float: |
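        # Checkpoint filenames are assumed to embed the validation loss, e.g.
        # "epoch12_val0.0831.pth" (illustrative); files without the pattern sort
        # last and the smallest value wins.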
|
|
match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name) |
|
|
if match: |
|
|
try: |
|
|
return float(match.group(1)) |
|
|
except ValueError: |
|
|
pass |
|
|
return float("inf") |
|
|
|
|
|
best = min(candidates, key=metric) |
|
|
return best |
|
|
|
|
|
|
|
|
def resolve_models_directory(args: argparse.Namespace) -> Path: |
|
|
base = args.models_root.expanduser().resolve() |
|
|
if not base.exists(): |
|
|
raise FileNotFoundError(f"Models root not found: {base}") |
|
|
matches: List[Path] = [] |
|
|
for comm in args.comm_types: |
|
|
subdir = base / f"{comm}_models" |
|
|
if subdir.exists(): |
|
|
matches.append(subdir) |
|
|
else: |
|
|
print(f"[WARN] Models directory for {comm} not found at {subdir}") |
|
|
if len(matches) == 1: |
|
|
print(f"[INFO] Using models directory for {args.comm_types[0]}: {matches[0]}") |
|
|
return matches[0] |
|
|
if len(matches) > 1: |
|
|
raise ValueError( |
|
|
"Multiple communication-specific model directories detected; please provide --checkpoint explicitly." |
|
|
) |
|
|
print(f"[INFO] Using shared models directory: {base}") |
|
|
return base |
|
|
|
|
|
|
|
|
def resolve_checkpoint_and_stats(args: argparse.Namespace, require_checkpoint: bool) -> Tuple[Path | None, Dict[str, object]]: |
|
|
checkpoint: Path | None = None |
|
|
models_dir = resolve_models_directory(args) |
|
|
user_provided_stats = args.stats is not None |
|
|
|
|
|
if args.checkpoint is not None: |
|
|
checkpoint = args.checkpoint.expanduser().resolve() |
|
|
if not checkpoint.exists(): |
|
|
raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") |
|
|
stats_path = args.stats.expanduser().resolve() if user_provided_stats else checkpoint.parent / "dataset_stats.json" |
|
|
else: |
|
|
run_dir = find_latest_run(models_dir) |
|
|
stats_path = run_dir / "dataset_stats.json" |
|
|
if require_checkpoint: |
|
|
checkpoint = find_best_checkpoint(run_dir) |
|
|
else: |
|
|
checkpoint = None |
|
|
|
|
|
if stats_path.exists(): |
|
|
try: |
|
|
with open(stats_path, "r", encoding="utf-8") as f: |
|
|
stats = json.load(f) |
|
|
except json.JSONDecodeError as exc: |
|
|
if user_provided_stats: |
|
|
raise ValueError( |
|
|
f"Failed to parse dataset_stats.json at {stats_path}: {exc}" |
|
|
) from exc |
|
|
print( |
|
|
f"[WARN] Corrupt dataset_stats.json at {stats_path}; " |
|
|
"falling back to mean=0/std=1 per-sample normalization." |
|
|
) |
|
|
stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} |
|
|
else: |
|
|
if "mean" not in stats or "std" not in stats: |
|
|
raise ValueError("dataset_stats.json must contain 'mean' and 'std'") |
|
|
stats.setdefault("normalization", stats.get("mode", "dataset")) |
|
|
else: |
|
|
if user_provided_stats: |
|
|
raise FileNotFoundError(f"dataset_stats.json not found: {stats_path}") |
|
|
stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} |
|
|
print(f"[WARN] dataset_stats.json not found at {stats_path}. Falling back to per-sample normalization.") |
|
|
|
|
|
if checkpoint is not None: |
|
|
print(f"[INFO] Using checkpoint: {checkpoint}") |
|
|
elif require_checkpoint: |
|
|
raise FileNotFoundError("LWM requested but no checkpoint available") |
|
|
else: |
|
|
print("[INFO] No LWM checkpoint required for selected models") |
|
|
|
|
|
norm_mode = str(stats.get("normalization", "dataset")) |
|
|
if norm_mode.lower() == "dataset": |
|
|
print(f"[INFO] Dataset stats -> mean={stats['mean']:.4f}, std={stats['std']:.4f}") |
|
|
else: |
|
|
print("[INFO] Normalization mode: per_sample") |
|
|
|
|
|
return checkpoint, { |
|
|
"mean": float(stats.get("mean", 0.0)), |
|
|
"std": float(stats.get("std", 1.0)), |
|
|
"normalization": norm_mode, |
|
|
} |
|
|
|
|
|
|
|
|
def identify_modulation(path: str) -> tuple[int | None, str | None]: |
|
|
for mod_name, label in MODULATION_LABELS.items(): |
|
|
if mod_name in path: |
|
|
return label, mod_name |
|
|
return None, None |
|
|
|
|
|
|
|
|
def _extract_metadata(parts: Sequence[str]) -> Tuple[str, str, str]: |
|
|
rate = next((part for part in parts if part.startswith("rate")), "rate_unknown") |
|
|
snr = next((part for part in parts if part.startswith("SNR")), "SNR_unknown") |
|
|
mobility = next((part for part in parts if part in {"static", "pedestrian", "vehicular"}), "mobility_unknown") |
|
|
return rate, snr, mobility |
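# Illustrative path layout these helpers parse (exact nesting may vary):
#   <city>/<comm>/QPSK/rate_7/SNR20dB/static/.../512FFT/.../spectrograms/000.pkl
# The modulation is matched as a substring of the full path, while rate, SNR and
# mobility are read from individual path components.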
|
|
|
|
|
|
|
|
def discover_snr_mobility( |
|
|
data_root: Path, |
|
|
cities: Sequence[str], |
|
|
comm_types: Sequence[str], |
|
|
fft_folder: str, |
|
|
) -> Tuple[List[str], List[str]]: |
|
|
snrs: set[str] = set() |
|
|
mobilities: set[str] = set() |
|
|
for city in cities: |
|
|
for comm in comm_types: |
|
|
base = data_root / city / comm |
|
|
if not base.exists(): |
|
|
continue |
|
|
for root, dirs, _ in os.walk(base): |
|
|
parts = Path(root).parts |
|
|
for part in parts: |
|
|
if part.startswith("SNR") and part.endswith("dB"): |
|
|
snrs.add(part) |
|
|
elif part in {"static", "pedestrian", "vehicular"}: |
|
|
mobilities.add(part) |
|
|
if not snrs: |
|
|
snrs.add("SNR20dB") |
|
|
if not mobilities: |
|
|
mobilities.add("static") |
|
|
return sorted(snrs), sorted(mobilities) |
|
|
|
|
|
|
|
|
def build_config_map( |
|
|
data_root: Path, |
|
|
cities: Sequence[str], |
|
|
comm_types: Sequence[str], |
|
|
snrs: Sequence[str], |
|
|
mobilities: Sequence[str], |
|
|
fft_folder: str, |
|
|
) -> Dict[int, Dict[str, List[str]]]: |
|
|
class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) |
|
|
for city in cities: |
|
|
for comm in comm_types: |
|
|
base = data_root / city / comm |
|
|
for snr in snrs: |
|
|
for mobility in mobilities: |
|
|
pattern = str(base / "**" / snr / mobility / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") |
|
|
for path_str in glob.glob(pattern, recursive=True): |
|
|
cls, modulation_name = identify_modulation(path_str) |
|
|
if cls is None: |
|
|
continue |
|
|
rate, _, _ = _extract_metadata(Path(path_str).parts) |
|
|
config_name = f"{modulation_name}_{rate}_{snr}_{mobility}" |
|
|
class_configs[cls][config_name].append(path_str) |
|
|
return class_configs |
|
|
|
|
|
|
|
|
def build_global_config_map( |
|
|
data_root: Path, |
|
|
cities: Sequence[str], |
|
|
comm_types: Sequence[str], |
|
|
fft_folder: str, |
|
|
) -> Dict[int, Dict[str, List[str]]]: |
|
|
class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) |
|
|
for city in cities: |
|
|
for comm in comm_types: |
|
|
base = data_root / city / comm |
|
|
pattern = str(base / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") |
|
|
for path_str in glob.glob(pattern, recursive=True): |
|
|
cls, modulation_name = identify_modulation(path_str) |
|
|
if cls is None: |
|
|
continue |
|
|
rate, snr_part, mobility_part = _extract_metadata(Path(path_str).parts) |
|
|
config_name = f"{modulation_name}_{rate}_{snr_part}_{mobility_part}" |
|
|
class_configs[cls][config_name].append(path_str) |
|
|
return class_configs |
|
|
|
|
|
|
|
|
def _count_samples_in_path(path: str) -> int: |
|
|
cached = _SAMPLE_COUNT_CACHE.get(path) |
|
|
if cached is not None: |
|
|
return cached |
|
|
arr = load_all_samples(path) |
|
|
count = int(arr.shape[0]) |
|
|
_SAMPLE_COUNT_CACHE[path] = count |
|
|
return count |
|
|
|
|
|
|
|
|
class LazyConfigArray: |
|
|
"""Lazily views spectrograms spread across multiple pickled files.""" |
|
|
|
|
|
__slots__ = ("paths", "_counts", "_offsets", "_total", "shape", "dtype", "ndim") |
|
|
|
|
|
def __init__(self, paths: Sequence[str]) -> None: |
|
|
filtered_paths: List[str] = [] |
|
|
counts: List[int] = [] |
|
|
for path in sorted(paths): |
|
|
count = _count_samples_in_path(path) |
|
|
if count <= 0: |
|
|
continue |
|
|
filtered_paths.append(path) |
|
|
counts.append(count) |
|
|
|
|
|
self.paths: Tuple[str, ...] = tuple(filtered_paths) |
|
|
if counts: |
|
|
self._counts = np.array(counts, dtype=np.int64) |
|
|
self._offsets = np.concatenate(([0], np.cumsum(self._counts))) |
|
|
self._total = int(self._offsets[-1]) |
|
|
else: |
|
|
self._counts = np.empty(0, dtype=np.int64) |
|
|
self._offsets = np.array([0], dtype=np.int64) |
|
|
self._total = 0 |
|
|
|
|
|
self.shape = (self._total, 128, 128) |
|
|
self.dtype = np.float32 |
|
|
self.ndim = 3 |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return self._total |
|
|
|
|
|
def _resolve_index(self, index: int) -> Tuple[int, int]: |
|
|
if self._total == 0: |
|
|
raise IndexError("attempting to index empty LazyConfigArray") |
|
|
if index < 0: |
|
|
index += self._total |
|
|
if index < 0 or index >= self._total: |
|
|
raise IndexError("index out of range for LazyConfigArray") |
|
|
path_idx = int(np.searchsorted(self._offsets[1:], index, side="right")) |
|
|
start = int(self._offsets[path_idx]) |
|
|
return path_idx, int(index - start) |
|
|
|
|
|
def _load_path(self, path_idx: int) -> np.ndarray: |
|
|
path = self.paths[path_idx] |
|
|
return load_all_samples(path) |
|
|
|
|
|
def __getitem__(self, item: Any) -> np.ndarray: |
|
|
if isinstance(item, (int, np.integer)): |
|
|
path_idx, local_idx = self._resolve_index(int(item)) |
|
|
data = self._load_path(path_idx) |
|
|
sample = data[local_idx].copy() |
|
|
return sample |
|
|
|
|
|
indices = np.asarray(item, dtype=np.int64) |
|
|
if indices.ndim == 0: |
|
|
indices = indices.reshape(1) |
|
|
else: |
|
|
indices = indices.reshape(-1) |
|
|
if indices.size == 0: |
|
|
return np.empty((0, 128, 128), dtype=np.float32) |
|
|
|
|
|
resolved: Dict[int, List[Tuple[int, int]]] = {} |
|
|
for pos, raw_idx in enumerate(indices): |
|
|
path_idx, local_idx = self._resolve_index(int(raw_idx)) |
|
|
resolved.setdefault(path_idx, []).append((pos, local_idx)) |
|
|
|
|
|
result = np.empty((indices.size, 128, 128), dtype=np.float32) |
|
|
for path_idx, items in resolved.items(): |
|
|
data = self._load_path(path_idx) |
|
|
local_positions = [loc for _, loc in items] |
|
|
chunk = data[local_positions] |
|
|
for offset, (pos, _) in enumerate(items): |
|
|
result[pos] = chunk[offset] |
|
|
return result |
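# Sketch of intended use (path is hypothetical):
#   lazy = LazyConfigArray(["/data/.../QPSK_rate_7/spectrograms/000.pkl"])
#   chunk = lazy[np.array([0, 5, 9])]   # float32 array of shape (3, 128, 128)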
|
|
|
|
|
|
|
|
def load_config_arrays(class_configs: Dict[int, Dict[str, List[str]]]) -> Dict[int, Dict[str, LazyConfigArray]]: |
|
|
loaded: Dict[int, Dict[str, LazyConfigArray]] = {} |
|
|
for cls, configs in class_configs.items(): |
|
|
arrays_for_cls: Dict[str, LazyConfigArray] = {} |
|
|
for config_name, paths in configs.items(): |
|
|
lazy_array = LazyConfigArray(paths) |
|
|
if len(lazy_array) == 0: |
|
|
continue |
|
|
arrays_for_cls[config_name] = lazy_array |
|
|
if arrays_for_cls: |
|
|
loaded[cls] = arrays_for_cls |
|
|
return loaded |
|
|
|
|
|
|
|
|
def load_all_samples(path: str) -> np.ndarray: |
|
|
with open(path, "rb") as f: |
|
|
data = pickle.load(f) |
|
|
if isinstance(data, dict) and "spectrograms" in data: |
|
|
arr = data["spectrograms"] |
|
|
elif isinstance(data, np.ndarray): |
|
|
arr = data |
|
|
else: |
|
|
return np.empty((0, 128, 128), dtype=np.float32) |
|
|
|
|
|
arr = np.asarray(arr, dtype=np.float32) |
|
|
if arr.ndim == 2: |
|
|
arr = arr[None, ...] |
|
|
if arr.shape[1:] != (128, 128): |
|
|
return np.empty((0, 128, 128), dtype=np.float32) |
|
|
return arr |
|
|
|
|
|
|
|
|
def sample_from_paths( |
|
|
paths: Sequence[str], |
|
|
n_samples: int, |
|
|
rng: np.random.Generator, |
|
|
used_map: Dict[str, set[int]], |
|
|
) -> Tuple[np.ndarray, List[Tuple[str, np.ndarray]]]: |
|
|
if not paths: |
|
|
raise RuntimeError("No files available for sampling") |
|
|
|
|
|
paths_array = np.array(paths, dtype=object) |
|
|
order = rng.permutation(len(paths_array)) |
|
|
remaining = n_samples |
|
|
collected: List[np.ndarray] = [] |
|
|
info: List[Tuple[str, np.ndarray]] = [] |
|
|
|
|
|
for idx in order: |
|
|
if remaining <= 0: |
|
|
break |
|
|
path = str(paths_array[idx]) |
|
|
samples = load_all_samples(path) |
|
|
total = samples.shape[0] |
|
|
used = used_map[path] |
|
|
if used: |
|
|
used_idx = np.fromiter(used, dtype=np.int64, count=len(used)) |
|
|
available = np.setdiff1d(np.arange(total), used_idx, assume_unique=True) |
|
|
else: |
|
|
available = np.arange(total) |
|
|
if available.size == 0: |
|
|
continue |
|
|
take = min(remaining, available.size) |
|
|
chosen = rng.choice(available, size=take, replace=False) |
|
|
collected.append(samples[chosen]) |
|
|
used_map[path].update(int(i) for i in chosen) |
|
|
info.append((path, chosen)) |
|
|
remaining -= take |
|
|
|
|
|
if remaining > 0: |
|
|
raise RuntimeError("Insufficient samples remaining to satisfy request") |
|
|
|
|
|
result = np.concatenate(collected, axis=0) if len(collected) > 1 else collected[0] |
|
|
return result, info |
|
|
|
|
|
|
|
|
def _ensure_available(total_needed: int, availability: Dict[str, set]) -> None: |
|
|
remaining = sum(len(indices) for indices in availability.values()) |
|
|
if remaining < total_needed: |
|
|
raise RuntimeError( |
|
|
f"Insufficient samples: need {total_needed}, only {remaining} available across configs" |
|
|
) |
|
|
|
|
|
|
|
|
def _sample_from_availability( |
|
|
arrays_map: Dict[str, LazyConfigArray], |
|
|
availability: Dict[str, set[int]], |
|
|
total_needed: int, |
|
|
rng: np.random.Generator, |
|
|
) -> Tuple[np.ndarray, Dict[str, set[int]]]: |
|
|
if total_needed <= 0: |
|
|
return np.empty((0, 128, 128), dtype=np.float32), {cfg: set() for cfg in arrays_map} |
|
|
|
|
|
_ensure_available(total_needed, availability) |
|
|
remaining = total_needed |
|
|
configs = [cfg for cfg, indices in availability.items() if indices] |
|
|
used: Dict[str, set[int]] = {cfg: set() for cfg in arrays_map} |
|
|
collected: List[np.ndarray] = [] |
|
|
|
|
|
while remaining > 0 and configs: |
|
|
cfg = rng.choice(configs) |
|
|
available_indices = np.array(list(availability[cfg]), dtype=np.int64) |
|
|
if available_indices.size == 0: |
|
|
configs = [c for c in configs if c != cfg] |
|
|
continue |
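        # Spread the remaining request roughly evenly across the configs that
        # still have samples, but never take more than is needed or than this
        # config can provide.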
|
|
take = min(max(1, remaining // max(len(configs), 1)), remaining, available_indices.size) |
|
|
chosen = rng.choice(available_indices, size=take, replace=False) |
|
|
collected.append(arrays_map[cfg][chosen]) |
|
|
chosen_set = {int(idx) for idx in chosen} |
|
|
used[cfg].update(chosen_set) |
|
|
availability[cfg].difference_update(chosen_set) |
|
|
remaining -= take |
|
|
configs = [c for c in configs if availability[c]] |
|
|
|
|
|
if remaining > 0: |
|
|
raise RuntimeError("Sampling failed to collect the requested number of samples") |
|
|
|
|
|
stacked = np.concatenate(collected, axis=0) if collected else np.empty((0, 128, 128), dtype=np.float32) |
|
|
return stacked.astype(np.float32, copy=False), used |
|
|
|
|
|
|
|
|
def sample_train_arrays( |
|
|
arrays_map: Dict[str, LazyConfigArray], |
|
|
availability: Dict[str, set[int]], |
|
|
train_size: int, |
|
|
rng: np.random.Generator, |
|
|
) -> Tuple[np.ndarray, Dict[str, set[int]]]: |
|
|
return _sample_from_availability(arrays_map, availability, train_size, rng) |
|
|
|
|
|
|
|
|
def sample_global_arrays( |
|
|
arrays_map: Dict[str, LazyConfigArray], |
|
|
availability: Dict[str, set[int]], |
|
|
per_class: int, |
|
|
rng: np.random.Generator, |
|
|
) -> Tuple[np.ndarray, Dict[str, set[int]]]: |
|
|
return _sample_from_availability(arrays_map, availability, per_class, rng) |
|
|
|
|
|
|
|
|
class SpectrogramDataset(Dataset): |
|
|
def __init__(self, specs: np.ndarray, labels: np.ndarray): |
|
|
self.specs = specs.astype(np.float32, copy=False) |
|
|
self.labels = labels.astype(np.int64, copy=False) |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return len(self.labels) |
|
|
|
|
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: |
|
|
return torch.from_numpy(self.specs[idx]), int(self.labels[idx]) |
|
|
|
|
|
|
|
|
def normalize_batch(specs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: |
|
|
mean = specs.mean() |
|
|
std = specs.std(unbiased=False) |
|
|
std = torch.clamp(std, min=eps) |
|
|
return (specs - mean) / std |
|
|
|
|
|
|
|
|
def apply_spec_augment( |
|
|
specs: torch.Tensor, |
|
|
*, |
|
|
freq_mask_width: int = 12, |
|
|
time_mask_width: int = 16, |
|
|
freq_masks: int = 2, |
|
|
time_masks: int = 2, |
|
|
mask_prob: float = 0.5, |
|
|
noise_std: float = 0.0, |
|
|
) -> torch.Tensor: |
|
|
"""Apply light-weight SpecAugment-style masking to a batch of spectrograms. |
|
|
|
|
|
The function accepts tensors shaped ``[B, H, W]`` or ``[B, 1, H, W]`` and |
|
|
returns an augmented tensor with the same shape. Masks use the sample mean |
|
|
to avoid introducing large bias and are applied per-sample with the given |
|
|
    probability. Gaussian noise (if requested) is injected before masking and
    only affects the samples selected for augmentation.
|
|
""" |
|
|
|
|
|
if mask_prob <= 0.0 and noise_std <= 0.0: |
|
|
return specs |
|
|
|
|
|
if specs.dim() not in (3, 4): |
|
|
raise ValueError(f"Spectrograms must be rank-3 or rank-4, got shape {tuple(specs.shape)}") |
|
|
|
|
|
needs_squeeze = specs.dim() == 3 |
|
|
augmented = specs.unsqueeze(1) if needs_squeeze else specs |
|
|
batch_size, _, freq_dim, time_dim = augmented.shape |
|
|
|
|
|
if mask_prob < 1.0: |
|
|
apply_mask = torch.rand(batch_size, device=augmented.device) < mask_prob |
|
|
else: |
|
|
apply_mask = torch.ones(batch_size, dtype=torch.bool, device=augmented.device) |
|
|
|
|
|
freq_mask_width = max(0, int(freq_mask_width)) |
|
|
time_mask_width = max(0, int(time_mask_width)) |
|
|
freq_masks = max(0, int(freq_masks)) |
|
|
time_masks = max(0, int(time_masks)) |
|
|
|
|
|
for idx in range(batch_size): |
|
|
if not apply_mask[idx]: |
|
|
continue |
|
|
|
|
|
sample = augmented[idx] |
|
|
if noise_std > 0.0: |
|
|
sample = sample + noise_std * torch.randn_like(sample) |
|
|
|
|
|
fill_value = sample.mean() |
|
|
|
|
|
if freq_mask_width > 0 and freq_masks > 0: |
|
|
max_width = min(freq_mask_width, freq_dim) |
|
|
for _ in range(freq_masks): |
|
|
width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item()) |
|
|
if width == 0 or width > freq_dim: |
|
|
continue |
|
|
start = 0 if freq_dim == width else int(torch.randint(0, freq_dim - width + 1, (1,), device=augmented.device).item()) |
|
|
sample[:, start:start + width, :] = fill_value |
|
|
|
|
|
if time_mask_width > 0 and time_masks > 0: |
|
|
max_width = min(time_mask_width, time_dim) |
|
|
for _ in range(time_masks): |
|
|
width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item()) |
|
|
if width == 0 or width > time_dim: |
|
|
continue |
|
|
start = 0 if time_dim == width else int(torch.randint(0, time_dim - width + 1, (1,), device=augmented.device).item()) |
|
|
sample[:, :, start:start + width] = fill_value |
|
|
|
|
|
augmented[idx] = sample |
|
|
|
|
|
return augmented.squeeze(1) if needs_squeeze else augmented |
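# Minimal usage sketch (shapes illustrative):
#   specs = torch.randn(8, 128, 128)
#   specs = apply_spec_augment(specs, mask_prob=0.5, noise_std=0.01)
# The output keeps the [8, 128, 128] shape; roughly half of the samples receive
# up to two frequency masks, two time masks, and the added Gaussian noise.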
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _write_epoch_history( |
|
|
rep_root: Path, |
|
|
records: Sequence[Dict[str, object]], |
|
|
enable_csv: bool, |
|
|
enable_plot: bool, |
|
|
) -> None: |
|
|
if not records: |
|
|
return |
|
|
|
|
|
rep_root.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
if enable_csv: |
|
|
base_fields = [ |
|
|
"model", |
|
|
"epoch", |
|
|
"phase", |
|
|
"train_loss", |
|
|
"val_loss", |
|
|
"val_acc", |
|
|
"val_f1", |
|
|
"lr", |
|
|
"train_size_requested", |
|
|
"train_size_effective", |
|
|
] |
|
|
extra_fields = sorted( |
|
|
{key for rec in records for key in rec.keys() if key not in base_fields} |
|
|
) |
|
|
fieldnames = base_fields + extra_fields |
|
|
sorted_records = sorted(records, key=lambda r: (r["epoch"], r["model"], r.get("phase", ""))) |
|
|
history_path = rep_root / "epoch_history.csv" |
|
|
with open(history_path, "w", newline="", encoding="utf-8") as csvfile: |
|
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction="ignore") |
|
|
writer.writeheader() |
|
|
writer.writerows(sorted_records) |
|
|
|
|
|
if enable_plot and HAVE_MPL: |
|
|
models_in_run = sorted({rec["model"] for rec in records}) |
|
|
fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) |
|
|
for ax in axes: |
|
|
ax.grid(True, linestyle='--', alpha=0.3) |
|
|
for model_name_plot in models_in_run: |
|
|
model_records = [rec for rec in records if rec["model"] == model_name_plot] |
|
|
if not model_records: |
|
|
continue |
|
|
epochs = [rec["epoch"] for rec in model_records] |
|
|
val_loss_values = [rec["val_loss"] for rec in model_records] |
|
|
val_f1_values = [rec["val_f1"] for rec in model_records] |
|
|
axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot) |
|
|
axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot) |
|
|
axes[0].set_ylabel('Val Loss') |
|
|
axes[1].set_ylabel('Val F1') |
|
|
axes[1].set_xlabel('Epoch') |
|
|
axes[0].legend(loc='best') |
|
|
axes[0].set_title('Per-epoch validation metrics') |
|
|
fig.tight_layout() |
|
|
fig.savefig(rep_root / 'epoch_history.png', dpi=150) |
|
|
plt.close(fig) |
|
|
|
|
|
|
|
|
class ResidualBlock1D(nn.Module): |
|
|
"""Lightweight residual block used by the res1dcnn head.""" |
|
|
|
|
|
def __init__(self, in_channels: int, out_channels: int) -> None: |
|
|
super().__init__() |
|
|
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) |
|
|
self.bn1 = nn.BatchNorm1d(out_channels) |
|
|
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1) |
|
|
self.bn2 = nn.BatchNorm1d(out_channels) |
|
|
self.shortcut = nn.Identity() |
|
|
if in_channels != out_channels: |
|
|
self.shortcut = nn.Sequential( |
|
|
nn.Conv1d(in_channels, out_channels, kernel_size=1), |
|
|
nn.BatchNorm1d(out_channels), |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
residual = x |
|
|
x = F.relu(self.bn1(self.conv1(x))) |
|
|
x = self.bn2(self.conv2(x)) |
|
|
x = x + self.shortcut(residual) |
|
|
x = F.relu(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class Res1DCNNHead(nn.Module): |
|
|
"""Residual 1D CNN classifier head that operates on 128-d LWM features.""" |
|
|
|
|
|
def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1) -> None: |
|
|
super().__init__() |
|
|
self.input_dim = int(input_dim) |
|
|
hidden_dim = 64 |
|
|
self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1) |
|
|
self.bn1 = nn.BatchNorm1d(hidden_dim) |
|
|
self.res_block = ResidualBlock1D(hidden_dim, hidden_dim) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
self.fc = nn.Linear(hidden_dim, num_classes) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = x.unsqueeze(1) |
|
|
x = F.relu(self.bn1(self.conv1(x))) |
|
|
x = self.res_block(x) |
|
|
x = F.adaptive_avg_pool1d(x, 1).squeeze(-1) |
|
|
x = self.dropout(x) |
|
|
return self.fc(x) |
|
|
|
|
|
|
|
|
class LWMClassifier(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
backbone: nn.Module, |
|
|
trainable_layers: int, |
|
|
num_classes: int, |
|
|
classifier_dim: int = 128, |
|
|
head_dropout: float = 0.1, |
|
|
head_type: str = "mlp", |
|
|
): |
|
|
super().__init__() |
|
|
self.backbone = backbone |
|
|
self.patch_size = 4 |
|
|
self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size) |
|
|
head_dropout = max(0.0, float(head_dropout)) |
|
|
head_type = head_type.lower().strip() |
|
|
|
|
|
if head_type == "linear": |
|
|
head_layers: List[nn.Module] = [nn.LayerNorm(128)] |
|
|
if head_dropout > 0: |
|
|
head_layers.append(nn.Dropout(head_dropout)) |
|
|
head_layers.append(nn.Linear(128, num_classes)) |
|
|
self.classifier = nn.Sequential(*head_layers) |
|
|
elif head_type == "res1dcnn": |
|
|
self.classifier = nn.Sequential( |
|
|
nn.LayerNorm(128), |
|
|
Res1DCNNHead(128, num_classes, dropout=head_dropout), |
|
|
) |
|
|
else: |
|
|
head_layers = [ |
|
|
nn.LayerNorm(128), |
|
|
nn.Linear(128, classifier_dim), |
|
|
nn.GELU(), |
|
|
] |
|
|
if head_dropout > 0: |
|
|
head_layers.append(nn.Dropout(head_dropout)) |
|
|
head_layers.append(nn.Linear(classifier_dim, num_classes)) |
|
|
self.classifier = nn.Sequential(*head_layers) |
|
|
|
|
|
for param in self.backbone.parameters(): |
|
|
param.requires_grad = False |
|
|
if trainable_layers > 0: |
|
|
for layer in self.backbone.layers[-trainable_layers:]: |
|
|
for param in layer.parameters(): |
|
|
param.requires_grad = True |
|
|
|
|
|
if hasattr(layer, 'gradient_checkpointing'): |
|
|
layer.gradient_checkpointing = True |
|
|
|
|
|
def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor: |
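        # For 128x128 inputs the 4x4 unfold yields 1024 patch tokens of length 16;
        # the constant CLS token prepended below brings the sequence to 1025,
        # matching max_len=1025 used when the LWM backbone is built.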
|
|
x = x.unsqueeze(1) |
|
|
patches = self.unfold(x).transpose(1, 2) |
|
|
cls = torch.full( |
|
|
(patches.size(0), 1, patches.size(-1)), 0.2, dtype=patches.dtype, device=patches.device |
|
|
) |
|
|
return torch.cat([cls, patches], dim=1) |
|
|
|
|
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor: |
|
|
tokens = self.spectrogram_to_tokens(x) |
|
|
outputs = self.backbone(tokens) |
|
|
if outputs.size(1) <= 1: |
|
|
return outputs[:, 0, :] |
|
|
return outputs[:, 1:, :].mean(dim=1) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
cls = self.forward_features(x) |
|
|
return self.classifier(cls) |
|
|
|
|
|
|
|
|
def create_simple_cnn( |
|
|
num_classes: int, |
|
|
hidden_dims: Tuple[int, ...] = (192,), |
|
|
dropout: float = 0.3, |
|
|
) -> nn.Module: |
|
|
"""Create baseline CNN with configurable classifier width.""" |
|
|
|
|
|
if not hidden_dims: |
|
|
raise ValueError("hidden_dims must contain at least one value") |
|
|
|
|
|
layers: List[nn.Module] = [ |
|
|
nn.Conv2d(1, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2), |
|
|
nn.Conv2d(16, 32, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2), |
|
|
nn.Conv2d(32, 64, 5, padding=2), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4)), |
|
|
nn.Flatten(), |
|
|
nn.Dropout(dropout), |
|
|
] |
|
|
|
|
|
in_dim = 4 * 4 * 64 |
|
|
fc_layers: List[nn.Module] = [] |
|
|
    for hidden_dim in hidden_dims:
|
|
fc_layers.append(nn.Linear(in_dim, hidden_dim)) |
|
|
fc_layers.append(nn.ReLU()) |
|
|
fc_layers.append(nn.Dropout(dropout)) |
|
|
in_dim = hidden_dim |
|
|
|
|
|
fc_layers.append(nn.Linear(in_dim, num_classes)) |
|
|
|
|
|
return nn.Sequential(*layers, *fc_layers) |
|
|
|
|
|
|
|
|
def create_ieee_cnn( |
|
|
num_classes: int, |
|
|
hidden_dims: Tuple[int, ...] = (512, 256), |
|
|
dropout: float = 0.3, |
|
|
) -> nn.Module: |
|
|
"""CNN inspired by IEEE 2021 joint SNR/mobility classifier.""" |
|
|
|
|
|
if not hidden_dims: |
|
|
raise ValueError("hidden_dims must contain at least one value") |
|
|
|
|
|
layers: List[nn.Module] = [ |
|
|
nn.Conv2d(1, 32, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(32), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.MaxPool2d(2, 2), |
|
|
nn.Dropout2d(p=dropout), |
|
|
nn.Conv2d(32, 64, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(64), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.MaxPool2d(2, 2), |
|
|
nn.Dropout2d(p=dropout), |
|
|
nn.Conv2d(64, 128, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(128), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.MaxPool2d(2, 2), |
|
|
nn.Dropout2d(p=dropout), |
|
|
nn.Conv2d(128, 256, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(256), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.MaxPool2d(2, 2), |
|
|
nn.Dropout2d(p=dropout), |
|
|
nn.Conv2d(256, 256, kernel_size=3, padding=1), |
|
|
nn.BatchNorm2d(256), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.AdaptiveAvgPool2d((4, 4)), |
|
|
nn.Flatten(), |
|
|
nn.Dropout(dropout), |
|
|
] |
|
|
|
|
|
in_dim = 4 * 4 * 256 |
|
|
fc_layers: List[nn.Module] = [] |
|
|
for hidden_dim in hidden_dims: |
|
|
fc_layers.append(nn.Linear(in_dim, hidden_dim)) |
|
|
fc_layers.append(nn.BatchNorm1d(hidden_dim)) |
|
|
fc_layers.append(nn.ReLU(inplace=True)) |
|
|
fc_layers.append(nn.Dropout(dropout)) |
|
|
in_dim = hidden_dim |
|
|
|
|
|
fc_layers.append(nn.Linear(in_dim, num_classes)) |
|
|
return nn.Sequential(*layers, *fc_layers) |
|
|
|
|
|
|
|
|
def build_model( |
|
|
name: str, |
|
|
num_classes: int, |
|
|
checkpoint: Path, |
|
|
device: torch.device, |
|
|
trainable_layers: int, |
|
|
backbone_lr_factor: float, |
|
|
overrides: Dict[str, object] | None = None, |
|
|
) -> Tuple[nn.Module, List[Dict[str, object]]]: |
|
|
name = name.lower() |
|
|
param_groups: List[Dict[str, object]] = [] |
|
|
overrides = overrides or {} |
|
|
|
|
|
if name == "lwm": |
|
|
backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1) |
|
|
if checkpoint is None: |
|
|
raise FileNotFoundError("Checkpoint is required for LWM-based models") |
|
|
try: |
|
|
state = torch.load(checkpoint, map_location="cpu", weights_only=True) |
|
|
except TypeError: |
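            # Older PyTorch releases do not accept weights_only; retry with the
            # legacy signature.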
|
|
|
|
|
state = torch.load(checkpoint, map_location="cpu") |
|
|
if any(k.startswith("module.") for k in state): |
|
|
state = {k.replace("module.", ""): v for k, v in state.items()} |
|
|
if any(k.startswith("backbone.") for k in state): |
|
|
backbone_state = { |
|
|
k.split("backbone.", 1)[1]: v |
|
|
for k, v in state.items() |
|
|
if k.startswith("backbone.") |
|
|
} |
|
|
else: |
|
|
backbone_state = { |
|
|
k: v |
|
|
for k, v in state.items() |
|
|
if not k.startswith("classifier.") and not k.startswith("projection_head.") |
|
|
} |
|
|
backbone.load_state_dict(backbone_state, strict=False) |
|
|
|
|
|
classifier_dim = int(overrides.get("lwm_classifier_dim", 96)) |
|
|
head_dropout = float(overrides.get("lwm_head_dropout", 0.1)) |
|
|
head_type = str(overrides.get("lwm_head_type", "mlp")).lower() |
|
|
model = LWMClassifier( |
|
|
backbone, |
|
|
trainable_layers=trainable_layers, |
|
|
num_classes=num_classes, |
|
|
classifier_dim=classifier_dim, |
|
|
head_dropout=head_dropout, |
|
|
head_type=head_type, |
|
|
) |
|
|
head_params = list(model.classifier.parameters()) |
|
|
param_groups.append({"params": head_params, "scale": 1.0}) |
|
|
|
|
|
if trainable_layers > 0: |
|
|
backbone_params: List[nn.Parameter] = [] |
|
|
for layer in model.backbone.layers[-trainable_layers:]: |
|
|
backbone_params.extend(layer.parameters()) |
|
|
backbone_params = _unique_parameters(backbone_params) |
|
|
if backbone_params: |
|
|
param_groups.append({"params": backbone_params, "scale": backbone_lr_factor}) |
|
|
|
|
|
elif name == "resnet18": |
|
|
from torchvision import models |
|
|
|
|
|
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) |
|
|
backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
nn.init.kaiming_normal_(backbone.conv1.weight, mode='fan_out', nonlinearity='relu') |
|
|
in_features = backbone.fc.in_features |
|
|
head_width = int(overrides.get("resnet_head_width", 384)) |
|
|
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) |
|
|
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) |
|
|
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) |
|
|
backbone.fc = nn.Sequential( |
|
|
nn.Dropout(p=pre_fc_dropout), |
|
|
nn.Linear(in_features, head_width), |
|
|
nn.LayerNorm(head_width), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.Dropout(p=imagenet_head_dropout), |
|
|
nn.Linear(head_width, num_classes), |
|
|
) |
|
|
|
|
|
for param in backbone.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
head_params = list(backbone.fc.parameters()) |
|
|
for param in head_params: |
|
|
param.requires_grad = True |
|
|
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) |
|
|
head_group: Dict[str, object] = {"params": head_params, "scale": 1.0} |
|
|
if imagenet_weight_decay is not None: |
|
|
head_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(head_group) |
|
|
|
|
|
adapt_params: List[nn.Parameter] = [] |
|
|
if hasattr(backbone.layer4[0], "downsample") and backbone.layer4[0].downsample is not None: |
|
|
adapt_params.extend(backbone.layer4[0].downsample[0].parameters()) |
|
|
if len(backbone.layer4[0].downsample) > 1: |
|
|
adapt_params.extend(backbone.layer4[0].downsample[1].parameters()) |
|
|
for module in backbone.layer4[-1].modules(): |
|
|
if isinstance(module, nn.BatchNorm2d): |
|
|
adapt_params.extend(module.parameters()) |
|
|
adapt_params = _unique_parameters(adapt_params) |
|
|
for param in adapt_params: |
|
|
param.requires_grad = True |
|
|
if adapt_params: |
|
|
adapt_group: Dict[str, object] = {"params": adapt_params, "scale": backbone_lr_factor} |
|
|
if imagenet_weight_decay is not None: |
|
|
adapt_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(adapt_group) |
|
|
|
|
|
model = backbone |
|
|
|
|
|
elif name == "efficientnet_b0": |
|
|
from torchvision import models |
|
|
|
|
|
backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT) |
|
|
first_conv = backbone.features[0][0] |
|
|
backbone.features[0][0] = nn.Conv2d(1, first_conv.out_channels, kernel_size=3, stride=2, padding=1, bias=False) |
|
|
nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu') |
|
|
in_features = backbone.classifier[-1].in_features |
|
|
head_width = int(overrides.get("efficientnet_head_width", 192)) |
|
|
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) |
|
|
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) |
|
|
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) |
|
|
backbone.classifier = nn.Sequential( |
|
|
nn.Dropout(p=pre_fc_dropout), |
|
|
nn.Linear(in_features, head_width), |
|
|
nn.LayerNorm(head_width), |
|
|
nn.ReLU(inplace=True), |
|
|
nn.Dropout(p=imagenet_head_dropout), |
|
|
nn.Linear(head_width, num_classes), |
|
|
) |
|
|
|
|
|
for param in backbone.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
head_params = list(backbone.classifier.parameters()) |
|
|
for param in head_params: |
|
|
param.requires_grad = True |
|
|
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) |
|
|
head_group = {"params": head_params, "scale": 1.0} |
|
|
if imagenet_weight_decay is not None: |
|
|
head_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(head_group) |
|
|
|
|
|
adapt_params: List[nn.Parameter] = [] |
|
|
final_block = backbone.features[7][0] |
|
|
|
|
|
depthwise = final_block.block[1][0] |
|
|
adapt_params.extend(depthwise.parameters()) |
|
|
for idx in (0, 1, 3): |
|
|
for module in final_block.block[idx].modules(): |
|
|
if isinstance(module, nn.BatchNorm2d): |
|
|
adapt_params.extend(module.parameters()) |
|
|
adapt_params = _unique_parameters(adapt_params) |
|
|
for param in adapt_params: |
|
|
param.requires_grad = True |
|
|
if adapt_params: |
|
|
adapt_group = {"params": adapt_params, "scale": backbone_lr_factor} |
|
|
if imagenet_weight_decay is not None: |
|
|
adapt_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(adapt_group) |
|
|
|
|
|
model = backbone |
|
|
|
|
|
elif name == "mobilenet_v3_small": |
|
|
from torchvision import models |
|
|
|
|
|
backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT) |
|
|
backbone.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False) |
|
|
nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu') |
|
|
with torch.no_grad(): |
|
|
dummy = torch.zeros(1, 1, 128, 128) |
|
|
features = backbone.features(dummy) |
|
|
pooled = backbone.avgpool(features) |
|
|
flattened = torch.flatten(pooled, 1) |
|
|
in_features = flattened.shape[1] |
|
|
head_width = int(overrides.get("mobilenet_head_width", 320)) |
|
|
imagenet_head_dropout = float(overrides.get("imagenet_head_dropout", 0.45)) |
|
|
imagenet_head_dropout = max(0.0, min(imagenet_head_dropout, 0.9)) |
|
|
pre_fc_dropout = max(0.0, min(imagenet_head_dropout * 0.5, 0.9)) |
|
|
backbone.classifier = nn.Sequential( |
|
|
nn.Dropout(p=pre_fc_dropout), |
|
|
nn.Linear(in_features, head_width), |
|
|
nn.LayerNorm(head_width), |
|
|
nn.Hardswish(), |
|
|
nn.Dropout(p=imagenet_head_dropout), |
|
|
nn.Linear(head_width, num_classes), |
|
|
) |
|
|
|
|
|
for param in backbone.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
head_params = list(backbone.classifier.parameters()) |
|
|
for param in head_params: |
|
|
param.requires_grad = True |
|
|
imagenet_weight_decay = overrides.get("imagenet_weight_decay", None) |
|
|
head_group = {"params": head_params, "scale": 1.0} |
|
|
if imagenet_weight_decay is not None: |
|
|
head_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(head_group) |
|
|
|
|
|
adapt_params: List[nn.Parameter] = [] |
|
|
adapt_params.extend(backbone.features[-1][0].parameters()) |
|
|
for module in backbone.features[-2].modules(): |
|
|
if isinstance(module, nn.BatchNorm2d): |
|
|
adapt_params.extend(module.parameters()) |
|
|
adapt_params = _unique_parameters(adapt_params) |
|
|
for param in adapt_params: |
|
|
param.requires_grad = True |
|
|
if adapt_params: |
|
|
adapt_group = {"params": adapt_params, "scale": backbone_lr_factor} |
|
|
if imagenet_weight_decay is not None: |
|
|
adapt_group["weight_decay"] = float(imagenet_weight_decay) |
|
|
param_groups.append(adapt_group) |
|
|
|
|
|
model = backbone |
|
|
|
|
|
elif name in {"simple_cnn", "simplecnn"}: |
|
|
hidden_dims = overrides.get("simple_cnn_hidden_dims", (192,)) |
|
|
if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str): |
|
|
simple_dims = tuple(int(dim) for dim in hidden_dims) |
|
|
else: |
|
|
simple_dims = (int(hidden_dims),) |
|
|
model = create_simple_cnn(num_classes, hidden_dims=simple_dims) |
|
|
head_params = list(model.parameters()) |
|
|
param_groups.append({"params": head_params, "scale": 1.0}) |
|
|
|
|
|
elif name in {"ieee_cnn", "ieeecnn"}: |
|
|
hidden_dims = overrides.get("ieee_cnn_hidden_dims", (512, 256)) |
|
|
if isinstance(hidden_dims, Sequence) and not isinstance(hidden_dims, str): |
|
|
ieee_dims = tuple(int(dim) for dim in hidden_dims) |
|
|
else: |
|
|
ieee_dims = (int(hidden_dims),) |
|
|
dropout = float(overrides.get("ieee_cnn_dropout", 0.3)) |
|
|
model = create_ieee_cnn(num_classes, hidden_dims=ieee_dims, dropout=dropout) |
|
|
head_params = list(model.parameters()) |
|
|
param_groups.append({"params": head_params, "scale": 1.0}) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown model: {name}") |
|
|
|
|
|
return model.to(device), param_groups |
|
|
|
|
|
|
|
|
def _unwrap_module(model: nn.Module) -> nn.Module: |
|
|
return model.module if isinstance(model, nn.DataParallel) else model |
|
|
|
|
|
|
|
|
def _strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
|
if not state_dict: |
|
|
return state_dict |
|
|
needs_strip = any(key.startswith("module.") for key in state_dict) |
|
|
if not needs_strip: |
|
|
return state_dict |
|
|
    stripped = state_dict.__class__()
|
|
for key, value in state_dict.items(): |
|
|
new_key = key.split("module.", 1)[1] if key.startswith("module.") else key |
|
|
stripped[new_key] = value |
|
|
return stripped |
|
|
|
|
|
|
|
|
def _model_forward( |
|
|
model: nn.Module, |
|
|
specs: torch.Tensor, |
|
|
input_stats: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
base_model = _unwrap_module(model) |
|
|
is_lwm_like = isinstance(base_model, LWMClassifier) |
|
|
if not is_lwm_like and LWMClassifierMinimal is not None: |
|
|
is_lwm_like = isinstance(base_model, LWMClassifierMinimal) |
|
|
if not is_lwm_like and hasattr(base_model, "spectrogram_to_tokens"): |
|
|
is_lwm_like = True |
|
|
if is_lwm_like: |
|
|
while specs.dim() > 3 and specs.size(1) == 1: |
|
|
specs = specs.squeeze(1) |
|
|
if specs.dim() != 3: |
|
|
specs = specs.view(specs.size(0), specs.size(-2), specs.size(-1)) |
|
|
if not is_lwm_like and specs.dim() == 3: |
|
|
specs = specs.unsqueeze(1) |
|
|
if input_stats is not None: |
|
|
input_stats = input_stats.to(specs.device, non_blocking=True) |
|
|
    supports_stats = bool(is_lwm_like and getattr(base_model, "append_input_stats", False))
|
|
if supports_stats and input_stats is not None: |
|
|
return model(specs, input_stats=input_stats) |
|
|
return model(specs) |
|
|
|
|
|
|
|
|
def train_one_epoch(model, loader, optimizer, device, scaler=None, batch_normalize: bool = False): |
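    """Train for one epoch and return the mean cross-entropy loss per sample.

    Mixed-precision autocast is used when a GradScaler is supplied on CUDA;
    ``batch_normalize=True`` normalizes each batch on the fly (raw-input models).
    """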
|
|
criterion = nn.CrossEntropyLoss(reduction='mean') |
|
|
model.train() |
|
|
total_loss = 0.0 |
|
|
total = 0 |
|
|
for specs, labels in loader: |
|
|
specs = specs.to(device, non_blocking=True) |
|
|
if batch_normalize: |
|
|
specs = normalize_batch(specs) |
|
|
labels = labels.to(device, non_blocking=True) |
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
|
|
|
|
|
if scaler is not None and device.type == 'cuda': |
|
|
with autocast(device_type='cuda'): |
|
|
logits = _model_forward(model, specs) |
|
|
loss = criterion(logits, labels) |
|
|
scaler.scale(loss).backward() |
|
|
scaler.step(optimizer) |
|
|
scaler.update() |
|
|
else: |
|
|
|
|
|
logits = _model_forward(model, specs) |
|
|
loss = criterion(logits, labels) |
|
|
loss.backward() |
|
|
optimizer.step() |
|
|
|
|
|
total_loss += loss.item() * labels.size(0) |
|
|
total += labels.size(0) |
|
|
|
|
|
|
|
|
if device.type == 'cuda': |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
return total_loss / max(total, 1) |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def evaluate( |
|
|
model, |
|
|
loader, |
|
|
device, |
|
|
debug: Optional[Dict[str, object]] = None, |
|
|
batch_normalize: bool = False, |
|
|
) -> Tuple[float, float, float]: |
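    """Evaluate ``model`` on ``loader`` and return (mean loss, accuracy, macro-F1).

    Batches may be (specs, labels) or (specs, input_stats, labels). The optional
    ``debug`` dict ({"log_batches", "log_every", "log_softmax"}) enables per-batch
    diagnostics of inputs, logits, and label distributions.
    """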
|
|
criterion = nn.CrossEntropyLoss(reduction='mean') |
|
|
model.eval() |
|
|
total_loss = 0.0 |
|
|
correct = 0 |
|
|
total = 0 |
|
|
all_preds: List[np.ndarray] = [] |
|
|
all_labels: List[np.ndarray] = [] |
|
|
|
|
|
debug_batches = int(debug.get("log_batches", 0)) if debug else 0 |
|
|
debug_every = max(1, int(debug.get("log_every", 1))) if debug else 1 |
|
|
log_softmax = bool(debug.get("log_softmax", False)) if debug else False |
|
|
debug_logged = 0 |
|
|
|
|
|
for batch_idx, batch in enumerate(loader, start=1): |
|
|
stats_batch: Optional[torch.Tensor] |
|
|
if isinstance(batch, (list, tuple)) and len(batch) == 3: |
|
|
specs, stats_batch, labels = batch |
|
|
stats_batch = stats_batch.to(device, non_blocking=True) |
|
|
else: |
|
|
specs, labels = batch |
|
|
stats_batch = None |
|
|
specs = specs.to(device, non_blocking=True) |
|
|
if batch_normalize: |
|
|
specs = normalize_batch(specs) |
|
|
labels = labels.to(device, non_blocking=True) |
|
|
|
|
|
|
|
|
if device.type == 'cuda': |
|
|
context = autocast(device_type='cuda') |
|
|
else: |
|
|
context = nullcontext() |
|
|
|
|
|
with context: |
|
|
logits = _model_forward(model, specs, stats_batch) |
|
|
loss = criterion(logits, labels) |
|
|
|
|
|
preds = logits.argmax(dim=1) |
|
|
total_loss += loss.item() * labels.size(0) |
|
|
correct += (preds == labels).sum().item() |
|
|
total += labels.size(0) |
|
|
all_preds.append(preds.detach().cpu().numpy()) |
|
|
all_labels.append(labels.detach().cpu().numpy()) |
|
|
|
|
|
should_log = ( |
|
|
debug_batches > 0 |
|
|
and debug_logged < debug_batches |
|
|
and ((batch_idx - 1) % debug_every == 0) |
|
|
) |
|
|
if should_log: |
|
|
specs_cpu = specs.detach().cpu() |
|
|
logits_cpu = logits.detach().cpu() |
|
|
loss_scalar = float(loss.detach().cpu().item()) |
|
|
finite_specs = torch.isfinite(specs).all().item() |
|
|
finite_logits = torch.isfinite(logits).all().item() |
|
|
print( |
|
|
f" [DEBUG][eval][batch {batch_idx}] loss={loss_scalar:.6f} " |
|
|
f"reduction={criterion.reduction} labels_shape={tuple(labels.shape)}" |
|
|
) |
|
|
print( |
|
|
f" specs dtype={specs.dtype} mean={specs_cpu.mean():.4f} std={specs_cpu.std():.4f} " |
|
|
f"min={specs_cpu.min():.4f} max={specs_cpu.max():.4f} finite={bool(finite_specs)}" |
|
|
) |
|
|
print( |
|
|
f" logits dtype={logits.dtype} mean={logits_cpu.mean():.4f} std={logits_cpu.std():.4f} " |
|
|
f"min={logits_cpu.min():.4f} max={logits_cpu.max():.4f} finite={bool(finite_logits)}" |
|
|
) |
|
|
unique_labels, counts = torch.unique(labels.detach().cpu(), return_counts=True) |
|
|
label_info = ", ".join( |
|
|
f"{int(lbl)}:{int(cnt)}" for lbl, cnt in zip(unique_labels, counts) |
|
|
) |
|
|
print(f" label distribution -> {label_info}") |
|
|
if log_softmax: |
|
|
probs = torch.softmax(logits_cpu, dim=1) |
|
|
print( |
|
|
f" softmax mean={probs.mean():.4f} std={probs.std():.4f} " |
|
|
f"min={probs.min():.4f} max={probs.max():.4f}" |
|
|
) |
|
|
debug_logged += 1 |
|
|
|
|
|
|
|
|
if device.type == 'cuda': |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
y_true = np.concatenate(all_labels) if all_labels else np.empty(0) |
|
|
y_pred = np.concatenate(all_preds) if all_preds else np.empty(0) |
|
|
f1 = compute_f1(y_true, y_pred) if y_true.size > 0 else 0.0 |
|
|
return total_loss / max(total, 1), correct / max(total, 1), f1 |
|
|
|
|
|
|
|
|
def set_seed(seed: int) -> None: |
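    """Seed Python, NumPy, and torch (including CUDA and HPU when available)."""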
|
|
random.seed(seed) |
|
|
np.random.seed(seed) |
|
|
torch.manual_seed(seed) |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.manual_seed(seed) |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
|
|
|
if HPU_AVAILABLE and hasattr(torch.hpu, "manual_seed"): |
|
|
torch.hpu.manual_seed(seed) |
|
|
|
|
|
|
|
|
def main() -> None: |
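    """Run the benchmark: resolve data and checkpoints, build splits, train each model, and write summaries."""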
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
        # ``os`` is already imported at module scope; enable expandable segments to
        # reduce CUDA memory fragmentation across the many per-model training runs.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
|
|
|
args = parse_args() |
|
|
|
|
|
if args.early_min_epochs < 10: |
|
|
print( |
|
|
f"[INFO] Requested early_min_epochs={args.early_min_epochs} < 10; enforcing minimum of 10" |
|
|
) |
|
|
args.early_min_epochs = 10 |
|
|
set_seed(args.seed) |
|
|
|
|
|
require_checkpoint = any(model.lower() == "lwm" for model in args.models) |
|
|
checkpoint, stats = resolve_checkpoint_and_stats(args, require_checkpoint=require_checkpoint) |
|
|
|
|
|
normalization_mode = str(stats.get("normalization", "dataset")).lower() |
|
|
print(f"[INFO] Normalization mode from stats: {normalization_mode}") |
|
|
|
|
|
data_root = Path(args.data_root) |
|
|
available_snrs, available_mobilities = discover_snr_mobility( |
|
|
data_root, args.cities, args.comm_types, args.fft_folder |
|
|
) |
|
|
train_snrs = args.snrs if args.snrs else available_snrs |
|
|
train_mobilities = args.mobilities if args.mobilities else available_mobilities |
|
|
val_snrs = args.val_snrs if args.val_snrs else available_snrs |
|
|
val_mobilities = args.val_mobilities if args.val_mobilities else available_mobilities |
|
|
|
|
|
class_configs = build_config_map( |
|
|
data_root, args.cities, args.comm_types, train_snrs, train_mobilities, args.fft_folder |
|
|
) |
|
|
active_labels = [cls for cls, configs in class_configs.items() if any(configs.values())] |
|
|
if not active_labels: |
|
|
raise RuntimeError("No modulation classes found with the provided filters.") |
|
|
class_configs = {cls: class_configs[cls] for cls in active_labels} |
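    # Map the surviving global class ids to contiguous local indices [0, num_classes)
    # so that CrossEntropyLoss targets line up with the classifier output dimension.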
|
|
label_to_local = {cls: idx for idx, cls in enumerate(sorted(active_labels))} |
|
|
num_classes = len(active_labels) |
|
|
print("[INFO] Active modulation classes:", ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in sorted(active_labels))) |
|
|
config_arrays = load_config_arrays(class_configs) |
|
|
global_config_map = build_config_map( |
|
|
Path(args.data_root), args.cities, args.comm_types, val_snrs, val_mobilities, args.fft_folder |
|
|
) |
|
|
global_config_arrays = load_config_arrays(global_config_map) |
|
|
print("[INFO] Training SNRs:", ", ".join(train_snrs)) |
|
|
print("[INFO] Training mobilities:", ", ".join(train_mobilities)) |
|
|
print("[INFO] Validation/Test SNRs:", ", ".join(val_snrs)) |
|
|
print("[INFO] Validation/Test mobilities:", ", ".join(val_mobilities)) |
|
|
|
|
|
per_class_totals: Dict[int, int] = {} |
|
|
for cls in sorted(active_labels): |
|
|
configs = config_arrays[cls] |
|
|
total_samples = sum(arr.shape[0] for arr in configs.values()) |
|
|
per_class_totals[cls] = total_samples |
|
|
print(f"[INFO] Class {LABEL_NAMES.get(cls, str(cls))}: {len(configs)} configs, {total_samples} samples") |
|
|
if cls not in global_config_arrays or not global_config_arrays[cls]: |
|
|
raise RuntimeError(f"No global data found for modulation {LABEL_NAMES.get(cls, str(cls))}") |
|
|
|
|
|
min_class_total = min(per_class_totals.values()) |
|
|
max_train_per_class = min_class_total - args.val_per_class - args.test_per_class |
|
|
if max_train_per_class <= 0: |
|
|
raise RuntimeError( |
|
|
"Requested val/test splits leave no data for training. " |
|
|
f"Minimum class has {min_class_total} samples; " |
|
|
f"val={args.val_per_class}, test={args.test_per_class}." |
|
|
) |
|
|
if any(size > max_train_per_class for size in args.train_sizes): |
|
|
adjusted: List[int] = [] |
|
|
for size in args.train_sizes: |
|
|
if size > max_train_per_class: |
|
|
print( |
|
|
f"[WARN] Requested train size {size} exceeds available " |
|
|
f"{max_train_per_class} per class after val/test splits; capping." |
|
|
) |
|
|
capped = min(size, max_train_per_class) |
|
|
if capped not in adjusted: |
|
|
adjusted.append(capped) |
|
|
args.train_sizes = adjusted if adjusted else [max_train_per_class] |
|
|
print(f"[INFO] Effective train sizes per class: {args.train_sizes}") |
|
|
|
|
|
|
|
|
requested_device = args.device.lower() |
|
|
|
|
|
if requested_device == "auto": |
|
|
if HPU_AVAILABLE: |
|
|
requested_device = "hpu" |
|
|
elif torch.cuda.is_available(): |
|
|
requested_device = "cuda" |
|
|
else: |
|
|
requested_device = "cpu" |
|
|
|
|
|
|
|
|
if requested_device == "hpu": |
|
|
if not HPU_AVAILABLE: |
|
|
raise RuntimeError( |
|
|
"HPU device requested but not available. " |
|
|
"Install Habana PyTorch or select --device cuda/cpu." |
|
|
) |
|
|
device = torch.device("hpu") |
|
|
|
|
|
if hasattr(torch.hpu, "set_device"): |
|
|
torch.hpu.set_device(0) |
|
|
        print("[INFO] Using HPU device")
|
|
active_gpu_ids = [] |
|
|
multi_gpu = False |
|
|
|
|
|
elif requested_device == "cuda": |
|
|
cuda_available = torch.cuda.is_available() |
|
|
if not cuda_available: |
|
|
raise RuntimeError("CUDA device requested but not available.") |
|
|
|
|
|
available_gpu_ids = list(range(torch.cuda.device_count())) |
|
|
if args.gpu_ids is not None: |
|
|
invalid_ids = [gpu_id for gpu_id in args.gpu_ids if gpu_id not in available_gpu_ids] |
|
|
if invalid_ids: |
|
|
raise ValueError( |
|
|
f"Requested GPU IDs not available: {invalid_ids}; available: {available_gpu_ids}" |
|
|
) |
|
|
active_gpu_ids = list(dict.fromkeys(args.gpu_ids)) |
|
|
else: |
|
|
active_gpu_ids = available_gpu_ids |
|
|
|
|
|
if active_gpu_ids: |
|
|
primary_gpu = active_gpu_ids[0] |
|
|
torch.cuda.set_device(primary_gpu) |
|
|
device = torch.device(f"cuda:{primary_gpu}") |
|
|
print(f"[INFO] Using CUDA device(s): {', '.join(str(i) for i in active_gpu_ids)}") |
|
|
else: |
|
|
device = torch.device("cpu") |
|
|
print("[INFO] CUDA requested but no GPUs available, using CPU") |
|
|
|
|
|
multi_gpu = len(active_gpu_ids) > 1 |
|
|
if multi_gpu: |
|
|
print(f"[INFO] Enabling DataParallel across GPUs: {', '.join(str(i) for i in active_gpu_ids)}") |
|
|
|
|
|
else: |
|
|
device = torch.device("cpu") |
|
|
if args.gpu_ids is not None: |
|
|
print("[WARN] GPU IDs specified but using CPU") |
|
|
print("[INFO] Using CPU") |
|
|
active_gpu_ids = [] |
|
|
multi_gpu = False |
|
|
|
|
|
print(f"[INFO] Using device: {device}") |
|
|
args.output_dir.mkdir(parents=True, exist_ok=True) |
|
|
print(f"[INFO] Saving outputs under: {args.output_dir}") |
|
|
|
|
|
eval_debug_config: Optional[Dict[str, object]] = None |
|
|
if args.debug_eval_batches > 0: |
|
|
eval_debug_config = { |
|
|
"log_batches": int(args.debug_eval_batches), |
|
|
"log_every": max(1, int(args.debug_eval_interval)), |
|
|
"log_softmax": bool(args.debug_eval_softmax), |
|
|
} |
|
|
print( |
|
|
"[INFO] Evaluation debug logging enabled -> batches:" |
|
|
f" {eval_debug_config['log_batches']}, interval: {eval_debug_config['log_every']}" |
|
|
) |
|
|
|
|
|
summary_device = torch.device("cpu") if device.type == "cuda" else device |
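    # Parameter counting below builds each model once on CPU (when training on CUDA)
    # so no GPU memory is allocated before the actual training runs.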
|
|
model_overrides: Dict[str, object] = { |
|
|
"resnet_head_width": args.resnet_head_width, |
|
|
"efficientnet_head_width": args.efficientnet_head_width, |
|
|
"mobilenet_head_width": args.mobilenet_head_width, |
|
|
"simple_cnn_hidden_dims": tuple(args.simple_cnn_hidden_dims), |
|
|
"ieee_cnn_hidden_dims": tuple(args.ieee_cnn_hidden_dims), |
|
|
"ieee_cnn_dropout": args.ieee_cnn_dropout, |
|
|
"lwm_classifier_dim": args.lwm_classifier_dim, |
|
|
"lwm_head_dropout": args.lwm_head_dropout, |
|
|
"lwm_head_type": args.lwm_head_type, |
|
|
"imagenet_head_dropout": args.imagenet_head_dropout, |
|
|
"imagenet_weight_decay": args.weight_decay * args.imagenet_weight_decay_scale, |
|
|
} |
|
|
print("\n[INFO] Parameter counts per model (total/trainable):") |
|
|
for model_name in args.models: |
|
|
lower_name = model_name.lower() |
|
|
trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 |
|
|
model_checkpoint = checkpoint |
|
|
backbone_lr_factor = args.backbone_lr_factor |
|
|
if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: |
|
|
backbone_lr_factor = args.lwm_backbone_lr_factor |
|
|
model, _ = build_model( |
|
|
model_name, |
|
|
num_classes, |
|
|
model_checkpoint, |
|
|
summary_device, |
|
|
trainable_layers=trainable_layers, |
|
|
backbone_lr_factor=backbone_lr_factor, |
|
|
overrides=model_overrides, |
|
|
) |
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print(f" {model_name}: {total_params:,} / {trainable_params:,}") |
|
|
del model |
|
|
if device.type == 'cuda': |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
raw_input_models = set(args.raw_input_models) |
|
|
active_raw_models = [model for model in args.models if model.lower() in raw_input_models] |
|
|
if active_raw_models: |
|
|
print( |
|
|
"[INFO] Raw spectrogram input (per-batch normalization) for models: " |
|
|
+ ", ".join(active_raw_models) |
|
|
) |
|
|
|
|
|
normalized_models = [model for model in args.models if model.lower() not in raw_input_models] |
|
|
requires_normalized_inputs = len(normalized_models) > 0 |
|
|
if requires_normalized_inputs: |
|
|
print( |
|
|
"[INFO] Applying normalization for models: " |
|
|
+ ", ".join(normalized_models) |
|
|
) |
|
|
else: |
|
|
print("[INFO] All selected models consume raw spectrograms; normalization skipped") |
|
|
|
|
|
summary: Dict[str, Dict[int, Dict[str, List[float]]]] = { |
|
|
model: {size: {"acc": [], "f1": [], "val_f1": [], "val_loss": []} for size in args.train_sizes} |
|
|
for model in args.models |
|
|
} |
|
|
|
|
|
train_sizes_sorted = sorted(args.train_sizes) |
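    # Sizes are visited in ascending order so that, within a repetition, each larger
    # training set is a superset of the smaller one: previously selected indices are
    # kept per class and only topped up with newly sampled examples.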
|
|
|
|
|
for repetition in range(1, args.repetitions + 1): |
|
|
selection_for_repetition: Dict[int, Dict[str, set[int]]] = {} |
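        # Validation/test samples are drawn once per repetition from the global pool and
        # their indices are reserved, so later training selections can never overlap them.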
|
|
val_rng_seed = args.seed + repetition * 100000 |
|
|
val_rng = np.random.default_rng(val_rng_seed) |
|
|
fixed_val_samples: Dict[int, np.ndarray] = {} |
|
|
fixed_test_samples: Dict[int, np.ndarray] = {} |
|
|
val_reserved_indices: Dict[int, Dict[str, set[int]]] = {} |
|
|
test_reserved_indices: Dict[int, Dict[str, set[int]]] = {} |
|
|
|
|
|
for cls in sorted(config_arrays.keys()): |
|
|
global_arrays = global_config_arrays[cls] |
|
|
global_avail = {cfg: set(range(arr.shape[0])) for cfg, arr in global_arrays.items()} |
|
|
val_samples, val_used = sample_global_arrays(global_arrays, global_avail, args.val_per_class, val_rng) |
|
|
test_samples, test_used = sample_global_arrays(global_arrays, global_avail, args.test_per_class, val_rng) |
|
|
fixed_val_samples[cls] = val_samples |
|
|
fixed_test_samples[cls] = test_samples |
|
|
val_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in val_used.items()} |
|
|
test_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in test_used.items()} |
|
|
|
|
|
for train_size in train_sizes_sorted: |
|
|
rep_seed = args.seed + train_size * 1000 + repetition |
|
|
rng = np.random.default_rng(rep_seed) |
|
|
|
|
|
repetition_records: List[Dict[str, object]] = [] |
|
|
per_size_val_metrics: List[Tuple[str, float, float, float]] = [] |
|
|
|
|
|
train_specs, train_labels = [], [] |
|
|
val_specs, val_labels = [], [] |
|
|
test_specs, test_labels = [], [] |
|
|
|
|
|
class_contexts: Dict[int, Dict[str, Any]] = {} |
|
|
class_capacities: Dict[int, int] = {} |
|
|
for cls in sorted(config_arrays.keys()): |
|
|
arrays_map = config_arrays[cls] |
|
|
if not arrays_map: |
|
|
raise RuntimeError(f"No data for class {LABEL_NAMES[cls]}") |
|
|
|
|
|
if cls not in selection_for_repetition: |
|
|
selection_for_repetition[cls] = defaultdict(set) |
|
|
prev_config_indices = selection_for_repetition[cls] |
|
|
|
|
|
prev_total = sum(len(sel_indices) for sel_indices in prev_config_indices.values()) |
|
|
if prev_total > train_size: |
|
|
raise ValueError( |
|
|
f"Requested train size {train_size} is smaller than previously selected {prev_total} " |
|
|
f"for class {LABEL_NAMES[cls]}" |
|
|
) |
|
|
|
|
|
train_avail = {config: set(range(arr.shape[0])) for config, arr in arrays_map.items()} |
|
|
for config, sel_indices in prev_config_indices.items(): |
|
|
if sel_indices and config in train_avail: |
|
|
train_avail[config].difference_update(sel_indices) |
|
|
val_reserved = val_reserved_indices.get(cls, {}) |
|
|
test_reserved = test_reserved_indices.get(cls, {}) |
|
|
for config, reserved in val_reserved.items(): |
|
|
if reserved and config in train_avail: |
|
|
train_avail[config].difference_update(reserved) |
|
|
for config, reserved in test_reserved.items(): |
|
|
if reserved and config in train_avail: |
|
|
train_avail[config].difference_update(reserved) |
|
|
|
|
|
available_now = sum(len(indices) for indices in train_avail.values()) |
|
|
capacity = prev_total + available_now |
|
|
class_contexts[cls] = { |
|
|
"arrays_map": arrays_map, |
|
|
"prev_indices": prev_config_indices, |
|
|
"train_avail": train_avail, |
|
|
} |
|
|
class_capacities[cls] = capacity |
|
|
|
|
|
if not class_capacities: |
|
|
raise RuntimeError("No modulation classes available for training") |
|
|
|
|
|
min_capacity = min(class_capacities.values()) |
|
|
limiting_classes = sorted(cls for cls, cap in class_capacities.items() if cap == min_capacity) |
|
|
effective_train_size = min(train_size, min_capacity) |
|
|
if effective_train_size < train_size: |
|
|
limiting_labels = ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in limiting_classes) |
|
|
if not limiting_labels: |
|
|
limiting_labels = "unknown" |
|
|
print( |
|
|
f"[WARN] Requested train size {train_size} exceeds available " |
|
|
f"{min_capacity} after reserving val/test samples; using {effective_train_size} " |
|
|
f"(limited by {limiting_labels})" |
|
|
) |
|
|
if effective_train_size <= 0: |
|
|
raise RuntimeError("No training samples available after reserving val/test splits") |
|
|
|
|
|
for cls in sorted(config_arrays.keys()): |
|
|
ctx = class_contexts[cls] |
|
|
arrays_map = ctx["arrays_map"] |
|
|
prev_config_indices = ctx["prev_indices"] |
|
|
train_avail = ctx["train_avail"] |
|
|
|
|
|
selected_arrays: List[np.ndarray] = [] |
|
|
prev_total = 0 |
|
|
for config, sel_indices in prev_config_indices.items(): |
|
|
if sel_indices: |
|
|
idx_sorted = sorted(sel_indices) |
|
|
selected_arrays.append(arrays_map[config][idx_sorted]) |
|
|
prev_total += len(sel_indices) |
|
|
|
|
|
needed = max(effective_train_size - prev_total, 0) |
|
|
if needed > 0: |
|
|
additional_samples, train_used = sample_train_arrays(arrays_map, train_avail, needed, rng) |
|
|
if additional_samples.size == 0: |
|
|
raise RuntimeError("Failed to collect additional training samples") |
|
|
selected_arrays.append(additional_samples) |
|
|
for config, indices in train_used.items(): |
|
|
prev_config_indices[config].update(int(idx) for idx in indices) |
|
|
|
|
|
if not selected_arrays: |
|
|
raise RuntimeError("No training samples collected") |
|
|
|
|
|
train_samples = np.concatenate(selected_arrays, axis=0) |
|
|
if train_samples.shape[0] != effective_train_size: |
|
|
print( |
|
|
f"[WARN] Collected {train_samples.shape[0]} training samples for " |
|
|
f"{LABEL_NAMES.get(cls, str(cls))}, expected {effective_train_size}" |
|
|
) |
|
|
|
|
|
val_samples = fixed_val_samples[cls] |
|
|
test_samples = fixed_test_samples[cls] |
|
|
train_specs.append(train_samples) |
|
|
val_specs.append(val_samples) |
|
|
test_specs.append(test_samples) |
|
|
local_label = label_to_local[cls] |
|
|
train_labels.append(np.full(train_samples.shape[0], local_label, dtype=np.int64)) |
|
|
val_labels.append(np.full(val_samples.shape[0], local_label, dtype=np.int64)) |
|
|
test_labels.append(np.full(test_samples.shape[0], local_label, dtype=np.int64)) |
|
|
|
|
|
train_specs_raw = np.concatenate(train_specs) |
|
|
val_specs_raw = np.concatenate(val_specs) |
|
|
test_specs_raw = np.concatenate(test_specs) |
|
|
train_labels = np.concatenate(train_labels) |
|
|
val_labels = np.concatenate(val_labels) |
|
|
test_labels = np.concatenate(test_labels) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print( |
|
|
f"[INFO] Verifying data splits for train_size={train_size} " |
|
|
f"(effective {effective_train_size}), rep={repetition}..." |
|
|
) |
|
|
print( |
|
|
f" Train: {len(train_labels)} samples " |
|
|
f"(~{effective_train_size} per class expected)" |
|
|
) |
|
|
print(f" Val: {len(val_labels)} samples ({args.val_per_class} per class)") |
|
|
print(f" Test: {len(test_labels)} samples ({args.test_per_class} per class)") |
|
|
|
|
|
|
|
|
train_class_counts = Counter(train_labels) |
|
|
val_class_counts = Counter(val_labels) |
|
|
test_class_counts = Counter(test_labels) |
|
|
|
|
|
print(f"[INFO] Train class distribution: {dict(train_class_counts)}") |
|
|
print(f"[INFO] Val class distribution: {dict(val_class_counts)}") |
|
|
print(f"[INFO] Test class distribution: {dict(test_class_counts)}") |
|
|
|
|
|
|
|
|
            expected_train_per_class = effective_train_size
            splits_balanced = True
            for cls_idx in range(num_classes):
                if train_class_counts[cls_idx] != expected_train_per_class:
                    print(f"[WARN] Class {cls_idx} has {train_class_counts[cls_idx]} train samples, expected {expected_train_per_class}")
                    splits_balanced = False
                if val_class_counts[cls_idx] != args.val_per_class:
                    print(f"[WARN] Class {cls_idx} has {val_class_counts[cls_idx]} val samples, expected {args.val_per_class}")
                    splits_balanced = False
                if test_class_counts[cls_idx] != args.test_per_class:
                    print(f"[WARN] Class {cls_idx} has {test_class_counts[cls_idx]} test samples, expected {args.test_per_class}")
                    splits_balanced = False

            if splits_balanced:
                print("[INFO] ✓ All splits have balanced class distribution")
            else:
                print("[WARN] Class distribution is not perfectly balanced; see warnings above")
|
|
|
|
|
train_ds_raw = SpectrogramDataset(train_specs_raw, train_labels) |
|
|
val_ds_raw = SpectrogramDataset(val_specs_raw, val_labels) |
|
|
test_ds_raw = SpectrogramDataset(test_specs_raw, test_labels) |
|
|
|
|
|
train_loader_raw = DataLoader( |
|
|
train_ds_raw, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
val_loader_raw = DataLoader( |
|
|
val_ds_raw, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
test_loader_raw = DataLoader( |
|
|
test_ds_raw, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
|
|
|
train_loader_normalized: Optional[DataLoader] = None |
|
|
val_loader_normalized: Optional[DataLoader] = None |
|
|
test_loader_normalized: Optional[DataLoader] = None |
|
|
|
|
|
if requires_normalized_inputs: |
|
|
train_specs_normalized = apply_normalization(train_specs_raw, stats) |
|
|
val_specs_normalized = apply_normalization(val_specs_raw, stats) |
|
|
test_specs_normalized = apply_normalization(test_specs_raw, stats) |
|
|
|
|
|
train_ds_normalized = SpectrogramDataset(train_specs_normalized, train_labels) |
|
|
val_ds_normalized = SpectrogramDataset(val_specs_normalized, val_labels) |
|
|
test_ds_normalized = SpectrogramDataset(test_specs_normalized, test_labels) |
|
|
|
|
|
train_loader_normalized = DataLoader( |
|
|
train_ds_normalized, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
val_loader_normalized = DataLoader( |
|
|
val_ds_normalized, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
test_loader_normalized = DataLoader( |
|
|
test_ds_normalized, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=2, |
|
|
pin_memory=False, |
|
|
) |
|
|
|
|
|
rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" |
|
|
rep_root.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
for model_name in args.models: |
|
|
model_root = rep_root / model_name |
|
|
model_root.mkdir(parents=True, exist_ok=True) |
|
|
epoch_ckpt_dir: Optional[Path] = None |
|
|
if args.save_epoch_checkpoints: |
|
|
epoch_ckpt_dir = model_root / "epoch_checkpoints" |
|
|
epoch_ckpt_dir.mkdir(parents=True, exist_ok=True) |
|
|
for old_path in epoch_ckpt_dir.glob("epoch_*.pth"): |
|
|
if old_path.is_file(): |
|
|
old_path.unlink() |
|
|
print( |
|
|
f"\n[INFO] Size {train_size} (effective {effective_train_size}), " |
|
|
f"repetition {repetition}, model {model_name}" |
|
|
) |
|
|
                # Use a deterministic per-model offset; str hash() changes between interpreter
                # runs unless PYTHONHASHSEED is fixed, which would break reproducibility.
                model_seed_offset = sum(ord(ch) for ch in model_name) % 1000
                set_seed(rep_seed + model_seed_offset)
|
|
lower_name = model_name.lower() |
|
|
use_raw_input = lower_name in raw_input_models |
|
|
if use_raw_input: |
|
|
train_loader = train_loader_raw |
|
|
val_loader = val_loader_raw |
|
|
test_loader = test_loader_raw |
|
|
print(" [INFO] Feeding raw spectrograms with per-batch normalization") |
|
|
else: |
|
|
if ( |
|
|
train_loader_normalized is None |
|
|
or val_loader_normalized is None |
|
|
or test_loader_normalized is None |
|
|
): |
|
|
raise RuntimeError( |
|
|
"Normalized loaders were requested but could not be constructed." |
|
|
) |
|
|
train_loader = train_loader_normalized |
|
|
val_loader = val_loader_normalized |
|
|
test_loader = test_loader_normalized |
|
|
|
|
|
trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 |
|
|
backbone_lr_factor = args.backbone_lr_factor |
|
|
if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: |
|
|
backbone_lr_factor = args.lwm_backbone_lr_factor |
|
|
model_checkpoint = checkpoint |
|
|
model, param_groups = build_model( |
|
|
model_name, |
|
|
num_classes, |
|
|
model_checkpoint, |
|
|
device, |
|
|
trainable_layers=trainable_layers, |
|
|
backbone_lr_factor=backbone_lr_factor, |
|
|
overrides=model_overrides, |
|
|
) |
|
|
if multi_gpu: |
|
|
model = nn.DataParallel(model, device_ids=active_gpu_ids) |
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print( |
|
|
f"[INFO] Parameters (total/trainable): {total_params:,} / {trainable_params:,}" |
|
|
) |
|
|
def make_optimizer(base_lr: float) -> torch.optim.Optimizer: |
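                    """Build AdamW with per-group learning rates scaled by each group's ``scale`` factor."""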
|
|
optim_groups: List[Dict[str, object]] = [] |
|
|
if param_groups: |
|
|
for group in param_groups: |
|
|
scale = float(group.get("scale", 1.0)) |
|
|
params = [p for p in group.get("params", []) if p.requires_grad] |
|
|
if params: |
|
|
group_cfg: Dict[str, object] = { |
|
|
"params": list(params), |
|
|
"lr": base_lr * scale, |
|
|
} |
|
|
if "weight_decay" in group: |
|
|
group_cfg["weight_decay"] = float(group["weight_decay"]) |
|
|
optim_groups.append(group_cfg) |
|
|
if not optim_groups: |
|
|
optim_groups.append({ |
|
|
"params": [p for p in model.parameters() if p.requires_grad], |
|
|
"lr": base_lr, |
|
|
}) |
|
|
return torch.optim.AdamW(optim_groups, lr=base_lr, weight_decay=args.weight_decay) |
|
|
|
|
|
def make_scheduler(optimizer: torch.optim.Optimizer, base_lr: float, patience_limit: int): |
|
|
plateau_patience = max(2, patience_limit // 2) |
|
|
return torch.optim.lr_scheduler.ReduceLROnPlateau( |
|
|
optimizer, |
|
|
mode="min", |
|
|
factor=0.5, |
|
|
patience=plateau_patience, |
|
|
min_lr=base_lr * 0.01, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
scaler = GradScaler('cuda') if device.type == 'cuda' else None |
|
|
best_val_loss = float("inf") |
|
|
best_val_acc = 0.0 |
|
|
best_state = None |
|
|
epoch_history: List[Dict[str, object]] = [] |
|
|
best_val_f1 = 0.0 |
|
|
best_epoch = 0 |
|
|
total_epochs_ran = 0 |
|
|
overall_early_stopped = False |
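                # Training runs in up to two phases: a main phase at the base learning rate,
                # then (when fine-tune epochs are configured) a fine-tune phase that restarts
                # from the best checkpoint so far at a scaled-down learning rate.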
|
|
|
|
|
phase_configs = [ |
|
|
{ |
|
|
"name": "main", |
|
|
"max_epochs": args.epochs, |
|
|
"base_lr": args.lr, |
|
|
"patience": max(1, args.early_patience), |
|
|
"min_epochs": max(0, args.early_min_epochs), |
|
|
} |
|
|
] |
|
|
|
|
|
ft_epochs = max(0, args.finetune_epochs) |
|
|
ft_lr_factor = args.finetune_lr_factor |
|
|
ft_patience = max(1, args.finetune_patience) |
|
|
ft_min_epochs = max(0, args.finetune_min_epochs) |
|
|
|
|
|
if ft_epochs > 0: |
|
|
phase_configs.append( |
|
|
{ |
|
|
"name": "finetune", |
|
|
"max_epochs": ft_epochs, |
|
|
"base_lr": args.lr * ft_lr_factor, |
|
|
"patience": ft_patience, |
|
|
"min_epochs": ft_min_epochs, |
|
|
} |
|
|
) |
|
|
|
|
|
for phase_idx, phase in enumerate(phase_configs): |
|
|
if phase["max_epochs"] <= 0: |
|
|
continue |
|
|
|
|
|
if phase_idx > 0: |
|
|
print( |
|
|
f"\n [INFO] Starting {phase['name']} phase: lr={phase['base_lr']:.2e}, " |
|
|
f"max_epochs={phase['max_epochs']}" |
|
|
) |
|
|
if best_state is not None: |
|
|
model.load_state_dict(best_state["model"]) |
|
|
|
|
|
optimizer = make_optimizer(phase["base_lr"]) |
|
|
scheduler = make_scheduler(optimizer, phase["base_lr"], phase["patience"]) |
|
|
patience_counter = 0 |
|
|
phase_early_stopped = False |
|
|
phase_epochs_completed = 0 |
|
|
phase_min_epochs = max(0, phase["min_epochs"]) |
|
|
|
|
|
for local_epoch in range(1, phase["max_epochs"] + 1): |
|
|
overall_epoch = total_epochs_ran + local_epoch |
|
|
train_loss = train_one_epoch( |
|
|
model, |
|
|
train_loader, |
|
|
optimizer, |
|
|
device, |
|
|
scaler, |
|
|
batch_normalize=use_raw_input, |
|
|
) |
|
|
val_loss, val_acc, val_f1 = evaluate( |
|
|
model, |
|
|
val_loader, |
|
|
device, |
|
|
eval_debug_config, |
|
|
batch_normalize=use_raw_input, |
|
|
) |
|
|
scheduler.step(val_loss) |
|
|
current_lr = optimizer.param_groups[0]["lr"] |
|
|
print( |
|
|
f" [{phase['name']}] Epoch {overall_epoch:02d}: " |
|
|
f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} " |
|
|
f"val_acc={val_acc:.4%} val_f1={val_f1:.4f}" |
|
|
) |
|
|
epoch_history.append( |
|
|
{ |
|
|
"epoch": int(overall_epoch), |
|
|
"train_loss": float(train_loss), |
|
|
"val_loss": float(val_loss), |
|
|
"val_acc": float(val_acc), |
|
|
"val_f1": float(val_f1), |
|
|
"lr": float(current_lr), |
|
|
"phase": phase["name"], |
|
|
} |
|
|
) |
|
|
repetition_records.append( |
|
|
{ |
|
|
"model": model_name, |
|
|
"epoch": int(overall_epoch), |
|
|
"phase": phase["name"], |
|
|
"train_loss": float(train_loss), |
|
|
"val_loss": float(val_loss), |
|
|
"val_acc": float(val_acc), |
|
|
"val_f1": float(val_f1), |
|
|
"lr": float(current_lr), |
|
|
"train_size_requested": int(train_size), |
|
|
"train_size_effective": int(effective_train_size), |
|
|
} |
|
|
) |
|
|
_write_epoch_history(rep_root, repetition_records, args.save_epoch_history, args.plot_epoch_history) |
|
|
raw_epoch_state = _strip_module_prefix(model.state_dict()) |
|
|
if epoch_ckpt_dir is not None: |
|
|
epoch_state = raw_epoch_state.__class__() |
|
|
for key, value in raw_epoch_state.items(): |
|
|
epoch_state[key] = value.detach().cpu() |
|
|
epoch_ckpt_path = epoch_ckpt_dir / f"epoch_{overall_epoch:03d}.pth" |
|
|
torch.save(epoch_state, epoch_ckpt_path) |
|
|
if val_loss < best_val_loss: |
|
|
best_val_loss = val_loss |
|
|
best_val_acc = val_acc |
|
|
best_val_f1 = val_f1 |
|
|
best_model_state = { |
|
|
key: value.detach().cpu() |
|
|
for key, value in model.state_dict().items() |
|
|
} |
|
|
best_state = { |
|
|
"model": best_model_state, |
|
|
"val_loss": val_loss, |
|
|
"val_acc": val_acc, |
|
|
"val_f1": val_f1, |
|
|
"epoch": int(overall_epoch), |
|
|
"lr": current_lr, |
|
|
"phase": phase["name"], |
|
|
} |
|
|
best_epoch = int(overall_epoch) |
|
|
patience_counter = 0 |
|
|
else: |
|
|
if local_epoch >= phase_min_epochs: |
|
|
patience_counter += 1 |
|
|
if patience_counter >= phase["patience"]: |
|
|
print( |
|
|
f" [INFO] Early stopping ({phase['name']}) at epoch {overall_epoch:02d} " |
|
|
f"after {patience_counter} epochs without val loss improvement" |
|
|
) |
|
|
overall_early_stopped = True |
|
|
phase_early_stopped = True |
|
|
phase_epochs_completed = local_epoch |
|
|
break |
|
|
phase_epochs_completed = local_epoch |
|
|
|
|
|
total_epochs_ran += phase_epochs_completed |
|
|
|
|
|
if phase_early_stopped is False and phase_epochs_completed < phase["max_epochs"]: |
|
|
|
|
|
phase_early_stopped = True |
|
|
|
|
|
if best_state is None: |
|
|
raise RuntimeError("Training finished without recording a validation improvement") |
|
|
|
|
|
model.load_state_dict(best_state["model"]) |
|
|
test_loss, test_acc, test_f1 = evaluate( |
|
|
model, |
|
|
test_loader, |
|
|
device, |
|
|
eval_debug_config, |
|
|
batch_normalize=use_raw_input, |
|
|
) |
|
|
print( |
|
|
f" -> Test loss={test_loss:.4f} Test acc={test_acc:.4%} Test f1={test_f1:.4f}" |
|
|
) |
|
|
|
|
|
export_dir = getattr(args, "export_full_model", None) |
|
|
if export_dir is not None: |
|
|
export_dir = export_dir.expanduser().resolve() |
|
|
export_dir.mkdir(parents=True, exist_ok=True) |
|
|
comm_token = getattr(args, "comm_suffix", "multi") |
|
|
filename = f"{comm_token}_{model_name}_size{train_size}_rep{repetition}.pth" |
|
|
export_path = export_dir / filename |
|
|
full_state = {k: v.detach().cpu() for k, v in model.state_dict().items()} |
|
|
torch.save(full_state, export_path) |
|
|
print(f" [INFO] Saved full model (backbone + head) to {export_path}") |
|
|
|
|
|
summary[model_name][train_size]["acc"].append(test_acc) |
|
|
summary[model_name][train_size]["f1"].append(test_f1) |
|
|
summary[model_name][train_size]["val_f1"].append(best_state["val_f1"]) |
|
|
summary[model_name][train_size]["val_loss"].append(best_state["val_loss"]) |
|
|
per_size_val_metrics.append( |
|
|
(model_name, best_state["val_f1"], best_state["val_loss"], test_f1) |
|
|
) |
|
|
|
|
|
result_dir = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" / model_name |
|
|
result_dir.mkdir(parents=True, exist_ok=True) |
|
|
state_to_save = copy.deepcopy(best_state) |
|
|
state_to_save["model"] = _strip_module_prefix(state_to_save["model"]) |
|
|
torch.save(state_to_save, result_dir / "checkpoint.pt") |
|
|
with open(result_dir / "metrics.json", "w", encoding="utf-8") as f: |
|
|
json.dump( |
|
|
{ |
|
|
"train_size_per_class": effective_train_size, |
|
|
"train_size_per_class_requested": train_size, |
|
|
"repetition": repetition, |
|
|
"model": model_name, |
|
|
"best_val_loss": best_state.get("val_loss", None), |
|
|
"best_val_acc": best_val_acc, |
|
|
"best_val_f1": best_state["val_f1"], |
|
|
"test_loss": test_loss, |
|
|
"test_acc": test_acc, |
|
|
"test_f1": test_f1, |
|
|
"best_epoch": best_epoch, |
|
|
"epochs_ran": total_epochs_ran, |
|
|
"early_stopped": overall_early_stopped, |
|
|
"history": epoch_history, |
|
|
}, |
|
|
f, |
|
|
indent=2, |
|
|
) |
|
|
|
|
|
rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" |
|
|
rep_root.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
if args.plot_epoch_history and HAVE_MPL and repetition_records: |
|
|
models_in_run = sorted({rec["model"] for rec in repetition_records}) |
|
|
fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) |
|
|
for ax in axes: |
|
|
ax.grid(True, linestyle='--', alpha=0.3) |
|
|
for model_name_plot in models_in_run: |
|
|
model_records = [rec for rec in repetition_records if rec["model"] == model_name_plot] |
|
|
if not model_records: |
|
|
continue |
|
|
epochs = [rec["epoch"] for rec in model_records] |
|
|
val_loss_values = [rec["val_loss"] for rec in model_records] |
|
|
val_f1_values = [rec["val_f1"] for rec in model_records] |
|
|
axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot) |
|
|
axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot) |
|
|
axes[0].set_ylabel('Val Loss') |
|
|
axes[1].set_ylabel('Val F1') |
|
|
axes[1].set_xlabel('Epoch') |
|
|
axes[0].legend(loc='best') |
|
|
axes[0].set_title(f'Size {train_size} / Rep {repetition} per-epoch metrics') |
|
|
fig.tight_layout() |
|
|
fig.savefig(rep_root / 'epoch_history.png', dpi=150) |
|
|
plt.close(fig) |
|
|
|
|
|
|
|
|
del model, optimizer, scheduler, best_state |
|
|
if scaler is not None: |
|
|
del scaler |
|
|
if device.type == 'cuda': |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
if per_size_val_metrics: |
|
|
print(f"\n[INFO] Validation summary for train_size={train_size}, rep={repetition}:") |
|
|
for model_name, val_f1, val_loss, test_f1 in sorted( |
|
|
per_size_val_metrics, key=lambda item: item[1], reverse=True |
|
|
): |
|
|
print( |
|
|
f" {model_name:<16} val_f1={val_f1:.4f} " |
|
|
f"val_loss={val_loss:.4f} test_f1={test_f1:.4f}" |
|
|
) |
|
|
|
|
|
summary_path = args.output_dir / "summary.json" |
|
|
summary_path.parent.mkdir(parents=True, exist_ok=True) |
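    # summary.json layout (model names, sizes, and values below are illustrative only):
    # {
    #   "resnet18": {
    #     "128": {"acc": [...], "f1": [...], "val_f1": [...], "val_loss": [...]}
    #   }
    # }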
|
|
serializable_summary = { |
|
|
model_name: { |
|
|
size: { |
|
|
"acc": metrics["acc"], |
|
|
"f1": metrics["f1"], |
|
|
"val_f1": metrics["val_f1"], |
|
|
"val_loss": metrics["val_loss"], |
|
|
} |
|
|
for size, metrics in size_dict.items() |
|
|
} |
|
|
for model_name, size_dict in summary.items() |
|
|
} |
|
|
with open(summary_path, "w", encoding="utf-8") as f: |
|
|
json.dump(serializable_summary, f, indent=2) |
|
|
|
|
|
print("\n[INFO] Final accuracy summary:") |
|
|
for model_name, results in summary.items(): |
|
|
for size, metrics in results.items(): |
|
|
if metrics["acc"]: |
|
|
acc_mean = float(np.mean(metrics["acc"])) |
|
|
acc_std = float(np.std(metrics["acc"])) |
|
|
f1_mean = float(np.mean(metrics["f1"])) |
|
|
f1_std = float(np.std(metrics["f1"])) |
|
|
n = len(metrics["acc"]) |
|
|
print( |
|
|
f" {model_name} @ {size:4d}/class -> " |
|
|
f"acc={acc_mean:.4%} ± {acc_std:.4%}, f1={f1_mean:.4f} ± {f1_std:.4f} (n={n})" |
|
|
) |
|
|
|
|
|
print("\n[INFO] Final validation F1 summary:") |
|
|
for model_name, results in summary.items(): |
|
|
for size, metrics in results.items(): |
|
|
if metrics["val_f1"]: |
|
|
val_mean = float(np.mean(metrics["val_f1"])) |
|
|
val_std = float(np.std(metrics["val_f1"])) |
|
|
n = len(metrics["val_f1"]) |
|
|
print( |
|
|
f" {model_name} @ {size:4d}/class -> " |
|
|
f"val_f1={val_mean:.4f} ± {val_std:.4f} (n={n})" |
|
|
) |
|
|
|
|
|
print("\n[INFO] Final validation loss summary:") |
|
|
for model_name, results in summary.items(): |
|
|
for size, metrics in results.items(): |
|
|
if metrics["val_loss"]: |
|
|
loss_mean = float(np.mean(metrics["val_loss"])) |
|
|
loss_std = float(np.std(metrics["val_loss"])) |
|
|
n = len(metrics["val_loss"]) |
|
|
print( |
|
|
f" {model_name} @ {size:4d}/class -> " |
|
|
f"val_loss={loss_mean:.4f} ± {loss_std:.4f} (n={n})" |
|
|
) |
|
|
|
|
|
if HAVE_MPL: |
|
|
train_sizes_sorted = sorted(args.train_sizes) |
|
|
plt.figure(figsize=(8, 5)) |
|
|
plotted = False |
|
|
for model_name in args.models: |
|
|
model_results = summary.get(model_name, {}) |
|
|
means: List[float] = [] |
|
|
for size in train_sizes_sorted: |
|
|
val_list = model_results.get(size, {}).get("val_f1", []) |
|
|
means.append(float(np.mean(val_list)) if val_list else float("nan")) |
|
|
if not any(np.isfinite(means)): |
|
|
continue |
|
|
plt.plot(train_sizes_sorted, means, marker="o", linewidth=2, label=model_name) |
|
|
plotted = True |
|
|
if plotted: |
|
|
plt.title("Validation F1 vs. Training Size") |
|
|
plt.xlabel("Training samples per class") |
|
|
plt.ylabel("Validation F1 (macro)") |
|
|
plt.xticks(train_sizes_sorted) |
|
|
plt.ylim(0.0, 1.0) |
|
|
plt.grid(True, which="both", linestyle="--", alpha=0.4) |
|
|
plt.legend(title="Model", frameon=False) |
|
|
plt.tight_layout() |
|
|
plot_path = args.output_dir / "val_f1_summary.png" |
|
|
plt.savefig(plot_path, dpi=200) |
|
|
plt.close() |
|
|
print(f"[INFO] Saved validation F1 plot to {plot_path}") |
|
|
else: |
|
|
plt.close() |
|
|
print("[WARN] No validation F1 data available to plot.") |
|
|
else: |
|
|
print("[WARN] Matplotlib not available; skipping validation F1 plot.") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|