#!/usr/bin/env python3 """Generate t-SNE plots comparing raw spectrograms and LWM embeddings. This reproduces the style of `/workspace/lwm-spectro/mcs_separation_plot.png`, but it automatically loads the most recent LWM checkpoint stored under `models/` (selecting the run directory with the latest modification time and within it the checkpoint with the best validation metric). The script focuses on a single communication family (default: LTE) and treats all code-rate variants of the same modulation as a single class (three modulation labels) at SNR20dB static conditions using the 128×128 spectrograms generated from 512-point FFTs. Usage example: ```bash python task1/plot_mod_tsne.py \ --data-root spectrograms/city_1_losangeles/LTE \ --save-path task1/mcs_separation_plot_latest.png ``` Adjust `--samples-per-class` if you want to control how many examples feed into t-SNE (default 500 per class). Pass `--pooling cls` to use CLS token embeddings instead of mean pooling. """ from __future__ import annotations import argparse import glob import pickle import random import re from pathlib import Path from collections import Counter, defaultdict from typing import Dict, Iterable, List, Tuple import matplotlib.pyplot as plt import numpy as np import torch from sklearn.manifold import TSNE from sklearn.metrics import silhouette_score from sklearn.model_selection import StratifiedKFold from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler from pretraining.pretrained_model import lwm as lwm_model def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray: means = specs.mean(axis=(1, 2), keepdims=True) stds = specs.std(axis=(1, 2), keepdims=True) stds = np.maximum(stds, eps) return ((specs - means) / stds).astype(np.float32, copy=False) # --------------------------------------------------------------------------- # Utility helpers # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--data-root", default="spectrograms/city_1_losangeles/LTE", help="Root directory containing modulation folders (default: %(default)s)", ) parser.add_argument( "--snr", default="all", help=( "SNR folder name to filter on. Pass 'all' to include every SNR " "(default: %(default)s)" ), ) parser.add_argument( "--mobility", default="all", help=( "Mobility folder to filter on. Pass 'all' to include every mobility " "(default: %(default)s)" ), ) parser.add_argument( "--fft-folder", default="all", help=( "FFT size folder name to use. Pass 'all' to include every FFT variant " "(default: %(default)s)" ), ) parser.add_argument( "--samples-per-class", type=int, default=500, help="Maximum number of samples to draw for each modulation class", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for sampling and t-SNE", ) parser.add_argument( "--pooling", choices=("mean", "cls"), default="mean", help="How to collapse token embeddings into a single vector", ) parser.add_argument( "--save-path", default="task1/mcs_separation_plot_latest.png", help="Location to save the generated figure (default: %(default)s)", ) parser.add_argument( "--checkpoint", default=None, help="Optional explicit checkpoint path; overrides automatic latest selection", ) parser.add_argument( "--models-root", default="models/LTE_models", help=( "Directory containing checkpoints. When --checkpoint is not given, " "the latest/best checkpoint inside this directory will be used " "(default: %(default)s)" ), ) parser.add_argument( "--report-metrics", action="store_true", help="Print clustering metrics (silhouette, 5-fold kNN accuracy)", ) parser.add_argument( "--metrics-only", action="store_true", help="Exit after reporting metrics without running t-SNE or saving figures", ) parser.add_argument( "--sampling-mode", choices=("first", "reservoir"), default="first", help="How to down-sample each class (default: first)", ) return parser.parse_args() def find_latest_checkpoint(models_root: Path) -> Path: """Return a checkpoint path under ``models_root``. Works with either a parent directory that contains multiple run folders, or directly with a single run directory containing ``*.pth`` files. Chooses the checkpoint with the lowest parsed validation value when available, else falls back to most-recent modification time. """ if not models_root.exists(): raise FileNotFoundError(f"Models root not found: {models_root}") if models_root.is_file(): raise FileNotFoundError(f"Expected a directory, got file: {models_root}") # If the provided directory itself contains checkpoints, use it directly. checkpoints = list(models_root.glob("*.pth")) if not checkpoints: # Otherwise, look for subdirectories that contain checkpoints and ignore others (e.g., tensorboard) run_dirs = [p for p in models_root.iterdir() if p.is_dir()] candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))] if not candidate_runs: raise FileNotFoundError( f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)" ) latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime) checkpoints = list(latest_run.glob("*.pth")) def parse_val_metric(path: Path) -> float | None: match = re.search(r"_val([0-9.]+)", path.name) if match: try: return float(match.group(1)) except ValueError: return None return None parsed = [(parse_val_metric(p), p) for p in checkpoints] valid = [item for item in parsed if item[0] is not None] if valid: valid.sort(key=lambda item: item[0]) return valid[0][1] # Fallback to most recent modification time return max(checkpoints, key=lambda p: p.stat().st_mtime) def list_class_samples( data_root: Path, snr: str, mobility: str, fft_folder: str, max_per_class: int, rng: random.Random, mode: str, ) -> Dict[str, List[np.ndarray]]: """Collect spectrogram samples grouped by modulation class. Only samples that match the requested SNR, mobility and FFT size are included (or every sample when the corresponding argument is 'all'). Each sample is stored as a float32 array with shape (128, 128). """ class_samples: Dict[str, List[np.ndarray]] = defaultdict(list) seen_counts: Dict[str, int] = defaultdict(int) pattern = str(data_root / "**" / "spectrograms" / "*.pkl") for path_str in glob.glob(pattern, recursive=True): path = Path(path_str) parts = path.parts if snr != "all" and snr not in parts: continue if mobility != "all" and mobility not in parts: continue if fft_folder != "all" and fft_folder not in parts: continue try: rel_parts = path.relative_to(data_root).parts except ValueError: continue if len(rel_parts) < 7: # Expect: modulation / rate / SNR / mobility / window / samples / fft / ... continue modulation = rel_parts[0] class_label = modulation if mode == "first" and len(class_samples[class_label]) >= max_per_class: continue try: with open(path, "rb") as fh: data = pickle.load(fh) except Exception as exc: # pragma: no cover - I/O heavy print(f"[WARN] Failed to load {path}: {exc}") continue if isinstance(data, dict) and "spectrograms" in data: specs = data["spectrograms"] elif isinstance(data, np.ndarray): specs = data else: print(f"[WARN] Unknown format in {path}: {type(data)}") continue specs = np.asarray(specs) if specs.ndim == 3: pass # Already [samples, 128, 128] elif specs.ndim == 2: specs = specs[None, ...] else: print(f"[WARN] Unexpected spectrogram shape in {path}: {specs.shape}") continue for spec in specs: sample = spec.astype(np.float32) bucket = class_samples[class_label] if len(bucket) < max_per_class: bucket.append(sample) seen_counts[class_label] += 1 elif mode == "reservoir": seen_counts[class_label] += 1 j = rng.randint(0, seen_counts[class_label] - 1) if j < max_per_class: bucket[j] = sample else: # mode == "first" and already full break return class_samples def sample_balanced_dataset( class_samples: Dict[str, List[np.ndarray]], ) -> Tuple[np.ndarray, np.ndarray, List[str]]: """Draw up to `samples_per_class` from each class and stack into arrays.""" features: List[np.ndarray] = [] labels: List[str] = [] class_names = sorted(class_samples.keys()) for class_name in class_names: samples = class_samples[class_name] if not samples: continue features.extend(samples) labels.extend([class_name] * len(samples)) if not features: raise RuntimeError("No spectrogram samples collected for the specified filters") stacked = np.stack(features) # [N, 128, 128] return stacked, np.array(labels), class_names def unfold_patches(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor: # Input shape: [B, 128, 128] patches_h = x.unfold(1, patch_size, patch_size) patches = patches_h.unfold(2, patch_size, patch_size) return patches.contiguous().view(x.shape[0], -1, patch_size * patch_size) def extract_tokens(spec: np.ndarray, device: torch.device) -> torch.Tensor: tensor = torch.from_numpy(spec).unsqueeze(0).to(device) return unfold_patches(tensor) # [1, 1024, 16] def pool_embeddings( tokens: torch.Tensor, model: torch.nn.Module, pooling: str, ) -> np.ndarray: # Append CLS token (value 0.2) before passing through the transformer. cls_token = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device) inputs = torch.cat([cls_token, tokens], dim=1) # [B, 1025, 16] with torch.no_grad(): outputs = model(inputs) # [B, 1025, 128] if pooling == "cls": pooled = outputs[:, 0] else: # mean pooling across patch tokens (exclude CLS) pooled = outputs[:, 1:].mean(dim=1) return pooled.detach().cpu().numpy() def run_tsne(x: np.ndarray, labels: np.ndarray, title: str, ax: plt.Axes) -> None: scaler = StandardScaler() x_scaled = scaler.fit_transform(x) # Use a safe perplexity relative to sample count (sklearn requirement: < n_samples). max_perplexity = max(5, min(30, len(x_scaled) // 10)) perplexity = min(max_perplexity, len(x_scaled) - 1) perplexity = max(perplexity, 5) tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42) embedding = tsne.fit_transform(x_scaled) class_names = sorted(np.unique(labels)) colors = plt.cm.Set3(np.linspace(0, 1, len(class_names))) for color, class_name in zip(colors, class_names): mask = labels == class_name ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name) ax.set_title(title, fontsize=14, fontweight="bold") ax.set_xlabel("t-SNE Component 1", fontsize=12) ax.set_ylabel("t-SNE Component 2", fontsize=12) ax.grid(True, alpha=0.3) ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8) def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None: scaler = StandardScaler() features_scaled = scaler.fit_transform(features) silhouette = silhouette_score(features_scaled, labels) skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) scores: List[float] = [] for train_idx, test_idx in skf.split(features_scaled, labels): clf = KNeighborsClassifier(n_neighbors=5) clf.fit(features_scaled[train_idx], labels[train_idx]) scores.append(clf.score(features_scaled[test_idx], labels[test_idx])) mean_acc = float(np.mean(scores)) std_acc = float(np.std(scores)) print( f"[METRIC] {name}: silhouette={silhouette:.3f}, " f"5-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}" ) # --------------------------------------------------------------------------- # Main execution # --------------------------------------------------------------------------- def main() -> None: args = parse_args() random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) data_root = Path(args.data_root) if not data_root.exists(): raise FileNotFoundError(f"Data root not found: {data_root}") class_samples = list_class_samples( data_root, args.snr, args.mobility, args.fft_folder, args.samples_per_class, random, args.sampling_mode, ) samples, labels, _ = sample_balanced_dataset(class_samples) print(f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(np.unique(labels))} classes") class_counts = Counter(labels) print("[INFO] Class counts:") for name, count in sorted(class_counts.items()): print(f" {name}: {count}") normalized_samples = normalize_per_sample(samples) # Flatten spectrograms (after optional normalization) for the raw t-SNE view. raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1) # Prepare LWM model and embeddings for the right subplot. if args.checkpoint: checkpoint_path = Path(args.checkpoint) if not checkpoint_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") else: checkpoint_path = find_latest_checkpoint(Path(args.models_root)) print(f"[INFO] Using checkpoint: {checkpoint_path}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Using device: {device}") print(f"[INFO] Pooling strategy: {args.pooling}") model = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1) state_dict = torch.load(checkpoint_path, map_location=device) if any(k.startswith("module.") for k in state_dict): state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} model.load_state_dict(state_dict, strict=False) model = model.to(device).eval() embeddings: List[np.ndarray] = [] for spec in normalized_samples: tokens = extract_tokens(spec, device) embedding = pool_embeddings(tokens, model, args.pooling) embeddings.append(embedding.squeeze(0)) embeddings_np = np.vstack(embeddings) print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}") if args.report_metrics: compute_metrics("Raw spectrogram", raw_vectors, labels) pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS" compute_metrics(pool_label, embeddings_np, labels) if args.metrics_only: return # Plot results (two subplots matching the original figure format). fig, axes = plt.subplots(1, 2, figsize=(18, 7)) raw_title = "Raw Spectrogram t-SNE" pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token" embedding_title = f"LWM Embedding t-SNE ({pooling_label})" run_tsne(raw_vectors, labels, raw_title, axes[0]) run_tsne(embeddings_np, labels, embedding_title, axes[1]) fig.tight_layout() save_path = Path(args.save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"[INFO] Figure saved to {save_path}") if __name__ == "__main__": main()