|
|
|
|
|
"""Visualise how strongly metadata drives the learned embedding space. |
|
|
|
|
|
This script mirrors the functionality of ``task1/plot_mod_tsne.py`` but groups |
|
|
spectrograms by their SNR folder name (e.g. ``SNR0dB``) instead of modulation. |
|
|
It is useful for checking whether the self-supervised LWM backbone mostly |
|
|
captures channel/SNR differences rather than modulation characteristics. |
|
|
|
|
|
Pass ``--label-field modulation`` to reuse the same sampled spectrograms while |
|
|
colouring and scoring them by their modulation folder instead of SNR. Use |
|
|
``--label-field mobility`` to highlight link-level mobility categories when |
|
|
present in the dataset tree. Saved figures automatically include the detected |
|
|
communication profile (e.g. LTE/WiFi/5G) and label mode in the filename when |
|
|
those suffixes are not already present. |
|
|
|
|
|
Usage example: |
|
|
|
|
|
```bash |
|
|
python task1/plot_snr_tsne.py \ |
|
|
--data-root spectrograms/city_1_losangeles/LTE \ |
|
|
--snrs SNR-5dB,SNR0dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB \ |
|
|
--save-path task1/snr_separation_plot_latest.png |
|
|
``` |
|
|
Shortcut presets: |
|
|
|
|
|
```bash |
|
|
python task1/plot_snr_tsne.py --WiFi --report-metrics |
|
|
``` |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import glob |
|
|
import pickle |
|
|
import random |
|
|
import re |
|
|
from pathlib import Path |
|
|
from collections import Counter, defaultdict |
|
|
from typing import Dict, Iterable, List, Tuple |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
from sklearn.manifold import TSNE |
|
|
from sklearn.metrics import silhouette_score |
|
|
from sklearn.model_selection import StratifiedKFold |
|
|
from sklearn.neighbors import KNeighborsClassifier |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
|
|
|
from pretraining.pretrained_model import lwm as lwm_model |
|
|
from utils import load_spectrogram_data |
|
|
|
|
|
|
|
|
# LTE locations double as the CLI defaults. main() compares the parsed
# --data-root/--models-root against these to decide whether a profile preset
# is allowed to replace them.
DEFAULT_DATA_ROOT = "spectrograms/city_1_losangeles/LTE"
DEFAULT_MODELS_ROOT = "models/LTE_models"

# Communication-profile presets selectable via --profile / --LTE / --WiFi / --5G.
# Insertion order is significant: parse_args() uses it for the --profile choices.
PROFILE_PRESETS: Dict[str, Dict[str, str]] = {
    "LTE": {"data_root": DEFAULT_DATA_ROOT, "models_root": DEFAULT_MODELS_ROOT},
    "WiFi": {"data_root": "spectrograms/city_1_losangeles/WiFi", "models_root": "models/WiFi_models"},
    "5G": {"data_root": "spectrograms/city_1_losangeles/5G", "models_root": "models/5G_models"},
}
|
|
|
|
|
|
|
|
def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Z-score every spectrogram in ``specs`` independently.

    Statistics are taken over axes 1 and 2, so ``specs`` is expected to be a
    ``(N, H, W)`` stack. Each standard deviation is floored at ``eps`` so
    constant spectrograms become all zeros instead of NaN. Returns float32.
    """
    reduce_axes = (1, 2)
    centre = specs.mean(axis=reduce_axes, keepdims=True)
    spread = np.maximum(specs.std(axis=reduce_axes, keepdims=True), eps)
    shifted = specs - centre
    return (shifted / spread).astype(np.float32, copy=False)
