|
|
|
|
|
"""Visualise how strongly Doppler/mobility drives the learned embedding space. |
|
|
|
|
|
This script mirrors ``task1/plot_mod_tsne.py`` but groups spectrograms by their |
|
|
mobility (``static``, ``pedestrian``, ``vehicular``) to inspect whether LWM |
|
|
embeddings primarily encode Doppler rather than modulation differences. |
|
|
|
|
|
Usage example: |
|
|
|
|
|
```bash |
|
|
python task1/plot_doppler_tsne.py \ |
|
|
--data-root spectrograms/city_1_losangeles/LTE \ |
|
|
--modulation QPSK \ |
|
|
--snr SNR10dB \ |
|
|
--dopplers static,pedestrian,vehicular \ |
|
|
--save-path task1/doppler_separation_plot_latest.png |
|
|
``` |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import glob |
|
|
import pickle |
|
|
import random |
|
|
import re |
|
|
from pathlib import Path |
|
|
from collections import Counter, defaultdict |
|
|
from typing import Dict, Iterable, List, Tuple |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
from sklearn.manifold import TSNE |
|
|
from sklearn.metrics import silhouette_score |
|
|
from sklearn.model_selection import StratifiedKFold |
|
|
from sklearn.neighbors import KNeighborsClassifier |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
|
|
|
from pretraining.pretrained_model import lwm as lwm_model |
|
|
|
|
|
|
|
|
def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardise every spectrogram in a stack independently.

    Mean and standard deviation are computed over axes 1 and 2, so ``specs``
    is expected to be a stack of 2-D spectrograms, ``(N, H, W)``. Each sample
    is shifted to zero mean and scaled to unit standard deviation; the
    per-sample deviation is floored at ``eps`` so constant inputs do not
    divide by zero.

    Returns:
        A float32 array with the same shape as ``specs``.
    """
    per_sample_axes = (1, 2)
    centre = np.mean(specs, axis=per_sample_axes, keepdims=True)
    spread = np.maximum(np.std(specs, axis=per_sample_axes, keepdims=True), eps)
    standardised = (specs - centre) / spread
    return standardised.astype(np.float32, copy=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv: Iterable[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing callers are unaffected; passing a
            list makes the parser usable from tests.

    The module docstring is used as the parser description and rendered
    verbatim (``RawDescriptionHelpFormatter``). This keeps the multi-line
    usage example intact in ``--help`` instead of re-wrapping it into a
    single paragraph.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--data-root",
        default="spectrograms/city_0_newyork/WiFi",
        help="Root directory containing modulation folders (default: %(default)s)",
    )
    parser.add_argument(
        "--modulation",
        default="QPSK",
        help=(
            "Modulation folder(s) to load. Pass 'all' or a comma-separated list "
            "to include multiple values (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--snr",
        default="SNR10dB",
        help=(
            "SNR folder(s) to analyse. Pass 'all' or a comma-separated list to "
            "include multiple values (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--dopplers",
        default="static,pedestrian,vehicular",
        help=(
            "Comma-separated list of mobility folders to include. Pass 'all' "
            "to include every mobility present (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--fft-folder",
        default="all",
        help=(
            "FFT size folder name to use. Pass 'all' to include every FFT variant "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--samples-per-doppler",
        type=int,
        default=500,
        help=(
            "Maximum number of samples to draw for each mobility label "
            "(used when --balance-mode=mobility; default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--balance-mode",
        choices=("mobility", "mobility_snr_mod"),
        default="mobility",
        help="Sampling strategy: uniform per mobility or per (modulation, SNR, mobility)",
    )
    parser.add_argument(
        "--samples-per-combo",
        type=int,
        default=150,
        help=(
            "Maximum samples per (modulation, SNR, mobility) combo when "
            "balance-mode=mobility_snr_mod; values <= 0 keep every sample "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for sampling and t-SNE",
    )
    parser.add_argument(
        "--pooling",
        choices=("mean", "cls"),
        default="mean",
        help="How to collapse token embeddings into a single vector",
    )
    parser.add_argument(
        "--save-path",
        default="task1/doppler_separation_plot_latest.png",
        help="Location to save the generated figure (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path; overrides automatic latest selection",
    )
    parser.add_argument(
        "--contrastive-checkpoint",
        default=None,
        help="Optional checkpoint path after contrastive fine-tuning for comparison",
    )
    parser.add_argument(
        "--models-root",
        default="models/20250922_235752",
        help=(
            "Directory containing checkpoints. When --checkpoint is not given, "
            "the latest/best checkpoint inside this directory will be used "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--report-metrics",
        action="store_true",
        help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
    )
    parser.add_argument(
        "--metrics-only",
        action="store_true",
        help="Exit after reporting metrics without running t-SNE or saving figures",
    )
    parser.add_argument(
        "--sampling-mode",
        choices=("first", "reservoir"),
        default="first",
        help="How to down-sample each class (default: first)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
# Matches the number after "_val" in a checkpoint file name. Only well-formed
# numbers are accepted, so the "." that starts a ".pth" suffix is never
# swallowed into the match (e.g. "epoch3_val0.125.pth" -> "0.125").
_VAL_METRIC_PATTERN = re.compile(r"_val(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def _parse_val_metric(path: Path) -> float | None:
    """Return the validation value encoded as ``_val<number>`` in ``path.name``, or ``None``."""
    match = _VAL_METRIC_PATTERN.search(path.name)
    return float(match.group(1)) if match else None


def find_latest_checkpoint(models_root: Path) -> Path:
    """Return the best (or newest) ``*.pth`` checkpoint under ``models_root``.

    ``models_root`` may be either a single run directory that holds ``*.pth``
    files directly, or a parent directory whose sub-directories are run
    folders. In the latter case, the run folder with the most recent
    modification time (among those containing at least one ``*.pth``) is
    searched.

    Within the selected directory, files whose names carry a ``_val<number>``
    tag are ranked by that number, lowest first (the value is assumed to be
    a loss-like metric). Ties go to the first name in sorted order. If no
    file carries a parsable tag, the most recently modified checkpoint is
    returned.

    Raises:
        FileNotFoundError: if ``models_root`` does not exist, is a file, or
            contains no checkpoint (directly or in a run sub-directory).
    """
    if not models_root.exists():
        raise FileNotFoundError(f"Models root not found: {models_root}")
    if models_root.is_file():
        raise FileNotFoundError(f"Expected a directory, got file: {models_root}")

    # Sorted so ties (equal metrics) resolve deterministically.
    checkpoints = sorted(models_root.glob("*.pth"))
    if not checkpoints:
        # No checkpoints at this level: treat models_root as a parent of run folders.
        candidate_runs = [
            run_dir
            for run_dir in models_root.iterdir()
            if run_dir.is_dir() and any(run_dir.glob("*.pth"))
        ]
        if not candidate_runs:
            raise FileNotFoundError(
                f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)"
            )
        latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
        checkpoints = sorted(latest_run.glob("*.pth"))

    scored = []
    for checkpoint in checkpoints:
        metric = _parse_val_metric(checkpoint)
        if metric is not None:
            scored.append((metric, checkpoint))

    if scored:
        return min(scored, key=lambda item: item[0])[1]

    # No parsable validation tag anywhere: fall back to the newest file.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
|
|
|
|
|
|
|
|
def parse_list_argument(argument: str | None) -> set[str] | None:
    """Parse a comma-separated CLI filter into a set of folder names.

    ``None`` or ``"all"`` (case-insensitive, surrounding whitespace ignored)
    means "no filter" and yields ``None``. Otherwise each comma-separated
    item is stripped and empty items are discarded, e.g.
    ``"static, vehicular,"`` -> ``{"static", "vehicular"}``.
    """
    # Strip before comparing so " all " is treated as "all", not as a folder name.
    if argument is None or argument.strip().lower() == "all":
        return None
    return {item.strip() for item in argument.split(",") if item.strip()}
|
|
|
|
|
|
|
|
def _load_spectrogram_stack(path: Path) -> np.ndarray | None:
    """Load one pickle file and return its spectrograms as a 3-D stack.

    Accepts either a raw ndarray or a dict with a ``"spectrograms"`` entry.
    A single 2-D spectrogram is promoted to a stack of one. On a read error or
    an unexpected layout, a warning is printed and ``None`` is returned so the
    caller can skip the file.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point --data-root at trusted, locally generated data.
    try:
        with open(path, "rb") as fh:
            data = pickle.load(fh)
    except Exception as exc:  # best-effort scan: report and skip unreadable files
        print(f"[WARN] Failed to load {path}: {exc}")
        return None

    if isinstance(data, dict) and "spectrograms" in data:
        specs = data["spectrograms"]
    elif isinstance(data, np.ndarray):
        specs = data
    else:
        print(f"[WARN] Unknown format in {path}: {type(data)}")
        return None

    specs = np.asarray(specs)
    if specs.ndim == 2:
        specs = specs[None, ...]
    elif specs.ndim != 3:
        print(f"[WARN] Unexpected spectrogram shape in {path}: {specs.shape}")
        return None
    return specs


def _offer_to_bucket(
    bucket: List[np.ndarray],
    seen: int,
    sample: np.ndarray,
    limit: int,
    mode: str,
    rng: random.Random,
) -> Tuple[int, bool]:
    """Offer ``sample`` to a bucket capped at ``limit`` items.

    While the bucket holds fewer than ``limit`` items, the sample is
    appended. Once the bucket is full, ``"reservoir"`` mode overwrites a
    uniformly chosen slot with probability ``limit / seen`` (Algorithm R).
    Any other mode rejects the sample.

    Returns:
        ``(seen, keep_going)``: the updated ``seen`` counter, and ``False``
        when the caller should stop offering samples from the current file
        (bucket full in non-reservoir mode).
    """
    if len(bucket) < limit:
        bucket.append(sample)
        return seen + 1, True
    if mode == "reservoir":
        seen += 1
        slot = rng.randint(0, seen - 1)
        if slot < limit:
            bucket[slot] = sample
        return seen, True
    return seen, False


def list_doppler_samples(
    data_root: Path,
    allowed_modulations: set[str] | None,
    allowed_snrs: set[str] | None,
    allowed_dopplers: set[str] | None,
    fft_folder: str,
    max_per_class: int,
    rng: random.Random,
    mode: str,
    balance_mode: str,
    samples_per_combo: int,
) -> Dict[str, List[np.ndarray]]:
    """Collect spectrogram samples grouped by mobility label.

    Every ``*.pkl`` file inside a ``spectrograms`` directory below
    ``data_root`` is visited in sorted path order, so "first" sampling and
    the reservoir RNG stream are reproducible across filesystems. A file's
    path relative to ``data_root`` must have at least seven components:
    component 0 is read as the modulation, 2 as the SNR, 3 as the mobility
    label and 6 as the FFT-size folder. Other components are ignored.

    Args:
        data_root: Root of the spectrogram tree.
        allowed_modulations, allowed_snrs, allowed_dopplers: Folder-name
            filters; ``None`` disables the filter.
        fft_folder: FFT folder name to keep, or ``"all"``.
        max_per_class: Cap per mobility label when ``balance_mode == "mobility"``.
        rng: Provides ``randint`` for reservoir sampling.
        mode: ``"reservoir"`` keeps a uniform random subset. Any other value
            (``"first"``) keeps the first samples encountered.
        balance_mode: ``"mobility"`` caps per mobility label. Any other value
            caps per (mobility, SNR, modulation) combo and then merges the
            combos by mobility.
        samples_per_combo: Per-combo cap in combo mode; ``<= 0`` keeps everything.

    Returns:
        Mapping from mobility label to a list of float32 2-D spectrograms.
    """
    class_samples: Dict[str, List[np.ndarray]] = defaultdict(list)
    seen_counts: Dict[str, int] = defaultdict(int)
    combo_samples: Dict[Tuple[str, str, str], List[np.ndarray]] = defaultdict(list)
    combo_counts: Dict[Tuple[str, str, str], int] = defaultdict(int)
    per_mobility = balance_mode == "mobility"

    pattern = str(data_root / "**" / "spectrograms" / "*.pkl")
    for path_str in sorted(glob.glob(pattern, recursive=True)):
        path = Path(path_str)
        try:
            rel_parts = path.relative_to(data_root).parts
        except ValueError:
            continue
        if len(rel_parts) < 7:
            continue

        modulation_folder = rel_parts[0]
        snr_folder = rel_parts[2]
        mobility_folder = rel_parts[3]
        fft_folder_name = rel_parts[6]

        if allowed_modulations is not None and modulation_folder not in allowed_modulations:
            continue
        if allowed_snrs is not None and snr_folder not in allowed_snrs:
            continue
        if allowed_dopplers is not None and mobility_folder not in allowed_dopplers:
            continue
        if fft_folder != "all" and fft_folder_name != fft_folder:
            continue

        class_label = mobility_folder
        combo_key = (class_label, snr_folder, modulation_folder)

        # In "first" mode a full bucket can never change again, so skip the
        # (comparatively expensive) pickle load entirely.
        if mode == "first":
            if per_mobility and len(class_samples[class_label]) >= max_per_class:
                continue
            if not per_mobility and 0 < samples_per_combo <= len(combo_samples.get(combo_key, ())):
                continue

        specs = _load_spectrogram_stack(path)
        if specs is None:
            continue

        for spec in specs:
            sample = spec.astype(np.float32)
            if per_mobility:
                seen_counts[class_label], keep_going = _offer_to_bucket(
                    class_samples[class_label],
                    seen_counts[class_label],
                    sample,
                    max_per_class,
                    mode,
                    rng,
                )
            elif samples_per_combo <= 0:
                # Non-positive combo cap means "keep every sample".
                combo_samples[combo_key].append(sample)
                combo_counts[combo_key] += 1
                keep_going = True
            else:
                combo_counts[combo_key], keep_going = _offer_to_bucket(
                    combo_samples[combo_key],
                    combo_counts[combo_key],
                    sample,
                    samples_per_combo,
                    mode,
                    rng,
                )
            if not keep_going:
                break

    if per_mobility:
        return class_samples

    balanced: Dict[str, List[np.ndarray]] = defaultdict(list)
    for (mobility, _snr_label, _mod_label), samples in combo_samples.items():
        if samples:
            balanced[mobility].extend(samples)
    return balanced
|
|
|
|
|
|
|
|
def sample_balanced_dataset(
    class_samples: Dict[str, List[np.ndarray]],
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Flatten per-mobility sample lists into one stacked array plus labels.

    No sub-sampling happens here. The buckets are taken as given (callers cap
    them beforehand, e.g. via ``list_doppler_samples``). Classes are emitted
    in sorted order; empty buckets contribute no rows.

    Returns:
        ``(stacked, labels, class_names)``, where:

        - ``stacked`` stacks every sample along a new leading axis.
        - ``labels`` is an array of class names aligned row-for-row with
          ``stacked``.
        - ``class_names`` is the sorted list of every key in
          ``class_samples``, including keys whose bucket is empty.

    Raises:
        RuntimeError: If no samples were collected at all.
        ValueError: If the samples do not all share one shape (for example
            when several FFT sizes are mixed via ``--fft-folder all``).
    """
    features: List[np.ndarray] = []
    labels: List[str] = []
    class_names = sorted(class_samples.keys())

    for class_name in class_names:
        samples = class_samples[class_name]
        if not samples:
            continue
        features.extend(samples)
        labels.extend([class_name] * len(samples))

    if not features:
        raise RuntimeError("No spectrogram samples collected for the specified filters")

    # np.stack would fail with an opaque message on mixed shapes; say what went wrong.
    shape_counts = Counter(sample.shape for sample in features)
    if len(shape_counts) > 1:
        raise ValueError(
            f"Collected spectrograms have mismatched shapes {dict(shape_counts)}; "
            "select a single FFT size with --fft-folder instead of 'all'"
        )

    stacked = np.stack(features)
    return stacked, np.array(labels), class_names
|
|
|
|
|
|
|
|
def unfold_patches(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    """Cut each 2-D map in a batch into flattened, non-overlapping square patches.

    ``x`` is indexed as ``(batch, rows, cols)``. Patches are ordered
    row-major (left to right, then top to bottom), and each patch is
    flattened row-major. The result has shape
    ``(batch, n_patches, patch_size ** 2)`` and is contiguous.
    ``Tensor.unfold`` drops trailing rows/columns that do not fill a whole
    patch.
    """
    batch = x.shape[0]
    tokens_per_patch = patch_size * patch_size
    # (B, R, C) -> (B, R//p, C, p): windows along the row axis.
    row_tiles = x.unfold(1, patch_size, patch_size)
    # (B, R//p, C, p) -> (B, R//p, C//p, p, p): windows along the column axis.
    tiles = row_tiles.unfold(2, patch_size, patch_size)
    return tiles.contiguous().view(batch, -1, tokens_per_patch)
|
|
|
|
|
|
|
|
def extract_tokens(spec: np.ndarray, device: torch.device) -> torch.Tensor:
    """Turn one 2-D spectrogram into a batch-of-one patch-token tensor.

    The array is wrapped with ``torch.from_numpy``, given a leading batch
    axis, moved to ``device``, and split into patches by ``unfold_patches``
    using its default patch size.
    """
    batched = torch.from_numpy(spec)[None, ...]
    on_device = batched.to(device)
    return unfold_patches(on_device)
|
|
|
|
|
|
|
|
def pool_embeddings(
    tokens: torch.Tensor,
    model: torch.nn.Module,
    pooling: str,
) -> np.ndarray:
    """Run ``model`` on ``tokens`` with a constant CLS row prepended, then pool.

    A row filled with ``0.2`` is concatenated in front of the token sequence
    along dim 1. With ``pooling == "cls"``, the model output at position 0 is
    returned. Any other value averages the outputs over every position after
    the CLS row. The forward pass runs without gradients, and the result is
    returned as a CPU NumPy array.
    """
    # NOTE(review): 0.2 is the fixed CLS fill value used by this script;
    # presumably it mirrors the pre-training setup — confirm.
    cls_row = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
    model_input = torch.cat((cls_row, tokens), dim=1)

    with torch.no_grad():
        hidden = model(model_input)

    pooled = hidden[:, 0] if pooling == "cls" else hidden[:, 1:].mean(dim=1)
    return pooled.detach().cpu().numpy()
|
|
|
|
|
|
|
|
def run_tsne(
    x: np.ndarray,
    labels: np.ndarray,
    title: str,
    ax: plt.Axes,
    random_state: int = 42,
) -> None:
    """Project ``x`` to 2-D with t-SNE and scatter it on ``ax``, coloured by label.

    Features are standardised first. Perplexity is ``n // 10`` clamped to
    ``[5, 30]`` and then capped at ``n - 1``, because scikit-learn requires
    ``perplexity < n_samples``.

    Args:
        x: ``(n_samples, n_features)`` feature matrix.
        labels: Length-``n_samples`` array of class names aligned with ``x``.
        title: Axis title.
        ax: Matplotlib axes to draw on.
        random_state: Seed forwarded to ``TSNE``.

    Raises:
        ValueError: If fewer than two samples are supplied.
    """
    n_samples = len(x)
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

    x_scaled = StandardScaler().fit_transform(x)

    # Cap at n - 1 last so tiny inputs (n <= 5) still get a valid perplexity.
    perplexity = min(max(5, min(30, n_samples // 10)), n_samples - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    embedding = tsne.fit_transform(x_scaled)

    class_names = np.unique(labels)  # np.unique returns sorted values
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for color, class_name in zip(colors, class_names):
        mask = labels == class_name
        ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)

    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("t-SNE Component 1", fontsize=12)
    ax.set_ylabel("t-SNE Component 2", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8)


def compute_metrics(
    name: str,
    features: np.ndarray,
    labels: np.ndarray,
    random_state: int = 42,
    n_splits: int = 5,
    n_neighbors: int = 5,
) -> None:
    """Print the silhouette score and stratified k-fold kNN accuracy for ``features``.

    Features are standardised first, and ``labels`` are the ground-truth
    classes. When the data cannot support a metric, a ``skipped`` line is
    printed instead of letting scikit-learn raise.

    Args:
        name: Label printed in front of the metrics.
        features: ``(n_samples, n_features)`` matrix.
        labels: Length-``n_samples`` class labels (ndarray, indexed by fold).
        random_state: Seed for the shuffled ``StratifiedKFold`` split.
        n_splits: Requested number of folds. Reduced to the largest class size
            when that is smaller, since scikit-learn rejects more folds than
            any class has members.
        n_neighbors: Requested ``k``. Reduced to the smallest training-fold size.
    """
    n_samples = len(labels)
    class_counts = Counter(np.asarray(labels).tolist())
    n_classes = len(class_counts)
    # silhouette_score requires 2 <= n_labels <= n_samples - 1.
    if n_classes < 2 or n_classes > n_samples - 1:
        print(
            f"[METRIC] {name}: skipped ({n_classes} class(es) across {n_samples} "
            "sample(s); need at least 2 classes and more samples than classes)"
        )
        return

    features_scaled = StandardScaler().fit_transform(features)
    silhouette = silhouette_score(features_scaled, labels)

    # StratifiedKFold only errors when *every* class is smaller than n_splits.
    folds = min(n_splits, max(class_counts.values()))
    if folds < 2:
        print(f"[METRIC] {name}: silhouette={silhouette:.3f}, kNN accuracy skipped (too few samples)")
        return

    skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=random_state)
    splits = list(skf.split(features_scaled, labels))
    # KNeighborsClassifier requires n_neighbors <= number of training samples.
    k = min(n_neighbors, min(len(train_idx) for train_idx, _ in splits))

    scores: List[float] = []
    for train_idx, test_idx in splits:
        clf = KNeighborsClassifier(n_neighbors=k)
        clf.fit(features_scaled[train_idx], labels[train_idx])
        scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))

    mean_acc = float(np.mean(scores))
    std_acc = float(np.std(scores))
    fold_note = "" if folds == n_splits else f" ({folds}-fold)"
    print(
        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
        f"{k}-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}{fold_note}"
    )


def _strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove a leading ``module.`` from every key that has one.

    This prefix appears, for example, in state dicts saved from a
    DataParallel/DistributedDataParallel-wrapped model. Only the prefix is
    removed; ``module.`` occurring later in a key is left untouched.
    """
    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_checkpoint_into(model: torch.nn.Module, path: Path, device: torch.device) -> None:
    """Load the state dict at ``path`` into ``model`` non-strictly, warning about key mismatches.

    ``strict=False`` lets partially matching checkpoints load. A silent
    mismatch, however, would leave layers at whatever weights they already
    had, so the missing/unexpected key counts are printed when non-zero.
    """
    # NOTE(review): assumes the file holds a bare state_dict rather than a wrapper
    # dict (e.g. {"model_state_dict": ...}); a wrapper would surface below as
    # every model key being missing.
    state_dict = _strip_module_prefix(torch.load(path, map_location=device))
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"[WARN] {path}: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected keys (loaded with strict=False)"
        )


def main() -> None:
    """CLI entry point: load spectrograms, embed them with LWM, report metrics and plot.

    Steps:

    1. Seed the RNGs.
    2. Resolve the checkpoint paths before touching the data, so bad paths
       fail fast.
    3. Collect mobility-labelled spectrograms and standardise each one.
    4. Embed them with the baseline (and optional contrastive) checkpoint.
    5. Print clustering metrics when ``--report-metrics`` or
       ``--metrics-only`` is given (the latter then exits).
    6. Save a side-by-side t-SNE figure.

    ``--seed`` drives sampling, t-SNE and the metric CV split.
    """
    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_root = Path(args.data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    # Resolve checkpoints before the (potentially slow) data scan.
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    else:
        checkpoint_path = find_latest_checkpoint(Path(args.models_root))

    contrastive_path = None
    if args.contrastive_checkpoint:
        contrastive_path = Path(args.contrastive_checkpoint)
        if not contrastive_path.exists():
            raise FileNotFoundError(f"Contrastive checkpoint not found: {contrastive_path}")

    class_samples = list_doppler_samples(
        data_root,
        parse_list_argument(args.modulation),
        parse_list_argument(args.snr),
        parse_list_argument(args.dopplers),
        args.fft_folder,
        args.samples_per_doppler,
        random,
        args.sampling_mode,
        args.balance_mode,
        args.samples_per_combo,
    )
    samples, labels, _ = sample_balanced_dataset(class_samples)
    class_counts = Counter(labels)
    print(f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(class_counts)} mobility buckets")
    print("[INFO] Samples per mobility:")
    for name, count in sorted(class_counts.items()):
        print(f"  {name}: {count}")

    normalized_samples = normalize_per_sample(samples)
    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    print(f"[INFO] Pooling strategy: {args.pooling}")

    transformer = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    transformer = transformer.to(device)

    def embed_with_checkpoint(path: Path, label: str) -> np.ndarray:
        """Load ``path`` into the shared transformer and embed every normalised sample."""
        print(f"[INFO] Using checkpoint ({label}): {path}")
        # NOTE(review): the same model instance is reused, so keys absent from a
        # later checkpoint keep the previously loaded weights (reported as missing).
        _load_checkpoint_into(transformer, path, device)
        transformer.eval()

        embeddings = [
            pool_embeddings(extract_tokens(spec, device), transformer, args.pooling).squeeze(0)
            for spec in normalized_samples
        ]
        embeddings_np = np.vstack(embeddings)
        print(f"[INFO] Generated embeddings ({label}) with shape {embeddings_np.shape}")
        return embeddings_np

    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
    embedding_views: List[Tuple[str, np.ndarray]] = [
        (f"LWM Embedding t-SNE ({pooling_label})", embed_with_checkpoint(checkpoint_path, "baseline")),
    ]
    if contrastive_path is not None:
        embedding_views.append(
            (
                f"Contrastive Embedding t-SNE ({pooling_label})",
                embed_with_checkpoint(contrastive_path, "contrastive"),
            )
        )

    # --metrics-only implies reporting metrics; otherwise it would silently plot.
    if args.report_metrics or args.metrics_only:
        compute_metrics("Raw spectrogram", raw_vectors, labels, random_state=args.seed)
        for title, embedding in embedding_views:
            compute_metrics(title, embedding, labels, random_state=args.seed)
    if args.metrics_only:
        return

    panels = [("Raw Spectrogram t-SNE", raw_vectors), *embedding_views]
    fig, axes = plt.subplots(1, len(panels), figsize=(9 * len(panels), 7), squeeze=False)
    for ax, (title, features) in zip(axes[0], panels):
        run_tsne(features, labels, title, ax, random_state=args.seed)

    fig.tight_layout()
    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[INFO] Figure saved to {save_path}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|
|