diff --git "a/mixture/train_embedding_router.py" "b/mixture/train_embedding_router.py" deleted file mode 100644--- "a/mixture/train_embedding_router.py" +++ /dev/null @@ -1,3381 +0,0 @@ -#!/usr/bin/env python3 -"""Train a router that selects top-k LWM backbones and feeds a shared classifier. - -This variant differs from ``train_top1_router.py`` by: - * loading each expert checkpoint only for its backbone (classifier discarded) - * extracting 128-d embeddings with the standard mean-pooled LWM features - * selecting the top-k experts per sample (default k=2) via a router network - * feeding each selected embedding through a shared Res1DCNN head - * weighting the per-embedding logits using the router probabilities - -The script auto-discovers six experts by default: - 1. Latest ``lwm_epoch*.pth`` from ``models/LTE_models`` (base LTE backbone) - 2. Latest ``lwm_epoch*.pth`` from ``models/WiFi_models`` (base WiFi backbone) - 3. Latest ``lwm_epoch*.pth`` from ``models/5G_models`` (base 5G backbone) - 4. Most recent epoch checkpoint from ``task2/mobility_benchmark/lte`` (LTE mobility expert) - 5. Most recent epoch checkpoint from ``task2/mobility_benchmark/wifi`` (WiFi mobility expert) - 6. Most recent epoch checkpoint from ``task2/mobility_benchmark/5g`` (5G mobility expert) - -Each expert must expose dataset statistics (mean/std or per-sample flag). The router -is first warmed up on communication labels, then fine-tuned jointly with the shared -classifier using both task loss and an auxiliary communication loss. -""" - -from __future__ import annotations - -import argparse -import csv -import json -import os -import math -import re -import sys -from collections import defaultdict -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.amp import GradScaler, autocast -from torch.utils.data import DataLoader, Dataset - -try: - from sklearn.metrics import f1_score - SKLEARN_AVAILABLE = True -except ImportError: - SKLEARN_AVAILABLE = False - -REPO_ROOT = Path(__file__).resolve().parent.parent -DEFAULT_EXPERT_ROOT = REPO_ROOT / "mixture" / "experts" - -sys.path.append(str(REPO_ROOT)) - -from task1.train_mcs_models import Res1DCNNHead, load_all_samples # type: ignore -from task2.mobility_utils import prepare_model # type: ignore -from mixture.train_top1_router import ( - SampleMetadata, - _collect_candidate_files, - load_dataset_stats, - snr_sort_key, -) # type: ignore - - -try: - from tqdm.auto import tqdm -except ImportError: # pragma: no cover - optional dependency - tqdm = None - - -COMM_CANONICAL = {"lte": "LTE", "wifi": "WiFi", "5g": "5G"} -ExpertIndex = int - - -@dataclass(slots=True) -class SampleEntry: - """Describe a single spectrogram sample stored inside a pickled tensor file.""" - - path: Path - index: int - metadata: SampleMetadata - - -def canonical_comm_name(name: str) -> str: - lower = name.strip().lower() - if lower in COMM_CANONICAL: - return COMM_CANONICAL[lower] - for canonical in COMM_CANONICAL.values(): - if canonical.lower() == lower: - return canonical - raise ValueError(f"Unknown communication type: {name}") - - -def discover_latest_base_checkpoint(comm: str) -> Path: - # Normalize communication type name (lte -> LTE) - comm_upper = canonical_comm_name(comm) - - # Try multiple possible locations - possible_folders = [ - REPO_ROOT / "models" 
/ "experts" / "baseline" / f"{comm_upper}_models", - REPO_ROOT / "models" / f"{comm_upper}_models", - REPO_ROOT / "models" / f"{comm.capitalize()}_models", - ] - - folder = None - for candidate in possible_folders: - if candidate.exists(): - folder = candidate - break - - if folder is None: - raise FileNotFoundError( - f"Base model directory not found for {comm}. Tried: {[str(p) for p in possible_folders]}" - ) - - candidates = sorted(folder.glob("lwm_epoch*_val*.pth")) - if not candidates: - raise FileNotFoundError(f"No base checkpoints found under {folder}") - - def key(path: Path) -> Tuple[float, float]: - match = re.search(r"epoch(\d+)_val([\d.]+)\.pth$", path.name) - if match: - epoch = float(match.group(1)) - val = float(match.group(2)) - return val, -epoch - return float("inf"), 0.0 - - return min(candidates, key=key) - - -def discover_latest_mobility_checkpoint(comm: str) -> Path: - # Normalize communication type name (lte -> LTE) - comm_upper = canonical_comm_name(comm) - - # Try multiple possible locations - possible_locations = [ - # First try the experts/task2 folder - REPO_ROOT / "models" / "experts" / "task2" / f"{comm_upper}_models", - # Then try task2/mobility_benchmark - REPO_ROOT / "task2" / "mobility_benchmark" / comm.lower(), - ] - - for base_dir in possible_locations: - if not base_dir.exists(): - continue - - # If it's a models directory, look for checkpoint files directly - if "_models" in base_dir.name: - candidates = sorted(base_dir.glob("*.pth")) - if candidates: - return candidates[-1] - - # Otherwise, it's a run directory - search for checkpoints - run_dirs = sorted([p for p in base_dir.iterdir() if p.is_dir()]) - if run_dirs: - for run_dir in reversed(run_dirs): - epoch_dir = run_dir / "epoch_checkpoints" - if epoch_dir.exists(): - epochs = sorted(epoch_dir.glob("epoch_*.pth")) - if epochs: - return epochs[-1] - # Fallback: search recursively for *.pth within run - candidates = sorted(run_dir.rglob("*.pth")) - if candidates: - return candidates[-1] - - raise FileNotFoundError( - f"No mobility checkpoints found for {comm}. 
-@dataclass(slots=True)
-class ExpertSpec:
-    name: str
-    comm: str
-    checkpoint: Path
-    stats_path: Optional[Path] = None
-
-
-def infer_comm_from_path(path: Path) -> Optional[str]:
-    parts = [p.lower() for p in path.parts]
-    stem = path.stem.lower()
-    filename = path.name.lower()
-    for part in parts:
-        for key, canonical in COMM_CANONICAL.items():
-            canonical_lower = canonical.lower()
-            if key in part or canonical_lower in part:
-                return canonical
-    for key, canonical in COMM_CANONICAL.items():
-        canonical_lower = canonical.lower()
-        if (
-            stem.startswith(f"{key}_")
-            or stem.startswith(f"{canonical_lower}_")
-            or key in filename
-            or canonical_lower in filename
-        ):
-            return canonical
-    return None
-
-
-def discover_experts_from_directory(base_dir: Path) -> List[ExpertSpec]:
-    specs: List[ExpertSpec] = []
-    if not base_dir.exists():
-        return specs
-    for ckpt in sorted(base_dir.rglob("*.pth")):
-        comm = infer_comm_from_path(ckpt)
-        if comm is None:
-            print(f"[WARN] Unable to infer communication type for expert at {ckpt}; skipping")
-            continue
-        name = ckpt.stem
-        # Stats path is optional - we use per-sample normalization
-        stats_candidates = [
-            ckpt.with_suffix(".json"),
-            ckpt.parent / "dataset_stats.json",
-            REPO_ROOT / "models" / f"{comm}_models" / "dataset_stats.json",
-        ]
-        stats_path: Optional[Path] = None
-        for candidate in stats_candidates:
-            if candidate.exists():
-                stats_path = candidate
-                break
-
-        specs.append(
-            ExpertSpec(
-                name=name,
-                comm=comm,
-                checkpoint=ckpt.resolve(),
-                stats_path=stats_path.resolve() if stats_path else None,
-            )
-        )
-    return specs
-
-
-def discover_default_experts() -> List[ExpertSpec]:
-    directory_specs = discover_experts_from_directory(DEFAULT_EXPERT_ROOT)
-    if directory_specs:
-        print(f"[INFO] Discovered {len(directory_specs)} expert(s) under {DEFAULT_EXPERT_ROOT}")
-        return directory_specs
-
-    specs: List[ExpertSpec] = []
-    for comm in ("lte", "wifi", "5g"):
-        pretty = canonical_comm_name(comm)
-        base_ckpt = discover_latest_base_checkpoint(comm)
-        # Stats path is optional - we use per-sample normalization
-        stats_path = base_ckpt.parent / "dataset_stats.json"
-        specs.append(
-            ExpertSpec(
-                name=f"{pretty}_base",
-                comm=pretty,
-                checkpoint=base_ckpt,
-                stats_path=stats_path if stats_path.exists() else None,
-            )
-        )
-    for comm in ("lte", "wifi", "5g"):
-        pretty = canonical_comm_name(comm)
-        ckpt = discover_latest_mobility_checkpoint(comm)
-        # Stats path is optional - we use per-sample normalization
-        stats_candidates = [
-            ckpt.parent / "dataset_stats.json",
-            (REPO_ROOT / "models" / f"{pretty}_models" / "dataset_stats.json"),
-        ]
-        stats_path = next((p for p in stats_candidates if p.exists()), None)
-        specs.append(
-            ExpertSpec(
-                name=f"{pretty}_mobility",
-                comm=pretty,
-                checkpoint=ckpt,
-                stats_path=stats_path,
-            )
-        )
-    return specs
-
-
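# NOTE (editor): illustrative behaviour of infer_comm_from_path with hypothetical
# paths, assuming the definitions above are in scope; directory components are
# checked before file names.
from pathlib import Path

assert infer_comm_from_path(Path("experts/wifi/run1/best.pth")) == "WiFi"
assert infer_comm_from_path(Path("lte_mobility.pth")) == "LTE"
assert infer_comm_from_path(Path("unknown/backbone.pth")) is None
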
-def collect_sample_entries_for_comm(
-    *,
-    data_root: Path,
-    cities: Sequence[str],
-    comm: str,
-    snrs: Optional[Sequence[str]],
-    mobilities: Optional[Sequence[str]],
-    modulations: Optional[Sequence[str]],
-    fft_folders: Optional[Sequence[str]],
-    max_samples: int,
-    max_per_combo: Optional[int],
-    target_per_combo: Optional[int],
-    rng: np.random.Generator,
-) -> List[SampleEntry]:
-    """Gather per-sample references without materialising full tensors."""
-    candidates = _collect_candidate_files(
-        data_root=data_root,
-        cities=cities,
-        comm=comm,
-        snr_filters=snrs,
-        mobility_filters=mobilities,
-        modulation_filters=modulations,
-        fft_filters=fft_folders,
-    )
-    if not candidates:
-        raise RuntimeError(f"No spectrogram files matched filters for {comm}")
-
-    # Resolve all paths upfront to avoid repeated resolve() calls
-    candidates = [(path.resolve(), meta) for path, meta in candidates]
-    rng.shuffle(candidates)
-    combo_counts = defaultdict(int)
-    entries: List[SampleEntry] = []
-    remaining = max_samples if max_samples > 0 else None
-    per_combo_limit = max_per_combo if (max_per_combo is not None and max_per_combo > 0) else None
-    combos_available: Set[Tuple[str, str, str]] = {
-        (meta.modulation, meta.snr, meta.mobility) for _, meta in candidates
-    }
-    combo_targets: Optional[Dict[Tuple[str, str, str], int]] = None
-    satisfied_combos: Set[Tuple[str, str, str]] = set()
-    warned_combo_limit = False
-    if target_per_combo is not None:
-        combo_targets = {}
-        for combo in combos_available:
-            target = target_per_combo
-            if per_combo_limit is not None:
-                effective = min(target, per_combo_limit)
-                if effective < target and not warned_combo_limit:
-                    print(
-                        f"[WARN] {comm}: per-combo limit ({per_combo_limit}) is below requested total ({target}); "
-                        "consider relaxing --max-per-combo or per-class caps."
-                    )
-                    warned_combo_limit = True
-                target = effective
-            combo_targets[combo] = target
-            if target <= 0:
-                satisfied_combos.add(combo)
-
-    files_processed = 0
-    for file_idx, (path, meta) in enumerate(candidates, start=1):
-        files_processed = file_idx
-        if remaining is not None and remaining <= 0:
-            break
-        combo_key = (meta.modulation, meta.snr, meta.mobility)
-        already = combo_counts[combo_key]
-        if per_combo_limit is not None and already >= per_combo_limit:
-            continue
-
-        path_str = str(path)
-        try:
-            # Fast metadata-only read to get sample count
-            num_samples = get_sample_count_fast(path_str)
-        except Exception as exc:  # pragma: no cover - guard against corrupted files
-            print(f"[WARN] Failed to load {path_str}: {exc}")
-            continue
-
-        if num_samples == 0:
-            continue
-
-        remaining_for_combo = (
-            per_combo_limit - already if per_combo_limit is not None else num_samples
-        )
-        allowed = min(num_samples, remaining_for_combo)
-        if remaining is not None:
-            allowed = min(allowed, remaining)
-        if allowed <= 0:
-            continue
-
-        if allowed == num_samples:
-            chosen_indices = np.arange(num_samples)
-        else:
-            chosen_indices = rng.choice(num_samples, size=allowed, replace=False)
-
-        # Reuse metadata for all samples from the same file
-        entry_meta = SampleMetadata(
-            comm=meta.comm,
-            modulation=meta.modulation,
-            snr=meta.snr,
-            mobility=meta.mobility,
-            rate=meta.rate,
-            source=path_str,
-        )
-
-        # Batch create and extend entries (faster than repeated append)
-        batch_entries = [
-            SampleEntry(path=path, index=int(idx), metadata=entry_meta)
-            for idx in chosen_indices.tolist()
-        ]
-        entries.extend(batch_entries)
-
-        combo_counts[combo_key] += int(len(chosen_indices))
-        if remaining is not None:
-            remaining -= int(len(chosen_indices))
-        # Only log at major milestones to reduce overhead
-        if file_idx == 1 or file_idx % 50 == 0:
-            print(f"[DATA] {comm}: gathered {len(entries):,} samples after {file_idx} files", flush=True)
-        if combo_targets is not None:
-            target = combo_targets.get(combo_key)
-            if target is not None and combo_counts[combo_key] >= target:
-                satisfied_combos.add(combo_key)
-                if len(satisfied_combos) == len(combo_targets):
-                    break
-
-    if not entries:
-        raise RuntimeError(f"Unable to collect samples for {comm} after applying limits")
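    # (editor) All candidate files have been scanned at this point; warn if some
    # (modulation, snr, mobility) cells still fall short of their requested target.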
-    if combo_targets is not None:
-        unmet = [
-            combo for combo, target in combo_targets.items() if combo_counts[combo] < target
-        ]
-        if unmet:
-            print(
-                f"[WARN] {comm}: target not met for {len(unmet)} combo(s); "
-                "consider lowering per-class requirements.",
-                flush=True,
-            )
-    print(
-        f"[DATA] {comm}: gathered {len(entries)} samples after scanning {files_processed} files",
-        flush=True,
-    )
-    return entries
-
-
-def iterate_batches(loader: Iterable, desc: str, *, log_every: Optional[int] = None) -> Iterable:
-    """Yield from a loader while emitting progress information."""
-    if tqdm is not None:
-        for item in tqdm(loader, desc=desc, leave=False, dynamic_ncols=True):
-            yield item
-        return
-
-    try:
-        total = len(loader)  # type: ignore[arg-type]
-    except Exception:
-        total = None
-    if log_every is None:
-        if total:
-            log_every = max(1, total // 10)
-        else:
-            log_every = 50
-    for idx, batch in enumerate(loader, start=1):
-        if idx == 1 or idx % log_every == 0 or (total is not None and idx == total):
-            if total is not None:
-                print(f"[Progress] {desc}: {idx}/{total}", flush=True)
-            else:
-                print(f"[Progress] {desc}: batch {idx}", flush=True)
-        yield batch
-
-
-def _get_cache_capacity(default: int = 32) -> int:
-    """Derive cache size from env while guarding against invalid values."""
-    override = os.environ.get("LWM_FILE_CACHE_SIZE")
-    if not override:
-        return default
-    try:
-        value = int(override)
-        return max(1, value)
-    except ValueError:
-        print(
-            f"[WARN] Ignoring invalid LWM_FILE_CACHE_SIZE={override!r}; using default {default}",
-            flush=True,
-        )
-        return default
-
-
-_FILE_CACHE_SIZE = _get_cache_capacity()
-
-
-def _resolve_preload_dtype(default: str = "float16") -> torch.dtype:
-    override = os.environ.get("LWM_PRELOAD_DTYPE", default)
-    alias = override.strip().lower()
-    mapping = {
-        "float16": torch.float16,
-        "fp16": torch.float16,
-        "half": torch.float16,
-        "float32": torch.float32,
-        "fp32": torch.float32,
-        "single": torch.float32,
-    }
-    dtype = mapping.get(alias)
-    if dtype is None:
-        print(
-            f"[WARN] Unknown LWM_PRELOAD_DTYPE={override!r}; defaulting to {default}",
-            flush=True,
-        )
-        dtype = mapping[default]
-    return dtype
-
-
-def _parse_float_env(name: str) -> Optional[float]:
-    value = os.environ.get(name)
-    if not value:
-        return None
-    try:
-        return float(value)
-    except ValueError:
-        print(f"[WARN] Ignoring invalid {name}={value!r}", flush=True)
-        return None
-
-
-def _available_ram_bytes() -> Optional[int]:
-    try:
-        import psutil  # type: ignore
-
-        return int(psutil.virtual_memory().available)
-    except Exception:
-        pass
-    try:
-        with open("/proc/meminfo", "r", encoding="utf-8") as fh:
-            for line in fh:
-                if line.startswith("MemAvailable:"):
-                    parts = line.split()
-                    if len(parts) >= 2:
-                        # Value reported in kB
-                        return int(float(parts[1]) * 1024)
-    except Exception:
-        pass
-    return None
-
-
-@lru_cache(maxsize=_FILE_CACHE_SIZE)
-def _load_file_spectrograms(path_str: str) -> np.ndarray:
-    """Load all spectrograms for a path with a bounded LRU cache."""
-    arr = load_all_samples(path_str)
-    if arr.dtype != np.float16:
-        arr = arr.astype(np.float16, copy=False)
-    return arr
-
-
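# NOTE (editor): the environment knobs read above, with illustrative values. They
# must be exported before the script starts, since _FILE_CACHE_SIZE is fixed at
# import time; the lru_cache wrapper also exposes hit/miss statistics via
# _load_file_spectrograms.cache_info().
#
#   LWM_FILE_CACHE_SIZE=64     # keep up to 64 pickle files in the LRU cache
#   LWM_PRELOAD_DTYPE=float16  # preload buffer dtype (fp16 halves RAM use)
#   LWM_PRELOAD_MAX_GB=48      # refuse to preload beyond this RAM budget
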
- """ - # Dataset convention: each spectrogram pickle file contains 1000 samples - return 1000 - - -def load_spec_tensor(entry: SampleEntry) -> torch.Tensor: - """Materialise a single spectrogram as a float32 tensor.""" - path_str = str(entry.path) - specs = _load_file_spectrograms(path_str) - if entry.index < 0 or entry.index >= specs.shape[0]: - raise IndexError(f"Sample index {entry.index} out of range for {path_str}") - sample = specs[entry.index] - if sample.ndim != 2: - raise ValueError(f"Expected 2-D spectrogram, got shape {sample.shape}") - # Clone to decouple the tensor lifetime from the cached numpy array. - # Explicitly ensure float32 dtype for GPU efficiency - return torch.from_numpy(sample).clone().float() - - -class EmbeddingRouterDataset(Dataset): - def __init__( - self, - entries: Sequence[SampleEntry], - comm_labels: np.ndarray, - task_labels: np.ndarray, - preload: bool = True, - ) -> None: - if not (len(entries) == len(comm_labels) == len(task_labels)): - raise ValueError("Dataset inputs must share the same length") - self.entries = list(entries) - self.comm_labels = torch.from_numpy(comm_labels.astype(np.int64, copy=False)) - self.task_labels = torch.from_numpy(task_labels.astype(np.int64, copy=False)) - self.metadata = [entry.metadata for entry in self.entries] - - # Preload all spectrograms into memory for faster training - self.preload = preload - self.spectrograms = None - self.preload_dtype = _resolve_preload_dtype() - if self.preload: - element_size = torch.tensor([], dtype=self.preload_dtype).element_size() - required_bytes = len(entries) * 128 * 128 * element_size - required_gb = required_bytes / 1e9 - max_gb = _parse_float_env("LWM_PRELOAD_MAX_GB") - available_bytes = _available_ram_bytes() - allow_preload = True - if max_gb is not None and required_gb > max_gb: - print( - f"[WARN] Requested preload requires {required_gb:.2f} GB, " - f"exceeding LWM_PRELOAD_MAX_GB={max_gb:.2f}; falling back to streaming.", - flush=True, - ) - allow_preload = False - elif available_bytes is not None and required_bytes > available_bytes * 0.8: - available_gb = available_bytes / 1e9 - print( - f"[WARN] Requested preload requires {required_gb:.2f} GB but only " - f"{available_gb:.2f} GB appears available; falling back to streaming.", - flush=True, - ) - allow_preload = False - if not allow_preload: - self.preload = False - - file_groups: Optional[Dict[str, List[Tuple[int, int]]]] = None - if self.preload: - print(f"[INFO] Preloading {len(entries):,} spectrograms into RAM...") - # Group entries by file path for efficient loading - from collections import defaultdict - groups: Dict[str, List[Tuple[int, int]]] = defaultdict(list) - for idx, entry in enumerate(entries): - groups[str(entry.path)].append((idx, entry.index)) - file_groups = groups - - # Preallocate tensor using the configured dtype to control memory footprint - try: - self.spectrograms = torch.empty( - (len(entries), 128, 128), dtype=self.preload_dtype - ) - except RuntimeError as exc: - print( - f"[WARN] Failed to allocate preload buffer ({exc}); falling back to streaming.", - flush=True, - ) - self.preload = False - self.spectrograms = None - - if self.preload and self.spectrograms is not None: - if file_groups is None: - from collections import defaultdict - - file_groups = defaultdict(list) - for idx, entry in enumerate(entries): - file_groups[str(entry.path)].append((idx, entry.index)) - - # Load files in batch - if tqdm is not None: - iter_files = tqdm( - file_groups.items(), desc="Loading files", leave=False, 
-            else:
-                iter_files = file_groups.items()
-
-            for path_str, indices_list in iter_files:
-                # Load file once
-                file_data = load_all_samples(path_str)
-                # Extract all needed samples from this file
-                for sample_idx, file_offset in indices_list:
-                    tensor = torch.from_numpy(file_data[file_offset]).to(dtype=self.preload_dtype)
-                    self.spectrograms[sample_idx] = tensor
-
-            print(
-                f"[INFO] Preloaded {self.spectrograms.shape[0]:,} spectrograms "
-                f"({self.spectrograms.element_size() * self.spectrograms.nelement() / 1e9:.2f} GB)"
-            )
-
-    def __len__(self) -> int:
-        return len(self.entries)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
-        if self.preload and self.spectrograms is not None:
-            spec = self.spectrograms[idx].to(dtype=torch.float32)
-        else:
-            entry = self.entries[idx]
-            spec = load_spec_tensor(entry)
-        return spec, int(self.comm_labels[idx]), int(self.task_labels[idx])
-
-
-def modulation_labels_from_metadata(metadata: Sequence[SampleMetadata]) -> np.ndarray:
-    from task1.train_mcs_models import MODULATION_LABELS  # type: ignore
-
-    labels: List[int] = []
-    for meta in metadata:
-        label = MODULATION_LABELS.get(meta.modulation.upper())
-        if label is None:
-            raise ValueError(f"Unknown modulation label in metadata: {meta.modulation}")
-        labels.append(label)
-    return np.array(labels, dtype=np.int64)
-
-
-def snr_mobility_labels_from_metadata(
-    metadata: Sequence[SampleMetadata],
-    *,
-    snr_order: Sequence[str],
-    mobility_order: Sequence[str],
-) -> Tuple[np.ndarray, Dict[int, Tuple[str, str]]]:
-    combos: List[Tuple[str, str]] = []
-    for snr in snr_order:
-        for mobility in mobility_order:
-            combos.append((snr, mobility))
-    combo_to_idx = {combo: idx for idx, combo in enumerate(combos)}
-
-    labels: List[int] = []
-    for meta in metadata:
-        combo = (meta.snr, meta.mobility)
-        if combo not in combo_to_idx:
-            raise ValueError(f"Sample combo {combo} not present in configured (snr, mobility) grid")
-        labels.append(combo_to_idx[combo])
-    mapping = {idx: combo for combo, idx in combo_to_idx.items()}
-    return np.array(labels, dtype=np.int64), mapping
-
-
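# NOTE (editor): illustration of the (snr, mobility) grid labelling above with
# hypothetical category names; the label index is
# snr_index * len(mobility_order) + mobility_index.
snr_order = ["snr_low", "snr_high"]
mobility_order = ["static", "mobile"]
combos = [(snr, mob) for snr in snr_order for mob in mobility_order]
assert combos.index(("snr_high", "static")) == 2
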
-def prepare_dataset(
-    *,
-    data_root: Path,
-    cities: Sequence[str],
-    comm_types: Sequence[str],
-    snrs: Optional[Sequence[str]],
-    mobilities: Optional[Sequence[str]],
-    modulations: Optional[Sequence[str]],
-    fft_folders: Optional[Sequence[str]],
-    max_samples_per_comm: int,
-    max_per_combo: Optional[int],
-    max_samples_per_class: int,
-    val_samples_per_class: int,
-    test_samples_per_class: int,
-    task: str,
-    seed: int,
-    preload: bool = True,
-) -> Tuple[EmbeddingRouterDataset, Dict[str, int], Optional[Dict[int, Tuple[str, str]]]]:
-    rng = np.random.default_rng(seed)
-    entries: List[SampleEntry] = []
-    comm_labels_list: List[int] = []
-    comm_to_idx: Dict[str, int] = {}
-    total_required = 0
-    have_requirements = False
-    if max_samples_per_class > 0:
-        total_required += max_samples_per_class
-        have_requirements = True
-    if val_samples_per_class > 0:
-        total_required += val_samples_per_class
-        have_requirements = True
-    if test_samples_per_class > 0:
-        total_required += test_samples_per_class
-        have_requirements = True
-    target_per_combo = total_required if have_requirements else None
-
-    for comm in comm_types:
-        try:
-            comm_entries = collect_sample_entries_for_comm(
-                data_root=data_root,
-                cities=cities,
-                comm=comm,
-                snrs=snrs,
-                mobilities=mobilities,
-                modulations=modulations,
-                fft_folders=fft_folders,
-                max_samples=max_samples_per_comm,
-                max_per_combo=max_per_combo,
-                target_per_combo=target_per_combo,
-                rng=rng,
-            )
-        except RuntimeError as exc:
-            print(f"[WARN] {exc}; skipping {comm}")
-            continue
-        if comm not in comm_to_idx:
-            comm_to_idx[comm] = len(comm_to_idx)
-        comm_idx = comm_to_idx[comm]
-        entries.extend(comm_entries)
-        comm_labels_list.extend([comm_idx] * len(comm_entries))
-
-    if not entries:
-        raise RuntimeError("No spectrogram data collected for any communication type")
-
-    comm_labels = np.array(comm_labels_list, dtype=np.int64)
-    order = rng.permutation(len(entries))
-    entries = [entries[idx] for idx in order]
-    comm_labels = comm_labels[order]
-    metadata = [entry.metadata for entry in entries]
-
-    if task == "modulation":
-        task_labels = modulation_labels_from_metadata(metadata)
-        mapping = None
-    else:
-        if snrs is None:
-            snr_order = sorted({meta.snr for meta in metadata}, key=snr_sort_key)
-        else:
-            snr_order = [snr for snr in snrs if any(meta.snr == snr for meta in metadata)]
-        if mobilities is None:
-            mobility_order = sorted({meta.mobility for meta in metadata})
-        else:
-            mobility_order = [mob for mob in mobilities if any(meta.mobility == mob for meta in metadata)]
-        task_labels, mapping = snr_mobility_labels_from_metadata(
-            metadata,
-            snr_order=snr_order,
-            mobility_order=mobility_order,
-        )
-
-    dataset = EmbeddingRouterDataset(entries, comm_labels, task_labels, preload=preload)
-
-    # Print data statistics
-    print(f"\n[DATA] Collected {len(dataset)} total samples:")
-    for comm_name, comm_idx in sorted(comm_to_idx.items(), key=lambda x: x[1]):
-        count = int((comm_labels == comm_idx).sum())
-        print(f"  {comm_name}: {count:,} samples")
-    print()
-
-    return dataset, comm_to_idx, mapping
-
-
-def stratified_split(
-    labels: np.ndarray,
-    *,
-    train_ratio: float,
-    val_ratio: float,
-    max_train_per_class: int = 0,
-    val_samples_per_class: int = 0,
-    test_samples_per_class: int = 0,
-    seed: int,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    if not (0 < train_ratio < 1):
-        raise ValueError("train_ratio must be in (0, 1)")
-    if not (0 < val_ratio < 1):
-        raise ValueError("val_ratio must be in (0, 1)")
-    if (
-        val_samples_per_class <= 0
-        and test_samples_per_class <= 0
-        and train_ratio + val_ratio >= 1.0
-    ):
-        raise ValueError("train_ratio + val_ratio must be < 1.0 when using ratios for all splits")
-
-    rng = np.random.default_rng(seed)
-    train_indices: List[int] = []
-    val_indices: List[int] = []
-    test_indices: List[int] = []
-    base_test_ratio = max(0.0, 1.0 - train_ratio - val_ratio)
-
-    for label in np.unique(labels):
-        idx = np.where(labels == label)[0]
-        rng.shuffle(idx)
-        n_total = idx.size
-        n_val = (
-            min(val_samples_per_class, n_total)
-            if val_samples_per_class > 0
-            else int(math.floor(val_ratio * n_total))
-        )
-        if val_samples_per_class > 0 and n_val < val_samples_per_class:
-            print(
-                f"[WARN] Class {label}: requested {val_samples_per_class} validation samples but only {n_total} available"
-            )
-        remaining_after_val = max(n_total - n_val, 0)
-        n_test = (
-            min(test_samples_per_class, remaining_after_val)
-            if test_samples_per_class > 0
-            else int(math.floor(base_test_ratio * n_total))
-        )
-        if test_samples_per_class > 0 and n_test < test_samples_per_class:
-            print(
-                f"[WARN] Class {label}: requested {test_samples_per_class} test samples but only {remaining_after_val} available after validation"
-            )
-        n_test = min(n_test, remaining_after_val)
-        base_train = remaining_after_val - n_test
-        if base_train < 0:
-            n_test = max(0, remaining_after_val)
-            base_train = remaining_after_val - n_test
-        if max_train_per_class > 0 and base_train > max_train_per_class:
-            overflow = base_train - max_train_per_class
-            n_train = max_train_per_class
-            n_test = min(n_test + overflow, remaining_after_val)
-        else:
-            n_train = base_train
-
-        used = n_val + n_test + n_train
-        if used < n_total:
-            # Prefer allocating leftovers to the test split for more evaluation coverage.
-            extra = min(n_total - used, remaining_after_val - n_test)
-            n_test += extra
-            used = n_val + n_test + n_train
-        if used > n_total:
-            overflow = used - n_total
-            reduction = min(overflow, n_test)
-            n_test -= reduction
-            overflow -= reduction
-            if overflow > 0:
-                n_val = max(0, n_val - overflow)
-                overflow = 0
-            used = n_val + n_test + n_train
-
-        if n_val < 0 or n_test < 0 or n_train < 0:
-            raise RuntimeError(
-                f"Negative split size encountered for class {label}: "
-                f"train={n_train}, val={n_val}, test={n_test}"
-            )
-
-        start_train = n_val + n_test
-        train_indices.extend(idx[start_train:start_train + n_train])
-        val_indices.extend(idx[:n_val])
-        test_indices.extend(idx[n_val:n_val + n_test])
-
-    return (
-        np.sort(np.array(train_indices, dtype=np.int64)),
-        np.sort(np.array(val_indices, dtype=np.int64)),
-        np.sort(np.array(test_indices, dtype=np.int64)),
-    )
-
-
-class RouterNet(nn.Module):
-    """Lightweight CNN router."""
-
-    def __init__(self, num_experts: int, dropout: float = 0.1) -> None:
-        super().__init__()
-        self.features = nn.Sequential(
-            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
-            nn.BatchNorm2d(32),
-            nn.SiLU(inplace=True),
-            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
-            nn.BatchNorm2d(64),
-            nn.SiLU(inplace=True),
-            nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
-            nn.BatchNorm2d(96),
-            nn.SiLU(inplace=True),
-            nn.Conv2d(96, 128, kernel_size=3, stride=2, padding=1),
-            nn.BatchNorm2d(128),
-            nn.SiLU(inplace=True),
-            nn.AdaptiveAvgPool2d((1, 1)),
-        )
-        head_layers: List[nn.Module] = [nn.Flatten()]
-        if dropout > 0:
-            head_layers.append(nn.Dropout(dropout))
-        head_layers.append(nn.Linear(128, num_experts))
-        self.classifier = nn.Sequential(*head_layers)
-
-    def forward(self, specs: torch.Tensor) -> torch.Tensor:
-        x = specs
-        if x.dim() == 3:
-            x = x.unsqueeze(1)
-        elif x.dim() != 4:
-            raise ValueError(f"Expected specs rank 3 or 4, got shape {tuple(specs.shape)}")
-        features = self.features(x)
-        logits = self.classifier(features)
-        return logits
-
-
-class TaskClassifier(nn.Module):
-    """Shared Res1DCNN head operating on 128-d embeddings."""
-
-    def __init__(self, num_classes: int, dropout: float = 0.1) -> None:
-        super().__init__()
-        self.head = nn.Sequential(
-            nn.LayerNorm(128),
-            Res1DCNNHead(128, num_classes, dropout=dropout),
-        )
-
-    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
-        return self.head(embeddings)
-
-
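# NOTE (editor): a minimal shape check, illustrative only, assuming RouterNet as
# defined above: the router maps a batch of 128x128 spectrograms to one logit per
# expert.
import torch

router = RouterNet(num_experts=6)
with torch.no_grad():
    assert router(torch.randn(4, 128, 128)).shape == (4, 6)
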
-class EmbeddingExpert(nn.Module):
-    def __init__(self, spec: ExpertSpec, device: torch.device, *, trainable: bool = False) -> None:
-        super().__init__()
-        self._trainable = bool(trainable)
-        # Use per-sample normalization by default if no stats are provided
-        if spec.stats_path is not None and spec.stats_path.exists():
-            stats = load_dataset_stats(spec.stats_path)
-            self.stats = {
-                "normalization": str(stats.get("normalization", "per_sample")).lower(),
-                "mean": float(stats.get("mean", 0.0)),
-                "std": float(stats.get("std", 1.0)),
-            }
-        else:
-            # Default to per-sample normalization
-            self.stats = {
-                "normalization": "per_sample",
-                "mean": 0.0,
-                "std": 1.0,
-            }
-        # Prepare normalization stats for prepare_model.
-        # If using per-sample normalization, we don't need to pass dataset stats.
-        normalization_stats = None
-        if self.stats["normalization"] != "per_sample":
-            normalization_stats = {
-                "normalization": self.stats["normalization"],
-                "mean": self.stats["mean"],
-                "std": self.stats["std"],
-            }
-
-        model = prepare_model(
-            checkpoint=spec.checkpoint,
-            num_classes=2,
-            classifier_dim=128,
-            dropout=0.0,
-            trainable_layers=2 if trainable else 0,
-            projection_dim=0,
-            append_input_stats=False,
-            normalization_stats=normalization_stats,
-            head_type="mlp",
-        )
-        # prepare_model already sets requires_grad correctly based on trainable_layers
-        model.train(self._trainable)
-        self.model = model.to(device)
-
-    @property
-    def trainable(self) -> bool:
-        return self._trainable
-
-    def set_trainable(self, trainable: bool) -> None:
-        self._trainable = bool(trainable)
-        # Freeze all backbone parameters first
-        for param in self.model.backbone.parameters():
-            param.requires_grad = False
-        # If trainable, enable the last 2 layers of the backbone
-        if self._trainable:
-            layers = getattr(self.model.backbone, "layers", None)
-            if layers is not None and len(layers) >= 2:
-                for layer in layers[-2:]:
-                    for param in layer.parameters():
-                        param.requires_grad = True
-        self.model.train(self._trainable)
-        super().train(self._trainable)
-
-    def train(self, mode: bool = True) -> "EmbeddingExpert":
-        effective_mode = bool(mode and self._trainable)
-        self.model.train(effective_mode)
-        return super().train(mode)
-
-    def eval(self) -> "EmbeddingExpert":
-        self.model.eval()
-        return super().eval()
-
-    def forward(self, specs: torch.Tensor, *, allow_grad: Optional[bool] = None) -> torch.Tensor:
-        use_grad = self._determine_grad(allow_grad)
-        x = self._normalize(specs)
-        if use_grad:
-            return self.model.forward_features(x)
-        with torch.no_grad():
-            return self.model.forward_features(x)
-
-    def forward_prenormalized(
-        self,
-        specs: torch.Tensor,
-        *,
-        allow_grad: Optional[bool] = None,
-    ) -> torch.Tensor:
-        """Forward pass with pre-normalized spectrograms (skip normalization)."""
-        use_grad = self._determine_grad(allow_grad)
-        if use_grad:
-            return self.model.forward_features(specs)
-        with torch.no_grad():
-            return self.model.forward_features(specs)
-
-    def _normalize(self, specs: torch.Tensor) -> torch.Tensor:
-        mode = self.stats["normalization"]
-        mean = self.stats["mean"]
-        std = max(abs(self.stats["std"]), 1e-6)
-        if mode == "dataset":
-            return (specs - mean) / std
-        mean_tensor = specs.mean(dim=(1, 2), keepdim=True)
-        std_tensor = specs.std(dim=(1, 2), keepdim=True, unbiased=False)
-        std_tensor = torch.clamp(std_tensor, min=1e-6)
-        return (specs - mean_tensor) / std_tensor
-
-    def _determine_grad(self, allow_grad: Optional[bool]) -> bool:
-        if allow_grad is None:
-            return self._trainable
-        return bool(allow_grad)
-
-
-def normalize_per_sample_tensor(specs: torch.Tensor) -> torch.Tensor:
-    mean = specs.mean(dim=(1, 2), keepdim=True)
-    std = specs.std(dim=(1, 2), keepdim=True, unbiased=False)
-    std = torch.clamp(std, min=1e-6)
-    return (specs - mean) / std
-
-
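# NOTE (editor): per-sample normalisation standardises each spectrogram by its own
# statistics; a quick illustrative check that the output is zero-mean, unit-std
# per sample, assuming normalize_per_sample_tensor as defined above.
import torch

x = torch.randn(2, 128, 128) * 3.0 + 5.0
y = normalize_per_sample_tensor(x)
assert torch.allclose(y.mean(dim=(1, 2)), torch.zeros(2), atol=1e-4)
assert torch.allclose(y.std(dim=(1, 2), unbiased=False), torch.ones(2), atol=1e-3)
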
-def build_dataloaders(
-    dataset: EmbeddingRouterDataset,
-    *,
-    train_idx: np.ndarray,
-    val_idx: np.ndarray,
-    test_idx: np.ndarray,
-    batch_size: int,
-    num_workers: int,
-) -> Tuple[DataLoader, DataLoader, DataLoader]:
-    def _subset(indices: np.ndarray) -> torch.utils.data.Subset:
-        return torch.utils.data.Subset(dataset, indices.tolist())
-
-    # Optimize DataLoader configuration based on whether we're using workers
-    use_cuda = torch.cuda.is_available()
-    persistent_workers = num_workers > 0
-    prefetch_factor = 4 if num_workers > 0 else None  # higher prefetch improves pipelining
-
-    # Pin host memory so .to(device, non_blocking=True) copies can overlap compute
-    train_loader = DataLoader(
-        _subset(train_idx),
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=use_cuda,
-        persistent_workers=persistent_workers,
-        prefetch_factor=prefetch_factor,
-        drop_last=True,  # Drop incomplete batches for more consistent training
-    )
-    val_loader = DataLoader(
-        _subset(val_idx),
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-        pin_memory=use_cuda,
-        persistent_workers=persistent_workers,
-        prefetch_factor=prefetch_factor,
-    )
-    test_loader = DataLoader(
-        _subset(test_idx),
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-        pin_memory=use_cuda,
-        persistent_workers=persistent_workers,
-        prefetch_factor=prefetch_factor,
-    )
-    return train_loader, val_loader, test_loader
-
-
-def aggregate_comm_probs(
-    probs: torch.Tensor,
-    group_map: Mapping[int, List[int]],
-) -> torch.Tensor:
-    num_comm = len(group_map)
-    agg = torch.zeros(probs.size(0), num_comm, device=probs.device, dtype=probs.dtype)
-    for comm_idx, expert_indices in group_map.items():
-        if not expert_indices:
-            continue
-        agg[:, comm_idx] = probs[:, expert_indices].sum(dim=1)
-    return agg
-
-
-def router_cross_entropy(
-    logits: torch.Tensor,
-    targets: torch.Tensor,
-    group_map: Mapping[int, List[int]],
-) -> torch.Tensor:
-    probs = torch.softmax(logits, dim=1)
-    agg = aggregate_comm_probs(probs, group_map)
-    agg = torch.clamp(agg, min=1e-12)
-    return F.nll_loss(agg.log(), targets)
-
-
-def build_group_map(experts: Sequence[ExpertSpec], comm_to_idx: Mapping[str, int]) -> Dict[int, List[int]]:
-    grouping: Dict[int, List[int]] = {idx: [] for idx in comm_to_idx.values()}
-    for expert_idx, spec in enumerate(experts):
-        comm_idx = comm_to_idx[spec.comm]
-        grouping[comm_idx].append(expert_idx)
-    return grouping
-
-
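# NOTE (editor): illustrative check of the grouped router loss above. With six
# experts split evenly over three comm types, a uniform router gives each comm an
# aggregate probability of 1/3, so the loss is -log(1/3) = log 3.
import torch

uniform_logits = torch.zeros(1, 6)
demo_group_map = {0: [0, 3], 1: [1, 4], 2: [2, 5]}
loss = router_cross_entropy(uniform_logits, torch.tensor([0]), demo_group_map)
assert torch.allclose(loss, torch.log(torch.tensor(3.0)), atol=1e-5)
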
-def train_router(
-    router: RouterNet,
-    *,
-    experts: Sequence[ExpertSpec],
-    comm_to_idx: Mapping[str, int],
-    train_loader: DataLoader,
-    val_loader: DataLoader,
-    device: torch.device,
-    epochs: int,
-    lr: float,
-    weight_decay: float,
-) -> Dict[str, List[float]]:
-    group_map = build_group_map(experts, comm_to_idx)
-    optimizer = torch.optim.AdamW(router.parameters(), lr=lr, weight_decay=weight_decay)
-    scaler = GradScaler(enabled=torch.cuda.is_available())
-
-    # HPU support
-    use_hpu = device.type == "hpu"
-    if use_hpu:
-        try:
-            import habana_frameworks.torch.core as htcore
-        except ImportError:
-            use_hpu = False
-
-    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
-    for epoch in range(1, epochs + 1):
-        router.train()
-        running_loss = 0.0
-        correct = 0
-        total = 0
-        desc = f"Router train {epoch:02d}"
-        for specs, comm_labels, _ in iterate_batches(train_loader, desc):
-            specs = specs.to(device, non_blocking=True)
-            comm_labels = comm_labels.to(device, non_blocking=True)
-            optimizer.zero_grad(set_to_none=True)
-            norm_specs = normalize_per_sample_tensor(specs)
-            context = autocast(device_type=device.type, enabled=scaler.is_enabled())
-            with context:
-                logits = router(norm_specs)
-                loss = router_cross_entropy(logits, comm_labels, group_map)
-            if scaler.is_enabled():
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
-            else:
-                loss.backward()
-                optimizer.step()
-
-            # HPU memory management
-            if use_hpu:
-                # In eager mode mark_step has no effect; fall back to synchronisation
-                try:
-                    htcore.mark_step()
-                except Exception:
-                    pass
-                # Force synchronization and memory cleanup
-                torch.hpu.synchronize()
-                import gc
-                gc.collect()
-
-            running_loss += loss.item() * specs.size(0)
-            probs = torch.softmax(logits.detach(), dim=1)
-            agg = aggregate_comm_probs(probs, group_map)
-            preds = agg.argmax(dim=1)
-            correct += (preds == comm_labels).sum().item()
-            total += specs.size(0)
-
-            # Clear unused tensors to free memory
-            del specs, comm_labels, logits, loss, norm_specs, probs, agg, preds
-
-            # Additional memory cleanup for HPU
-            if use_hpu and total % 100 == 0:  # Every 100 samples
-                torch.hpu.synchronize()
-                import gc
-                gc.collect()
-        train_loss = running_loss / max(total, 1)
-        train_acc = correct / max(total, 1)
-        val_loss, val_acc = evaluate_router(router, val_loader, group_map, device)
-        history["train_loss"].append(train_loss)
-        history["train_acc"].append(train_acc)
-        history["val_loss"].append(val_loss)
-        history["val_acc"].append(val_acc)
-        print(
-            f"[Router] Epoch {epoch:02d}: train_loss={train_loss:.4f} train_acc={train_acc:.3f} "
-            f"val_loss={val_loss:.4f} val_acc={val_acc:.3f}"
-        )
-    return history
-
-
-@torch.no_grad()
-def evaluate_router(
-    router: RouterNet,
-    loader: DataLoader,
-    group_map: Mapping[int, List[int]],
-    device: torch.device,
-) -> Tuple[float, float]:
-    router.eval()
-
-    # HPU support
-    use_hpu = device.type == "hpu"
-    if use_hpu:
-        try:
-            import habana_frameworks.torch.core as htcore
-        except ImportError:
-            use_hpu = False
-
-    total_loss = 0.0
-    correct = 0
-    total = 0
-    for specs, comm_labels, _ in iterate_batches(loader, "Router eval"):
-        specs = specs.to(device, non_blocking=True)
-        comm_labels = comm_labels.to(device, non_blocking=True)
-        logits = router(normalize_per_sample_tensor(specs))
-        loss = router_cross_entropy(logits, comm_labels, group_map)
-        total_loss += loss.item() * specs.size(0)
-        probs = torch.softmax(logits, dim=1)
-        agg = aggregate_comm_probs(probs, group_map)
-        preds = agg.argmax(dim=1)
-        correct += (preds == comm_labels).sum().item()
-        total += specs.size(0)
-
-        # HPU memory management
-        if use_hpu:
-            htcore.mark_step()
-
-        # Clear unused tensors to free memory
-        del specs, comm_labels, logits, loss, probs, agg, preds
-
-    return total_loss / max(total, 1), correct / max(total, 1)
-
-
-def stack_expert_embeddings(
-    experts: Sequence[EmbeddingExpert],
-    specs: torch.Tensor,
-) -> torch.Tensor:
-    embeddings: List[torch.Tensor] = []
-    for expert in experts:
-        emb = expert(specs)
-        embeddings.append(emb.unsqueeze(1))
-    return torch.cat(embeddings, dim=1)
-
-
-def compute_selected_expert_embeddings(
-    experts: Sequence[EmbeddingExpert],
-    specs_normalized: torch.Tensor,
-    topk_indices: torch.Tensor,
-    *,
-    allow_grad: bool,
-) -> torch.Tensor:
-    """Compute embeddings only for selected experts (optimized GPU version).
-
-    Args:
-        experts: List of expert models
-        specs_normalized: Pre-normalized input spectrograms [batch_size, H, W]
-        topk_indices: Selected expert indices [batch_size, k]
-
-    Returns:
-        Selected embeddings [batch_size, k, embed_dim]
-    """
-    batch_size, k = topk_indices.shape
-    device = specs_normalized.device
-    unique_experts = torch.unique(topk_indices)
-
-    output: Optional[torch.Tensor] = None
-    for expert_idx in unique_experts.tolist():
-        expert_idx_int = int(expert_idx)
-        # Identify samples that actually route to this expert
-        sample_mask = (topk_indices == expert_idx_int).any(dim=1)
-        if not torch.any(sample_mask):
-            continue
-        sample_indices = sample_mask.nonzero(as_tuple=False).squeeze(1)
-        specs_subset = specs_normalized.index_select(0, sample_indices)
-
-        # Handle DataParallel wrapper
-        expert_model = experts[expert_idx_int]
-        if isinstance(expert_model, nn.DataParallel):
-            expert_model = expert_model.module
-
-        embeddings_subset = expert_model.forward_prenormalized(
-            specs_subset,
-            allow_grad=allow_grad and expert_model.trainable,
-        )
-
-        if output is None:
-            embed_dim = embeddings_subset.shape[-1]
-            output = torch.empty(
-                batch_size,
-                k,
-                embed_dim,
-                device=device,
-                dtype=embeddings_subset.dtype,
-            )
-
-        for pos in range(k):
-            pos_mask = topk_indices[sample_indices, pos] == expert_idx_int
-            if pos_mask.any():
-                output[sample_indices[pos_mask], pos] = embeddings_subset[pos_mask]
-
-    if output is None:
-        raise RuntimeError("No experts selected for current batch.")
-
-    return output
-
-
-def gather_topk_embeddings(
-    embeddings: torch.Tensor,
-    topk_indices: torch.Tensor,
-) -> torch.Tensor:
-    batch, k = topk_indices.shape
-    feature_dim = embeddings.size(-1)
-    expanded_indices = topk_indices.unsqueeze(-1).expand(batch, k, feature_dim)
-    return embeddings.gather(dim=1, index=expanded_indices)
-
-
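# NOTE (editor): illustrative shape check for gather_topk_embeddings, picking k=2
# of four per-expert embeddings for each sample.
import torch

emb = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)  # [batch, experts, dim]
idx = torch.tensor([[0, 2], [3, 1]])
picked = gather_topk_embeddings(emb, idx)
assert picked.shape == (2, 2, 3)
assert torch.equal(picked[0, 1], emb[0, 2])
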
-def train_task_model(
-    *,
-    router: RouterNet,
-    experts: Sequence[EmbeddingExpert],
-    expert_specs: Sequence[ExpertSpec],
-    comm_to_idx: Mapping[str, int],
-    classifier: TaskClassifier,
-    train_loader: DataLoader,
-    val_loader: DataLoader,
-    device: torch.device,
-    epochs: int,
-    topk: int,
-    router_lr: float,
-    classifier_lr: float,
-    expert_lr: float,
-    weight_decay: float,
-    router_loss_weight: float,
-    load_balance_weight: float,
-    gating_noise_std: float,
-    gating_noise_epochs: int,
-    patience: int = 10,
-    eval_interval: int = 1,
-    early_delta: float = 0.0,
-    checkpoint_callback: Optional[Callable[[int], None]] = None,
-) -> Dict[str, List[Any]]:
-    expert_requires_grad = expert_lr > 0
-    param_groups: List[Dict[str, object]] = []
-
-    classifier_params = [p for p in classifier.parameters() if p.requires_grad]
-    if classifier_lr > 0 and classifier_params:
-        param_groups.append({"params": classifier_params, "lr": classifier_lr})
-
-    router_params = [p for p in router.parameters() if p.requires_grad]
-    if router_lr > 0 and router_params:
-        param_groups.append({"params": router_params, "lr": router_lr})
-
-    expert_params: List[torch.Tensor] = []
-    if expert_requires_grad:
-        for expert in experts:
-            expert_params.extend([p for p in expert.parameters() if p.requires_grad])
-        if expert_params:
-            param_groups.append({"params": expert_params, "lr": expert_lr})
-
-    if not param_groups:
-        raise ValueError(
-            "No parameters selected for optimisation. Ensure at least one learning rate is > 0."
-        )
-
-    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6
-    )
-
-    scaler = GradScaler(enabled=torch.cuda.is_available())
-    group_map = build_group_map(expert_specs, comm_to_idx) if router_loss_weight > 0 else None
-
-    history: Dict[str, List[Any]] = {
-        "train_loss": [],
-        "train_acc": [],
-        "val_loss": [],
-        "val_acc": [],
-        "val_f1": [],
-        "train_balance": [],
-        "val_balance": [],
-        "train_router_aux": [],
-        "val_router_aux": [],
-        "train_entropy": [],
-        "val_entropy": [],
-        "train_usage": [],
-        "val_usage": [],
-        "gating_noise": [],
-    }
-
-    best_val_loss = float("inf")
-    best_val_f1 = 0.0
-    patience_counter = 0
-    best_router_state: Optional[Dict[str, torch.Tensor]] = None
-    best_classifier_state: Optional[Dict[str, torch.Tensor]] = None
-    best_expert_states: Optional[List[Dict[str, torch.Tensor]]] = None
-    eval_interval = max(1, int(eval_interval))
-    early_delta = float(max(0.0, early_delta))
-
-    for epoch in range(1, epochs + 1):
-        router.train()
-        classifier.train()
-        for expert in experts:
-            expert.train(expert_requires_grad)
-
-        running_loss = 0.0
-        running_balance = 0.0
-        running_router_aux = 0.0
-        correct = 0
-        total = 0
-        # Keep usage_sum on CPU to prevent memory accumulation on GPU/HPU
-        usage_sum = torch.zeros(len(expert_specs), device="cpu")
-
-        desc = f"Task train {epoch:02d}"
-        if tqdm is not None:
-            pbar = tqdm(train_loader, desc=desc, leave=False, dynamic_ncols=True)
-        else:
-            pbar = train_loader
-
-        if gating_noise_std > 0 and gating_noise_epochs > 0:
-            decay = max(0.0, 1.0 - (epoch - 1) / float(max(gating_noise_epochs, 1)))
-            epoch_noise_std = gating_noise_std * decay
-        else:
-            epoch_noise_std = gating_noise_std
-
-        for specs, comm_labels, task_labels in pbar:
-            specs = specs.to(device, non_blocking=True)
-            comm_labels = comm_labels.to(device, non_blocking=True)
-            task_labels = task_labels.to(device, non_blocking=True)
-            optimizer.zero_grad(set_to_none=True)
-
-            specs_norm = normalize_per_sample_tensor(specs)
-            balance_penalty: Optional[torch.Tensor] = None
-            router_aux_penalty: Optional[torch.Tensor] = None
-
-            context = autocast(device_type=device.type, enabled=scaler.is_enabled())
-            with context:
-                router_logits = router(specs_norm)
-                if epoch_noise_std > 0:
-                    router_logits = router_logits + torch.randn_like(router_logits) * epoch_noise_std
-                router_probs = torch.softmax(router_logits, dim=1)
-                topk_vals, topk_idx = router_probs.topk(k=topk, dim=1)
-                weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
-
-                selected_embeddings = compute_selected_expert_embeddings(
-                    experts,
-                    specs_norm,
-                    topk_idx,
-                    allow_grad=expert_requires_grad,
-                )
-                logits_each = classifier(selected_embeddings.view(-1, selected_embeddings.size(-1)))
-                logits_each = logits_each.view(specs.size(0), topk, -1)
-                weighted_logits = (weights.unsqueeze(-1) * logits_each).sum(dim=1)
-                task_loss = F.cross_entropy(weighted_logits, task_labels)
-
-                loss = task_loss
-                if router_loss_weight > 0 and group_map is not None:
-                    router_aux_penalty = router_cross_entropy(router_logits, comm_labels, group_map)
-                    loss = loss + router_loss_weight * router_aux_penalty
-                if load_balance_weight > 0:
-                    avg_probs = router_probs.mean(dim=0)
-                    uniform = torch.full_like(avg_probs, 1.0 / max(avg_probs.numel(), 1))
-                    balance_penalty = F.mse_loss(avg_probs, uniform)
-                    loss = loss + load_balance_weight * balance_penalty
-
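            # (editor) With AMP enabled, the loss is scaled before backward so small
            # fp16 gradients do not flush to zero; scaler.step() unscales them again
            # before the optimizer update.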
-            if scaler.is_enabled():
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
-            else:
-                loss.backward()
-                optimizer.step()
-
-            batch_size = specs.size(0)
-            running_loss += loss.detach().item() * batch_size
-
-            if balance_penalty is not None:
-                running_balance += balance_penalty.detach().item() * batch_size
-            if router_aux_penalty is not None:
-                running_router_aux += router_aux_penalty.detach().item() * batch_size
-
-            preds = weighted_logits.argmax(dim=1)
-            correct += (preds == task_labels).sum().item()
-            total += batch_size
-            # Move to CPU to prevent GPU/HPU memory accumulation
-            usage_sum = usage_sum + router_probs.detach().sum(dim=0).cpu()
-
-            if tqdm is not None:
-                current_loss = running_loss / max(total, 1)
-                current_acc = correct / max(total, 1)
-                postfix: Dict[str, str] = {
-                    "loss": f"{current_loss:.4f}",
-                    "acc": f"{current_acc:.3f}",
-                }
-                if load_balance_weight > 0 and total > 0:
-                    postfix["lb"] = f"{(running_balance / total):.4f}"
-                pbar.set_postfix(postfix)
-
-        denom = max(total, 1)
-        train_loss = running_loss / denom
-        train_acc = correct / denom
-        train_balance = running_balance / denom
-        train_router_aux = running_router_aux / denom
-        train_usage_tensor = usage_sum / float(max(total, 1))
-        train_usage_tensor = train_usage_tensor.clamp(min=0.0)
-        train_entropy = float(
-            -(train_usage_tensor * train_usage_tensor.clamp_min(1e-8).log()).sum().item()
-        )
-        train_usage_list = train_usage_tensor.detach().cpu().tolist()
-        should_eval = (epoch % eval_interval == 0) or (epoch == epochs)
-
-        if should_eval:
-            val_metrics = evaluate_task_model(
-                router=router,
-                experts=experts,
-                classifier=classifier,
-                loader=val_loader,
-                topk=topk,
-                device=device,
-                comm_to_idx=comm_to_idx,
-                expert_specs=expert_specs,
-                router_loss_weight=router_loss_weight,
-                load_balance_weight=load_balance_weight,
-            )
-            val_loss = val_metrics["loss"]
-            val_acc = val_metrics["acc"]
-            val_f1 = val_metrics["f1"]
-            val_balance = val_metrics["balance"]
-            val_router_aux = val_metrics["router_aux"]
-            val_usage = val_metrics.get("usage")
-            val_entropy = val_metrics.get("entropy", float("nan"))
-
-            old_lr = optimizer.param_groups[0]["lr"]
-            scheduler.step(val_loss)
-            new_lr = optimizer.param_groups[0]["lr"]
-            if new_lr != old_lr:
-                print(f"[INFO] Learning rate reduced: {old_lr:.2e} -> {new_lr:.2e}")
-        else:
-            val_loss = float("nan")
-            val_acc = float("nan")
-            val_f1 = float("nan")
-            val_balance = float("nan")
-            val_router_aux = float("nan")
-            val_usage = None
-            val_entropy = float("nan")
-
-        history["train_loss"].append(train_loss)
-        history["train_acc"].append(train_acc)
-        history["val_loss"].append(val_loss)
-        history["val_acc"].append(val_acc)
-        history["val_f1"].append(val_f1)
-        history["train_balance"].append(train_balance)
-        history["val_balance"].append(val_balance)
-        history["train_router_aux"].append(train_router_aux)
-        history["val_router_aux"].append(val_router_aux)
-        history["train_entropy"].append(train_entropy)
-        history["val_entropy"].append(val_entropy if not math.isnan(val_entropy) else None)
-        history["train_usage"].append(train_usage_list)
-        history["val_usage"].append(val_usage if val_usage is not None else None)
-        history["gating_noise"].append(float(epoch_noise_std))
-
-        msg = [
-            f"[Task] Epoch {epoch:02d}:",
-            f"train_loss={train_loss:.4f}",
-            f"train_acc={train_acc:.3f}",
-        ]
-        if load_balance_weight > 0:
-            msg.append(f"train_lb={train_balance:.4f}")
-        if router_loss_weight > 0:
-            msg.append(f"train_aux={train_router_aux:.4f}")
-        if train_usage_list:
-            train_usage_min = min(train_usage_list)
-            train_usage_max = max(train_usage_list)
-            msg.append(f"train_usage=({train_usage_min:.2f},{train_usage_max:.2f})")
-            usage_list_str = ", ".join(f"{usage:.2f}" for usage in train_usage_list)
-            msg.append(f"train_usage_all=[{usage_list_str}]")
-        msg.append(f"train_H={train_entropy:.3f}")
-        if epoch_noise_std > 0:
-            msg.append(f"noise={epoch_noise_std:.3f}")
-
-        if should_eval:
-            msg.extend(
-                [
-                    f"val_loss={val_loss:.4f}",
-                    f"val_acc={val_acc:.3f}",
-                    f"val_f1={val_f1:.3f}",
-                ]
-            )
-            if load_balance_weight > 0:
-                msg.append(f"val_lb={val_balance:.4f}")
-            if router_loss_weight > 0:
-                msg.append(f"val_aux={val_router_aux:.4f}")
-            if val_usage:
-                val_usage_min = min(val_usage)
-                val_usage_max = max(val_usage)
-                msg.append(f"val_usage=({val_usage_min:.2f},{val_usage_max:.2f})")
-                usage_list_str = ", ".join(f"{usage:.2f}" for usage in val_usage)
-                msg.append(f"val_usage_all=[{usage_list_str}]")
-            if not math.isnan(val_entropy):
-                msg.append(f"val_H={val_entropy:.3f}")
-        else:
-            msg.append("validation skipped")
-        print(" ".join(msg))
-
-        if should_eval:
-            improved = (val_loss + early_delta) < best_val_loss
-            if improved:
-                best_val_loss = val_loss
-                best_val_f1 = val_f1
-                patience_counter = 0
-                # Free old state before saving new one to prevent memory accumulation
-                if best_router_state is not None:
-                    del best_router_state
-                if best_classifier_state is not None:
-                    del best_classifier_state
-                if best_expert_states is not None:
-                    del best_expert_states
-                best_router_state = {k: v.cpu().clone() for k, v in router.state_dict().items()}
-                best_classifier_state = {k: v.cpu().clone() for k, v in classifier.state_dict().items()}
-                if expert_requires_grad:
-                    best_expert_states = [
-                        {k: v.cpu().clone() for k, v in expert.state_dict().items()}
-                        for expert in experts
-                    ]
-            else:
-                patience_counter += 1
-                if patience_counter >= patience:
-                    print(f"[INFO] Early stopping triggered after {epoch} epochs without val_loss improvement")
-                    break
-
-        if checkpoint_callback is not None:
-            try:
-                checkpoint_callback(epoch)
-            except Exception as e:
-                print(f"[WARN] Epoch {epoch:02d} checkpoint save failed: {e}")
-
-        # Explicit memory cleanup at end of epoch
-        if device.type == "hpu":
-            try:
-                import habana_frameworks.torch.core as htcore
-                htcore.mark_step()
-            except (ImportError, AttributeError):
-                pass
-        elif torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    if best_router_state is not None:
-        router.load_state_dict({k: v.to(device) for k, v in best_router_state.items()})
-    if best_classifier_state is not None:
-        classifier.load_state_dict({k: v.to(device) for k, v in best_classifier_state.items()})
-    if best_expert_states is not None:
-        for expert, state in zip(experts, best_expert_states):
-            expert.load_state_dict({k: v.to(device) for k, v in state.items()})
-        print(f"[INFO] Restored best expert checkpoints (val_f1={best_val_f1:.3f})")
-    elif best_router_state is not None and best_classifier_state is not None:
-        print(f"[INFO] Restored best model with val_f1={best_val_f1:.3f}")
-
-    return history
-
-
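# NOTE (editor): the usage entropy logged above is the Shannon entropy of the mean
# router distribution; with six experts it peaks at ln(6) ~= 1.79 under perfectly
# balanced routing. Illustrative computation:
import math
import torch

avg_usage = torch.full((6,), 1.0 / 6.0)
entropy = float(-(avg_usage * avg_usage.clamp_min(1e-8).log()).sum())
assert abs(entropy - math.log(6)) < 1e-4
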
-@torch.no_grad()
-def evaluate_task_model(
-    *,
-    router: RouterNet,
-    experts: Sequence[EmbeddingExpert],
-    classifier: TaskClassifier,
-    loader: DataLoader,
-    topk: int,
-    device: torch.device,
-    comm_to_idx: Mapping[str, int],
-    expert_specs: Sequence[ExpertSpec],
-    router_loss_weight: float,
-    load_balance_weight: float,
-) -> Dict[str, Any]:
-    """Evaluate task model and return aggregate metrics."""
-    router.eval()
-    classifier.eval()
-    for expert in experts:
-        expert.eval()
-
-    group_map = build_group_map(expert_specs, comm_to_idx) if router_loss_weight > 0 else None
-
-    total_loss = 0.0
-    total_balance = 0.0
-    total_router_aux = 0.0
-    total_samples = 0
-    usage_sum: Optional[torch.Tensor] = None
-    all_preds: List[int] = []
-    all_targets: List[int] = []
-
-    for specs, comm_labels, task_labels in iterate_batches(loader, "Task eval"):
-        specs = specs.to(device, non_blocking=True)
-        comm_labels = comm_labels.to(device, non_blocking=True)
-        task_labels = task_labels.to(device, non_blocking=True)
-        batch_size = specs.size(0)
-
-        specs_norm = normalize_per_sample_tensor(specs)
-        router_logits = router(specs_norm)
-        router_probs = torch.softmax(router_logits, dim=1)
-        topk_vals, topk_idx = router_probs.topk(k=topk, dim=1)
-        weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
-        if usage_sum is None:
-            usage_sum = torch.zeros(router_probs.size(1), device=device)
-        usage_sum = usage_sum + router_probs.sum(dim=0)
-
-        selected_embeddings = compute_selected_expert_embeddings(
-            experts,
-            specs_norm,
-            topk_idx,
-            allow_grad=False,
-        )
-        logits_each = classifier(selected_embeddings.view(-1, selected_embeddings.size(-1)))
-        logits_each = logits_each.view(batch_size, topk, -1)
-        weighted_logits = (weights.unsqueeze(-1) * logits_each).sum(dim=1)
-
-        task_loss = F.cross_entropy(weighted_logits, task_labels)
-        loss = task_loss
-
-        if router_loss_weight > 0 and group_map is not None:
-            router_aux = router_cross_entropy(router_logits, comm_labels, group_map)
-            total_router_aux += router_aux.item() * batch_size
-            loss = loss + router_loss_weight * router_aux
-
-        if load_balance_weight > 0:
-            avg_probs = router_probs.mean(dim=0)
-            uniform = torch.full_like(avg_probs, 1.0 / max(avg_probs.numel(), 1))
-            balance_penalty = F.mse_loss(avg_probs, uniform)
-            total_balance += balance_penalty.item() * batch_size
-            loss = loss + load_balance_weight * balance_penalty
-
-        total_loss += loss.item() * batch_size
-        total_samples += batch_size
-
-        preds = weighted_logits.argmax(dim=1)
-        all_preds.extend(preds.cpu().tolist())
-        all_targets.extend(task_labels.cpu().tolist())
-
-    all_preds_arr = np.array(all_preds)
-    all_targets_arr = np.array(all_targets)
-    acc = float((all_preds_arr == all_targets_arr).mean()) if len(all_preds_arr) > 0 else 0.0
-
-    f1 = 0.0
-    if SKLEARN_AVAILABLE and len(all_preds_arr) > 0:
-        try:
-            f1 = float(f1_score(all_targets_arr, all_preds_arr, average="weighted", zero_division=0))
-        except Exception:
-            pass
-
-    denom = max(total_samples, 1)
-    if usage_sum is not None:
-        avg_usage = usage_sum / float(denom)
-        avg_usage = avg_usage.clamp(min=0.0)
-        entropy = float(-(avg_usage * avg_usage.clamp_min(1e-8).log()).sum().item())
-        usage_list: Optional[List[float]] = avg_usage.detach().cpu().tolist()
-    else:
-        usage_list = None
-        entropy = float("nan")
-    return {
-        "loss": total_loss / denom,
-        "acc": acc,
-        "f1": f1,
-        "balance": (total_balance / denom) if load_balance_weight > 0 else 0.0,
-        "router_aux": (total_router_aux / denom) if router_loss_weight > 0 else 0.0,
-        "usage": usage_list,
-        "entropy": entropy,
-    }
-
-
expert index mapping (only baseline experts) - comm_to_expert_idx = build_baseline_expert_map(expert_specs, comm_to_idx) - - optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr, weight_decay=weight_decay) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-6 - ) - scaler = GradScaler(enabled=torch.cuda.is_available()) - history = {"train_loss": [], "val_loss": [], "val_acc": [], "val_f1": []} - - best_val_f1 = 0.0 - patience_counter = 0 - best_classifier_state = None - - for epoch in range(1, epochs + 1): - classifier.train() - running_loss = 0.0 - total = 0 - desc = f"Oracle train {epoch:02d}" - - for specs, comm_labels, task_labels in iterate_batches(train_loader, desc): - specs = specs.to(device, non_blocking=True) - comm_labels = comm_labels.to(device, non_blocking=True) - task_labels = task_labels.to(device, non_blocking=True) - optimizer.zero_grad(set_to_none=True) - - context = autocast(device_type=device.type, enabled=scaler.is_enabled()) - with context: - # Use ground-truth comm labels to select experts - with torch.no_grad(): - embeddings = stack_expert_embeddings(experts, specs) - - # Select expert based on ground-truth comm label - batch_size = specs.size(0) - selected_embeddings = [] - for i in range(batch_size): - comm_idx = int(comm_labels[i].item()) - expert_idx = comm_to_expert_idx[comm_idx] # Will raise KeyError if missing - selected_embeddings.append(embeddings[i, expert_idx]) - selected_embeddings = torch.stack(selected_embeddings, dim=0) - - logits = classifier(selected_embeddings) - loss = F.cross_entropy(logits, task_labels) - - if scaler.is_enabled(): - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - optimizer.step() - - running_loss += loss.item() * specs.size(0) - total += specs.size(0) - - train_loss = running_loss / max(total, 1) - val_loss, val_acc, val_f1 = evaluate_oracle_baseline( - experts=experts, - expert_specs=expert_specs, - comm_to_idx=comm_to_idx, - comm_to_expert_idx=comm_to_expert_idx, - classifier=classifier, - loader=val_loader, - device=device, - ) - - old_lr = optimizer.param_groups[0]['lr'] - scheduler.step(val_loss) - new_lr = optimizer.param_groups[0]['lr'] - if new_lr != old_lr: - print(f"[INFO] Learning rate reduced: {old_lr:.2e} -> {new_lr:.2e}") - - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["val_acc"].append(val_acc) - history["val_f1"].append(val_f1) - - print( - f"[Oracle] Epoch {epoch:02d}: train_loss={train_loss:.4f} " - f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} val_f1={val_f1:.3f}" - ) - - if val_f1 > best_val_f1: - best_val_f1 = val_f1 - patience_counter = 0 - best_classifier_state = {k: v.cpu().clone() for k, v in classifier.state_dict().items()} - else: - patience_counter += 1 - if patience_counter >= patience: - print(f"[INFO] Early stopping triggered after {epoch} epochs") - break - - if best_classifier_state is not None: - classifier.load_state_dict({k: v.to(device) for k, v in best_classifier_state.items()}) - print(f"[INFO] Restored best model with val_f1={best_val_f1:.3f}") - - return history - - -@torch.no_grad() -def evaluate_oracle_baseline( - *, - experts: Sequence[EmbeddingExpert], - expert_specs: Sequence[ExpertSpec], - comm_to_idx: Mapping[str, int], - comm_to_expert_idx: Mapping[int, int], - classifier: TaskClassifier, - loader: DataLoader, - device: torch.device, -) -> Tuple[float, float, float]: - """Evaluate oracle baseline.""" - 
classifier.eval() - total_loss = 0.0 - all_preds: List[int] = [] - all_targets: List[int] = [] - - for specs, comm_labels, task_labels in iterate_batches(loader, "Oracle eval"): - specs = specs.to(device, non_blocking=True) - comm_labels = comm_labels.to(device, non_blocking=True) - task_labels = task_labels.to(device, non_blocking=True) - - embeddings = stack_expert_embeddings(experts, specs) - - batch_size = specs.size(0) - selected_embeddings = [] - for i in range(batch_size): - comm_idx = int(comm_labels[i].item()) - expert_idx = comm_to_expert_idx[comm_idx] # Will raise KeyError if missing - selected_embeddings.append(embeddings[i, expert_idx]) - selected_embeddings = torch.stack(selected_embeddings, dim=0) - - logits = classifier(selected_embeddings) - loss = F.cross_entropy(logits, task_labels) - total_loss += loss.item() * specs.size(0) - - preds = logits.argmax(dim=1) - all_preds.extend(preds.cpu().tolist()) - all_targets.extend(task_labels.cpu().tolist()) - - all_preds_arr = np.array(all_preds) - all_targets_arr = np.array(all_targets) - acc = float((all_preds_arr == all_targets_arr).mean()) - - f1 = 0.0 - if SKLEARN_AVAILABLE and len(all_preds_arr) > 0: - try: - f1 = float(f1_score(all_targets_arr, all_preds_arr, average='weighted', zero_division=0)) - except Exception: - pass - - total = len(all_preds_arr) - return total_loss / max(total, 1), acc, f1 - - -class SingleModelBackbone(nn.Module): - """Simple CNN backbone for single model baseline.""" - - def __init__(self, dropout: float = 0.1) -> None: - super().__init__() - self.features = nn.Sequential( - nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2), - nn.BatchNorm2d(32), - nn.SiLU(inplace=True), - nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), - nn.BatchNorm2d(64), - nn.SiLU(inplace=True), - nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1), - nn.BatchNorm2d(96), - nn.SiLU(inplace=True), - nn.Conv2d(96, 128, kernel_size=3, stride=2, padding=1), - nn.BatchNorm2d(128), - nn.SiLU(inplace=True), - nn.AdaptiveAvgPool2d((1, 1)), - nn.Flatten(), - ) - if dropout > 0: - self.features.add_module('dropout', nn.Dropout(dropout)) - - def forward(self, specs: torch.Tensor) -> torch.Tensor: - x = specs - if x.dim() == 3: - x = x.unsqueeze(1) - elif x.dim() != 4: - raise ValueError(f"Expected specs rank 3 or 4, got shape {tuple(specs.shape)}") - return self.features(x) - - -class ImageNetBackbone(nn.Module): - """ImageNet pretrained backbone (ResNet18) for baseline.""" - - def __init__(self, dropout: float = 0.1, freeze_backbone: bool = False) -> None: - super().__init__() - try: - import torchvision.models as models - except ImportError: - raise ImportError("torchvision is required for ImageNet backbone. 
Install with: pip install torchvision") - - # Load pretrained ResNet18 - resnet = models.resnet18(pretrained=True) - - # Convert grayscale to RGB by replicating first conv layer - # Original: Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - # New: Conv2d(1, 64, kernel_size=7, stride=2, padding=3) - original_conv1 = resnet.conv1 - self.conv1 = nn.Conv2d( - 1, 64, kernel_size=7, stride=2, padding=3, bias=False - ) - # Initialize by averaging RGB weights - with torch.no_grad(): - self.conv1.weight.data = original_conv1.weight.data.mean(dim=1, keepdim=True) - - self.bn1 = resnet.bn1 - self.relu = resnet.relu - self.maxpool = resnet.maxpool - self.layer1 = resnet.layer1 - self.layer2 = resnet.layer2 - self.layer3 = resnet.layer3 - self.layer4 = resnet.layer4 - self.avgpool = resnet.avgpool - - # Freeze backbone if requested - if freeze_backbone: - for param in [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]: - for p in param.parameters(): - p.requires_grad = False - - # Project from 512-d (ResNet18 output) to 128-d - self.projection = nn.Sequential( - nn.Flatten(), - nn.Linear(512, 128), - nn.ReLU(inplace=True), - ) - if dropout > 0: - self.projection.add_module('dropout', nn.Dropout(dropout)) - - def forward(self, specs: torch.Tensor) -> torch.Tensor: - x = specs - if x.dim() == 3: - x = x.unsqueeze(1) - elif x.dim() != 4: - raise ValueError(f"Expected specs rank 3 or 4, got shape {tuple(specs.shape)}") - - # ResNet forward pass - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = self.projection(x) - - return x - - -def train_single_model( - *, - backbone: SingleModelBackbone, - classifier: TaskClassifier, - train_loader: DataLoader, - val_loader: DataLoader, - device: torch.device, - epochs: int, - lr: float, - weight_decay: float, - patience: int = 10, - eval_interval: int = 1, -) -> Dict[str, List[float]]: - """Train single model baseline: one model for all communication types.""" - optimizer = torch.optim.AdamW( - list(backbone.parameters()) + list(classifier.parameters()), - lr=lr, - weight_decay=weight_decay, - ) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-6 - ) - scaler = GradScaler(enabled=torch.cuda.is_available()) - history = {"train_loss": [], "val_loss": [], "val_acc": [], "val_f1": []} - - best_val_f1 = 0.0 - patience_counter = 0 - best_backbone_state = None - best_classifier_state = None - eval_interval = max(1, int(eval_interval)) - - for epoch in range(1, epochs + 1): - backbone.train() - classifier.train() - running_loss = 0.0 - total = 0 - desc = f"Single train {epoch:02d}" - - for specs, _, task_labels in iterate_batches(train_loader, desc): - specs = specs.to(device, non_blocking=True) - task_labels = task_labels.to(device, non_blocking=True) - optimizer.zero_grad(set_to_none=True) - - # Normalize per-sample - specs_norm = normalize_per_sample_tensor(specs) - - context = autocast(device_type=device.type, enabled=scaler.is_enabled()) - with context: - embeddings = backbone(specs_norm) - logits = classifier(embeddings) - loss = F.cross_entropy(logits, task_labels) - - if scaler.is_enabled(): - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - optimizer.step() - - running_loss += loss.item() * specs.size(0) - total += specs.size(0) - - train_loss = running_loss / max(total, 1) - 
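# Validation runs only every `eval_interval`-th epoch (and always on the last
-        # one); skipped epochs record NaN placeholders, which
-        # sanitize_history_for_serialization later converts to null for JSON output.
-        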
should_eval = (epoch % eval_interval == 0) or (epoch == epochs) - if should_eval: - val_loss, val_acc, val_f1 = evaluate_single_model( - backbone=backbone, - classifier=classifier, - loader=val_loader, - device=device, - ) - old_lr = optimizer.param_groups[0]['lr'] - scheduler.step(val_loss) - new_lr = optimizer.param_groups[0]['lr'] - if new_lr != old_lr: - print(f"[INFO] Learning rate reduced: {old_lr:.2e} -> {new_lr:.2e}") - else: - val_loss = float("nan") - val_acc = float("nan") - val_f1 = float("nan") - - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["val_acc"].append(val_acc) - history["val_f1"].append(val_f1) - - if should_eval: - print( - f"[Single] Epoch {epoch:02d}: train_loss={train_loss:.4f} " - f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} val_f1={val_f1:.3f}" - ) - else: - print(f"[Single] Epoch {epoch:02d}: train_loss={train_loss:.4f} (validation skipped)") - - if should_eval: - if val_f1 > best_val_f1: - best_val_f1 = val_f1 - patience_counter = 0 - best_backbone_state = {k: v.cpu().clone() for k, v in backbone.state_dict().items()} - best_classifier_state = {k: v.cpu().clone() for k, v in classifier.state_dict().items()} - else: - patience_counter += 1 - if patience_counter >= patience: - print(f"[INFO] Early stopping triggered after {epoch} epochs") - break - - if best_backbone_state is not None and best_classifier_state is not None: - backbone.load_state_dict({k: v.to(device) for k, v in best_backbone_state.items()}) - classifier.load_state_dict({k: v.to(device) for k, v in best_classifier_state.items()}) - print(f"[INFO] Restored best model with val_f1={best_val_f1:.3f}") - - return history - - -@torch.no_grad() -def evaluate_single_model( - *, - backbone: SingleModelBackbone, - classifier: TaskClassifier, - loader: DataLoader, - device: torch.device, -) -> Tuple[float, float, float]: - """Evaluate single model baseline.""" - backbone.eval() - classifier.eval() - total_loss = 0.0 - all_preds: List[int] = [] - all_targets: List[int] = [] - - for specs, _, task_labels in iterate_batches(loader, "Single eval"): - specs = specs.to(device, non_blocking=True) - task_labels = task_labels.to(device, non_blocking=True) - - specs_norm = normalize_per_sample_tensor(specs) - embeddings = backbone(specs_norm) - logits = classifier(embeddings) - loss = F.cross_entropy(logits, task_labels) - total_loss += loss.item() * specs.size(0) - - preds = logits.argmax(dim=1) - all_preds.extend(preds.cpu().tolist()) - all_targets.extend(task_labels.cpu().tolist()) - - all_preds_arr = np.array(all_preds) - all_targets_arr = np.array(all_targets) - acc = float((all_preds_arr == all_targets_arr).mean()) - - f1 = 0.0 - if SKLEARN_AVAILABLE and len(all_preds_arr) > 0: - try: - f1 = float(f1_score(all_targets_arr, all_preds_arr, average='weighted', zero_division=0)) - except Exception: - pass - - total = len(all_preds_arr) - return total_loss / max(total, 1), acc, f1 - - -@torch.no_grad() -def evaluate_test_metrics( - *, - router: RouterNet, - experts: Sequence[EmbeddingExpert], - classifier: TaskClassifier, - loader: DataLoader, - topk: int, - device: torch.device, -) -> Dict[str, object]: - router.eval() - classifier.eval() - for expert in experts: - expert.eval() - all_preds: List[int] = [] - all_targets: List[int] = [] - coverage = torch.zeros(len(experts), dtype=torch.float64) - for specs, _, task_labels in iterate_batches(loader, "Task test"): - specs = specs.to(device, non_blocking=True) - task_labels = task_labels.to(device, non_blocking=True) - # 
Normalize once and reuse - specs_norm = normalize_per_sample_tensor(specs) - router_probs = torch.softmax(router(specs_norm), dim=1) - topk_vals, topk_idx = router_probs.topk(k=topk, dim=1) - weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6) - # Run only selected experts with pre-normalized specs - selected_embeddings = compute_selected_expert_embeddings( - experts, - specs_norm, - topk_idx, - allow_grad=False, - ) - logits_each = classifier(selected_embeddings.view(-1, selected_embeddings.size(-1))) - logits_each = logits_each.view(specs.size(0), topk, -1) - weighted_logits = (weights.unsqueeze(-1) * logits_each).sum(dim=1) - preds = weighted_logits.argmax(dim=1) - all_preds.extend(preds.detach().cpu().tolist()) - all_targets.extend(task_labels.detach().cpu().tolist()) - for b in range(specs.size(0)): - for rank in range(topk): - coverage[topk_idx[b, rank].item()] += float(weights[b, rank].item()) - all_preds_arr = np.array(all_preds, dtype=np.int64) - all_targets_arr = np.array(all_targets, dtype=np.int64) - acc = float((all_preds_arr == all_targets_arr).mean()) - - # Compute F1 score (weighted) - f1 = 0.0 - conf = None - - if SKLEARN_AVAILABLE and len(all_preds_arr) > 0: - try: - from sklearn.metrics import confusion_matrix - conf = confusion_matrix(all_targets_arr, all_preds_arr).tolist() - f1 = float(f1_score(all_targets_arr, all_preds_arr, average='weighted', zero_division=0)) - except Exception as e: - print(f"[WARN] Could not compute sklearn metrics: {e}") - - coverage_dict = { - idx: float(coverage[idx].item()) for idx in range(len(experts)) - } - return { - "test_accuracy": acc, - "test_f1": f1, - "confusion_matrix": conf, - "coverage": coverage_dict, - } - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--data-root", type=Path, default=Path("spectrograms"), help="Root directory with spectrogram data") - parser.add_argument("--cities", nargs="*", default=["city_1_losangeles"], help="City folders to include") - parser.add_argument("--comm-types", nargs="*", default=["LTE", "WiFi", "5G"], help="Communication standards to model") - parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include") - parser.add_argument( - "--mobilities", - nargs="*", - default=["pedestrian", "vehicular"], - help="Mobility folders to include (default: pedestrian vehicular)", - ) - parser.add_argument("--modulations", nargs="*", default=None, help="Modulation classes to include (default: all)") - parser.add_argument("--fft-folders", nargs="*", default=None, help="Specific FFT/window folders to include") - parser.add_argument("--task", choices=("modulation", "snr_mobility"), default="snr_mobility", help="Downstream task label") - parser.add_argument("--max-samples-per-comm", type=int, default=0, help="Maximum samples per communication profile (0=use all available data)") - parser.add_argument("--max-per-combo", type=int, default=0, help="Cap per (modulation,SNR,mobility) combo (0=unbounded, use all available)") - parser.add_argument("--seed", type=int, default=42, help="Random seed") - parser.add_argument("--train-ratio", type=float, default=0.7, help="Fraction of data for training") - parser.add_argument("--val-ratio", type=float, default=0.15, help="Fraction of data for validation") - parser.add_argument("--max-samples-per-class", type=int, default=0, help="Maximum training samples per task class (0=no cap)") - parser.add_argument("--val-samples-per-class", type=int, default=0, 
help="Validation samples per task class (0=use fraction)") - parser.add_argument("--test-samples-per-class", type=int, default=0, help="Test samples per task class (0=use remaining)") - parser.add_argument("--batch-size", type=int, default=64, help="Mini-batch size (optimized for speed and memory)") - parser.add_argument("--gradient-accumulation-steps", type=int, default=1, help="Accumulate gradients over N steps (effective batch = batch_size * N)") - parser.add_argument( - "--router-epochs", - type=int, - default=2, - help="Warm-up epochs for router pre-training (default: 2; set to 0 to skip)", - ) - parser.add_argument("--task-epochs", type=int, default=25, help="Joint training epochs for classifier and router") - parser.add_argument("--router-lr", type=float, default=5e-4, help="Learning rate for router during joint training (increased for faster convergence)") - parser.add_argument("--router-warmup-lr", type=float, default=3e-4, help="Learning rate during router warm-up") - parser.add_argument("--classifier-lr", type=float, default=2e-3, help="Learning rate for task classifier (increased for faster convergence)") - parser.add_argument("--expert-lr", type=float, default=5e-5, help="Learning rate for expert fine-tuning (0 keeps experts frozen)") - parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay for optimizers") - parser.add_argument("--router-loss-weight", type=float, default=0.05, help="Weight for communication auxiliary loss") - parser.add_argument("--load-balance-weight", type=float, default=0.05, help="Weight for expert load-balancing regulariser (0 disables)") - parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability for router and classifier heads") - parser.add_argument("--routing-topk", type=int, default=2, help="Number of experts to keep per sample") - parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers (4 recommended for optimal performance)") - parser.add_argument("--preload-data", action="store_true", default=True, help="Preload all data into RAM for faster training") - parser.add_argument("--no-preload-data", dest="preload_data", action="store_false", help="Disable data preloading") - parser.add_argument("--patience", type=int, default=10, help="Early stopping patience (epochs)") - parser.add_argument( - "--expert", - action="append", - default=None, - help="Optional manual expert definition NAME=COMM:checkpoint[:stats_path]", - ) - parser.add_argument("--output-dir", type=Path, default=Path("mixture/runs/embedding_router"), help="Directory for outputs") - parser.add_argument("--save-router", action="store_true", help="Save trained router state_dict") - parser.add_argument("--save-classifier", action="store_true", help="Save trained classifier state_dict") - parser.add_argument( - "--resume-checkpoint", - type=Path, - default=None, - help="Resume MoE training from an existing checkpoint (router/classifier fine-tuning)", - ) - parser.add_argument( - "--resume-router-warmup", - action="store_true", - help="When resuming, run the router warm-up stage before joint training", - ) - parser.add_argument( - "--baseline", - choices=["oracle", "single", "imagenet"], - default=None, - help="Baseline mode: 'oracle' uses ground-truth comm labels with baseline experts, 'single' trains a single CNN model, 'imagenet' uses pretrained ResNet18", - ) - parser.add_argument( - "--gating-noise-std", - type=float, - default=0.1, - help="Stddev of Gaussian noise added to router logits during early training 
(0 disables)", - ) - parser.add_argument( - "--gating-noise-epochs", - type=int, - default=5, - help="Number of epochs over which gating noise decays to zero (0 keeps constant std while enabled)", - ) - parser.add_argument( - "--freeze-backbone", - action="store_true", - help="Freeze ImageNet backbone weights (only train projection and classifier)", - ) - return parser.parse_args() - - -def parse_manual_expert(entry: str) -> ExpertSpec: - if "=" not in entry: - raise ValueError(f"Expert definition must use NAME=COMM:path syntax (got: {entry})") - name_part, _, payload = entry.partition("=") - if ":" not in payload: - raise ValueError(f"Expert definition missing COMM:path separator (got: {entry})") - comm_part, _, remainder = payload.partition(":") - comm = canonical_comm_name(comm_part) - if ":" in remainder: - checkpoint_str, stats_str = remainder.split(":", 1) - stats_path = Path(stats_str).expanduser().resolve() - if not stats_path.exists(): - print(f"[WARN] Stats file not found: {stats_path}; will use per-sample normalization") - stats_path = None - else: - checkpoint_str = remainder - # Stats path is optional - we use per-sample normalization - stats_path_candidate = Path(REPO_ROOT / "models" / f"{comm}_models" / "dataset_stats.json").resolve() - stats_path = stats_path_candidate if stats_path_candidate.exists() else None - checkpoint = Path(checkpoint_str).expanduser().resolve() - if not checkpoint.exists(): - raise FileNotFoundError(f"Expert checkpoint does not exist: {checkpoint}") - return ExpertSpec(name=name_part.strip(), comm=comm, checkpoint=checkpoint, stats_path=stats_path) - - -def load_experts( - specs: Sequence[ExpertSpec], - device: torch.device, - *, - trainable: bool = False, -) -> List[EmbeddingExpert]: - embeddings: List[EmbeddingExpert] = [] - for spec in specs: - print( - f"[INFO] Loading expert '{spec.name}' ({spec.comm}) from {spec.checkpoint}" - + (" [trainable]" if trainable else "") - ) - embeddings.append(EmbeddingExpert(spec, device, trainable=trainable)) - return embeddings - - -def build_baseline_expert_map( - expert_specs: Sequence[ExpertSpec], - comm_to_idx: Mapping[str, int], -) -> Dict[int, int]: - """Build mapping from communication type to baseline expert index.""" - comm_to_expert_idx: Dict[int, int] = {} - for expert_idx, spec in enumerate(expert_specs): - # Select baseline experts: either has "baseline" in path or not in "task2" folder - is_baseline = "baseline" in str(spec.checkpoint).lower() or "task2" not in str(spec.checkpoint) - if is_baseline: - comm_idx = comm_to_idx[spec.comm] - if comm_idx not in comm_to_expert_idx: - comm_to_expert_idx[comm_idx] = expert_idx - - # Validate that we have experts for all communication types - if len(comm_to_expert_idx) != len(comm_to_idx): - missing_indices = set(comm_to_idx.keys()) - set(comm_to_expert_idx.keys()) - missing_names = [name for name, idx in comm_to_idx.items() if idx in missing_indices] - raise RuntimeError(f"Missing baseline experts for communication types: {missing_names}") - - return comm_to_expert_idx - - -def sanitize_history_for_serialization(history: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Convert NaN/Inf values in history to None for safe serialization.""" - - def _sanitize(value: Any) -> Any: - if isinstance(value, float): - if math.isnan(value) or math.isinf(value): - return None - return value - if isinstance(value, list): - return [_sanitize(item) for item in value] - if isinstance(value, tuple): - return [_sanitize(item) for item in value] - return value - - return 
{key: [_sanitize(entry) for entry in values] for key, values in history.items()} - - -def write_training_metrics_csv( - history: Dict[str, List[Any]], - expert_specs: Sequence[ExpertSpec], - csv_path: Path, -) -> None: - """Write per-epoch training metrics to CSV with expert usage columns.""" - epochs = len(history.get("train_loss", [])) - if epochs == 0: - return - - csv_path.parent.mkdir(parents=True, exist_ok=True) - - fieldnames = [ - "epoch", - "train_loss", - "train_acc", - "val_loss", - "val_acc", - "val_f1", - "train_balance", - "val_balance", - "train_router_aux", - "val_router_aux", - "train_entropy", - "val_entropy", - "gating_noise", - ] - usage_headers: List[str] = [] - for spec in expert_specs: - usage_headers.append(f"train_usage[{spec.name}]") - usage_headers.append(f"val_usage[{spec.name}]") - fieldnames.extend(usage_headers) - - with csv_path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=fieldnames) - writer.writeheader() - for epoch_idx in range(epochs): - row: Dict[str, Any] = { - "epoch": epoch_idx + 1, - "train_loss": history["train_loss"][epoch_idx], - "train_acc": history["train_acc"][epoch_idx], - "val_loss": history["val_loss"][epoch_idx], - "val_acc": history["val_acc"][epoch_idx], - "val_f1": history["val_f1"][epoch_idx], - "train_balance": history["train_balance"][epoch_idx], - "val_balance": history["val_balance"][epoch_idx], - "train_router_aux": history["train_router_aux"][epoch_idx], - "val_router_aux": history["val_router_aux"][epoch_idx], - "train_entropy": history["train_entropy"][epoch_idx], - "val_entropy": history["val_entropy"][epoch_idx], - "gating_noise": history["gating_noise"][epoch_idx], - } - - train_usage = history["train_usage"][epoch_idx] or [None] * len(expert_specs) - if len(train_usage) < len(expert_specs): - train_usage = list(train_usage) + [None] * (len(expert_specs) - len(train_usage)) - val_usage = history["val_usage"][epoch_idx] - if val_usage is None: - val_usage = [None] * len(expert_specs) - elif len(val_usage) < len(expert_specs): - val_usage = list(val_usage) + [None] * (len(expert_specs) - len(val_usage)) - - for spec, usage_value in zip(expert_specs, train_usage): - row[f"train_usage[{spec.name}]"] = usage_value - for spec, usage_value in zip(expert_specs, val_usage): - row[f"val_usage[{spec.name}]"] = usage_value - - writer.writerow(row) - - -def save_complete_checkpoint( - *, - router: Optional[RouterNet], - classifier: TaskClassifier, - expert_models: Optional[Sequence[EmbeddingExpert]], - expert_specs: Sequence[ExpertSpec], - comm_to_idx: Mapping[str, int], - task_type: str, - num_classes: int, - topk: int, - dropout: float, - mapping: Optional[Dict[int, Tuple[str, str]]], - output_path: Path, - model_type: str = "embedding_router_moe", - backbone_state_dict: Optional[Dict[str, torch.Tensor]] = None, - backbone_meta: Optional[Dict[str, Any]] = None, - expert_trainable: bool = False, -) -> None: - """Save complete MoE checkpoint for inference.""" - checkpoint = { - "model_type": model_type, - "task": task_type, - "num_classes": num_classes, - "topk": topk, - "dropout": dropout, - "comm_to_idx": dict(comm_to_idx), - "experts": [ - { - "name": spec.name, - "comm": spec.comm, - "checkpoint": str(spec.checkpoint), - "stats_path": str(spec.stats_path) if spec.stats_path else None, - } - for spec in expert_specs - ], - "classifier_state_dict": classifier.state_dict(), - "mapping": {int(k): v for k, v in mapping.items()} if mapping else None, - "expert_trainable": bool(expert_trainable), 
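-        # NOTE: modules wrapped in nn.DataParallel serialize their weights with
-        # a "module." key prefix; _build_checkpoint_components strips that
-        # prefix when the checkpoint is reloaded.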
- } - if router is not None: - checkpoint["router_state_dict"] = router.state_dict() - if backbone_state_dict is not None: - checkpoint["backbone_state_dict"] = backbone_state_dict - if backbone_meta is not None: - checkpoint["backbone_meta"] = backbone_meta - if expert_models is not None: - def _is_trainable(expert: nn.Module) -> bool: - # Handle plain EmbeddingExpert or DataParallel-wrapped - if hasattr(expert, "trainable"): - return bool(getattr(expert, "trainable")) - if hasattr(expert, "module") and hasattr(expert.module, "trainable"): # type: ignore[attr-defined] - return bool(expert.module.trainable) # type: ignore[attr-defined] - return False - - trainable_flags = [_is_trainable(expert) for expert in expert_models] - if any(trainable_flags): - checkpoint["expert_state_dicts"] = [ - { - "name": spec.name, - "state_dict": {k: v.cpu() for k, v in expert.state_dict().items()}, - } - for spec, expert in zip(expert_specs, expert_models) - ] - - torch.save(checkpoint, output_path) - print(f"[INFO] Complete checkpoint saved to {output_path}") - - -def _resolve_repo_path(path_str: str) -> Path: - """Resolve paths saved inside checkpoints relative to the repository root.""" - path = Path(path_str).expanduser() - if path.is_absolute(): - if path.exists(): - return path - # Fallback for checkpoints saved with absolute training paths - repo_name = REPO_ROOT.name - if repo_name in path.parts: - try: - repo_idx = path.parts.index(repo_name) - candidate = REPO_ROOT.joinpath(*path.parts[repo_idx + 1 :]) - if candidate.exists(): - return candidate - except ValueError: - pass - return path - return (REPO_ROOT / path).resolve() - - -def _checkpoint_to_expert_specs(checkpoint: Mapping[str, Any]) -> List[ExpertSpec]: - specs: List[ExpertSpec] = [] - for expert in checkpoint.get("experts", []): - checkpoint_path = _resolve_repo_path(expert["checkpoint"]) - if not checkpoint_path.exists(): - raise FileNotFoundError( - f"Expert checkpoint referenced in resume file missing: {checkpoint_path}" - ) - stats_path = None - stats_path_str = expert.get("stats_path") - if stats_path_str: - stats_candidate = _resolve_repo_path(stats_path_str) - if stats_candidate.exists(): - stats_path = stats_candidate - else: - print( - f"[WARN] Stats file referenced in checkpoint missing: {stats_candidate}; " - "defaulting to per-sample normalization" - ) - specs.append( - ExpertSpec( - name=expert["name"], - comm=canonical_comm_name(expert["comm"]), - checkpoint=checkpoint_path, - stats_path=stats_path, - ) - ) - return specs - - -def _normalize_comm_mapping(comm_to_idx_raw: Mapping[str, Any]) -> Dict[str, int]: - normalized: Dict[str, int] = {} - for key, idx in comm_to_idx_raw.items(): - normalized[canonical_comm_name(str(key))] = int(idx) - return normalized - - -def _normalize_label_mapping(mapping_raw: Optional[Mapping[Any, Any]]) -> Optional[Dict[int, Tuple[str, str]]]: - if mapping_raw is None: - return None - mapping: Dict[int, Tuple[str, str]] = {} - for key, value in mapping_raw.items(): - idx = int(key) - if isinstance(value, (list, tuple)) and len(value) == 2: - mapping[idx] = (str(value[0]), str(value[1])) - else: - raise ValueError(f"Unexpected mapping entry for class {idx}: {value!r}") - return mapping - - -def _build_checkpoint_components( - checkpoint: Mapping[str, Any], - device: torch.device, - *, - train_mode: bool, -) -> Dict[str, Any]: - dropout = float(checkpoint.get("dropout", 0.1)) - num_classes = int(checkpoint["num_classes"]) - expert_specs = _checkpoint_to_expert_specs(checkpoint) - 
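# Rebuild every module from its original expert checkpoint path first, then
-    # overlay any fine-tuned weights stored inside this resume/inference checkpoint.
-    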
expert_trainable_flag = bool(checkpoint.get("expert_trainable", False))
-    experts = load_experts(
-        expert_specs,
-        device,
-        trainable=train_mode and expert_trainable_flag,
-    )
-
-    def _strip_module_prefix(state: Mapping[str, Any]) -> Dict[str, Any]:
-        # State dicts saved from nn.DataParallel-wrapped modules carry a
-        # "module." key prefix; strip it so the plain modules built here load.
-        return {
-            (key[len("module."):] if key.startswith("module.") else key): value
-            for key, value in state.items()
-        }
-
-    expert_state_dicts = checkpoint.get("expert_state_dicts")
-    if expert_state_dicts:
-        name_to_state = {
-            entry.get("name"): entry.get("state_dict")
-            for entry in expert_state_dicts
-            if isinstance(entry, Mapping)
-        }
-        for spec, expert in zip(expert_specs, experts):
-            state_dict = name_to_state.get(spec.name)
-            if state_dict:
-                expert.load_state_dict(_strip_module_prefix(state_dict))
-    if not train_mode:
-        for expert_model in experts:
-            expert_model.eval()
-
-    router: Optional[RouterNet] = None
-    router_state = checkpoint.get("router_state_dict")
-    if router_state is not None:
-        router = RouterNet(num_experts=len(expert_specs), dropout=dropout).to(device)
-        router.load_state_dict(_strip_module_prefix(router_state))
-        if train_mode:
-            router.train()
-        else:
-            router.eval()
-
-    classifier = TaskClassifier(num_classes=num_classes, dropout=dropout).to(device)
-    classifier.load_state_dict(_strip_module_prefix(checkpoint["classifier_state_dict"]))
-    if train_mode:
-        classifier.train()
-    else:
-        classifier.eval()
-
-    return {
-        "router": router,
-        "classifier": classifier,
-        "experts": experts,
-        "expert_specs": expert_specs,
-        "comm_to_idx": _normalize_comm_mapping(checkpoint["comm_to_idx"]),
-        "task": checkpoint["task"],
-        "topk": int(checkpoint.get("topk", 1)),
-        "num_classes": num_classes,
-        "dropout": dropout,
-        "mapping": _normalize_label_mapping(checkpoint.get("mapping")),
-        "expert_trainable": expert_trainable_flag,
-    }
-
-
-def load_checkpoint_for_training(
-    checkpoint_path: Path,
-    device: torch.device,
-    checkpoint_data: Optional[Mapping[str, Any]] = None,
-) -> Dict[str, Any]:
-    """Load checkpoint components for continued training."""
-    if checkpoint_data is None:
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-    else:
-        checkpoint = dict(checkpoint_data)
-    model_type = str(checkpoint.get("model_type", ""))
-    if model_type != "embedding_router_moe":
-        raise ValueError(
-            f"Resume checkpoint expected 'embedding_router_moe' but found '{model_type}'"
-        )
-    return _build_checkpoint_components(checkpoint, device, train_mode=True)
-
-
-def load_checkpoint_for_inference(checkpoint_path: Path, device: torch.device):
-    """Load complete checkpoint and return all components."""
-    checkpoint = torch.load(checkpoint_path, map_location=device)
-    return _build_checkpoint_components(checkpoint, device, train_mode=False)
-
-
-class MoEPredictor:
-    """Inference wrapper for trained MoE model."""
-
-    def __init__(
-        self,
-        *,
-        router: Optional[RouterNet],
-        classifier: TaskClassifier,
-        experts: Sequence[EmbeddingExpert],
-        expert_specs: Sequence[ExpertSpec],
-        comm_to_idx: Mapping[str, int],
-        task_type: str,
-        topk: int,
-        mapping: Optional[Dict[int, Tuple[str, str]]],
-        device: torch.device,
-    ) -> None:
-        self.router = router
-        self.classifier = classifier
-        self.experts = experts
-        self.expert_specs = expert_specs
-        self.comm_to_idx = comm_to_idx
-        self.task_type = task_type
-        self.topk = topk
-        self.mapping = mapping
-        self.device = device
-
-        # Build reverse mapping for results
-        if task_type == "modulation":
-            from task1.train_mcs_models import MODULATION_LABELS
-            self.idx_to_label = {v: k for k, v in MODULATION_LABELS.items()}
-        elif mapping:
-            self.idx_to_label = mapping
-        else:
-            self.idx_to_label = None
-
-    @classmethod
-    def from_checkpoint(cls, checkpoint_path: Path, device: Optional[torch.device] = None):
-        """Load predictor from 
checkpoint file.""" - if device is None: - # Try HPU first, then CUDA, then CPU - try: - import habana_frameworks.torch.core as htcore - device = torch.device("hpu") - except (ImportError, RuntimeError): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - components = load_checkpoint_for_inference(checkpoint_path, device) - - return cls( - router=components["router"], - classifier=components["classifier"], - experts=components["experts"], - expert_specs=components["expert_specs"], - comm_to_idx=components["comm_to_idx"], - task_type=components["task"], - topk=components["topk"], - mapping=components["mapping"], - device=device, - ) - - @torch.no_grad() - def predict( - self, - spectrogram: torch.Tensor, - return_probabilities: bool = False, - return_routing: bool = False, - ) -> Dict[str, object]: - """Predict task label for a single spectrogram or batch. - - Args: - spectrogram: Tensor of shape [H, W] or [B, H, W] - return_probabilities: If True, return class probabilities - return_routing: If True, return routing weights - - Returns: - Dictionary with prediction results - """ - # Handle single sample - single_sample = spectrogram.dim() == 2 - if single_sample: - spectrogram = spectrogram.unsqueeze(0) - - # Move to device and normalize - specs = spectrogram.to(self.device) - specs_norm = normalize_per_sample_tensor(specs) - - if self.router is not None: - # Router-based prediction - router_logits = self.router(specs_norm) - router_probs = torch.softmax(router_logits, dim=1) - topk_vals, topk_idx = router_probs.topk(k=self.topk, dim=1) - weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6) - - # Get embeddings from selected experts - selected_embeddings = compute_selected_expert_embeddings( - self.experts, - specs_norm, - topk_idx, - allow_grad=False, - ) - logits_each = self.classifier(selected_embeddings.view(-1, selected_embeddings.size(-1))) - logits_each = logits_each.view(specs.size(0), self.topk, -1) - weighted_logits = (weights.unsqueeze(-1) * logits_each).sum(dim=1) - else: - # Oracle baseline: use all experts (fallback) - embeddings = stack_expert_embeddings(self.experts, specs) - # Average all expert embeddings - avg_embedding = embeddings.mean(dim=1) - weighted_logits = self.classifier(avg_embedding) - topk_idx = None - weights = None - - probs = torch.softmax(weighted_logits, dim=1) - predicted_classes = weighted_logits.argmax(dim=1) - - # Build results - results = { - "predicted_class": int(predicted_classes[0].item()) if single_sample else predicted_classes.cpu().tolist(), - "confidence": float(probs[0, predicted_classes[0]].item()) if single_sample else [float(probs[i, predicted_classes[i]].item()) for i in range(len(predicted_classes))], - } - - # Add human-readable labels - if self.idx_to_label: - if single_sample: - results["label"] = self.idx_to_label.get(int(predicted_classes[0].item()), "Unknown") - else: - results["labels"] = [self.idx_to_label.get(int(c), "Unknown") for c in predicted_classes.cpu().tolist()] - - if return_probabilities: - results["probabilities"] = probs[0].cpu().tolist() if single_sample else probs.cpu().tolist() - - if return_routing and topk_idx is not None: - routing_info = [] - for b in range(specs.size(0)): - sample_routing = [] - for k in range(self.topk): - expert_idx = int(topk_idx[b, k].item()) - sample_routing.append({ - "expert": self.expert_specs[expert_idx].name, - "comm": self.expert_specs[expert_idx].comm, - "weight": float(weights[b, k].item()), - }) - 
routing_info.append(sample_routing) - results["routing"] = routing_info[0] if single_sample else routing_info - - return results - - -def main() -> None: - args = parse_args() - torch.manual_seed(args.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(args.seed) - - resume_checkpoint_path: Optional[Path] = None - resume_checkpoint_data: Optional[Dict[str, Any]] = None - resume_comm_to_idx: Optional[Dict[str, int]] = None - resume_topk: Optional[int] = None - resume_dropout: Optional[float] = None - resume_task: Optional[str] = None - resume_num_classes: Optional[int] = None - resume_mapping: Optional[Dict[int, Tuple[str, str]]] = None - - if args.resume_checkpoint is not None: - if args.baseline is not None: - raise ValueError("--resume-checkpoint is only supported for MoE training (omit --baseline).") - if args.expert: - raise ValueError("--resume-checkpoint cannot be combined with manual --expert definitions.") - resume_checkpoint_path = args.resume_checkpoint.expanduser().resolve() - if not resume_checkpoint_path.exists(): - raise FileNotFoundError(f"Resume checkpoint does not exist: {resume_checkpoint_path}") - resume_checkpoint_data = torch.load(resume_checkpoint_path, map_location="cpu") - model_type = str(resume_checkpoint_data.get("model_type", "")) - if model_type != "embedding_router_moe": - raise ValueError( - f"Cannot resume from checkpoint with model_type={model_type!r}; expected 'embedding_router_moe'" - ) - resume_comm_to_idx = _normalize_comm_mapping(resume_checkpoint_data["comm_to_idx"]) - resume_topk = int(resume_checkpoint_data.get("topk", 1)) - resume_dropout = float(resume_checkpoint_data.get("dropout", args.dropout)) - resume_task = str(resume_checkpoint_data.get("task", args.task)) - resume_num_classes = int(resume_checkpoint_data["num_classes"]) - resume_mapping = _normalize_label_mapping(resume_checkpoint_data.get("mapping")) - expert_specs = _checkpoint_to_expert_specs(resume_checkpoint_data) - comm_types = [comm for comm, _ in sorted(resume_comm_to_idx.items(), key=lambda kv: kv[1])] - requested_comm_types = [canonical_comm_name(comm) for comm in args.comm_types] - if set(requested_comm_types) != set(comm_types): - raise ValueError( - "--comm-types must match the communications present in the resume checkpoint " - f"{comm_types}; received {requested_comm_types}" - ) - if args.task != resume_task: - print(f"[WARN] Overriding task '{args.task}' -> '{resume_task}' to match resume checkpoint") - args.task = resume_task - if abs(args.dropout - resume_dropout) > 1e-6: - print(f"[WARN] Overriding dropout {args.dropout} -> {resume_dropout} to match resume checkpoint") - args.dropout = resume_dropout - args.routing_topk = resume_topk - else: - if args.expert: - expert_specs = [parse_manual_expert(entry) for entry in args.expert] - else: - expert_specs = discover_default_experts() - comm_types = [canonical_comm_name(comm) for comm in args.comm_types] - for spec in expert_specs: - if spec.comm not in comm_types: - comm_types.append(spec.comm) - comm_types = list(dict.fromkeys(comm_types)) - - dataset, comm_to_idx, mapping = prepare_dataset( - data_root=args.data_root, - cities=args.cities, - comm_types=comm_types, - snrs=args.snrs, - mobilities=args.mobilities, - modulations=args.modulations, - fft_folders=args.fft_folders, - max_samples_per_comm=args.max_samples_per_comm, - max_per_combo=args.max_per_combo, - max_samples_per_class=args.max_samples_per_class, - val_samples_per_class=args.val_samples_per_class, - 
test_samples_per_class=args.test_samples_per_class, - task=args.task, - seed=args.seed, - preload=args.preload_data, - ) - - available_comms = set(comm_to_idx.keys()) - filtered_specs: List[ExpertSpec] = [] - missing_specs: List[ExpertSpec] = [] - for spec in expert_specs: - if spec.comm in available_comms: - filtered_specs.append(spec) - else: - missing_specs.append(spec) - if missing_specs: - missing_comm = ", ".join(sorted({spec.comm for spec in missing_specs})) - if resume_checkpoint_path is not None: - raise RuntimeError( - "Resume dataset is missing communication types required by the checkpoint: " - f"{missing_comm}" - ) - missing_names = ", ".join(sorted({f"{spec.name} ({spec.comm})" for spec in missing_specs})) - print(f"[WARN] Skipping experts with no matching data: {missing_names}") - expert_specs = filtered_specs - if not expert_specs: - raise RuntimeError("No experts remain after filtering by available communication types") - - if resume_comm_to_idx is not None and comm_to_idx != resume_comm_to_idx: - raise RuntimeError( - "Communication mapping inferred from data does not match resume checkpoint. " - "Ensure the same communication types are present." - ) - - dataset_num_classes = int(dataset.task_labels.max()) + 1 - if resume_num_classes is not None and dataset_num_classes != resume_num_classes: - raise RuntimeError( - f"Dataset provides {dataset_num_classes} task classes but resume checkpoint expects {resume_num_classes}. " - "Ensure the limited dataset still covers every class." - ) - if resume_mapping is not None: - if mapping is None: - raise RuntimeError("Resume checkpoint includes task mapping but dataset could not infer one.") - if mapping != resume_mapping: - raise RuntimeError( - "Task label mapping from limited dataset does not match resume checkpoint. " - "Ensure all (SNR, mobility) combinations are present." 
- ) - - train_idx, val_idx, test_idx = stratified_split( - dataset.task_labels.numpy(), - train_ratio=args.train_ratio, - val_ratio=args.val_ratio, - max_train_per_class=args.max_samples_per_class, - val_samples_per_class=args.val_samples_per_class, - test_samples_per_class=args.test_samples_per_class, - seed=args.seed, - ) - - print(f"[SPLIT] Train: {len(train_idx):,} samples ({len(train_idx)/len(dataset)*100:.1f}%)") - print(f"[SPLIT] Val: {len(val_idx):,} samples ({len(val_idx)/len(dataset)*100:.1f}%)") - print(f"[SPLIT] Test: {len(test_idx):,} samples ({len(test_idx)/len(dataset)*100:.1f}%)\n") - - train_loader, val_loader, test_loader = build_dataloaders( - dataset, - train_idx=train_idx, - val_idx=val_idx, - test_idx=test_idx, - batch_size=args.batch_size, - num_workers=args.num_workers, - ) - # Device selection with HPU support - # Try HPU first, then CUDA, then CPU - try: - import habana_frameworks.torch.core as htcore - device = torch.device("hpu") - print("[INFO] HPU device detected and selected") - except (ImportError, RuntimeError): - if torch.cuda.is_available(): - device = torch.device("cuda") - num_gpus = torch.cuda.device_count() - print(f"[INFO] CUDA device detected - using {num_gpus} GPU(s)") - else: - device = torch.device("cpu") - num_gpus = 0 - print("[INFO] Using CPU device") - num_classes = dataset_num_classes - - resume_state: Optional[Dict[str, Any]] = None - if resume_checkpoint_path is not None: - resume_state = load_checkpoint_for_training( - resume_checkpoint_path, - device, - checkpoint_data=resume_checkpoint_data, - ) - if resume_state["comm_to_idx"] != comm_to_idx: - raise RuntimeError("Mismatch between resume checkpoint comm_to_idx and dataset mapping") - args.routing_topk = resume_state["topk"] - args.dropout = resume_state["dropout"] - expert_specs = resume_state["expert_specs"] - print(f"[INFO] Resuming MoE training from {resume_checkpoint_path}") - if resume_state.get("expert_trainable") and args.expert_lr <= 0: - print( - "[WARN] Resume checkpoint includes fine-tuned experts but --expert-lr is 0; " - "experts will remain frozen unless you provide a positive --expert-lr." 
-            )
- ) - embedding_models: Optional[List[EmbeddingExpert]] = None - training_history: Optional[Dict[str, Any]] = None - - effective_topk = max(1, min(args.routing_topk, len(expert_specs))) - model_type = "embedding_router_moe" - router: Optional[RouterNet] = None - classifier: TaskClassifier - backbone_state_dict: Optional[Dict[str, torch.Tensor]] = None - backbone_meta: Optional[Dict[str, Any]] = None - use_data_parallel = torch.cuda.is_available() and torch.cuda.device_count() > 1 - - if args.baseline == "single": - print("[INFO] Training single model baseline...") - backbone = SingleModelBackbone(dropout=args.dropout).to(device) - classifier = TaskClassifier(num_classes=num_classes, dropout=args.dropout).to(device) - train_single_model( - backbone=backbone, - classifier=classifier, - train_loader=train_loader, - val_loader=val_loader, - device=device, - epochs=args.task_epochs, - lr=args.classifier_lr, - weight_decay=args.weight_decay, - patience=args.patience, - ) - print("[INFO] Evaluating on test split...") - test_loss, test_acc, test_f1 = evaluate_single_model( - backbone=backbone, - classifier=classifier, - loader=test_loader, - device=device, - ) - test_metrics = { - "test_accuracy": test_acc, - "test_f1": test_f1, - "test_loss": test_loss, - } - model_type = "baseline_single" - backbone_state_dict = {k: v.cpu() for k, v in backbone.state_dict().items()} - backbone_meta = { - "baseline_mode": "single", - "backbone_class": backbone.__class__.__name__, - } - elif args.baseline == "imagenet": - print("[INFO] Training ImageNet pretrained baseline...") - backbone = ImageNetBackbone(dropout=args.dropout, freeze_backbone=args.freeze_backbone).to(device) - classifier = TaskClassifier(num_classes=num_classes, dropout=args.dropout).to(device) - train_single_model( - backbone=backbone, - classifier=classifier, - train_loader=train_loader, - val_loader=val_loader, - device=device, - epochs=args.task_epochs, - lr=args.classifier_lr, - weight_decay=args.weight_decay, - patience=args.patience, - ) - print("[INFO] Evaluating on test split...") - test_loss, test_acc, test_f1 = evaluate_single_model( - backbone=backbone, - classifier=classifier, - loader=test_loader, - device=device, - ) - test_metrics = { - "test_accuracy": test_acc, - "test_f1": test_f1, - "test_loss": test_loss, - } - model_type = "baseline_imagenet" - backbone_state_dict = {k: v.cpu() for k, v in backbone.state_dict().items()} - backbone_meta = { - "baseline_mode": "imagenet", - "backbone_class": backbone.__class__.__name__, - "freeze_backbone": args.freeze_backbone, - } - elif args.baseline == "oracle": - print("[INFO] Training oracle baseline (ground-truth comm labels)...") - embedding_models = load_experts(expert_specs, device) - classifier = TaskClassifier(num_classes=num_classes, dropout=args.dropout).to(device) - train_oracle_baseline( - experts=embedding_models, - expert_specs=expert_specs, - comm_to_idx=comm_to_idx, - classifier=classifier, - train_loader=train_loader, - val_loader=val_loader, - device=device, - epochs=args.task_epochs, - lr=args.classifier_lr, - weight_decay=args.weight_decay, - patience=args.patience, - ) - print("[INFO] Evaluating on test split...") - comm_to_expert_idx = build_baseline_expert_map(expert_specs, comm_to_idx) - test_loss, test_acc, test_f1 = evaluate_oracle_baseline( - experts=embedding_models, - expert_specs=expert_specs, - comm_to_idx=comm_to_idx, - comm_to_expert_idx=comm_to_expert_idx, - classifier=classifier, - loader=test_loader, - device=device, - ) - test_metrics = { - 
"test_accuracy": test_acc, - "test_f1": test_f1, - "test_loss": test_loss, - } - model_type = "baseline_oracle" - else: - expert_trainable = args.expert_lr > 0 - if resume_state is None: - router = RouterNet(num_experts=len(expert_specs), dropout=args.dropout).to(device) - if use_data_parallel: - print(f"[INFO] Wrapping router with DataParallel") - router = nn.DataParallel(router) - if args.router_epochs > 0: - print("[INFO] Pre-training router...") - train_router( - router, - experts=expert_specs, - comm_to_idx=comm_to_idx, - train_loader=train_loader, - val_loader=val_loader, - device=device, - epochs=args.router_epochs, - lr=args.router_warmup_lr, - weight_decay=args.weight_decay, - ) - else: - print("[INFO] Skipping router warm-up (router_epochs=0).") - embedding_models = load_experts(expert_specs, device, trainable=expert_trainable) - if use_data_parallel: - print(f"[INFO] Wrapping {len(embedding_models)} experts with DataParallel") - embedding_models = [nn.DataParallel(m) if hasattr(m, 'forward') else m for m in embedding_models] - classifier = TaskClassifier(num_classes=num_classes, dropout=args.dropout).to(device) - if use_data_parallel: - print(f"[INFO] Wrapping classifier with DataParallel") - classifier = nn.DataParallel(classifier) - else: - router = resume_state["router"] - classifier = resume_state["classifier"] - embedding_models = resume_state["experts"] - for expert in embedding_models: - expert.set_trainable(expert_trainable) - if args.resume_router_warmup and args.router_epochs > 0: - print("[INFO] Running router warm-up on resume data...") - train_router( - router, - experts=expert_specs, - comm_to_idx=comm_to_idx, - train_loader=train_loader, - val_loader=val_loader, - device=device, - epochs=args.router_epochs, - lr=args.router_warmup_lr, - weight_decay=args.weight_decay, - ) - else: - print("[INFO] Skipping router warm-up (resume).") - # Ensure experts are moved to correct device/state after resume - for expert in embedding_models: - expert.to(device) - output_dir = args.output_dir.expanduser().resolve() - output_dir.mkdir(parents=True, exist_ok=True) - epoch_checkpoint_dir = output_dir / "epoch_checkpoints" - epoch_checkpoint_dir.mkdir(parents=True, exist_ok=True) - - def save_epoch_checkpoint(epoch_idx: int) -> None: - checkpoint_path = epoch_checkpoint_dir / f"epoch_{epoch_idx:02d}.pth" - expert_models_to_save = ( - embedding_models if (embedding_models is not None and args.baseline is None) else None - ) - expert_trainable_flag = ( - bool(expert_models_to_save) - and any(expert.trainable for expert in expert_models_to_save) # type: ignore[arg-type] - ) - save_complete_checkpoint( - router=router if args.baseline is None else None, - classifier=classifier, - expert_models=expert_models_to_save, - expert_specs=expert_specs, - comm_to_idx=comm_to_idx, - task_type=args.task, - num_classes=num_classes, - topk=effective_topk if args.baseline is None else 1, - dropout=args.dropout, - mapping=mapping, - output_path=checkpoint_path, - model_type=model_type, - backbone_state_dict=backbone_state_dict if args.baseline in {"single", "imagenet"} else None, - backbone_meta=backbone_meta if args.baseline in {"single", "imagenet"} else None, - expert_trainable=expert_trainable_flag, - ) - - print("[INFO] Training classifier with router-guided embeddings...") - training_history = train_task_model( - router=router, - experts=embedding_models, - expert_specs=expert_specs, - comm_to_idx=comm_to_idx, - classifier=classifier, - train_loader=train_loader, - val_loader=val_loader, - 
device=device,
-            epochs=args.task_epochs,
-            topk=effective_topk,
-            router_lr=args.router_lr,
-            classifier_lr=args.classifier_lr,
-            expert_lr=max(0.0, args.expert_lr),
-            weight_decay=args.weight_decay,
-            router_loss_weight=max(0.0, args.router_loss_weight),
-            load_balance_weight=max(0.0, args.load_balance_weight),
-            gating_noise_std=max(0.0, args.gating_noise_std),
-            gating_noise_epochs=max(0, args.gating_noise_epochs),
-            patience=args.patience,
-            checkpoint_callback=save_epoch_checkpoint,
-        )
-        print("[INFO] Evaluating on test split...")
-        test_metrics = evaluate_test_metrics(
-            router=router,
-            experts=embedding_models,
-            classifier=classifier,
-            loader=test_loader,
-            topk=effective_topk,
-            device=device,
-        )
-
-    print("\n" + "=" * 60)
-    print("TEST RESULTS")
-    print("=" * 60)
-    print(f"Accuracy: {test_metrics['test_accuracy']:.4f}")
-    print(f"F1 Score: {test_metrics['test_f1']:.4f}")
-    print("=" * 60 + "\n")
-
-    # Resolve the output directory here as well: baseline modes never enter the
-    # MoE branch above, so make sure it exists (idempotent for MoE runs).
-    output_dir = args.output_dir.expanduser().resolve()
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    metrics_path = output_dir / "metrics.json"
-    with metrics_path.open("w", encoding="utf-8") as fh:
-        json.dump(test_metrics, fh, indent=2)
-    print("[RESULT] Test metrics saved to", metrics_path)
-
-    if training_history is not None:
-        sanitized_history = sanitize_history_for_serialization(training_history)
-        history_path = output_dir / "training_history.json"
-        with history_path.open("w", encoding="utf-8") as fh:
-            json.dump(sanitized_history, fh, indent=2)
-        metrics_csv_path = output_dir / "training_metrics.csv"
-        write_training_metrics_csv(sanitized_history, expert_specs, metrics_csv_path)
-        print("[RESULT] Training history saved to", history_path)
-        print("[RESULT] Training metrics saved to", metrics_csv_path)
-
-    checkpoint_path = output_dir / "moe_checkpoint.pth"
-    expert_models_to_save = (
-        embedding_models if (embedding_models is not None and args.baseline is None) else None
-    )
-    expert_trainable_flag = (
-        bool(expert_models_to_save)
-        # getattr handles experts wrapped in nn.DataParallel
-        and any(getattr(expert, "module", expert).trainable for expert in expert_models_to_save)
-    )
-    save_complete_checkpoint(
-        router=router if args.baseline is None else None,
-        classifier=classifier,
-        expert_models=expert_models_to_save,
-        expert_specs=expert_specs,
-        comm_to_idx=comm_to_idx,
-        task_type=args.task,
-        num_classes=num_classes,
-        topk=effective_topk if args.baseline is None else 1,
-        dropout=args.dropout,
-        mapping=mapping,
-        output_path=checkpoint_path,
-        model_type=model_type,
-        backbone_state_dict=backbone_state_dict if args.baseline in {"single", "imagenet"} else None,
-        backbone_meta=backbone_meta if args.baseline in {"single", "imagenet"} else None,
-        expert_trainable=expert_trainable_flag,
-    )
-
-    if args.baseline is None and args.save_router:
-        # Unwrap DataParallel so the standalone state_dict has clean key names
-        router_to_save = router.module if isinstance(router, nn.DataParallel) else router
-        torch.save(router_to_save.state_dict(), output_dir / "router_state_dict.pth")
-        print("[INFO] Router state_dict saved")
-    if args.save_classifier:
-        classifier_to_save = classifier.module if isinstance(classifier, nn.DataParallel) else classifier
-        torch.save(classifier_to_save.state_dict(), output_dir / "classifier_state_dict.pth")
-        print("[INFO] Classifier state_dict saved")
-    if mapping is not None:
-        with (output_dir / "snr_mobility_mapping.json").open("w", encoding="utf-8") as fh:
-            json.dump({int(k): v for k, v in mapping.items()}, fh, indent=2)
-
-
-if __name__ == "__main__":
-    main()
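-
-
-# Usage sketch (illustrative, not part of the training flow): load the default
-# checkpoint written above and classify one spectrogram. The path, tensor shape,
-# and variable names below are placeholders.
-#
-#   predictor = MoEPredictor.from_checkpoint(
-#       Path("mixture/runs/embedding_router/moe_checkpoint.pth")
-#   )
-#   spec = torch.rand(128, 128)  # placeholder [H, W] spectrogram
-#   result = predictor.predict(spec, return_probabilities=True, return_routing=True)
-#   # "routing" is only present when the checkpoint contains a router
-#   print(result["predicted_class"], result.get("label"), result.get("routing"))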