|
|
|
|
|
|
|
|
def normalize_dataset(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Z-score ``specs`` with a single global mean/std shared by all elements.

    The standard deviation is floored at ``eps`` so a constant input maps to
    zeros. Returns float32.
    """
    centre = float(specs.mean())
    spread = max(float(specs.std()), eps)
    shifted = specs - centre
    return (shifted / spread).astype(np.float32, copy=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. An explicit list lets tests or
            notebooks drive the script without touching ``sys.argv``.

    Returns:
        The parsed namespace. ``--profile``, ``--LTE``, ``--WiFi`` and ``--5G``
        are mutually exclusive and all store into ``args.profile``.

    Exits via ``parser.error`` (SystemExit) when ``--samples-per-snr`` is not
    positive, because no sample could ever be collected in that case.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # --- data selection ---------------------------------------------------
    parser.add_argument(
        "--data-root",
        default=DEFAULT_DATA_ROOT,
        help="Root directory containing modulation folders (default: %(default)s)",
    )
    parser.add_argument(
        "--modulation",
        default="all",
        help="Modulation folder to load (default: %(default)s)",
    )
    parser.add_argument(
        "--snrs",
        default="SNR-5dB,SNR0dB,SNR5dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB",
        help=(
            "Comma-separated list of SNR folder names to include. Pass 'all' "
            "to include every SNR discovered under the modulation (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--mobility",
        nargs="+",
        default=["all"],
        help=(
            "Mobility folder(s) to filter on. Pass 'all' to include every mobility "
            "(default: %(default)s). Multiple values can be provided either as a "
            "space-separated list (e.g. '--mobility vehicular pedestrian') or a "
            "comma-separated string."
        ),
    )
    parser.add_argument(
        "--fft-folder",
        default="all",
        help=(
            "FFT size folder name to use. Pass 'all' to include every FFT variant "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--samples-per-snr",
        type=int,
        default=500,
        help="Maximum number of samples to draw for each SNR label",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for sampling and t-SNE",
    )

    # --- embedding / output -----------------------------------------------
    parser.add_argument(
        "--pooling",
        choices=("mean", "cls"),
        default="mean",
        help="How to collapse token embeddings into a single vector",
    )
    parser.add_argument(
        "--save-path",
        default="task1/snr_separation_plot_latest.png",
        help="Location to save the generated figure (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path; overrides automatic latest selection",
    )
    parser.add_argument(
        "--models-root",
        default=DEFAULT_MODELS_ROOT,
        help=(
            "Directory containing checkpoints. When --checkpoint is not given, "
            "the latest/best checkpoint inside this directory will be used "
            "(default: %(default)s)"
        ),
    )

    # --- communication-profile presets (all write to args.profile) ----------
    preset_group = parser.add_mutually_exclusive_group()
    preset_group.add_argument(
        "--profile",
        dest="profile",
        choices=tuple(PROFILE_PRESETS.keys()),
        help=(
            "Convenience preset that sets --data-root and --models-root when they "
            "are left at their defaults"
        ),
    )
    preset_group.add_argument(
        "--LTE",
        dest="profile",
        action="store_const",
        const="LTE",
        help="Shortcut for --profile LTE",
    )
    preset_group.add_argument(
        "--WiFi",
        dest="profile",
        action="store_const",
        const="WiFi",
        help="Shortcut for --profile WiFi",
    )
    preset_group.add_argument(
        "--5G",
        dest="profile",
        action="store_const",
        const="5G",
        help="Shortcut for --profile 5G",
    )

    # --- analysis options -------------------------------------------------
    parser.add_argument(
        "--report-metrics",
        action="store_true",
        help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
    )
    parser.add_argument(
        "--metrics-only",
        action="store_true",
        help="Exit after reporting metrics without running t-SNE or saving figures",
    )
    parser.add_argument(
        "--sampling-mode",
        choices=("first", "reservoir"),
        default="first",
        help="How to down-sample each class (default: first)",
    )
    parser.add_argument(
        "--complex-mode",
        choices=("auto", "magnitude", "interleaved"),
        default="auto",
        help=(
            "How to handle complex spectrograms: 'magnitude' (abs), 'interleaved' (real/imag interleaved along width), "
            "or 'auto' (prefer interleaved when complex). Real-valued inputs are unaffected."
        ),
    )
    parser.add_argument(
        "--label-field",
        choices=("snr", "modulation", "mobility"),
        default="snr",
        help="Choose which label to visualise and score (default: %(default)s)",
    )
    parser.add_argument(
        "--normalization",
        choices=("per-sample", "dataset"),
        default="per-sample",
        help="Normalisation strategy applied before embedding extraction",
    )

    args = parser.parse_args(argv)
    # A non-positive cap would silently collect nothing and fail much later
    # with "No spectrogram samples collected"; report the real cause instead.
    if args.samples_per_snr <= 0:
        parser.error("--samples-per-snr must be a positive integer")
    return args
|
|
|
|
|
|
|
|
def find_latest_checkpoint(models_root: Path) -> Path: |
|
|
"""Return a checkpoint path under ``models_root``. |
|
|
|
|
|
Works with either a parent directory that contains multiple run folders, |
|
|
or directly with a single run directory containing ``*.pth`` files. |
|
|
Chooses the checkpoint with the lowest parsed validation value when |
|
|
available, else falls back to most-recent modification time. |
|
|
""" |
|
|
|
|
|
if not models_root.exists(): |
|
|
raise FileNotFoundError(f"Models root not found: {models_root}") |
|
|
|
|
|
if models_root.is_file(): |
|
|
raise FileNotFoundError(f"Expected a directory, got file: {models_root}") |
|
|
|
|
|
|
|
|
checkpoints = list(models_root.glob("*.pth")) |
|
|
if not checkpoints: |
|
|
|
|
|
run_dirs = [p for p in models_root.iterdir() if p.is_dir()] |
|
|
candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))] |
|
|
if not candidate_runs: |
|
|
raise FileNotFoundError( |
|
|
f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)" |
|
|
) |
|
|
latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime) |
|
|
checkpoints = list(latest_run.glob("*.pth")) |
|
|
|
|
|
def parse_val_metric(path: Path) -> float | None: |
|
|
match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name) |
|
|
if match: |
|
|
try: |
|
|
return float(match.group(1)) |
|
|
except ValueError: |
|
|
return None |
|
|
return None |
|
|
|
|
|
parsed = [(parse_val_metric(p), p) for p in checkpoints] |
|
|
valid = [item for item in parsed if item[0] is not None] |
|
|
if valid: |
|
|
valid.sort(key=lambda item: item[0]) |
|
|
return valid[0][1] |
|
|
|
|
|
|
|
|
return max(checkpoints, key=lambda p: p.stat().st_mtime) |
|
|
|
|
|
|
|
|
def parse_snr_list(snr_argument: str | None) -> set[str] | None: |
|
|
if snr_argument is None or snr_argument.lower() == "all": |
|
|
return None |
|
|
values = [item.strip() for item in snr_argument.split(",") if item.strip()] |
|
|
return set(values) |
|
|
|
|
|
|
|
|
def list_snr_samples( |
|
|
data_root: Path, |
|
|
modulation: str, |
|
|
allowed_snrs: set[str] | None, |
|
|
mobility_filter: set[str] | None, |
|
|
fft_folder: str, |
|
|
max_per_class: int, |
|
|
rng: random.Random, |
|
|
mode: str, |
|
|
complex_mode: str, |
|
|
) -> Dict[str, List[Tuple[np.ndarray, str, str]]]: |
|
|
"""Collect spectrogram samples grouped by SNR label. |
|
|
|
|
|
Supports both legacy PKL layout with a trailing 'spectrograms/' folder and |
|
|
MATLAB .mat bundles saved directly under the mobility folder. |
|
|
|
|
|
Returns: mapping from SNR label to list of tuples: (spec, modulation, mobility) |
|
|
""" |
|
|
|
|
|
class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]] = defaultdict(list) |
|
|
seen_counts: Dict[str, int] = defaultdict(int) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
patterns = [ |
|
|
str(data_root / "**" / "spectrograms" / "*.pkl"), |
|
|
str(data_root / "**" / "spectrogram_*.mat"), |
|
|
] |
|
|
|
|
|
mobility_set = {"static", "pedestrian", "vehicular"} |
|
|
|
|
|
def extract_tokens(rel_parts: Tuple[str, ...]) -> Tuple[str, str, str, str] | None: |
|
|
|
|
|
|
|
|
if not rel_parts: |
|
|
return None |
|
|
modulation_folder = rel_parts[0] |
|
|
|
|
|
|
|
|
snr_folder = next((p for p in rel_parts if re.match(r"^SNR-?\d+dB$", p)), None) |
|
|
if snr_folder is None: |
|
|
return None |
|
|
|
|
|
|
|
|
mobility_folder = next((p for p in rel_parts if p.lower() in mobility_set), None) |
|
|
if mobility_folder is None: |
|
|
return None |
|
|
|
|
|
|
|
|
fft_folder_name = next((p for p in rel_parts if p.startswith("win") or p.startswith("fft")), "fft_unknown") |
|
|
|
|
|
return modulation_folder, snr_folder, mobility_folder, fft_folder_name |
|
|
|
|
|
for pattern in patterns: |
|
|
for path_str in glob.iglob(pattern, recursive=True): |
|
|
path = Path(path_str) |
|
|
try: |
|
|
rel_parts = path.relative_to(data_root).parts |
|
|
except ValueError: |
|
|
continue |
|
|
|
|
|
tokens = extract_tokens(rel_parts) |
|
|
if tokens is None: |
|
|
continue |
|
|
modulation_folder, snr_folder, mobility_folder, fft_folder_name = tokens |
|
|
|
|
|
|
|
|
if modulation.lower() != "all" and modulation_folder != modulation: |
|
|
continue |
|
|
if allowed_snrs is not None and snr_folder not in allowed_snrs: |
|
|
continue |
|
|
if mobility_filter is not None and mobility_folder.lower() not in mobility_filter: |
|
|
continue |
|
|
if fft_folder != "all" and fft_folder_name != fft_folder: |
|
|
continue |
|
|
|
|
|
class_label = snr_folder |
|
|
if mode == "first" and len(class_samples[class_label]) >= max_per_class: |
|
|
continue |
|
|
|
|
|
|
|
|
try: |
|
|
arr = load_spectrogram_data(str(path)) |
|
|
except Exception as exc: |
|
|
print(f"[WARN] Failed to load {path}: {exc}") |
|
|
continue |
|
|
|
|
|
if not isinstance(arr, np.ndarray) or arr.size == 0: |
|
|
continue |
|
|
|
|
|
|
|
|
if np.iscomplexobj(arr): |
|
|
if complex_mode == "magnitude": |
|
|
arr = np.abs(arr) |
|
|
else: |
|
|
|
|
|
if arr.ndim == 4 and arr.shape[1] == 1: |
|
|
arr = arr[:, 0] |
|
|
if arr.ndim == 3: |
|
|
real = arr.real.astype(np.float32, copy=False) |
|
|
imag = arr.imag.astype(np.float32, copy=False) |
|
|
n, h, w = real.shape |
|
|
inter = np.empty((n, h, w * 2), dtype=np.float32) |
|
|
inter[:, :, 0::2] = real |
|
|
inter[:, :, 1::2] = imag |
|
|
arr = inter |
|
|
else: |
|
|
|
|
|
arr = np.abs(arr) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if arr.ndim == 4: |
|
|
|
|
|
if arr.shape[1] > 1: |
|
|
specs = arr.mean(axis=1) |
|
|
else: |
|
|
specs = arr[:, 0] |
|
|
elif arr.ndim == 3: |
|
|
specs = arr |
|
|
elif arr.ndim == 2: |
|
|
specs = arr[None, ...] |
|
|
else: |
|
|
print(f"[WARN] Unexpected spectrogram shape in {path}: {arr.shape}") |
|
|
continue |
|
|
|
|
|
for spec in specs: |
|
|
sample = np.asarray(spec, dtype=np.float32) |
|
|
bucket = class_samples[class_label] |
|
|
|
|
|
if len(bucket) < max_per_class: |
|
|
bucket.append((sample, modulation_folder, mobility_folder)) |
|
|
seen_counts[class_label] += 1 |
|
|
elif mode == "reservoir": |
|
|
seen_counts[class_label] += 1 |
|
|
j = rng.randint(0, seen_counts[class_label] - 1) |
|
|
if j < max_per_class: |
|
|
bucket[j] = (sample, modulation_folder, mobility_folder) |
|
|
else: |
|
|
break |
|
|
|
|
|
return class_samples |
|
|
|
|
|
|
|
|
def sample_balanced_dataset(
    class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """Stack the sampled spectrograms alongside SNR, modulation, and mobility labels.

    Despite the name, no rebalancing happens here. Each class contributes
    whatever ``list_snr_samples`` collected for it, already capped per class
    there.

    Args:
        class_samples: Mapping from SNR label to ``(spec, modulation, mobility)``
            tuples. All specs must share one shape so they can be stacked.

    Returns:
        ``(samples, snr_labels, modulation_labels, mobility_labels, class_names)``.
        Rows are grouped by SNR label in lexicographic order. ``class_names``
        lists only the labels that actually contributed rows, so it always
        agrees with ``np.unique(snr_labels)``.

    Raises:
        RuntimeError: if no class has any sample.
    """
    # Exclude empty buckets so class_names matches the labels actually returned.
    class_names = sorted(name for name, bucket in class_samples.items() if bucket)
    rows = [
        (sample, class_name, modulation_label, mobility_label)
        for class_name in class_names
        for sample, modulation_label, mobility_label in class_samples[class_name]
    ]

    if not rows:
        raise RuntimeError("No spectrogram samples collected for the specified filters")

    samples, snr_labels, modulation_labels, mobility_labels = zip(*rows)
    return (
        np.stack(samples),
        np.array(snr_labels),
        np.array(modulation_labels),
        np.array(mobility_labels),
        class_names,
    )
|
|
|
|
|
|
|
|
def unfold_patches_square(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    """Cut a ``(B, H, W)`` tensor into non-overlapping square patches.

    Returns ``(B, num_patches, patch_size * patch_size)``. Patches are ordered
    row-major over the patch grid and each patch is flattened row-major.
    Trailing rows/columns that do not fill a whole patch are dropped, because
    ``Tensor.unfold`` only produces complete windows.
    """
    batch = x.shape[0]
    tiles = x.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    return tiles.reshape(batch, -1, patch_size * patch_size)
|
|
|
|
|
|
|
|
def unfold_patches_rect(x: torch.Tensor, patch_rows: int = 4, patch_cols: int = 8) -> torch.Tensor:
    """Cut a ``(B, H, W)`` tensor into non-overlapping ``patch_rows x patch_cols`` patches.

    Returns ``(B, num_patches, patch_rows * patch_cols)``. Patches are ordered
    row-major over the patch grid and each patch is flattened row-major.
    Incomplete trailing rows/columns are dropped, because ``Tensor.unfold``
    only produces complete windows.
    """
    batch = x.shape[0]
    tiles = x.unfold(1, patch_rows, patch_rows).unfold(2, patch_cols, patch_cols)
    return tiles.reshape(batch, -1, patch_rows * patch_cols)
|
|
|
|
|
|
|
|
def extract_tokens(spec: np.ndarray, device: torch.device, interleaved: bool) -> torch.Tensor:
    """Patchify one 2-D spectrogram into a ``(1, num_patches, patch_dim)`` token tensor on ``device``.

    Magnitude inputs use 4x4 patches (16 values per token). Interleaved inputs
    (real/imag alternating along the width) use 4x8 patches, so each
    32-value token still spans 4x4 complex bins.
    """
    batch = torch.from_numpy(spec).unsqueeze(0).to(device)
    if not interleaved:
        return unfold_patches_square(batch, 4)
    return unfold_patches_rect(batch, 4, 8)
|
|
|
|
|
|
|
|
def pool_embeddings(
    tokens: torch.Tensor,
    model: torch.nn.Module,
    pooling: str,
) -> np.ndarray:
    """Run ``model`` on ``tokens`` and reduce the output to one vector per sample.

    A constant CLS token (every value 0.2) is prepended along the sequence axis
    before the forward pass. The 0.2 presumably matches the pretraining setup;
    confirm against the training code. ``pooling == "cls"`` returns the output
    at position 0. Any other value averages the outputs of the real tokens
    (positions 1 onwards).

    Returns:
        A NumPy array with one row per batch element.
    """
    cls_prefix = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
    with torch.no_grad():
        hidden = model(torch.cat([cls_prefix, tokens], dim=1))
        pooled = hidden[:, 0] if pooling == "cls" else hidden[:, 1:].mean(dim=1)
    return pooled.detach().cpu().numpy()
|
|
|
|
|
|
|
|
def sort_snr_labels(labels: List[str]) -> List[str]:
    """Order labels numerically by SNR (``SNR-5dB`` < ``SNR0dB`` < ``SNR10dB``).

    Labels without an ``SNR<int>dB`` token sort after every SNR label and keep
    their relative input order, because ``sorted`` is stable.
    """
    def numeric_snr(label: str) -> float:
        found = re.search(r'SNR(-?\d+)dB', label)
        return float(found.group(1)) if found is not None else float('inf')

    return sorted(labels, key=numeric_snr)
|
|
|
|
|
|
|
|
def run_tsne(
    x: np.ndarray,
    labels: np.ndarray,
    title: str,
    ax: plt.Axes,
    random_state: int = 42,
) -> None:
    """Project ``x`` to 2-D with t-SNE and draw a label-coloured scatter on ``ax``.

    Features are standardised, then NaN/inf are replaced by 0, values clipped
    to +/-1e6 and the result cast to float32. If t-SNE raises, a 2-component
    PCA projection is drawn instead, and the title and axis labels say so.

    Args:
        x: ``(n_samples, n_features)`` feature matrix.
        labels: One label per row of ``x``. Classes are drawn in the order
            given by ``sort_snr_labels``.
        title: Axes title.
        ax: Matplotlib axes to draw on.
        random_state: Seed for t-SNE / PCA. Defaults to 42, as before.
    """
    x_scaled = StandardScaler().fit_transform(x)
    # Guard t-SNE against degenerate features (e.g. NaN from zero-variance inputs).
    x_scaled = np.nan_to_num(x_scaled, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
    x_scaled = np.clip(x_scaled, -1e6, 1e6).astype(np.float32, copy=False)

    n_samples = len(x_scaled)
    # Roughly n/10, capped at 30 with a floor of 5. scikit-learn requires
    # perplexity < n_samples, so tiny inputs get n_samples - 1 instead of 5.
    perplexity = min(max(5, min(30, n_samples // 10)), max(1, n_samples - 1))

    method = "t-SNE"
    try:
        embedding = TSNE(
            n_components=2,
            perplexity=perplexity,
            random_state=random_state,
            init="random",
            learning_rate="auto",
        ).fit_transform(x_scaled)
    except Exception as e:
        # PCA is only needed on this fallback path and was never imported at module level.
        from sklearn.decomposition import PCA

        print(f"[WARN] t-SNE failed ({e}); falling back to PCA.")
        embedding = PCA(n_components=2, svd_solver="full", random_state=random_state).fit_transform(x_scaled)
        method = "PCA"
        title = f"{title} [PCA fallback]"

    labels = np.asarray(labels)
    class_names = sort_snr_labels(list(np.unique(labels)))
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for color, class_name in zip(colors, class_names):
        mask = labels == class_name
        ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(f"{method} Component 1", fontsize=16)
    ax.set_ylabel(f"{method} Component 2", fontsize=16)
    ax.tick_params(labelsize=14)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=12)
|
|
|
|
|
|
|
|
def compute_metrics(
    name: str,
    features: np.ndarray,
    labels: np.ndarray,
    max_folds: int = 5,
    n_neighbors: int = 5,
) -> None:
    """Print how well ``labels`` separate in ``features`` space.

    Reports the silhouette score of the standardised features and the
    mean ± std accuracy of a k-nearest-neighbour classifier under stratified
    k-fold cross-validation (shuffled, ``random_state=42``).

    The fold count is reduced to the size of the smallest class, and k to the
    size of the training fold, when the defaults would make scikit-learn
    raise. The metric is skipped, with a message, when it is undefined.

    Args:
        name: Prefix used in the printed line.
        features: ``(n_samples, n_features)`` matrix.
        labels: One label per sample.
        max_folds: Upper bound on the number of CV folds (default 5).
        n_neighbors: Upper bound on k for the kNN classifier (default 5).
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    if len(classes) < 2:
        print(f"[METRIC] {name}: skipped (only one class present)")
        return
    if len(classes) >= len(labels):
        # silhouette_score requires 2 <= n_labels <= n_samples - 1.
        print(f"[METRIC] {name}: skipped (every sample is its own class)")
        return

    features_scaled = StandardScaler().fit_transform(features)
    silhouette = silhouette_score(features_scaled, labels)

    # StratifiedKFold raises unless every class has at least n_splits members.
    smallest_class = int(counts.min())
    n_splits = min(max_folds, smallest_class)
    if n_splits < 2:
        print(
            f"[METRIC] {name}: silhouette={silhouette:.3f}, "
            f"kNN skipped (smallest class has {smallest_class} sample)"
        )
        return

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores: List[float] = []
    k_used = n_neighbors
    for train_idx, test_idx in skf.split(features_scaled, labels):
        # KNeighborsClassifier cannot use more neighbours than training samples.
        k = min(n_neighbors, len(train_idx))
        k_used = min(k_used, k)
        clf = KNeighborsClassifier(n_neighbors=k)
        clf.fit(features_scaled[train_idx], labels[train_idx])
        scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))

    mean_acc = float(np.mean(scores))
    std_acc = float(np.std(scores))
    # Only mention the fold count when it differs from the requested one.
    fold_note = "" if n_splits == max_folds else f" ({n_splits}-fold)"
    print(
        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
        f"{k_used}-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}{fold_note}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_mobility_filter(values: List[str] | None) -> set[str] | None:
    """Normalise ``--mobility`` values into a lower-cased filter set.

    Values may be space-separated (``nargs="+"``) and/or comma-separated.
    Returns ``None`` ("no filtering") when nothing usable was given or when
    any value is ``all`` (case-insensitive).
    """
    tokens = [
        item.strip()
        for value in (values or [])
        for item in value.split(",")
        if item.strip()
    ]
    if not tokens or any(token.lower() == "all" for token in tokens):
        return None
    return {token.lower() for token in tokens}


def _load_lwm_model(
    checkpoint_path: Path,
    device: torch.device,
    use_interleaved: bool,
) -> Tuple[torch.nn.Module, bool]:
    """Build the LWM backbone and load ``checkpoint_path`` into it.

    Returns the model (moved to ``device``, in eval mode) and the effective
    ``use_interleaved`` flag. If the 32-wide (interleaved) model cannot load
    the checkpoint but a 16-wide (magnitude) model can, the function falls
    back to magnitude embedding. Otherwise the original load error propagates.
    """
    def build(element_length: int) -> torch.nn.Module:
        return lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)

    model = build(32 if use_interleaved else 16)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Strip only the leading prefix that DataParallel/DDP adds; names that merely
    # contain "module." elsewhere must be left intact.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    try:
        incompatible = model.load_state_dict(state_dict, strict=False)
    except RuntimeError as original_error:
        # Size mismatches raise even with strict=False. The error text contains
        # both the checkpoint and the model shapes, so it cannot tell us which
        # width the checkpoint expects. Instead, try the only alternative we can
        # feed (16-wide magnitude tokens).
        if not use_interleaved:
            raise
        model = build(16)
        try:
            incompatible = model.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            raise original_error
        print("[WARN] Checkpoint expects token dimension 16. Falling back to magnitude embedding.")
        use_interleaved = False

    missing = getattr(incompatible, "missing_keys", ())
    if missing:
        # strict=False would otherwise silently leave these at random initialisation.
        print(f"[WARN] {len(missing)} model parameter(s) not found in checkpoint; they keep their random initialisation.")
    return model.to(device).eval(), use_interleaved


