"""Generate t-SNE plots comparing raw spectrograms and LWM embeddings. |
|
|
|
|
|
This reproduces the style of `/workspace/lwm-spectro/mcs_separation_plot.png`, |
|
|
but it automatically loads the most recent LWM checkpoint stored under |
|
|
`models/` (selecting the run directory with the latest modification time and |
|
|
within it the checkpoint with the best validation metric). |
|
|
|
|
|
The script focuses on a single communication family (default: LTE) and treats |
|
|
all code-rate variants of the same modulation as a single class (three |
|
|
modulation labels) at SNR20dB static conditions using the 128×128 spectrograms |
|
|
generated from 512-point FFTs. |
|
|
|
|
|
Usage example: |
|
|
|
|
|
```bash |
|
|
python task1/plot_mod_tsne.py \ |
|
|
--data-root spectrograms/city_1_losangeles/LTE \ |
|
|
--save-path task1/mcs_separation_plot_latest.png |
|
|
``` |
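
To print clustering metrics (silhouette score and 5-fold kNN accuracy) without
rendering a figure, combine the flags defined below:

```bash
python task1/plot_mod_tsne.py --report-metrics --metrics-only
```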

Adjust `--samples-per-class` to control how many examples feed into t-SNE
(default: 500 per class). Pass `--pooling cls` to use CLS-token embeddings
instead of mean pooling.
"""

from __future__ import annotations

import argparse
import glob
import pickle
import random
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

from pretraining.pretrained_model import lwm as lwm_model


def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardize each spectrogram to zero mean and unit variance.

    ``eps`` guards against division by zero for (near-)constant inputs.
    """
    means = specs.mean(axis=(1, 2), keepdims=True)
    stds = specs.std(axis=(1, 2), keepdims=True)
    stds = np.maximum(stds, eps)
    return ((specs - means) / stds).astype(np.float32, copy=False)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--data-root",
        default="spectrograms/city_1_losangeles/LTE",
        help="Root directory containing modulation folders (default: %(default)s)",
    )
    parser.add_argument(
        "--snr",
        default="all",
        help=(
            "SNR folder name to filter on. Pass 'all' to include every SNR "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--mobility",
        default="all",
        help=(
            "Mobility folder to filter on. Pass 'all' to include every mobility "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--fft-folder",
        default="all",
        help=(
            "FFT size folder name to use. Pass 'all' to include every FFT variant "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--samples-per-class",
        type=int,
        default=500,
        help="Maximum number of samples to draw for each modulation class (default: %(default)s)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for sampling and t-SNE (default: %(default)s)",
    )
    parser.add_argument(
        "--pooling",
        choices=("mean", "cls"),
        default="mean",
        help="How to collapse token embeddings into a single vector (default: %(default)s)",
    )
    parser.add_argument(
        "--save-path",
        default="task1/mcs_separation_plot_latest.png",
        help="Location to save the generated figure (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path; overrides automatic latest selection",
    )
    parser.add_argument(
        "--models-root",
        default="models/LTE_models",
        help=(
            "Directory containing checkpoints. When --checkpoint is not given, "
            "the latest/best checkpoint inside this directory will be used "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--report-metrics",
        action="store_true",
        help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
    )
    parser.add_argument(
        "--metrics-only",
        action="store_true",
        help=(
            "Exit after reporting metrics without running t-SNE or saving figures "
            "(only takes effect together with --report-metrics)"
        ),
    )
    parser.add_argument(
        "--sampling-mode",
        choices=("first", "reservoir"),
        default="first",
        help="How to down-sample each class (default: %(default)s)",
    )
    return parser.parse_args()


def find_latest_checkpoint(models_root: Path) -> Path:
    """Return a checkpoint path under ``models_root``.

    Works with either a parent directory that contains multiple run folders,
    or directly with a single run directory containing ``*.pth`` files.
    Chooses the checkpoint with the lowest parsed validation value when
    available, else falls back to the most recent modification time.
    """
    if not models_root.exists():
        raise FileNotFoundError(f"Models root not found: {models_root}")
    if models_root.is_file():
        raise FileNotFoundError(f"Expected a directory, got file: {models_root}")

    checkpoints = list(models_root.glob("*.pth"))
    if not checkpoints:
        # No checkpoints at the top level: descend into the most recently
        # modified run directory that contains .pth files.
        run_dirs = [p for p in models_root.iterdir() if p.is_dir()]
        candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))]
        if not candidate_runs:
            raise FileNotFoundError(
                f"No checkpoints found under {models_root} "
                "(no .pth files in this dir or its run subdirs)"
            )
        latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
        checkpoints = list(latest_run.glob("*.pth"))

    def parse_val_metric(path: Path) -> float | None:
        # Extract a validation value from filenames containing "_val<number>".
        match = re.search(r"_val([0-9.]+)", path.name)
        if match:
            try:
                return float(match.group(1))
            except ValueError:
                return None
        return None

    parsed = [(parse_val_metric(p), p) for p in checkpoints]
    valid = [item for item in parsed if item[0] is not None]
    if valid:
        # Lower validation value is treated as better.
        valid.sort(key=lambda item: item[0])
        return valid[0][1]

    # No parsable validation values: fall back to the newest file.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
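
# Example (hypothetical tree): given
#   models/LTE_models/run_a/epoch10_val0.052.pth
#   models/LTE_models/run_b/epoch08_val0.031.pth
# with run_b modified most recently, find_latest_checkpoint(Path("models/LTE_models"))
# descends into run_b and returns its lowest "_val" checkpoint.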


def list_class_samples(
    data_root: Path,
    snr: str,
    mobility: str,
    fft_folder: str,
    max_per_class: int,
    rng: random.Random,
    mode: str,
) -> Dict[str, List[np.ndarray]]:
    """Collect spectrogram samples grouped by modulation class.

    Only samples that match the requested SNR, mobility, and FFT size are
    included (or every sample when the corresponding argument is 'all'). Each
    sample is stored as a float32 array with shape (128, 128).
    """
    class_samples: Dict[str, List[np.ndarray]] = defaultdict(list)
    seen_counts: Dict[str, int] = defaultdict(int)
    pattern = str(data_root / "**" / "spectrograms" / "*.pkl")
    for path_str in glob.glob(pattern, recursive=True):
        path = Path(path_str)
        parts = path.parts
        if snr != "all" and snr not in parts:
            continue
        if mobility != "all" and mobility not in parts:
            continue
        if fft_folder != "all" and fft_folder not in parts:
            continue

        try:
            rel_parts = path.relative_to(data_root).parts
        except ValueError:
            continue

        if len(rel_parts) < 7:
            # Path is too shallow to match the expected layout; skip it.
            continue

        modulation = rel_parts[0]
        class_label = modulation

        # In "first" mode, stop reading files once a class bucket is full.
        if mode == "first" and len(class_samples[class_label]) >= max_per_class:
            continue

        try:
            with open(path, "rb") as fh:
                data = pickle.load(fh)
        except Exception as exc:
            print(f"[WARN] Failed to load {path}: {exc}")
            continue

        if isinstance(data, dict) and "spectrograms" in data:
            specs = data["spectrograms"]
        elif isinstance(data, np.ndarray):
            specs = data
        else:
            print(f"[WARN] Unknown format in {path}: {type(data)}")
            continue

        specs = np.asarray(specs)
        if specs.ndim == 3:
            pass
        elif specs.ndim == 2:
            specs = specs[None, ...]
        else:
            print(f"[WARN] Unexpected spectrogram shape in {path}: {specs.shape}")
            continue

        for spec in specs:
            sample = spec.astype(np.float32)
            bucket = class_samples[class_label]

            if len(bucket) < max_per_class:
                bucket.append(sample)
                seen_counts[class_label] += 1
            elif mode == "reservoir":
                # Reservoir sampling (Algorithm R): replace an existing entry
                # with probability max_per_class / seen_counts.
                seen_counts[class_label] += 1
                j = rng.randint(0, seen_counts[class_label] - 1)
                if j < max_per_class:
                    bucket[j] = sample
            else:
                break

    return class_samples


def sample_balanced_dataset(
    class_samples: Dict[str, List[np.ndarray]],
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Stack the per-class sample lists (already capped upstream) into arrays.

    Returns the stacked float32 spectrograms, a parallel label array, and the
    sorted list of class names.
    """
    features: List[np.ndarray] = []
    labels: List[str] = []
    class_names = sorted(class_samples.keys())

    for class_name in class_names:
        samples = class_samples[class_name]
        if not samples:
            continue
        features.extend(samples)
        labels.extend([class_name] * len(samples))

    if not features:
        raise RuntimeError("No spectrogram samples collected for the specified filters")

    stacked = np.stack(features)
    return stacked, np.array(labels), class_names


def unfold_patches(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    # Slice each (B, H, W) spectrogram into non-overlapping patch_size×patch_size
    # tiles and flatten them into (B, num_patches, patch_size**2) tokens.
    patches_h = x.unfold(1, patch_size, patch_size)
    patches = patches_h.unfold(2, patch_size, patch_size)
    return patches.contiguous().view(x.shape[0], -1, patch_size * patch_size)


def extract_tokens(spec: np.ndarray, device: torch.device) -> torch.Tensor:
    # Add a batch dimension and tokenize a single (128, 128) spectrogram.
    tensor = torch.from_numpy(spec).unsqueeze(0).to(device)
    return unfold_patches(tensor)
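
# Shape sanity check (assuming 128×128 inputs and the default patch_size=4):
# unfold_patches yields (1, 1024, 16) tokens (32×32 patches of 16 values each),
# matching the model's element_length=16 and max_len=1025 (1024 tokens + CLS).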


def pool_embeddings(
    tokens: torch.Tensor,
    model: torch.nn.Module,
    pooling: str,
) -> np.ndarray:
    # Prepend a constant CLS token (0.2 here, presumed to match the LWM
    # pretraining convention) before running the encoder.
    cls_token = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
    inputs = torch.cat([cls_token, tokens], dim=1)

    with torch.no_grad():
        outputs = model(inputs)

    if pooling == "cls":
        pooled = outputs[:, 0]
    else:
        pooled = outputs[:, 1:].mean(dim=1)

    return pooled.detach().cpu().numpy()


def run_tsne(x: np.ndarray, labels: np.ndarray, title: str, ax: plt.Axes) -> None:
    scaler = StandardScaler()
    x_scaled = scaler.fit_transform(x)

    # t-SNE requires perplexity < n_samples; aim for ~n/10, clipped to [5, 30],
    # but never exceed n_samples - 1 for very small datasets.
    perplexity = min(max(5, min(30, len(x_scaled) // 10)), len(x_scaled) - 1)

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embedding = tsne.fit_transform(x_scaled)

    class_names = sorted(np.unique(labels))
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for color, class_name in zip(colors, class_names):
        mask = labels == class_name
        ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)

    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("t-SNE Component 1", fontsize=12)
    ax.set_ylabel("t-SNE Component 2", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8)


def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None:
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    silhouette = silhouette_score(features_scaled, labels)

    # 5-fold stratified cross-validation with a 5-NN classifier.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    scores: List[float] = []
    for train_idx, test_idx in skf.split(features_scaled, labels):
        clf = KNeighborsClassifier(n_neighbors=5)
        clf.fit(features_scaled[train_idx], labels[train_idx])
        scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))

    mean_acc = float(np.mean(scores))
    std_acc = float(np.std(scores))
    print(
        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
        f"5-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}"
    )
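    # Example output line (illustrative numbers only, not real results):
    #   [METRIC] LWM mean: silhouette=0.412, 5-NN accuracy=0.930 ± 0.012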


def main() -> None:
    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Use a dedicated Random instance, matching the rng parameter expected by
    # list_class_samples.
    rng = random.Random(args.seed)

    data_root = Path(args.data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    class_samples = list_class_samples(
        data_root,
        args.snr,
        args.mobility,
        args.fft_folder,
        args.samples_per_class,
        rng,
        args.sampling_mode,
    )
    samples, labels, _ = sample_balanced_dataset(class_samples)
    print(f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(np.unique(labels))} classes")
    class_counts = Counter(labels)
    print("[INFO] Class counts:")
    for name, count in sorted(class_counts.items()):
        print(f"  {name}: {count}")

    normalized_samples = normalize_per_sample(samples)

    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)
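    # raw_vectors is the raw-spectrogram baseline: each normalized spectrogram
    # flattened to a single vector (16384-dim given 128×128 inputs).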

    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    else:
        checkpoint_path = find_latest_checkpoint(Path(args.models_root))
    print(f"[INFO] Using checkpoint: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    print(f"[INFO] Pooling strategy: {args.pooling}")
    model = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    state_dict = torch.load(checkpoint_path, map_location=device)
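    # Assumption: the .pth file stores a bare state_dict. If a training loop
    # instead saved a wrapper dict (e.g. {"model_state_dict": ...}, a
    # hypothetical key), unwrap it before loading:
    #   state_dict = state_dict.get("model_state_dict", state_dict)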
    if any(k.startswith("module.") for k in state_dict):
        # Strip DataParallel "module." prefixes (first occurrence only).
        state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device).eval()

    embeddings: List[np.ndarray] = []
    for spec in normalized_samples:
        tokens = extract_tokens(spec, device)
        embedding = pool_embeddings(tokens, model, args.pooling)
        embeddings.append(embedding.squeeze(0))

    embeddings_np = np.vstack(embeddings)
    print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")

    if args.report_metrics:
        compute_metrics("Raw spectrogram", raw_vectors, labels)
        pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
        compute_metrics(pool_label, embeddings_np, labels)
        if args.metrics_only:
            return

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    raw_title = "Raw Spectrogram t-SNE"
    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
    embedding_title = f"LWM Embedding t-SNE ({pooling_label})"
    run_tsne(raw_vectors, labels, raw_title, axes[0])
    run_tsne(embeddings_np, labels, embedding_title, axes[1])

    fig.tight_layout()
    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"[INFO] Figure saved to {save_path}")


if __name__ == "__main__":
    main()