lwm-spectro / task1 /plot_mod_tsne.py
wi-lab's picture
Upload task1/plot_mod_tsne.py with huggingface_hub
157e42d verified
#!/usr/bin/env python3
"""Generate t-SNE plots comparing raw spectrograms and LWM embeddings.
This reproduces the style of `/workspace/lwm-spectro/mcs_separation_plot.png`,
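but it automatically picks an LWM checkpoint from `--models-root` (default:
`models/LTE_models`): it uses the run sub-directory with the latest
modification time (or the directory itself when it directly holds `*.pth`
files) and, within it, the checkpoint with the lowest `_val<number>` value in
its filename, falling back to the most recently modified checkpoint.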
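The script focuses on a single communication family (default: LTE) and treats
all code-rate variants of the same modulation as a single class (one label per
modulation folder, e.g. three for the LTE data). By default every SNR, mobility
and FFT-size folder is included; pass `--snr`, `--mobility` and `--fft-folder`
with the corresponding folder names to restrict the plot to one condition, such
as SNR20dB static 128×128 spectrograms generated from 512-point FFTs.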
Usage example:
```bash
python task1/plot_mod_tsne.py \
--data-root spectrograms/city_1_losangeles/LTE \
--save-path task1/mcs_separation_plot_latest.png
```
Adjust `--samples-per-class` if you want to control how many examples feed
into t-SNE (default 500 per class). Pass `--pooling cls` to use CLS token
embeddings instead of mean pooling.
"""
from __future__ import annotations
import argparse
import glob
import pickle
import random
import re
from pathlib import Path
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from pretraining.pretrained_model import lwm as lwm_model
def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardise every 2-D map in a stack independently.

    Mean and standard deviation are taken over axes 1 and 2, so each
    ``specs[i]`` is shifted to zero mean and divided by its own std. The std is
    floored at ``eps`` to avoid division by zero on constant maps. The result
    is returned as float32 (no extra copy when the arithmetic already produced
    float32).
    """
    reduce_axes = (1, 2)
    centre = specs.mean(axis=reduce_axes, keepdims=True)
    spread = np.maximum(specs.std(axis=reduce_axes, keepdims=True), eps)
    standardised = (specs - centre) / spread
    return standardised.astype(np.float32, copy=False)
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    The module docstring doubles as the parser description. Options are
    registered in a fixed order, which is also the order shown by ``--help``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    option = cli.add_argument

    option(
        "--data-root",
        default="spectrograms/city_1_losangeles/LTE",
        help="Root directory containing modulation folders (default: %(default)s)",
    )
    option(
        "--snr",
        default="all",
        help=(
            "SNR folder name to filter on. Pass 'all' to include every SNR "
            "(default: %(default)s)"
        ),
    )
    option(
        "--mobility",
        default="all",
        help=(
            "Mobility folder to filter on. Pass 'all' to include every mobility "
            "(default: %(default)s)"
        ),
    )
    option(
        "--fft-folder",
        default="all",
        help=(
            "FFT size folder name to use. Pass 'all' to include every FFT variant "
            "(default: %(default)s)"
        ),
    )
    option(
        "--samples-per-class",
        type=int,
        default=500,
        help="Maximum number of samples to draw for each modulation class",
    )
    option("--seed", type=int, default=42, help="Random seed for sampling and t-SNE")
    option(
        "--pooling",
        choices=("mean", "cls"),
        default="mean",
        help="How to collapse token embeddings into a single vector",
    )
    option(
        "--save-path",
        default="task1/mcs_separation_plot_latest.png",
        help="Location to save the generated figure (default: %(default)s)",
    )
    option(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path; overrides automatic latest selection",
    )
    option(
        "--models-root",
        default="models/LTE_models",
        help=(
            "Directory containing checkpoints. When --checkpoint is not given, the "
            "latest/best checkpoint inside this directory will be used "
            "(default: %(default)s)"
        ),
    )
    option(
        "--report-metrics",
        action="store_true",
        help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
    )
    option(
        "--metrics-only",
        action="store_true",
        help="Exit after reporting metrics without running t-SNE or saving figures",
    )
    option(
        "--sampling-mode",
        choices=("first", "reservoir"),
        default="first",
        help="How to down-sample each class (default: first)",
    )
    return cli.parse_args()
def find_latest_checkpoint(models_root: Path) -> Path:
    """Return the checkpoint path to load from ``models_root``.

    ``models_root`` may either hold ``*.pth`` files directly, or hold run
    sub-directories that do. In the latter case, sub-directories without any
    checkpoint (e.g. TensorBoard logs) are ignored and the run directory with
    the newest modification time is used.

    Within the chosen directory, the checkpoint whose filename carries the
    lowest ``_val<number>`` value (e.g. ``epoch12_val0.0421.pth``) is returned;
    lower is treated as better. If no filename carries a parsable value, the
    most recently modified checkpoint is returned instead.

    Args:
        models_root: directory to search (a ``str`` is also accepted).

    Returns:
        Path of the selected ``.pth`` file.

    Raises:
        FileNotFoundError: if ``models_root`` does not exist, is a file, or
            contains no checkpoints (directly or in a run sub-directory).
    """
    models_root = Path(models_root)
    if not models_root.exists():
        raise FileNotFoundError(f"Models root not found: {models_root}")
    if models_root.is_file():
        raise FileNotFoundError(f"Expected a directory, got file: {models_root}")

    # Prefer checkpoints directly inside models_root; otherwise use the newest
    # run sub-directory that contains any. Sorting makes ties deterministic.
    checkpoints = sorted(models_root.glob("*.pth"))
    if not checkpoints:
        candidate_runs = [
            run_dir
            for run_dir in models_root.iterdir()
            if run_dir.is_dir() and any(run_dir.glob("*.pth"))
        ]
        if not candidate_runs:
            raise FileNotFoundError(
                f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)"
            )
        latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
        checkpoints = sorted(latest_run.glob("*.pth"))

    # Accept digits with at most one decimal point (plus an optional exponent)
    # so the "." of a following ".pth" extension is not captured — otherwise
    # a name like "x_val0.1234.pth" would yield "0.1234." and fail to parse.
    val_pattern = re.compile(r"_val(\d*\.?\d+(?:[eE][-+]?\d+)?)")
    scored = []
    for ckpt in checkpoints:
        match = val_pattern.search(ckpt.name)
        if match:
            scored.append((float(match.group(1)), ckpt))
    if scored:
        # min() keeps the first of equal values, i.e. the alphabetically
        # earliest filename among ties.
        return min(scored, key=lambda item: item[0])[1]
    # No parsable validation value anywhere: fall back to the newest file.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
def list_class_samples(
    data_root: Path,
    snr: str,
    mobility: str,
    fft_folder: str,
    max_per_class: int,
    rng: random.Random,
    mode: str,
) -> Dict[str, List[np.ndarray]]:
    """Collect spectrogram samples grouped by modulation class.

    Every ``*.pkl`` file inside a ``spectrograms`` folder (at any depth) below
    ``data_root`` is considered, in sorted path order. Its path relative to
    ``data_root`` must have at least seven components; the first one (the
    modulation folder) becomes the class label. A file is kept only when each
    of ``snr``, ``mobility`` and ``fft_folder`` appears as a component of that
    relative path; passing ``'all'`` disables the corresponding filter.

    A pickle may contain either a dict with a ``"spectrograms"`` entry or a
    bare ndarray. 2-D arrays count as one sample and 3-D arrays as a stack of
    samples; any other rank is skipped with a warning. Samples are stored as
    float32 copies. NOTE(review): the per-sample shape is not validated here
    (presumably 128x128); mismatched shapes only surface when stacked.

    Args:
        data_root: directory to search (a ``str`` is also accepted).
        snr, mobility, fft_folder: folder names to require, or ``'all'``.
        max_per_class: maximum number of samples kept per class.
        rng: object providing ``randint`` (``main`` passes the seeded
            ``random`` module); only used in ``"reservoir"`` mode.
        mode: ``"first"`` keeps the first ``max_per_class`` samples
            encountered; ``"reservoir"`` keeps a uniform random subset of all
            matching samples via reservoir sampling.

    Returns:
        Mapping from class label to its list of samples. A label may map to an
        empty list when none of its files could be loaded.

    Raises:
        ValueError: if ``mode`` is not ``"first"`` or ``"reservoir"``.
    """
    if mode not in ("first", "reservoir"):
        raise ValueError(f"Unknown sampling mode: {mode!r}")
    data_root = Path(data_root)
    class_samples: Dict[str, List[np.ndarray]] = defaultdict(list)
    seen_counts: Dict[str, int] = defaultdict(int)
    required = [name for name in (snr, mobility, fft_folder) if name != "all"]

    # Escape data_root so characters such as "[" in it are matched literally.
    pattern = str(Path(glob.escape(str(data_root))) / "**" / "spectrograms" / "*.pkl")
    # Sort: glob order is filesystem-dependent, which would make "first" mode
    # and seeded reservoir sampling irreproducible.
    for path_str in sorted(glob.glob(pattern, recursive=True)):
        path = Path(path_str)
        try:
            rel_parts = path.relative_to(data_root).parts
        except ValueError:
            continue
        if len(rel_parts) < 7:
            # Expect: modulation / rate / SNR / mobility / window / samples / fft / ...
            continue
        # Filter on the part below data_root only, so folder names that happen
        # to appear in data_root itself cannot satisfy a filter.
        if any(name not in rel_parts for name in required):
            continue

        class_label = rel_parts[0]
        bucket = class_samples[class_label]
        if mode == "first" and len(bucket) >= max_per_class:
            continue

        try:
            with open(path, "rb") as fh:
                # NOTE: unpickling can execute arbitrary code; only point
                # --data-root at trusted data.
                data = pickle.load(fh)
        except Exception as exc:  # pragma: no cover - I/O heavy
            print(f"[WARN] Failed to load {path}: {exc}")
            continue

        if isinstance(data, dict) and "spectrograms" in data:
            specs = data["spectrograms"]
        elif isinstance(data, np.ndarray):
            specs = data
        else:
            print(f"[WARN] Unknown format in {path}: {type(data)}")
            continue

        specs = np.asarray(specs)
        if specs.ndim == 2:
            specs = specs[None, ...]
        elif specs.ndim != 3:
            print(f"[WARN] Unexpected spectrogram shape in {path}: {specs.shape}")
            continue

        for spec in specs:
            sample = spec.astype(np.float32)
            if len(bucket) < max_per_class:
                bucket.append(sample)
                seen_counts[class_label] += 1
            elif mode == "reservoir":
                # Standard reservoir step: the n-th item replaces a random
                # slot with probability max_per_class / n.
                seen_counts[class_label] += 1
                slot = rng.randint(0, seen_counts[class_label] - 1)
                if slot < max_per_class:
                    bucket[slot] = sample
            else:  # "first" mode and the bucket is already full
                break
    return class_samples
def sample_balanced_dataset(
    class_samples: Dict[str, List[np.ndarray]],
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Stack per-class samples into one array with a matching label array.

    No further sampling happens here; any per-class cap must already have
    been applied when ``class_samples`` was built. Classes are emitted in
    sorted label order, and classes with no samples are skipped entirely.

    Returns:
        ``(features, labels, class_names)``: ``features`` stacks every sample
        along a new leading axis, ``labels[i]`` is the class of
        ``features[i]``, and ``class_names`` lists (sorted) only the classes
        that contributed at least one sample.

    Raises:
        RuntimeError: if no class has any samples.
    """
    # Exclude empty classes so class_names matches the labels actually present.
    class_names = sorted(name for name, samples in class_samples.items() if samples)
    if not class_names:
        raise RuntimeError("No spectrogram samples collected for the specified filters")
    features = [sample for name in class_names for sample in class_samples[name]]
    labels = [name for name in class_names for _ in class_samples[name]]
    return np.stack(features), np.array(labels), class_names
def unfold_patches(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    """Cut a batch of 2-D maps into non-overlapping square patches.

    ``x`` is indexed as ``[batch, height, width]`` (e.g. ``[B, 128, 128]``).
    Patches are ordered row-major over the patch grid, and each patch is
    flattened row-major into ``patch_size ** 2`` values, giving a contiguous
    ``[batch, (height // patch_size) * (width // patch_size), patch_size ** 2]``
    tensor. Rows/columns that do not fill a whole patch are dropped by
    ``Tensor.unfold``.
    """
    batch = x.shape[0]
    # [batch, grid_rows, grid_cols, patch_size, patch_size]
    grid = x.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    return grid.reshape(batch, -1, patch_size * patch_size)
def extract_tokens(spec: np.ndarray, device: torch.device) -> torch.Tensor:
    """Turn one 2-D spectrogram into a batch-of-one patch-token tensor.

    The array is wrapped with ``torch.from_numpy`` (shares memory until the
    device move), given a leading batch axis, moved to ``device`` and split
    with ``unfold_patches`` at its default patch size. A 128x128 input yields
    a ``[1, 1024, 16]`` tensor.
    """
    batch_of_one = torch.from_numpy(spec)[None].to(device)
    return unfold_patches(batch_of_one)
def pool_embeddings(
    tokens: torch.Tensor,
    model: torch.nn.Module,
    pooling: str,
) -> np.ndarray:
    """Run ``model`` on patch tokens and reduce its output to one vector per item.

    A constant CLS token (every element 0.2) is prepended along the sequence
    axis before the forward pass, which runs under ``torch.no_grad()``.
    NOTE(review): 0.2 presumably mirrors the CLS value used in pretraining —
    confirm against the training code.

    Args:
        tokens: patch tokens shaped ``[B, num_patches, patch_dim]``.
        model: module mapping ``[B, num_patches + 1, patch_dim]`` to
            ``[B, num_patches + 1, hidden]`` (assumed from the indexing below).
        pooling: ``"cls"`` returns the CLS position; any other value averages
            the patch positions, excluding CLS.

    Returns:
        ``[B, hidden]`` NumPy array on the CPU.
    """
    # Build the CLS token in the tokens' own floating dtype so half/double
    # inputs are not promoted (e.g. fp16 -> fp32) by torch.cat, which would
    # mismatch a model running in that dtype.
    cls_dtype = tokens.dtype if tokens.is_floating_point() else None
    cls_token = torch.full(
        (tokens.size(0), 1, tokens.size(-1)),
        0.2,
        dtype=cls_dtype,
        device=tokens.device,
    )
    inputs = torch.cat([cls_token, tokens], dim=1)
    with torch.no_grad():
        outputs = model(inputs)
    if pooling == "cls":
        pooled = outputs[:, 0]
    else:  # mean over patch tokens only (position 0 is CLS)
        pooled = outputs[:, 1:].mean(dim=1)
    return pooled.detach().cpu().numpy()
def run_tsne(
    x: np.ndarray,
    labels: np.ndarray,
    title: str,
    ax: plt.Axes,
    random_state: int = 42,
) -> None:
    """Embed ``x`` in 2-D with t-SNE and draw one scatter series per label on ``ax``.

    Features are standardised per column before t-SNE. Perplexity scales with
    the sample count (``n // 10`` clamped to [5, 30]) and is then capped at
    ``n - 1``, because scikit-learn requires perplexity < n_samples.

    Args:
        x: ``[n_samples, n_features]`` feature matrix.
        labels: per-row class labels (converted with ``np.asarray``).
        title: axes title.
        ax: Matplotlib axes to draw on.
        random_state: seed forwarded to ``TSNE``; the default keeps the
            previous fixed value of 42.

    Raises:
        ValueError: if fewer than two samples are given.
    """
    n_samples = len(x)
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")
    labels = np.asarray(labels)
    x_scaled = StandardScaler().fit_transform(x)
    # The n - 1 cap must be applied last: raising the value back to 5
    # afterwards would violate perplexity < n_samples for n <= 5.
    perplexity = min(max(5, min(30, n_samples // 10)), n_samples - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    embedding = tsne.fit_transform(x_scaled)

    class_names = sorted(np.unique(labels))
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for color, class_name in zip(colors, class_names):
        mask = labels == class_name
        ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("t-SNE Component 1", fontsize=12)
    ax.set_ylabel("t-SNE Component 2", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8)
def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None:
    """Print how well ``labels`` separate in ``features``.

    On standardised features, reports the silhouette score and the mean ± std
    accuracy of a k-nearest-neighbour classifier (k=5) under stratified
    5-fold cross-validation. When the smallest class has fewer than 5 members,
    the fold count (and, if needed, k) is reduced so the metrics can still be
    computed; the printed line states the values actually used. With fewer
    than two classes, or no more samples than classes, the metrics are
    undefined and a warning is printed instead of raising.

    NOTE(review): the scaler is fit on all samples before cross-validation,
    so test folds influence the scaling — fine for a diagnostic, not a strict
    held-out estimate.
    """
    labels = np.asarray(labels)
    features_scaled = StandardScaler().fit_transform(features)
    n_samples = len(labels)
    _, class_counts = np.unique(labels, return_counts=True)
    n_classes = len(class_counts)
    # silhouette_score requires 2 <= n_classes <= n_samples - 1.
    if n_classes < 2 or n_classes >= n_samples:
        print(
            f"[WARN] {name}: metrics need at least 2 classes and more samples than "
            f"classes (got {n_classes} classes, {n_samples} samples); skipping"
        )
        return
    silhouette = silhouette_score(features_scaled, labels)

    # StratifiedKFold needs every class to have at least n_splits members.
    n_splits = min(5, int(class_counts.min()))
    if n_splits < 2:
        print(
            f"[METRIC] {name}: silhouette={silhouette:.3f}, "
            f"kNN accuracy skipped (smallest class has a single sample)"
        )
        return

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores: List[float] = []
    n_neighbors_used = 5
    for train_idx, test_idx in skf.split(features_scaled, labels):
        # kNN cannot use more neighbours than there are training samples.
        k = min(5, len(train_idx))
        n_neighbors_used = min(n_neighbors_used, k)
        clf = KNeighborsClassifier(n_neighbors=k)
        clf.fit(features_scaled[train_idx], labels[train_idx])
        scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))
    mean_acc = float(np.mean(scores))
    std_acc = float(np.std(scores))
    # Only mention the fold count when it differs from the default 5.
    fold_note = "" if n_splits == 5 else f" ({n_splits}-fold CV)"
    print(
        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
        f"{n_neighbors_used}-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}{fold_note}"
    )
# ---------------------------------------------------------------------------
# Main execution
# ---------------------------------------------------------------------------
def _resolve_checkpoint(args: argparse.Namespace) -> Path:
    """Return ``--checkpoint`` when given (it must exist), else auto-select under ``--models-root``."""
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        return checkpoint_path
    return find_latest_checkpoint(Path(args.models_root))


def _load_lwm_model(checkpoint_path: Path, device: torch.device) -> torch.nn.Module:
    """Build the LWM encoder, load weights from ``checkpoint_path`` and return it in eval mode on ``device``."""
    model = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Keys saved from a DataParallel/DDP-wrapped model typically carry a
    # leading "module." prefix. Strip it only at the start of each key;
    # str.replace would also mangle a "module." occurring later in a key.
    prefix = "module."
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not fit the architecture is not silently used with random weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(
            f"[WARN] {len(incompatible.missing_keys)} model keys missing from checkpoint "
            f"(e.g. {', '.join(incompatible.missing_keys[:3])})"
        )
    if incompatible.unexpected_keys:
        print(
            f"[WARN] {len(incompatible.unexpected_keys)} unexpected checkpoint keys ignored "
            f"(e.g. {', '.join(incompatible.unexpected_keys[:3])})"
        )
    return model.to(device).eval()


def main() -> None:
    """Load spectrograms, embed them with LWM, optionally report metrics, and plot t-SNE.

    Steps:
    1. Parse CLI args and seed ``random``, NumPy and PyTorch.
    2. Collect and stack the spectrogram samples, then standardise each one.
    3. Embed each sample with the selected LWM checkpoint.
    4. Print clustering metrics when ``--report-metrics`` or ``--metrics-only``
       is given; ``--metrics-only`` stops at this point.
    5. Save a side-by-side raw-vs-embedding t-SNE figure to ``--save-path``.

    Raises:
        FileNotFoundError: if the data root, an explicit checkpoint, or any
            checkpoint under ``--models-root`` is missing.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_root = Path(args.data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    class_samples = list_class_samples(
        data_root,
        args.snr,
        args.mobility,
        args.fft_folder,
        args.samples_per_class,
        random,  # the seeded module-level generator drives reservoir sampling
        args.sampling_mode,
    )
    samples, labels, _ = sample_balanced_dataset(class_samples)
    print(f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(np.unique(labels))} classes")
    print("[INFO] Class counts:")
    for name, count in sorted(Counter(labels).items()):
        print(f" {name}: {count}")

    normalized_samples = normalize_per_sample(samples)
    # Flattened per-sample-standardised spectrograms feed the raw t-SNE view.
    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)

    checkpoint_path = _resolve_checkpoint(args)
    print(f"[INFO] Using checkpoint: {checkpoint_path}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    print(f"[INFO] Pooling strategy: {args.pooling}")
    model = _load_lwm_model(checkpoint_path, device)

    # One forward pass per spectrogram; each yields a batch-of-one [1, hidden] array.
    embeddings_np = np.vstack(
        [
            pool_embeddings(extract_tokens(spec, device), model, args.pooling).squeeze(0)
            for spec in normalized_samples
        ]
    )
    print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")

    # --metrics-only promises to report metrics before exiting, so it implies
    # --report-metrics (otherwise it would exit having printed nothing).
    if args.report_metrics or args.metrics_only:
        compute_metrics("Raw spectrogram", raw_vectors, labels)
        pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
        compute_metrics(pool_label, embeddings_np, labels)
    if args.metrics_only:
        return

    # Two side-by-side subplots matching the original figure format.
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
    run_tsne(raw_vectors, labels, "Raw Spectrogram t-SNE", axes[0])
    run_tsne(embeddings_np, labels, f"LWM Embedding t-SNE ({pooling_label})", axes[1])
    fig.tight_layout()
    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[INFO] Figure saved to {save_path}")


if __name__ == "__main__":
    main()