def _embed_samples(
    samples: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    use_interleaved: bool,
    pooling: str,
    batch_size: int = 16,
) -> np.ndarray:
    """Tokenise ``(N, H, W)`` spectrograms, run ``model`` and return ``(N, D)`` pooled embeddings.

    When the model takes 16-wide magnitude tokens but the samples are wider
    than 128 columns with an even width, they are treated as real/imag
    interleaved and collapsed to magnitude first. Samples go through the
    model ``batch_size`` at a time. The model is in eval mode and every sample
    has the same token count, so batching does not change the result. The
    moderate default keeps attention memory bounded.
    """
    inputs = samples
    if not use_interleaved and samples.shape[2] > 128 and samples.shape[2] % 2 == 0:
        real = samples[:, :, 0::2]
        imag = samples[:, :, 1::2]
        inputs = np.sqrt(real * real + imag * imag).astype(np.float32, copy=False)

    pooled_chunks: List[np.ndarray] = []
    for start in range(0, len(inputs), batch_size):
        batch_tokens = torch.cat(
            [extract_tokens(spec, device, interleaved=use_interleaved) for spec in inputs[start:start + batch_size]],
            dim=0,
        )
        pooled_chunks.append(pool_embeddings(batch_tokens, model, pooling))
    return np.vstack(pooled_chunks)


