import os import shutil import netrc from pathlib import Path from typing import Dict, List, Tuple, Optional import gradio as gr import matplotlib.pyplot as plt from matplotlib.backends.backend_agg import FigureCanvasAgg import numpy as np import pandas as pd import plotly.express as px import plotly.graph_objects as go import torch from huggingface_hub import hf_hub_download from sklearn.decomposition import PCA from sklearn.manifold import TSNE from sklearn.metrics import accuracy_score, confusion_matrix, f1_score from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler APP_DIR = Path(__file__).resolve().parent DEMO_DATA_PATH = APP_DIR / "demo_data.pt" MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt" # Where to download the demo tensors from. # Configure in Space settings if the default repo is private or you need to pin an older revision. HUB_REPO_ID = os.getenv("LWM_SPECTRO_DEMO_REPO_ID", "wi-lab/lwm-spectro") HUB_REVISION = os.getenv("LWM_SPECTRO_DEMO_REVISION") # optional git sha / tag / branch HUB_DEMO_DATA_FILENAME = os.getenv("LWM_SPECTRO_DEMO_DATA_FILENAME", "demo_data.pt") HUB_MOE_DATA_FILENAME = os.getenv("LWM_SPECTRO_MOE_DATA_FILENAME", "demo_data_moe.pt") HUB_REPO_TYPES = tuple( t.strip() for t in os.getenv("LWM_SPECTRO_DEMO_REPO_TYPES", "model").split(",") if t.strip() ) def _get_hf_token() -> str | None: # Spaces / HF Hub tooling uses a few common names. token = ( os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_ACCESS_TOKEN") ) if token: return token # If a token exists in ~/.netrc (common in some environments), use it. try: auth = netrc.netrc().authenticators("huggingface.co") if auth and auth[2]: return auth[2] except Exception: return None return None HF_TOKEN = _get_hf_token() # Fixed ordering for the 14 joint SNR/Doppler labels JOINT_LABELS = [ ("SNR-5dB", "pedestrian"), ("SNR-5dB", "vehicular"), ("SNR0dB", "pedestrian"), ("SNR0dB", "vehicular"), ("SNR5dB", "pedestrian"), ("SNR5dB", "vehicular"), ("SNR10dB", "pedestrian"), ("SNR10dB", "vehicular"), ("SNR15dB", "pedestrian"), ("SNR15dB", "vehicular"), ("SNR20dB", "pedestrian"), ("SNR20dB", "vehicular"), ("SNR25dB", "pedestrian"), ("SNR25dB", "vehicular"), ] SNR_ORDER = ["SNR-5dB", "SNR0dB", "SNR5dB", "SNR10dB", "SNR15dB", "SNR20dB", "SNR25dB"] TECH_EXPERT_ORDER = ["LTE", "WiFi", "5G"] TECH_TO_EXPERT_IDX = {name: idx for idx, name in enumerate(TECH_EXPERT_ORDER)} DEFAULT_TSNE_SAMPLES_PER_SNR = 500 def _sort_snrs(labels: List[str] | np.ndarray) -> List[str]: ordering = {snr: idx for idx, snr in enumerate(SNR_ORDER)} return sorted(labels, key=lambda x: ordering.get(x, len(ordering))) def load_joint_mapping() -> Dict[str, object]: label_names = [f"{snr} | {mob}" for snr, mob in JOINT_LABELS] pair_to_name = {pair: name for pair, name in zip(JOINT_LABELS, label_names)} name_to_id = {name: idx for idx, name in enumerate(label_names)} pair_to_id = {pair: idx for idx, pair in enumerate(JOINT_LABELS)} return { "pairs": JOINT_LABELS, "label_names": label_names, "pair_to_name": pair_to_name, "name_to_id": name_to_id, "pair_to_id": pair_to_id, } def _safe_load_tensor(path: Path): # Torch 2.6 defaults to weights_only=True, which breaks our saved dicts. 
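    # The demo artifacts are trusted first-party pickles (lists of dicts holding
    # numpy arrays / tensors), so full unpickling is acceptable here; a stricter
    # alternative would be allow-listing types via torch.serialization.add_safe_globals.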
return torch.load(path, weights_only=False) def _is_git_lfs_pointer(path: Path) -> bool: try: with path.open("rb") as handle: head = handle.read(256) return b"git-lfs.github.com/spec" in head except OSError: return False def _normalize_tech_label(value: object) -> object: if value is None: return value text = str(value).strip() if not text: return value normalized = text.lower().replace(" ", "").replace("-", "") if normalized in {"wifi", "wi-fi", "wi_fi"}: return "WiFi" if normalized == "lte": return "LTE" if normalized in {"5g", "nr", "5gnr", "sub6", "sub6ghz", "5gsub6", "5gsub6ghz"}: return "5G" return text def _normalize_mobility_label(value: object) -> object: if value is None: return value text = str(value).strip() if not text: return value normalized = text.lower().replace(" ", "").replace("-", "") if normalized in {"ped", "pedestrian", "walking"}: return "pedestrian" if normalized in {"veh", "vehicular", "vehicle", "driving", "car"}: return "vehicular" return text def _normalize_sample(sample: Dict[str, object]) -> Dict[str, object]: out = dict(sample) # Schema aliases (some artifacts use longer names). if "tech" not in out and "technology" in out: out["tech"] = out.get("technology") if "mod" not in out and "modulation" in out: out["mod"] = out.get("modulation") if "mob" not in out and "mobility" in out: out["mob"] = out.get("mobility") if "snr" not in out and "snr_label" in out: out["snr"] = out.get("snr_label") out["tech"] = _normalize_tech_label(out.get("tech")) out["mob"] = _normalize_mobility_label(out.get("mob")) return out def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None: """Deprecated: kept for backward compatibility, but avoided in production.""" raise RuntimeError("Synthetic on-disk dataset generation disabled") def _create_dummy_samples() -> List[Dict[str, object]]: """In-memory fallback when the filesystem is not writable.""" rng = np.random.default_rng(42) samples: List[Dict[str, object]] = [] techs = ["LTE", "WiFi", "5G"] snrs = ["SNR0dB", "SNR10dB", "SNR20dB"] mods = ["QPSK", "16QAM", "64QAM"] mobs = ["pedestrian", "vehicular"] for i in range(30): tech = techs[i % len(techs)] snr = snrs[i % len(snrs)] mob = mobs[i % len(mobs)] mod = mods[i % len(mods)] spectrogram = rng.normal(size=(128, 128)).astype(np.float32) embedding = rng.normal(size=(128,)).astype(np.float32) moe_embedding = rng.normal(size=(128,)).astype(np.float32) samples.append( { "tech": tech, "snr": snr, "mod": mod, "mob": mob, "data": spectrogram, "embedding": embedding, "moe_embedding": moe_embedding, } ) return samples def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]: """Ensure a file exists locally; try Hub download if missing.""" if local_path.exists() and not _is_git_lfs_pointer(local_path): return local_path global LAST_DEMO_DOWNLOAD_ERROR # Prefer a stored token if present (Spaces sometimes have credentials available # even when HF_TOKEN env var is not explicitly set). token = HF_TOKEN or True # Try configured repo types (default: model). This Space historically used a model repo. 
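    # e.g. LWM_SPECTRO_DEMO_REPO_TYPES="model,dataset" makes this loop also probe a
    # dataset repo with the same repo_id before the Space-repo fallback below.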
last_exc: Exception | None = None for repo_type in HUB_REPO_TYPES: try: cached = hf_hub_download( repo_id=HUB_REPO_ID, filename=hub_filename, token=token, repo_type=repo_type, revision=HUB_REVISION, ) cached_path = Path(cached) print(f"[INFO] Using cached Hub file for {hub_filename}: {cached_path} (repo_type={repo_type})") return cached_path except Exception as exc: last_exc = exc # Final fallback: try downloading from the Space repo itself (useful when artifacts are stored in Space). try: cached = hf_hub_download( repo_id="wi-lab/LWM-Spectro", filename=hub_filename, token=token, repo_type="space", revision=None, ) cached_path = Path(cached) print(f"[INFO] Using cached Space file for {hub_filename}: {cached_path}") return cached_path except Exception as exc: # Persist a short error string for the UI status line. err = str(last_exc or exc) if len(err) > 240: err = err[:240] + "..." LAST_DEMO_DOWNLOAD_ERROR = err print( f"[WARN] Could not download {hub_filename} from Hub (repo_id={HUB_REPO_ID}, repo_types={HUB_REPO_TYPES}, revision={HUB_REVISION or 'main'}: {last_exc}) " f"or Space repo ({exc}); continuing without it." ) return None USING_SYNTHETIC_DATA = False LAST_DEMO_DOWNLOAD_ERROR: str | None = None def load_augmented_samples() -> Tuple[List[Dict[str, object]], bool]: moe_path = _ensure_local_file(MOE_DATA_PATH, HUB_MOE_DATA_FILENAME) base_path = _ensure_local_file(DEMO_DATA_PATH, HUB_DEMO_DATA_FILENAME) if moe_path and moe_path.exists() and not _is_git_lfs_pointer(moe_path): print(f"[INFO] Loading MoE-augmented dataset from {moe_path}") return _safe_load_tensor(moe_path), True if base_path and base_path.exists() and not _is_git_lfs_pointer(base_path): print(f"[WARN] MoE data missing; falling back to base data: {base_path}") return _safe_load_tensor(base_path), False # Last resort: in-memory synthetic data (keeps app alive, but clearly not the full demo dataset). global USING_SYNTHETIC_DATA USING_SYNTHETIC_DATA = True print( "[WARN] Falling back to a tiny synthetic dataset (30 samples). " "This usually means the real demo_data*.pt could not be downloaded. " "If the Hub repo is private, add a Space secret named HF_TOKEN with read access." 
) return _create_dummy_samples(), False def load_data(mapping: Dict[str, object]): data, has_moe = load_augmented_samples() pair_to_name = mapping["pair_to_name"] pair_to_id = mapping["pair_to_id"] records = [] skipped = 0 for i, sample in enumerate(data): if not isinstance(sample, dict): skipped += 1 continue sample = _normalize_sample(sample) if not sample.get("tech") or not sample.get("snr") or not sample.get("mob") or not sample.get("mod"): skipped += 1 continue if "embedding" not in sample or "data" not in sample: skipped += 1 continue embedding = sample["embedding"] if isinstance(embedding, torch.Tensor): base_embedding = embedding.detach().cpu().numpy() else: base_embedding = np.asarray(embedding) spectrogram = sample["data"] if isinstance(spectrogram, torch.Tensor): flat_spec = spectrogram.numpy().flatten() else: flat_spec = np.asarray(spectrogram).flatten() moe_embedding = sample.get("moe_embedding") if isinstance(moe_embedding, torch.Tensor): moe_embedding = moe_embedding.numpy() elif moe_embedding is not None: moe_embedding = np.asarray(moe_embedding) tech_embedding = sample.get("tech_embedding") if isinstance(tech_embedding, torch.Tensor): tech_embedding = tech_embedding.numpy() elif tech_embedding is not None: tech_embedding = np.asarray(tech_embedding) if tech_embedding is not None: tech_embedding = tech_embedding.astype(np.float32, copy=False) embed_dim_hint = sample.get("tech_embedding_dim") or sample.get("embedding_dim") try: embed_dim_hint = int(embed_dim_hint) if embed_dim_hint is not None else None except (TypeError, ValueError): embed_dim_hint = None if tech_embedding is None: tech_embedding = _select_tech_embedding(base_embedding, sample["tech"], embed_dim_hint) if tech_embedding is not None: tech_embedding = tech_embedding.astype(np.float32, copy=False) pair = (sample["snr"], sample["mob"]) joint_label = pair_to_name.get(pair) joint_label_id = pair_to_id.get(pair) tsne_x = sample.get("tsne_x") tsne_y = sample.get("tsne_y") tsne_raw_x = sample.get("tsne_raw_x") tsne_raw_y = sample.get("tsne_raw_y") records.append( { "index": i, "tech": sample["tech"], "snr": sample["snr"], "mod": sample["mod"], "mob": sample["mob"], "embedding": base_embedding, "tech_embedding": tech_embedding, "moe_embedding": moe_embedding, "spectrogram": flat_spec, "joint_label": joint_label, "joint_label_id": joint_label_id, "tsne_x": tsne_x, "tsne_y": tsne_y, "tsne_raw_x": tsne_raw_x, "tsne_raw_y": tsne_raw_y, } ) df = pd.DataFrame(records) if skipped: print(f"[WARN] Skipped {skipped} malformed samples while loading demo data") print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})") return df, has_moe def apply_filters( dataframe: pd.DataFrame, tech_filter, snr_filter, mod_filter, mob_filter, ) -> pd.DataFrame: filtered = dataframe.copy() if tech_filter: filtered = filtered[filtered["tech"].isin(tech_filter)] if snr_filter: filtered = filtered[filtered["snr"].isin(snr_filter)] if mod_filter: filtered = filtered[filtered["mod"].isin(mod_filter)] if mob_filter: filtered = filtered[filtered["mob"].isin(mob_filter)] return filtered def _select_tech_embedding(flat_embedding: np.ndarray | None, tech: str, embed_dim: Optional[int]) -> Optional[np.ndarray]: """Extract the technology-specific expert embedding. Some artifacts don't include an explicit embedding dimension hint. In that case, infer `embed_dim = total_dim / num_experts` when divisible. 
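    For example, a flat 384-dim vector with the 3 experts in TECH_EXPERT_ORDER splits
    into three 128-dim blocks; the block matching `tech` is returned, or the mean of
    the blocks if the technology is not recognized.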
""" if flat_embedding is None: return None flat_embedding = np.asarray(flat_embedding).reshape(-1) total = flat_embedding.size blocks = len(TECH_EXPERT_ORDER) if blocks <= 0: return None inferred_dim = embed_dim if inferred_dim is None: if total % blocks != 0: return None inferred_dim = total // blocks try: inferred_dim = int(inferred_dim) except (TypeError, ValueError): return None if inferred_dim <= 0: return None expected = blocks * inferred_dim if expected != total: # If metadata is wrong, don't crash; fall back to an even split only if possible. if total % blocks != 0: return None inferred_dim = total // blocks try: arr = flat_embedding.reshape(blocks, inferred_dim) except ValueError: return None tech_idx = TECH_TO_EXPERT_IDX.get(str(tech)) if tech_idx is None or tech_idx >= arr.shape[0]: return arr.mean(axis=0) return arr[tech_idx] def _sample_balanced_by_snr(dataframe: pd.DataFrame, samples_per_snr: int, seed: int) -> pd.DataFrame: if dataframe.empty: return dataframe rng = np.random.default_rng(int(seed)) grouped = {snr: grp for snr, grp in dataframe.groupby("snr") if not grp.empty} if not grouped: return dataframe.iloc[0:0] ordered_snrs = sorted(grouped.keys()) sampled_frames: List[pd.DataFrame] = [] for snr_label in ordered_snrs: group = grouped[snr_label] if samples_per_snr <= 0 or samples_per_snr >= len(group): sampled_frames.append(group) continue random_state = int(rng.integers(0, 1_000_000_000)) sampled_frames.append(group.sample(n=samples_per_snr, random_state=random_state)) if not sampled_frames: return dataframe.iloc[0:0] return pd.concat(sampled_frames).reset_index(drop=True) def plot_tsne( tech_filter, snr_filter, mod_filter, mob_filter, representation, color_label, perplexity, n_iter, samples_per_snr, sampling_seed, ): filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter) sampled_df = _sample_balanced_by_snr(filtered_df, samples_per_snr, sampling_seed) if len(sampled_df) < 5: fig = go.Figure() fig.update_layout( title=f"Not enough samples to plot (n={len(sampled_df)}). Widen filters or increase samples.", xaxis=dict(visible=False), yaxis=dict(visible=False), ) return fig sampled_df = sampled_df.copy() color_column = COLOR_OPTIONS.get(color_label, "snr") if representation == "LWM Embedding": embed_mask = sampled_df["tech_embedding"].apply(lambda x: x is not None) if embed_mask.sum() >= 5: sampled_df = sampled_df.loc[embed_mask].reset_index(drop=True) features = np.stack(sampled_df["tech_embedding"].values) title_prefix = "t-SNE of LWM Embedding" else: # Fallback: use the full embedding vector so the UI doesn't go blank when # per-expert metadata is missing in the artifact. 
base_mask = sampled_df["embedding"].apply(lambda x: x is not None) if base_mask.sum() < 5: fig = go.Figure() fig.update_layout( title="No embeddings available for the selected filters.", xaxis=dict(visible=False), yaxis=dict(visible=False), ) return fig sampled_df = sampled_df.loc[base_mask].reset_index(drop=True) features = np.stack(sampled_df["embedding"].values) title_prefix = "t-SNE of LWM Embedding (full vector)" else: features = build_tsne_raw_vectors(sampled_df["spectrogram"]) title_prefix = "t-SNE of Raw Spectrogram" if features.size == 0: fig = go.Figure() fig.update_layout(title="No features available for t-SNE.", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig features = _standardize_for_tsne(features) eff_perplexity = min(perplexity, len(sampled_df) - 1) eff_perplexity = max(5, eff_perplexity) tsne_kwargs = dict( n_components=2, perplexity=eff_perplexity, random_state=42, init="pca", learning_rate="auto", ) try: tsne = TSNE(**tsne_kwargs, n_iter=n_iter) except TypeError: tsne = TSNE(**tsne_kwargs) try: projections = tsne.fit_transform(features) except Exception: pca = PCA(n_components=2, random_state=42) projections = pca.fit_transform(features) sampled_df["x"] = projections[:, 0] sampled_df["y"] = projections[:, 1] category_orders = {} if color_column == "snr": category_orders["snr"] = [snr for snr in SNR_ORDER if snr in sampled_df["snr"].unique()] fig = px.scatter( sampled_df, x="x", y="y", color=color_column, hover_data=["tech", "snr", "mod", "mob"], title=f"{title_prefix} ({len(sampled_df)} samples)", template="plotly_white", category_orders=category_orders, ) height = 680 if color_label == "SNR" else 640 fig.update_layout(legend_title_text=color_label, width=640, height=height) fig.update_yaxes(scaleanchor="x", scaleratio=1) return fig def build_raw_feature_matrix( samples: pd.Series, max_components: Optional[int] = 256, *, normalize: bool = True, reduce_dim: bool = True, ) -> np.ndarray: raw_flat = [] for spec in samples: arr = np.asarray(spec, dtype=np.float32) raw_flat.append(arr.reshape(-1)) matrix = np.stack(raw_flat) matrix = np.nan_to_num(matrix, copy=False) if normalize: scaler = StandardScaler() matrix = scaler.fit_transform(matrix) if reduce_dim and max_components: # Cap n_components to valid PCA range: <= min(n_samples-1, n_features) n_samples, n_features = matrix.shape if n_samples > 1: max_valid = min(n_features, max(n_samples - 1, 1)) else: max_valid = 1 target = min(max_components, max_valid) if target < 1: target = 1 if target < n_features: projector = PCA(n_components=target, random_state=42) try: matrix = projector.fit_transform(matrix) except ValueError: safe_components = max(1, min(n_samples, n_features) - 1) safe_components = min(safe_components, target) if safe_components >= 1: fallback = PCA(n_components=safe_components, random_state=42) matrix = fallback.fit_transform(matrix) return matrix def build_tsne_raw_vectors(samples: pd.Series, eps: float = 1e-6) -> np.ndarray: rows: List[np.ndarray] = [] for spec in samples: arr = np.asarray(spec, dtype=np.float32) flat = arr.reshape(-1) mean = float(flat.mean()) std = float(flat.std()) if std < eps: std = eps normalized = (flat - mean) / std rows.append(normalized.astype(np.float32, copy=False)) if not rows: return np.empty((0, 0), dtype=np.float32) return np.stack(rows) def _standardize_for_tsne(features: np.ndarray) -> np.ndarray: if features.size == 0: return features scaler = StandardScaler() scaled = scaler.fit_transform(features) scaled = np.nan_to_num(scaled, copy=False, nan=0.0, 
posinf=0.0, neginf=0.0) scaled = np.clip(scaled, -1e6, 1e6) return scaled.astype(np.float32, copy=False) def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]: rng = np.random.default_rng(int(seed)) train_indices = [] test_indices = [] for label_id, group in filtered_df.groupby("joint_label_id"): indices = group.index.to_numpy() if indices.size < 2: raise ValueError(f"Class '{CLASS_LABELS[int(label_id)]}' needs at least 2 samples for evaluation.") rng.shuffle(indices) split = int(round(indices.size * train_ratio)) split = max(1, min(indices.size - 1, split)) train_indices.extend(indices[:split]) test_indices.extend(indices[split:]) return np.array(train_indices), np.array(test_indices) def select_knn_k(train_labels: np.ndarray, max_k: int = 9) -> int: if train_labels.size == 0: return 1 class_counts = pd.Series(train_labels).value_counts() min_class = int(class_counts.min()) heuristic = int(np.sqrt(train_labels.size)) candidate = max(1, min(max_k, heuristic)) k = max(1, min(candidate, min_class)) if k % 2 == 0 and k > 1: k -= 1 return k def plot_confusion_heatmap( confusion: np.ndarray, label_names: List[str], title: str = "Prototype Classifier Confusion Matrix" ) -> go.Figure: fig = go.Figure( data=go.Heatmap( z=confusion, x=label_names, y=label_names, colorscale="Viridis", hovertemplate="Predicted %{x}<br>True %{y}<br>Count %{z}", ) ) fig.update_layout( title=title, xaxis_title="Predicted", yaxis_title="True", xaxis=dict(tickangle=45), ) return fig def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter): if evaluation_disabled: fig = go.Figure() fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, "MoE embeddings are not available in this Space build." filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter) if filtered.empty: fig = go.Figure() fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, "No samples match the selected filters." if filtered["joint_label_id"].nunique() < 2: fig = go.Figure() fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, "Need at least two joint SNR/Doppler classes to evaluate." filtered = filtered.reset_index(drop=True) try: train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed) except ValueError as exc: fig = go.Figure() fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, str(exc) labels = filtered["joint_label_id"].to_numpy(dtype=int) moe_features = np.stack(filtered["moe_embedding"].values) raw_features = build_raw_feature_matrix( filtered["spectrogram"], max_components=None, normalize=False, reduce_dim=False, ) train_labels = labels[train_idx] knn_k = select_knn_k(train_labels) moe_metrics = compute_knn_metrics(moe_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS) raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS) moe_fig = plot_confusion_heatmap( moe_metrics["confusion"], moe_metrics["label_names"], title=f"MoE Embedding Confusion (k={moe_metrics['k']})" ) raw_fig = plot_confusion_heatmap( raw_metrics["confusion"], raw_metrics["label_names"], title=f"Raw Spectrogram Confusion (k={raw_metrics['k']})" ) status = ( f"### Joint SNR/Doppler Metrics\n" f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Train %:** {train_pct}% | **Seed:** {seed} | **k-NN k:** {knn_k}\n\n" "| Representation | Accuracy | Macro F1 |\n" "| --- | --- | --- |\n" f"| **MoE Embedding** | {moe_metrics['accuracy'] * 100:.2f}% | {moe_metrics['macro_f1']:.3f} |\n" f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |" ) return moe_fig, raw_fig, status def stratified_split_mod(df_subset: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]: rng = np.random.default_rng(int(seed)) train_idx = [] test_idx = [] for _, group in df_subset.groupby("mod"): indices = group.index.to_numpy() if indices.size < 2: raise ValueError("Each modulation needs at least 2 samples.") rng.shuffle(indices) split = int(round(len(indices) * train_ratio)) split = max(1, min(len(indices) - 1, split)) train_idx.extend(indices[:split]) test_idx.extend(indices[split:]) return np.array(train_idx), np.array(test_idx) def compute_knn_metrics( features: np.ndarray, labels: np.ndarray, train_idx: np.ndarray, test_idx: np.ndarray, knn_k: int, label_lookup: List[str] | None = None, ) -> Dict[str, object]: train_features = features[train_idx] test_features = features[test_idx] train_labels = labels[train_idx] test_labels = labels[test_idx] candidate_k = max(1, min(int(knn_k), len(train_labels))) if candidate_k % 2 ==
0 and candidate_k > 1: candidate_k -= 1 knn = KNeighborsClassifier(n_neighbors=candidate_k, metric="euclidean") knn.fit(train_features, train_labels) preds = knn.predict(test_features) acc = accuracy_score(test_labels, preds) active_labels = np.unique(np.concatenate([train_labels, test_labels, preds])) macro = f1_score(test_labels, preds, labels=active_labels, average="macro", zero_division=0) if label_lookup is None: label_names = [str(lbl) for lbl in active_labels] else: label_names = [label_lookup[int(lbl)] for lbl in active_labels] cm = confusion_matrix(test_labels, preds, labels=active_labels) return { "accuracy": acc, "macro_f1": macro, "confusion": cm, "label_names": label_names, "k": candidate_k, } def evaluate_modulation(tech: str, train_pct: int, seed: int): if not tech: fig = go.Figure() fig.update_layout(title="Select a technology to evaluate.", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, "No technology selected." subset = df[df["tech"] == tech].copy().reset_index(drop=True) if subset.empty or subset["mod"].nunique() < 2: fig = go.Figure() fig.update_layout( title="Need at least two modulation classes for this technology.", xaxis=dict(visible=False), yaxis=dict(visible=False), ) return fig, fig, "Not enough modulation classes." try: train_idx, test_idx = stratified_split_mod(subset, train_pct / 100.0, seed) except ValueError as exc: fig = go.Figure() fig.update_layout(title=str(exc), xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, str(exc) labels = subset["mod"].astype(str).to_numpy() emb_features = np.stack(subset["embedding"].values) raw_features = build_raw_feature_matrix( subset["spectrogram"], max_components=None, normalize=False, reduce_dim=False, ) train_labels = labels[train_idx] class_counts = pd.Series(train_labels).value_counts() if class_counts.empty: fig = go.Figure() fig.update_layout(title="No modulation classes found.", xaxis=dict(visible=False), yaxis=dict(visible=False)) return fig, fig, "No modulation classes found." 
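    # Reuse one adaptive k for both representations (odd, capped by sqrt(n_train) and the
    # rarest class count) so the embedding-vs-raw comparison is not skewed by the choice of k.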
knn_k = select_knn_k(train_labels) emb_metrics = compute_knn_metrics(emb_features, labels, train_idx, test_idx, knn_k) raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k) emb_fig = plot_confusion_heatmap(emb_metrics["confusion"], emb_metrics["label_names"], title="Embedding Confusion") raw_fig = plot_confusion_heatmap(raw_metrics["confusion"], raw_metrics["label_names"], title="Raw Confusion") summary = ( f"### {tech} Modulation Metrics\n" f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Classifier:** k-NN (k = {emb_metrics['k']})\n\n" "| Representation | Accuracy | Macro F1 |\n" "| --- | --- | --- |\n" f"| **LWM Embedding** | {emb_metrics['accuracy'] * 100:.2f}% | {emb_metrics['macro_f1']:.3f} |\n" f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |" ) return emb_fig, raw_fig, summary def _reshape_spectrogram(spec: np.ndarray) -> np.ndarray: arr = np.asarray(spec) if arr.ndim == 1: side = int(round(arr.size ** 0.5)) if side * side == arr.size: arr = arr.reshape(side, side) else: arr = arr.reshape(-1, side) elif arr.ndim == 3: arr = arr.squeeze() return arr def _spectrogram_to_image(spec: np.ndarray, title: str) -> np.ndarray: normalized = spec.astype(np.float32) if np.isnan(normalized).any(): normalized = np.nan_to_num(normalized) vmin, vmax = normalized.min(), normalized.max() if vmax - vmin > 0: normalized = (normalized - vmin) / (vmax - vmin) fig, ax = plt.subplots(figsize=(3, 3)) im = ax.imshow(normalized, cmap="turbo", aspect="auto", origin="lower") ax.set_xticks([]) ax.set_yticks([]) ax.set_title(title, fontsize=8) cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) cbar.ax.tick_params(labelsize=6) fig.tight_layout(pad=0.5) canvas = FigureCanvasAgg(fig) canvas.draw() width, height = canvas.get_width_height() buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4) image = buf[..., :3].copy() plt.close(fig) return image def render_spectrogram_gallery(tech, snr, mod, mob, sample_count, seed): tech_list = [tech] if tech else None snr_list = [snr] if snr else None mod_list = [mod] if mod else None mob_list = [mob] if mob else None filtered = apply_filters(df, tech_list, snr_list, mod_list, mob_list) if filtered.empty: return [], "No spectrograms match the selected filters." sample_count = max(1, int(sample_count)) rng = np.random.default_rng(int(seed)) if len(filtered) > sample_count: indices = rng.choice(filtered.index.to_numpy(), size=sample_count, replace=False) subset = filtered.loc[indices] else: subset = filtered gallery_items = [] for _, row in subset.iterrows(): spec = _reshape_spectrogram(row["spectrogram"]) caption = f"{row['tech']} | {row['mod']} | {row['snr']} | {row['mob']}" img = _spectrogram_to_image(spec, caption) gallery_items.append((img, caption)) status = f"Showing {len(subset)} of {len(filtered)} matches." 
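    # gallery_items is a list of (image, caption) tuples, the format gr.Gallery expects.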
return gallery_items, status mapping_info = load_joint_mapping() df, has_moe_embeddings = load_data(mapping_info) CLASS_LABELS = mapping_info["label_names"] DATASET_STATUS = ( f"Dataset loaded: {len(df)} samples | " f"MoE embeddings: {'yes' if has_moe_embeddings else 'no'} | " f"HF token detected: {'yes' if HF_TOKEN else 'no'} | " f"Synthetic fallback: {'yes' if USING_SYNTHETIC_DATA else 'no'} | " f"Demo repo: {HUB_REPO_ID}@{HUB_REVISION or 'main'} ({','.join(HUB_REPO_TYPES)})" ) if LAST_DEMO_DOWNLOAD_ERROR: DATASET_STATUS += f" | Download error: {LAST_DEMO_DOWNLOAD_ERROR}" has_moe_column = df["moe_embedding"].apply(lambda x: x is not None) joint_eval_df = df[has_moe_column & df["joint_label_id"].notna()] tech_choices = sorted(df["tech"].unique()) snr_choices = _sort_snrs(df["snr"].unique()) mod_choices = sorted(df["mod"].unique()) mob_choices = sorted(df["mob"].unique()) TECH_TO_MODS: Dict[str, List[str]] = { tech: sorted(df.loc[df["tech"] == tech, "mod"].unique().tolist()) for tech in tech_choices } COLOR_OPTIONS: Dict[str, str] = { "SNR": "snr", "Modulation": "mod", "Mobility": "mob", } default_tech = tech_choices[:1] if tech_choices else [] initial_spec_mod_choices = TECH_TO_MODS.get(default_tech[0], mod_choices) if default_tech else mod_choices evaluation_disabled = (not has_moe_embeddings) or joint_eval_df.empty def update_modulation_choices(selected_tech: Optional[str]): choices = mod_choices if selected_tech: choices = TECH_TO_MODS.get(selected_tech, mod_choices) return gr.Dropdown.update(choices=choices, value=None) with gr.Blocks(title="LWM-Spectro Lab") as demo: gr.Markdown("# 🔬 LWM-Spectro Interactive Demo") gr.Markdown(f"**{DATASET_STATUS}**") gr.Markdown( """ Interactive lab for exploring LWM spectrogram embeddings and cached MoE probes. """ ) with gr.Tabs(): with gr.Tab("Spectrograms"): gr.Markdown( """ ### 🔎 Spectrogram Studio - Peek at the raw 128×128 Sub-6 GHz I/Q baseband spectrograms that drive the SNR/mobility recognition tasks. - Filter by technology/SNR/modulation/mobility to understand how diverse the training pool is across scenarios. - Use the gallery to sanity-check preprocessing before sending samples through LWM or downstream models. """ ) with gr.Row(): with gr.Column(scale=1, min_width=320): spec_tech = gr.Dropdown( choices=tech_choices, value=default_tech[0] if default_tech else None, label="Technology", ) spec_snr = gr.Dropdown(choices=snr_choices, value=None, label="SNR (optional)") spec_mod = gr.Dropdown(choices=initial_spec_mod_choices, value=None, label="Modulation (optional)") spec_mob = gr.Dropdown(choices=mob_choices, value=None, label="Mobility (optional)") spec_count = gr.Slider(minimum=1, maximum=12, step=1, value=6, label="Samples to show") spec_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=0, label="Random seed") spec_btn = gr.Button("Show spectrograms", variant="primary") with gr.Column(scale=3): gallery = gr.Gallery( label="Spectrogram Samples", columns=[3], rows=[3], height=560, preview=True, ) gallery_status = gr.Textbox(label="Status", interactive=False) spec_inputs = [spec_tech, spec_snr, spec_mod, spec_mob, spec_count, spec_seed] spec_btn.click(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status]) demo.load(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status]) spec_tech.change(update_modulation_choices, inputs=spec_tech, outputs=spec_mod) with gr.Tab("t-SNE Analysis"): gr.Markdown( """ ### 🌀 Embedding vs. Raw Space
- Run quick t-SNE sweeps on either LWM embeddings or raw spectrogram vectors. - Toggle **Color By** to mirror the "colored by modulation vs. SNR" comparisons from the CLI examples. - Balanced per-SNR sampling plus configurable perplexity help match the figures you generate locally with `plot/plot_tsne.py`. """ ) with gr.Row(): with gr.Column(scale=1, min_width=300): gr.Markdown("### Filters") tech_filter = gr.CheckboxGroup(choices=tech_choices, value=default_tech, label="Technology") snr_filter = gr.Dropdown( choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)" ) mod_filter = gr.Dropdown( choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)" ) mob_filter = gr.Dropdown( choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)" ) gr.Markdown("### Visualization Settings") representation = gr.Radio( choices=["LWM Embedding", "Raw Spectrogram"], value="LWM Embedding", label="Representation", ) color_by = gr.Dropdown( choices=list(COLOR_OPTIONS.keys()), value="SNR", label="Color By", ) with gr.Accordion("Advanced t-SNE Settings", open=False): perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity") n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations") samples_per_snr = gr.Slider( minimum=20, maximum=500, value=DEFAULT_TSNE_SAMPLES_PER_SNR, step=10, label="Samples per SNR", ) sampling_seed = gr.Slider( minimum=0, maximum=9999, value=42, step=1, label="Sampling Seed", ) btn = gr.Button("Update Plot", variant="primary") with gr.Column(scale=3): plot = gr.Plot(label="t-SNE Visualization") btn.click( plot_tsne, inputs=[ tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter, samples_per_snr, sampling_seed, ], outputs=[plot], ) demo.load( plot_tsne, inputs=[ tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter, samples_per_snr, sampling_seed, ], outputs=[plot], ) with gr.Tab("Modulation Classification"): gr.Markdown( """ ### 🎯 Lightweight Modulation Head - Prototype how well the frozen LWM backbone separates modulation formats for each technology using spectrograms as input. - The adaptive k-NN classifier approximates the behavior of the downstream residual 1D-CNN before heavy training; each tech is evaluated separately to measure its expert’s modulation discrimination. - Sweep train/test splits and seeds to gauge robustness when only a portion of the dataset is labeled.
""" ) with gr.Row(): with gr.Column(scale=1, min_width=320): mod_tech = gr.Dropdown( choices=tech_choices, value=default_tech[0] if default_tech else None, label="Technology", ) mod_train = gr.Slider(minimum=50, maximum=90, step=5, value=70, label="Training Percentage (%)") mod_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed") gr.Markdown("k-NN uses an adaptive k based on the number of modulation classes and available training samples.") mod_btn = gr.Button("Run modulation evaluation", variant="primary") with gr.Column(scale=3): with gr.Row(): emb_plot = gr.Plot(label="Embedding Confusion Matrix") raw_plot = gr.Plot(label="Raw Confusion Matrix") mod_summary = gr.Markdown(value="Select a technology and run the evaluation to view metrics.") mod_btn.click( evaluate_modulation, inputs=[mod_tech, mod_train, mod_seed], outputs=[emb_plot, raw_plot, mod_summary], ) with gr.Tab("Joint SNR/Doppler Evaluation"): gr.Markdown( """ ### ๐ŸŒช๏ธ Joint Channel Dynamics Benchmark - Evaluate the precomputed MoE embeddings on the 14-class joint SNR/Doppler recognition task. - Mirrors the second stage of our reliability workflow where, without an explicit technology label, the MoE router sends samples to the most relevant expert and mobility-aware cues guide SNR-aware routing. - Upload or reference Hub-hosted tensors to compare MoE vs. raw spectrogram baselines before fine-tuning heavier heads. """ ) if evaluation_disabled: gr.Markdown( "โš ๏ธ Precomputed MoE embeddings are not bundled in this Space build. Upload a dataset locally to run evaluations." ) with gr.Row(): with gr.Column(scale=1, min_width=320): gr.Markdown("### Evaluation Filters") eval_tech_filter = gr.CheckboxGroup( choices=tech_choices, value=default_tech, label="Technology", interactive=not evaluation_disabled, ) eval_snr_filter = gr.Dropdown( choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)", interactive=not evaluation_disabled, ) eval_mod_filter = gr.Dropdown( choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)", interactive=not evaluation_disabled, ) eval_mob_filter = gr.Dropdown( choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)", interactive=not evaluation_disabled, ) gr.Markdown("### Prototype Settings") train_pct = gr.Slider( minimum=10, maximum=80, step=5, value=60, label="Training Percentage (%)", interactive=not evaluation_disabled, ) seed = gr.Slider( minimum=0, maximum=9999, step=1, value=42, label="Random Seed", interactive=not evaluation_disabled, ) eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled) with gr.Column(scale=3): with gr.Row(): eval_plot = gr.Plot(label="MoE Prototype Confusion") eval_plot_raw = gr.Plot(label="Raw Prototype Confusion") eval_status = gr.Markdown(value="Run an evaluation to compare MoE vs raw baselines.") eval_btn.click( run_joint_evaluation, inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter], outputs=[eval_plot, eval_plot_raw, eval_status], ) if __name__ == "__main__": demo.launch()