def _resolve_save_path(save_path: str, profile: str | None, data_root: str, label_field: str) -> Path:
    """Append the communication tag and the label mode to the figure filename.

    The tag is the active profile, or else the last component of
    ``data_root``. Non-SNR label modes add ``by_<field>``. A suffix is only
    appended when the stem does not already end with it.
    """
    path = Path(save_path)
    communication_tag = profile or Path(data_root).name or None

    def ensure_suffix(stem: str, suffix: str) -> str:
        return stem if stem.endswith(suffix) else f"{stem}_{suffix}"

    stem = path.stem
    if communication_tag:
        stem = ensure_suffix(stem, communication_tag)
    if label_field != "snr":
        stem = ensure_suffix(stem, f"by_{label_field}")
    return path if stem == path.stem else path.with_name(f"{stem}{path.suffix}")


def main() -> None:
    """Load and embed the sampled spectrograms, then report metrics and/or plot t-SNE."""
    args = parse_args()

    if args.profile:
        preset = PROFILE_PRESETS.get(args.profile)
        if not preset:
            raise ValueError(f"Unknown profile requested: {args.profile}")
        # Presets only replace paths the user left at their defaults.
        if args.data_root == DEFAULT_DATA_ROOT:
            args.data_root = preset["data_root"]
        if args.models_root == DEFAULT_MODELS_ROOT:
            args.models_root = preset["models_root"]
        print(f"[INFO] Profile preset active: {args.profile}")

    # --metrics-only promises "exit after reporting metrics". Without this it
    # reported nothing and still plotted unless --report-metrics was also given.
    report_metrics = args.report_metrics or args.metrics_only

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_root = Path(args.data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    allowed_snrs = parse_snr_list(args.snrs)
    mobility_filter = _parse_mobility_filter(args.mobility)
    if mobility_filter is not None:
        print("[INFO] Mobility filter active: " + ", ".join(sorted(mobility_filter)))

    class_samples = list_snr_samples(
        data_root,
        args.modulation,
        allowed_snrs,
        mobility_filter,
        args.fft_folder,
        args.samples_per_snr,
        random,
        args.sampling_mode,
        args.complex_mode,
    )
    samples, snr_labels, modulation_labels, mobility_labels, _ = sample_balanced_dataset(class_samples)

    # label_field -> (labels, name used in log lines, name used in plot titles)
    label_options = {
        "snr": (snr_labels, "SNR", "SNR"),
        "modulation": (modulation_labels, "modulation", "Modulation"),
        "mobility": (mobility_labels, "mobility", "Mobility"),
    }
    labels, label_name, label_display = label_options[args.label_field]

    unique_labels = np.unique(labels)
    print(
        f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(unique_labels)} {label_name} buckets"
    )
    print(f"[INFO] Samples per {label_name}:")
    for name, count in sorted(Counter(labels).items()):
        print(f"  {name}: {count}")

    if args.label_field != "snr":
        print("[INFO] SNR distribution (sampling classes):")
        for name, count in sorted(Counter(snr_labels).items()):
            print(f"  {name}: {count}")
    if args.label_field == "mobility":
        print("[INFO] Modulation distribution:")
        for name, count in sorted(Counter(modulation_labels).items()):
            print(f"  {name}: {count}")

    normalization_mode = args.normalization
    if normalization_mode == "per-sample":
        normalized_samples = normalize_per_sample(samples)
    else:
        normalized_samples = normalize_dataset(samples)
    print(f"[INFO] Normalisation mode: {normalization_mode}")

    # Baseline representation: the normalised spectrogram flattened per sample.
    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)

    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    else:
        checkpoint_path = find_latest_checkpoint(Path(args.models_root))
    print(f"[INFO] Using checkpoint: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    print(f"[INFO] Pooling strategy: {args.pooling}")

    # 'auto' treats widths above 128 as real/imag interleaved by list_snr_samples.
    use_interleaved = args.complex_mode == "interleaved" or (
        args.complex_mode == "auto"
        and normalized_samples.ndim == 3
        and normalized_samples.shape[2] > 128
    )

    model, use_interleaved = _load_lwm_model(checkpoint_path, device, use_interleaved)
    embeddings_np = _embed_samples(normalized_samples, model, device, use_interleaved, args.pooling)
    print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")

    if report_metrics:
        compute_metrics("Raw spectrogram", raw_vectors, labels)
        pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
        compute_metrics(pool_label, embeddings_np, labels)
        if args.metrics_only:
            return

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    raw_title = f"Raw Spectrogram t-SNE (by {label_display})"
    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
    embedding_title = f"LWM Embedding t-SNE ({pooling_label}, by {label_display})"
    run_tsne(raw_vectors, labels, raw_title, axes[0])
    run_tsne(embeddings_np, labels, embedding_title, axes[1])
    fig.tight_layout()

    save_path = _resolve_save_path(args.save_path, args.profile, args.data_root, args.label_field)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=600, bbox_inches="tight")
    print(f"[INFO] Figure saved to {save_path}")

    pdf_path = save_path.with_suffix('.pdf')
    fig.savefig(pdf_path, format='pdf', bbox_inches="tight")
    print(f"[INFO] PDF version saved to {pdf_path}")
    plt.close(fig)


if __name__ == "__main__":
    main()
|
|
|