import os
import shutil
import netrc
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
APP_DIR = Path(__file__).resolve().parent
DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
# Where to download the demo tensors from.
# Configure in Space settings if the default repo is private or you need to pin an older revision.
HUB_REPO_ID = os.getenv("LWM_SPECTRO_DEMO_REPO_ID", "wi-lab/lwm-spectro")
HUB_REVISION = os.getenv("LWM_SPECTRO_DEMO_REVISION") # optional git sha / tag / branch
HUB_DEMO_DATA_FILENAME = os.getenv("LWM_SPECTRO_DEMO_DATA_FILENAME", "demo_data.pt")
HUB_MOE_DATA_FILENAME = os.getenv("LWM_SPECTRO_MOE_DATA_FILENAME", "demo_data_moe.pt")
HUB_REPO_TYPES = tuple(
t.strip() for t in os.getenv("LWM_SPECTRO_DEMO_REPO_TYPES", "model").split(",") if t.strip()
)
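# Hypothetical override example for the variables above (placeholder values, not real repos):
#   LWM_SPECTRO_DEMO_REPO_ID=your-org/your-demo-data
#   LWM_SPECTRO_DEMO_REPO_TYPES=dataset,model
#   LWM_SPECTRO_DEMO_REVISION=<git sha, tag, or branch>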
def _get_hf_token() -> str | None:
# Spaces / HF Hub tooling uses a few common names.
token = (
os.getenv("HF_TOKEN")
or os.getenv("HF_HUB_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
or os.getenv("HF_API_TOKEN")
or os.getenv("HUGGINGFACE_TOKEN")
or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
)
if token:
return token
# If a token exists in ~/.netrc (common in some environments), use it.
try:
auth = netrc.netrc().authenticators("huggingface.co")
if auth and auth[2]:
return auth[2]
except Exception:
return None
return None
HF_TOKEN = _get_hf_token()
# Fixed ordering for the 14 joint SNR/Doppler labels
JOINT_LABELS = [
("SNR-5dB", "pedestrian"),
("SNR-5dB", "vehicular"),
("SNR0dB", "pedestrian"),
("SNR0dB", "vehicular"),
("SNR5dB", "pedestrian"),
("SNR5dB", "vehicular"),
("SNR10dB", "pedestrian"),
("SNR10dB", "vehicular"),
("SNR15dB", "pedestrian"),
("SNR15dB", "vehicular"),
("SNR20dB", "pedestrian"),
("SNR20dB", "vehicular"),
("SNR25dB", "pedestrian"),
("SNR25dB", "vehicular"),
]
SNR_ORDER = ["SNR-5dB", "SNR0dB", "SNR5dB", "SNR10dB", "SNR15dB", "SNR20dB", "SNR25dB"]
TECH_EXPERT_ORDER = ["LTE", "WiFi", "5G"]
TECH_TO_EXPERT_IDX = {name: idx for idx, name in enumerate(TECH_EXPERT_ORDER)}
DEFAULT_TSNE_SAMPLES_PER_SNR = 500
def _sort_snrs(labels: List[str] | np.ndarray) -> List[str]:
ordering = {snr: idx for idx, snr in enumerate(SNR_ORDER)}
return sorted(labels, key=lambda x: ordering.get(x, len(ordering)))
def load_joint_mapping() -> Dict[str, object]:
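# Builds a stable pair <-> name <-> id mapping, e.g. the pair ("SNR0dB", "pedestrian")
# maps to the name "SNR0dB | pedestrian" and label id 2 (its position in JOINT_LABELS).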
label_names = [f"{snr} | {mob}" for snr, mob in JOINT_LABELS]
pair_to_name = {pair: name for pair, name in zip(JOINT_LABELS, label_names)}
name_to_id = {name: idx for idx, name in enumerate(label_names)}
pair_to_id = {pair: idx for idx, pair in enumerate(JOINT_LABELS)}
return {
"pairs": JOINT_LABELS,
"label_names": label_names,
"pair_to_name": pair_to_name,
"name_to_id": name_to_id,
"pair_to_id": pair_to_id,
}
def _safe_load_tensor(path: Path):
# Torch 2.6 defaults to weights_only=True, which breaks our saved dicts.
return torch.load(path, weights_only=False)
def _is_git_lfs_pointer(path: Path) -> bool:
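# Git LFS pointer files are small text stubs instead of the real tensor, e.g.:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<hash>
#   size <bytes>
# If we find one, the artifact was not actually fetched and should be treated as missing.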
try:
with path.open("rb") as handle:
head = handle.read(256)
return b"git-lfs.github.com/spec" in head
except OSError:
return False
def _normalize_tech_label(value: object) -> object:
if value is None:
return value
text = str(value).strip()
if not text:
return value
normalized = text.lower().replace(" ", "").replace("-", "").replace("_", "")
if normalized == "wifi":
return "WiFi"
if normalized == "lte":
return "LTE"
if normalized in {"5g", "nr", "5gnr", "sub6", "sub6ghz", "5gsub6", "5gsub6ghz"}:
return "5G"
return text
def _normalize_mobility_label(value: object) -> object:
if value is None:
return value
text = str(value).strip()
if not text:
return value
normalized = text.lower().replace(" ", "").replace("-", "")
if normalized in {"ped", "pedestrian", "walking"}:
return "pedestrian"
if normalized in {"veh", "vehicular", "vehicle", "driving", "car"}:
return "vehicular"
return text
def _normalize_sample(sample: Dict[str, object]) -> Dict[str, object]:
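# Example: {"technology": "Wi Fi", "mobility": "Driving", ...} comes out with
# tech="WiFi" and mob="vehicular"; unrecognized labels are left as-is.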
out = dict(sample)
# Schema aliases (some artifacts use longer names).
if "tech" not in out and "technology" in out:
out["tech"] = out.get("technology")
if "mod" not in out and "modulation" in out:
out["mod"] = out.get("modulation")
if "mob" not in out and "mobility" in out:
out["mob"] = out.get("mobility")
if "snr" not in out and "snr_label" in out:
out["snr"] = out.get("snr_label")
out["tech"] = _normalize_tech_label(out.get("tech"))
out["mob"] = _normalize_mobility_label(out.get("mob"))
return out
def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
"""Deprecated: kept for backward compatibility, but avoided in production."""
raise RuntimeError("Synthetic on-disk dataset generation disabled")
def _create_dummy_samples() -> List[Dict[str, object]]:
"""In-memory fallback when the filesystem is not writable."""
rng = np.random.default_rng(42)
samples: List[Dict[str, object]] = []
techs = ["LTE", "WiFi", "5G"]
snrs = ["SNR0dB", "SNR10dB", "SNR20dB"]
mods = ["QPSK", "16QAM", "64QAM"]
mobs = ["pedestrian", "vehicular"]
for i in range(30):
tech = techs[i % len(techs)]
snr = snrs[i % len(snrs)]
mob = mobs[i % len(mobs)]
mod = mods[i % len(mods)]
spectrogram = rng.normal(size=(128, 128)).astype(np.float32)
embedding = rng.normal(size=(128,)).astype(np.float32)
moe_embedding = rng.normal(size=(128,)).astype(np.float32)
samples.append(
{
"tech": tech,
"snr": snr,
"mod": mod,
"mob": mob,
"data": spectrogram,
"embedding": embedding,
"moe_embedding": moe_embedding,
}
)
return samples
def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
"""Ensure a file exists locally; try Hub download if missing."""
if local_path.exists() and not _is_git_lfs_pointer(local_path):
return local_path
global LAST_DEMO_DOWNLOAD_ERROR
# Prefer a stored token if present (Spaces sometimes have credentials available
# even when HF_TOKEN env var is not explicitly set).
token = HF_TOKEN or True
# Try configured repo types (default: model). This Space historically used a model repo.
last_exc: Exception | None = None
for repo_type in HUB_REPO_TYPES:
try:
cached = hf_hub_download(
repo_id=HUB_REPO_ID,
filename=hub_filename,
token=token,
repo_type=repo_type,
revision=HUB_REVISION,
)
cached_path = Path(cached)
print(f"[INFO] Using cached Hub file for {hub_filename}: {cached_path} (repo_type={repo_type})")
return cached_path
except Exception as exc:
last_exc = exc
# Final fallback: try downloading from the Space repo itself (useful when artifacts are stored in Space).
try:
cached = hf_hub_download(
repo_id="wi-lab/LWM-Spectro",
filename=hub_filename,
token=token,
repo_type="space",
revision=None,
)
cached_path = Path(cached)
print(f"[INFO] Using cached Space file for {hub_filename}: {cached_path}")
return cached_path
except Exception as exc:
# Persist a short error string for the UI status line.
err = str(last_exc or exc)
if len(err) > 240:
err = err[:240] + "..."
LAST_DEMO_DOWNLOAD_ERROR = err
print(
f"[WARN] Could not download {hub_filename} from Hub (repo_id={HUB_REPO_ID}, repo_types={HUB_REPO_TYPES}, revision={HUB_REVISION or 'main'}): {last_exc}; "
f"Space repo fallback also failed: {exc}. Continuing without it."
)
return None
USING_SYNTHETIC_DATA = False
LAST_DEMO_DOWNLOAD_ERROR: str | None = None
def load_augmented_samples() -> Tuple[List[Dict[str, object]], bool]:
moe_path = _ensure_local_file(MOE_DATA_PATH, HUB_MOE_DATA_FILENAME)
base_path = _ensure_local_file(DEMO_DATA_PATH, HUB_DEMO_DATA_FILENAME)
if moe_path and moe_path.exists() and not _is_git_lfs_pointer(moe_path):
print(f"[INFO] Loading MoE-augmented dataset from {moe_path}")
return _safe_load_tensor(moe_path), True
if base_path and base_path.exists() and not _is_git_lfs_pointer(base_path):
print(f"[WARN] MoE data missing; falling back to base data: {base_path}")
return _safe_load_tensor(base_path), False
# Last resort: in-memory synthetic data (keeps app alive, but clearly not the full demo dataset).
global USING_SYNTHETIC_DATA
USING_SYNTHETIC_DATA = True
print(
"[WARN] Falling back to a tiny synthetic dataset (30 samples). "
"This usually means the real demo_data*.pt could not be downloaded. "
"If the Hub repo is private, add a Space secret named HF_TOKEN with read access."
)
return _create_dummy_samples(), False
def load_data(mapping: Dict[str, object]):
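# Flattens the raw sample dicts into a DataFrame with one row per valid sample:
# tech/snr/mod/mob labels, embedding / tech_embedding / moe_embedding vectors,
# the flattened spectrogram, the joint "SNR | mobility" label + id, and any
# precomputed t-SNE coordinates. Also returns whether MoE embeddings were found.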
data, has_moe = load_augmented_samples()
pair_to_name = mapping["pair_to_name"]
pair_to_id = mapping["pair_to_id"]
records = []
skipped = 0
for i, sample in enumerate(data):
if not isinstance(sample, dict):
skipped += 1
continue
sample = _normalize_sample(sample)
if not sample.get("tech") or not sample.get("snr") or not sample.get("mob") or not sample.get("mod"):
skipped += 1
continue
if "embedding" not in sample or "data" not in sample:
skipped += 1
continue
embedding = sample["embedding"]
if isinstance(embedding, torch.Tensor):
base_embedding = embedding.detach().cpu().numpy()
else:
base_embedding = np.asarray(embedding)
spectrogram = sample["data"]
if isinstance(spectrogram, torch.Tensor):
flat_spec = spectrogram.detach().cpu().numpy().flatten()
else:
flat_spec = np.asarray(spectrogram).flatten()
moe_embedding = sample.get("moe_embedding")
if isinstance(moe_embedding, torch.Tensor):
moe_embedding = moe_embedding.detach().cpu().numpy()
elif moe_embedding is not None:
moe_embedding = np.asarray(moe_embedding)
tech_embedding = sample.get("tech_embedding")
if isinstance(tech_embedding, torch.Tensor):
tech_embedding = tech_embedding.detach().cpu().numpy()
elif tech_embedding is not None:
tech_embedding = np.asarray(tech_embedding)
if tech_embedding is not None:
tech_embedding = tech_embedding.astype(np.float32, copy=False)
embed_dim_hint = sample.get("tech_embedding_dim") or sample.get("embedding_dim")
try:
embed_dim_hint = int(embed_dim_hint) if embed_dim_hint is not None else None
except (TypeError, ValueError):
embed_dim_hint = None
if tech_embedding is None:
tech_embedding = _select_tech_embedding(base_embedding, sample["tech"], embed_dim_hint)
if tech_embedding is not None:
tech_embedding = tech_embedding.astype(np.float32, copy=False)
pair = (sample["snr"], sample["mob"])
joint_label = pair_to_name.get(pair)
joint_label_id = pair_to_id.get(pair)
tsne_x = sample.get("tsne_x")
tsne_y = sample.get("tsne_y")
tsne_raw_x = sample.get("tsne_raw_x")
tsne_raw_y = sample.get("tsne_raw_y")
records.append(
{
"index": i,
"tech": sample["tech"],
"snr": sample["snr"],
"mod": sample["mod"],
"mob": sample["mob"],
"embedding": base_embedding,
"tech_embedding": tech_embedding,
"moe_embedding": moe_embedding,
"spectrogram": flat_spec,
"joint_label": joint_label,
"joint_label_id": joint_label_id,
"tsne_x": tsne_x,
"tsne_y": tsne_y,
"tsne_raw_x": tsne_raw_x,
"tsne_raw_y": tsne_raw_y,
}
)
df = pd.DataFrame(records)
if skipped:
print(f"[WARN] Skipped {skipped} malformed samples while loading demo data")
print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
return df, has_moe
def apply_filters(
dataframe: pd.DataFrame,
tech_filter,
snr_filter,
mod_filter,
mob_filter,
) -> pd.DataFrame:
filtered = dataframe.copy()
if tech_filter:
filtered = filtered[filtered["tech"].isin(tech_filter)]
if snr_filter:
filtered = filtered[filtered["snr"].isin(snr_filter)]
if mod_filter:
filtered = filtered[filtered["mod"].isin(mod_filter)]
if mob_filter:
filtered = filtered[filtered["mob"].isin(mob_filter)]
return filtered
def _select_tech_embedding(flat_embedding: np.ndarray | None, tech: str, embed_dim: Optional[int]) -> Optional[np.ndarray]:
"""Extract the technology-specific expert embedding.
Some artifacts don't include an explicit embedding dimension hint. In that case,
infer `embed_dim = total_dim / num_experts` when divisible.
"""
if flat_embedding is None:
return None
flat_embedding = np.asarray(flat_embedding).reshape(-1)
total = flat_embedding.size
blocks = len(TECH_EXPERT_ORDER)
if blocks <= 0:
return None
inferred_dim = embed_dim
if inferred_dim is None:
if total % blocks != 0:
return None
inferred_dim = total // blocks
try:
inferred_dim = int(inferred_dim)
except (TypeError, ValueError):
return None
if inferred_dim <= 0:
return None
expected = blocks * inferred_dim
if expected != total:
# If metadata is wrong, don't crash; fall back to an even split only if possible.
if total % blocks != 0:
return None
inferred_dim = total // blocks
try:
arr = flat_embedding.reshape(blocks, inferred_dim)
except ValueError:
return None
tech_idx = TECH_TO_EXPERT_IDX.get(str(tech))
if tech_idx is None or tech_idx >= arr.shape[0]:
return arr.mean(axis=0)
return arr[tech_idx]
def _sample_balanced_by_snr(dataframe: pd.DataFrame, samples_per_snr: int, seed: int) -> pd.DataFrame:
if dataframe.empty:
return dataframe
rng = np.random.default_rng(int(seed))
# Gradio sliders can deliver floats; pandas' DataFrame.sample needs an integer n.
samples_per_snr = int(samples_per_snr)
grouped = {snr: grp for snr, grp in dataframe.groupby("snr") if not grp.empty}
if not grouped:
return dataframe.iloc[0:0]
ordered_snrs = sorted(grouped.keys())
sampled_frames: List[pd.DataFrame] = []
for snr_label in ordered_snrs:
group = grouped[snr_label]
if samples_per_snr <= 0 or samples_per_snr >= len(group):
sampled_frames.append(group)
continue
random_state = int(rng.integers(0, 1_000_000_000))
sampled_frames.append(group.sample(n=samples_per_snr, random_state=random_state))
if not sampled_frames:
return dataframe.iloc[0:0]
return pd.concat(sampled_frames).reset_index(drop=True)
def plot_tsne(
tech_filter,
snr_filter,
mod_filter,
mob_filter,
representation,
color_label,
perplexity,
n_iter,
samples_per_snr,
sampling_seed,
):
filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
sampled_df = _sample_balanced_by_snr(filtered_df, samples_per_snr, sampling_seed)
if len(sampled_df) < 5:
fig = go.Figure()
fig.update_layout(
title=f"Not enough samples to plot (n={len(sampled_df)}). Widen filters or increase samples.",
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
return fig
sampled_df = sampled_df.copy()
color_column = COLOR_OPTIONS.get(color_label, "snr")
if representation == "LWM Embedding":
embed_mask = sampled_df["tech_embedding"].apply(lambda x: x is not None)
if embed_mask.sum() >= 5:
sampled_df = sampled_df.loc[embed_mask].reset_index(drop=True)
features = np.stack(sampled_df["tech_embedding"].values)
title_prefix = "t-SNE of LWM Embedding"
else:
# Fallback: use the full embedding vector so the UI doesn't go blank when
# per-expert metadata is missing in the artifact.
base_mask = sampled_df["embedding"].apply(lambda x: x is not None)
if base_mask.sum() < 5:
fig = go.Figure()
fig.update_layout(
title="No embeddings available for the selected filters.",
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
return fig
sampled_df = sampled_df.loc[base_mask].reset_index(drop=True)
features = np.stack(sampled_df["embedding"].values)
title_prefix = "t-SNE of LWM Embedding (full vector)"
else:
features = build_tsne_raw_vectors(sampled_df["spectrogram"])
title_prefix = "t-SNE of Raw Spectrogram"
if features.size == 0:
fig = go.Figure()
fig.update_layout(title="No features available for t-SNE.", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig
features = _standardize_for_tsne(features)
# Keep perplexity strictly below the sample count (required by scikit-learn's t-SNE).
eff_perplexity = max(5, perplexity)
eff_perplexity = min(eff_perplexity, len(sampled_df) - 1)
tsne_kwargs = dict(
n_components=2,
perplexity=eff_perplexity,
random_state=42,
init="pca",
learning_rate="auto",
)
try:
# Newer scikit-learn releases renamed n_iter to max_iter.
tsne = TSNE(**tsne_kwargs, max_iter=int(n_iter))
except TypeError:
try:
tsne = TSNE(**tsne_kwargs, n_iter=int(n_iter))
except TypeError:
tsne = TSNE(**tsne_kwargs)
try:
projections = tsne.fit_transform(features)
except Exception:
pca = PCA(n_components=2, random_state=42)
projections = pca.fit_transform(features)
sampled_df["x"] = projections[:, 0]
sampled_df["y"] = projections[:, 1]
category_orders = {}
if color_column == "snr":
category_orders["snr"] = [snr for snr in SNR_ORDER if snr in sampled_df["snr"].unique()]
fig = px.scatter(
sampled_df,
x="x",
y="y",
color=color_column,
hover_data=["tech", "snr", "mod", "mob"],
title=f"{title_prefix} ({len(sampled_df)} samples)",
template="plotly_white",
category_orders=category_orders,
)
height = 680 if color_label == "SNR" else 640
fig.update_layout(legend_title_text=color_label, width=640, height=height)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
return fig
def build_raw_feature_matrix(
samples: pd.Series,
max_components: Optional[int] = 256,
*,
normalize: bool = True,
reduce_dim: bool = True,
) -> np.ndarray:
raw_flat = []
for spec in samples:
arr = np.asarray(spec, dtype=np.float32)
raw_flat.append(arr.reshape(-1))
matrix = np.stack(raw_flat)
matrix = np.nan_to_num(matrix, copy=False)
if normalize:
scaler = StandardScaler()
matrix = scaler.fit_transform(matrix)
if reduce_dim and max_components:
# Cap n_components to valid PCA range: <= min(n_samples-1, n_features)
n_samples, n_features = matrix.shape
if n_samples > 1:
max_valid = min(n_features, max(n_samples - 1, 1))
else:
max_valid = 1
target = min(max_components, max_valid)
if target < 1:
target = 1
if target < n_features:
projector = PCA(n_components=target, random_state=42)
try:
matrix = projector.fit_transform(matrix)
except ValueError:
safe_components = max(1, min(n_samples, n_features) - 1)
safe_components = min(safe_components, target)
if safe_components >= 1:
fallback = PCA(n_components=safe_components, random_state=42)
matrix = fallback.fit_transform(matrix)
return matrix
def build_tsne_raw_vectors(samples: pd.Series, eps: float = 1e-6) -> np.ndarray:
rows: List[np.ndarray] = []
for spec in samples:
arr = np.asarray(spec, dtype=np.float32)
flat = arr.reshape(-1)
mean = float(flat.mean())
std = float(flat.std())
if std < eps:
std = eps
normalized = (flat - mean) / std
rows.append(normalized.astype(np.float32, copy=False))
if not rows:
return np.empty((0, 0), dtype=np.float32)
return np.stack(rows)
def _standardize_for_tsne(features: np.ndarray) -> np.ndarray:
if features.size == 0:
return features
scaler = StandardScaler()
scaled = scaler.fit_transform(features)
scaled = np.nan_to_num(scaled, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
scaled = np.clip(scaled, -1e6, 1e6)
return scaled.astype(np.float32, copy=False)
def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
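# Per-class split: e.g. train_ratio=0.6 turns a 10-sample class into 6 train / 4 test;
# classes with fewer than 2 samples raise so evaluation never sees an empty side.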
rng = np.random.default_rng(int(seed))
train_indices = []
test_indices = []
for label_id, group in filtered_df.groupby("joint_label_id"):
indices = group.index.to_numpy()
if indices.size < 2:
raise ValueError(f"Class '{CLASS_LABELS[int(label_id)]}' needs at least 2 samples for evaluation.")
rng.shuffle(indices)
split = int(round(indices.size * train_ratio))
split = max(1, min(indices.size - 1, split))
train_indices.extend(indices[:split])
test_indices.extend(indices[split:])
return np.array(train_indices), np.array(test_indices)
def select_knn_k(train_labels: np.ndarray, max_k: int = 9) -> int:
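# Example: 200 training samples with a smallest class of 12 gives sqrt(200) ~ 14,
# capped at max_k=9 and at the smallest class size, then forced odd -> k=9.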
if train_labels.size == 0:
return 1
class_counts = pd.Series(train_labels).value_counts()
min_class = int(class_counts.min())
heuristic = int(np.sqrt(train_labels.size))
candidate = max(1, min(max_k, heuristic))
k = max(1, min(candidate, min_class))
if k % 2 == 0 and k > 1:
k -= 1
return k
def plot_confusion_heatmap(
confusion: np.ndarray, label_names: List[str], title: str = "Prototype Classifier Confusion Matrix"
) -> go.Figure:
fig = go.Figure(
data=go.Heatmap(
z=confusion,
x=label_names,
y=label_names,
colorscale="Viridis",
hovertemplate="Predicted %{x}
True %{y}
Count %{z}",
)
)
fig.update_layout(
title=title,
xaxis_title="Predicted",
yaxis_title="True",
xaxis=dict(tickangle=45),
)
return fig
def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
if evaluation_disabled:
fig = go.Figure()
fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, "MoE embeddings are not available in this Space build."
filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
if filtered.empty:
fig = go.Figure()
fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, "No samples match the selected filters."
if filtered["joint_label_id"].nunique() < 2:
fig = go.Figure()
fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, "Need at least two joint SNR/Doppler classes to evaluate."
filtered = filtered.reset_index(drop=True)
try:
train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
except ValueError as exc:
fig = go.Figure()
fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, str(exc)
labels = filtered["joint_label_id"].to_numpy(dtype=int)
moe_features = np.stack(filtered["moe_embedding"].values)
raw_features = build_raw_feature_matrix(
filtered["spectrogram"],
max_components=None,
normalize=False,
reduce_dim=False,
)
train_labels = labels[train_idx]
knn_k = select_knn_k(train_labels)
moe_metrics = compute_knn_metrics(moe_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
moe_fig = plot_confusion_heatmap(
moe_metrics["confusion"], moe_metrics["label_names"], title=f"MoE Embedding Confusion (k={moe_metrics['k']})"
)
raw_fig = plot_confusion_heatmap(
raw_metrics["confusion"], raw_metrics["label_names"], title=f"Raw Spectrogram Confusion (k={raw_metrics['k']})"
)
status = (
f"### Joint SNR/Doppler Metrics\n"
f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Train %:** {train_pct}% | **Seed:** {seed} | **k-NN k:** {knn_k}\n\n"
"| Representation | Accuracy | Macro F1 |\n"
"| --- | --- | --- |\n"
f"| **MoE Embedding** | {moe_metrics['accuracy'] * 100:.2f}% | {moe_metrics['macro_f1']:.3f} |\n"
f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
)
return moe_fig, raw_fig, status
def stratified_split_mod(df_subset: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
rng = np.random.default_rng(int(seed))
train_idx = []
test_idx = []
for _, group in df_subset.groupby("mod"):
indices = group.index.to_numpy()
if indices.size < 2:
raise ValueError("Each modulation needs at least 2 samples.")
rng.shuffle(indices)
split = int(round(len(indices) * train_ratio))
split = max(1, min(len(indices) - 1, split))
train_idx.extend(indices[:split])
test_idx.extend(indices[split:])
return np.array(train_idx), np.array(test_idx)
def compute_knn_metrics(
features: np.ndarray,
labels: np.ndarray,
train_idx: np.ndarray,
test_idx: np.ndarray,
knn_k: int,
label_lookup: List[str] | None = None,
) -> Dict[str, object]:
train_features = features[train_idx]
test_features = features[test_idx]
train_labels = labels[train_idx]
test_labels = labels[test_idx]
candidate_k = max(1, min(int(knn_k), len(train_labels)))
if candidate_k % 2 == 0 and candidate_k > 1:
candidate_k -= 1
knn = KNeighborsClassifier(n_neighbors=candidate_k, metric="euclidean")
knn.fit(train_features, train_labels)
preds = knn.predict(test_features)
acc = accuracy_score(test_labels, preds)
active_labels = np.unique(np.concatenate([train_labels, test_labels, preds]))
macro = f1_score(test_labels, preds, labels=active_labels, average="macro", zero_division=0)
if label_lookup is None:
label_names = [str(lbl) for lbl in active_labels]
else:
label_names = [label_lookup[int(lbl)] for lbl in active_labels]
cm = confusion_matrix(test_labels, preds, labels=active_labels)
return {
"accuracy": acc,
"macro_f1": macro,
"confusion": cm,
"label_names": label_names,
"k": candidate_k,
}
def evaluate_modulation(tech: str, train_pct: int, seed: int):
if not tech:
fig = go.Figure()
fig.update_layout(title="Select a technology to evaluate.", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, "No technology selected."
subset = df[df["tech"] == tech].copy().reset_index(drop=True)
if subset.empty or subset["mod"].nunique() < 2:
fig = go.Figure()
fig.update_layout(
title="Need at least two modulation classes for this technology.",
xaxis=dict(visible=False),
yaxis=dict(visible=False),
)
return fig, fig, "Not enough modulation classes."
try:
train_idx, test_idx = stratified_split_mod(subset, train_pct / 100.0, seed)
except ValueError as exc:
fig = go.Figure()
fig.update_layout(title=str(exc), xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, str(exc)
labels = subset["mod"].astype(str).to_numpy()
emb_features = np.stack(subset["embedding"].values)
raw_features = build_raw_feature_matrix(
subset["spectrogram"],
max_components=None,
normalize=False,
reduce_dim=False,
)
train_labels = labels[train_idx]
class_counts = pd.Series(train_labels).value_counts()
if class_counts.empty:
fig = go.Figure()
fig.update_layout(title="No modulation classes found.", xaxis=dict(visible=False), yaxis=dict(visible=False))
return fig, fig, "No modulation classes found."
knn_k = select_knn_k(train_labels)
emb_metrics = compute_knn_metrics(emb_features, labels, train_idx, test_idx, knn_k)
raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k)
emb_fig = plot_confusion_heatmap(emb_metrics["confusion"], emb_metrics["label_names"], title="Embedding Confusion")
raw_fig = plot_confusion_heatmap(raw_metrics["confusion"], raw_metrics["label_names"], title="Raw Confusion")
summary = (
f"### {tech} Modulation Metrics\n"
f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Classifier:** k-NN (k = {emb_metrics['k']})\n\n"
"| Representation | Accuracy | Macro F1 |\n"
"| --- | --- | --- |\n"
f"| **LWM Embedding** | {emb_metrics['accuracy'] * 100:.2f}% | {emb_metrics['macro_f1']:.3f} |\n"
f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
)
return emb_fig, raw_fig, summary
def _reshape_spectrogram(spec: np.ndarray) -> np.ndarray:
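# Example: a flat 16384-element vector becomes 128x128; non-square lengths fall back
# to (-1, side) so imshow still receives a 2-D array.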
arr = np.asarray(spec)
if arr.ndim == 1:
side = int(round(arr.size ** 0.5))
if side * side == arr.size:
arr = arr.reshape(side, side)
else:
arr = arr.reshape(-1, side)
elif arr.ndim == 3:
arr = arr.squeeze()
return arr
def _spectrogram_to_image(spec: np.ndarray, title: str) -> np.ndarray:
normalized = spec.astype(np.float32)
if np.isnan(normalized).any():
normalized = np.nan_to_num(normalized)
vmin, vmax = normalized.min(), normalized.max()
if vmax - vmin > 0:
normalized = (normalized - vmin) / (vmax - vmin)
fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(normalized, cmap="turbo", aspect="auto", origin="lower")
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(title, fontsize=8)
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.ax.tick_params(labelsize=6)
fig.tight_layout(pad=0.5)
canvas = FigureCanvasAgg(fig)
canvas.draw()
width, height = canvas.get_width_height()
buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
image = buf[..., :3].copy()
plt.close(fig)
return image
def render_spectrogram_gallery(tech, snr, mod, mob, sample_count, seed):
tech_list = [tech] if tech else None
snr_list = [snr] if snr else None
mod_list = [mod] if mod else None
mob_list = [mob] if mob else None
filtered = apply_filters(df, tech_list, snr_list, mod_list, mob_list)
if filtered.empty:
return [], "No spectrograms match the selected filters."
sample_count = max(1, int(sample_count))
rng = np.random.default_rng(int(seed))
if len(filtered) > sample_count:
indices = rng.choice(filtered.index.to_numpy(), size=sample_count, replace=False)
subset = filtered.loc[indices]
else:
subset = filtered
gallery_items = []
for _, row in subset.iterrows():
spec = _reshape_spectrogram(row["spectrogram"])
caption = f"{row['tech']} | {row['mod']} | {row['snr']} | {row['mob']}"
img = _spectrogram_to_image(spec, caption)
gallery_items.append((img, caption))
status = f"Showing {len(subset)} of {len(filtered)} matches."
return gallery_items, status
mapping_info = load_joint_mapping()
df, has_moe_embeddings = load_data(mapping_info)
CLASS_LABELS = mapping_info["label_names"]
DATASET_STATUS = (
f"Dataset loaded: {len(df)} samples | "
f"MoE embeddings: {'yes' if has_moe_embeddings else 'no'} | "
f"HF token detected: {'yes' if HF_TOKEN else 'no'} | "
f"Synthetic fallback: {'yes' if USING_SYNTHETIC_DATA else 'no'} | "
f"Demo repo: {HUB_REPO_ID}@{HUB_REVISION or 'main'} ({','.join(HUB_REPO_TYPES)})"
)
if LAST_DEMO_DOWNLOAD_ERROR:
DATASET_STATUS += f" | Download error: {LAST_DEMO_DOWNLOAD_ERROR}"
has_moe_column = df["moe_embedding"].apply(lambda x: x is not None)
joint_eval_df = df[has_moe_column & df["joint_label_id"].notna()]
tech_choices = sorted(df["tech"].unique())
snr_choices = _sort_snrs(df["snr"].unique())
mod_choices = sorted(df["mod"].unique())
mob_choices = sorted(df["mob"].unique())
TECH_TO_MODS: Dict[str, List[str]] = {
tech: sorted(df.loc[df["tech"] == tech, "mod"].unique().tolist()) for tech in tech_choices
}
COLOR_OPTIONS: Dict[str, str] = {
"SNR": "snr",
"Modulation": "mod",
"Mobility": "mob",
}
default_tech = tech_choices[:1] if tech_choices else []
initial_spec_mod_choices = TECH_TO_MODS.get(default_tech[0], mod_choices) if default_tech else mod_choices
evaluation_disabled = (not has_moe_embeddings) or joint_eval_df.empty
def update_modulation_choices(selected_tech: Optional[str]):
choices = mod_choices
if selected_tech:
choices = TECH_TO_MODS.get(selected_tech, mod_choices)
# gr.update works across Gradio 3.x and 4.x (component .update() was removed in 4.x).
return gr.update(choices=choices, value=None)
with gr.Blocks(title="LWM-Spectro Lab") as demo:
gr.Markdown("# ๐ฌ LWM-Spectro Interactive Demo")
gr.Markdown(f"**{DATASET_STATUS}**")
gr.Markdown(
"""
Interactive lab for exploring LWM spectrogram embeddings and cached MoE probes.
"""
)
with gr.Tabs():
with gr.Tab("Spectrograms"):
gr.Markdown(
"""
### 📊 Spectrogram Studio
- Peek at the raw 128×128 Sub-6 GHz I/Q baseband spectrograms that drive the SNR/mobility recognition tasks.
- Filter by technology/SNR/modulation/mobility to understand how diverse the training pool is across scenarios.
- Use the gallery to sanity-check preprocessing before sending samples through LWM or downstream models.
"""
)
with gr.Row():
with gr.Column(scale=1, min_width=320):
spec_tech = gr.Dropdown(
choices=tech_choices,
value=default_tech[0] if default_tech else None,
label="Technology",
)
spec_snr = gr.Dropdown(choices=snr_choices, value=None, label="SNR (optional)")
spec_mod = gr.Dropdown(choices=initial_spec_mod_choices, value=None, label="Modulation (optional)")
spec_mob = gr.Dropdown(choices=mob_choices, value=None, label="Mobility (optional)")
spec_count = gr.Slider(minimum=1, maximum=12, step=1, value=6, label="Samples to show")
spec_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=0, label="Random seed")
spec_btn = gr.Button("Show spectrograms", variant="primary")
with gr.Column(scale=3):
gallery = gr.Gallery(
label="Spectrogram Samples",
columns=[3],
rows=[3],
height=560,
preview=True,
)
gallery_status = gr.Textbox(label="Status", interactive=False)
spec_inputs = [spec_tech, spec_snr, spec_mod, spec_mob, spec_count, spec_seed]
spec_btn.click(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
demo.load(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
spec_tech.change(update_modulation_choices, inputs=spec_tech, outputs=spec_mod)
with gr.Tab("t-SNE Analysis"):
gr.Markdown(
"""
### 📈 Embedding vs. Raw Space
- Run quick t-SNE sweeps on either LWM embeddings or raw spectrogram vectors.
- Toggle **Color By** to mirror the "colored by modulation vs. SNR" comparisons from the CLI examples.
- Balanced per-SNR sampling plus configurable perplexity help match the figures you generate locally with `plot/plot_tsne.py`.
"""
)
with gr.Row():
with gr.Column(scale=1, min_width=300):
gr.Markdown("### Filters")
tech_filter = gr.CheckboxGroup(choices=tech_choices, value=default_tech, label="Technology")
snr_filter = gr.Dropdown(
choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
)
mod_filter = gr.Dropdown(
choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)"
)
mob_filter = gr.Dropdown(
choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)"
)
gr.Markdown("### Visualization Settings")
representation = gr.Radio(
choices=["LWM Embedding", "Raw Spectrogram"],
value="LWM Embedding",
label="Representation",
)
color_by = gr.Dropdown(
choices=list(COLOR_OPTIONS.keys()),
value="SNR",
label="Color By",
)
with gr.Accordion("Advanced t-SNE Settings", open=False):
perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
samples_per_snr = gr.Slider(
minimum=20,
maximum=500,
value=DEFAULT_TSNE_SAMPLES_PER_SNR,
step=10,
label="Samples per SNR",
)
sampling_seed = gr.Slider(
minimum=0,
maximum=9999,
value=42,
step=1,
label="Sampling Seed",
)
btn = gr.Button("Update Plot", variant="primary")
with gr.Column(scale=3):
plot = gr.Plot(label="t-SNE Visualization")
btn.click(
plot_tsne,
inputs=[
tech_filter,
snr_filter,
mod_filter,
mob_filter,
representation,
color_by,
perplexity,
n_iter,
samples_per_snr,
sampling_seed,
],
outputs=[plot],
)
demo.load(
plot_tsne,
inputs=[
tech_filter,
snr_filter,
mod_filter,
mob_filter,
representation,
color_by,
perplexity,
n_iter,
samples_per_snr,
sampling_seed,
],
outputs=[plot],
)
with gr.Tab("Modulation Classification"):
gr.Markdown(
"""
### 🎯 Lightweight Modulation Head
- Prototype how well the frozen LWM backbone separates modulation formats for each technology using spectrograms as input.
- The adaptive k-NN classifier approximates the behavior of the downstream residual 1D-CNN before heavy training; each technology is evaluated separately to measure its expert's modulation discrimination.
- Sweep train/test splits and seeds to gauge robustness when only a portion of the dataset is labeled.
"""
)
with gr.Row():
with gr.Column(scale=1, min_width=320):
mod_tech = gr.Dropdown(
choices=tech_choices,
value=default_tech[0] if default_tech else None,
label="Technology",
)
mod_train = gr.Slider(minimum=50, maximum=90, step=5, value=70, label="Training Percentage (%)")
mod_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
gr.Markdown("k-NN uses an adaptive k based on the number of modulation classes and available training samples.")
mod_btn = gr.Button("Run modulation evaluation", variant="primary")
with gr.Column(scale=3):
with gr.Row():
emb_plot = gr.Plot(label="Embedding Confusion Matrix")
raw_plot = gr.Plot(label="Raw Confusion Matrix")
mod_summary = gr.Markdown(value="Select a technology and run the evaluation to view metrics.")
mod_btn.click(
evaluate_modulation,
inputs=[mod_tech, mod_train, mod_seed],
outputs=[emb_plot, raw_plot, mod_summary],
)
with gr.Tab("Joint SNR/Doppler Evaluation"):
gr.Markdown(
"""
### 🌪️ Joint Channel Dynamics Benchmark
- Evaluate the precomputed MoE embeddings on the 14-class joint SNR/Doppler recognition task.
- Mirrors the second stage of our reliability workflow: without an explicit technology label, the MoE router sends each sample to the most relevant expert, while mobility-aware cues guide SNR-aware routing.
- Upload or reference Hub-hosted tensors to compare MoE vs. raw spectrogram baselines before fine-tuning heavier heads.
"""
)
if evaluation_disabled:
gr.Markdown(
"โ ๏ธ Precomputed MoE embeddings are not bundled in this Space build. Upload a dataset locally to run evaluations."
)
with gr.Row():
with gr.Column(scale=1, min_width=320):
gr.Markdown("### Evaluation Filters")
eval_tech_filter = gr.CheckboxGroup(
choices=tech_choices,
value=default_tech,
label="Technology",
interactive=not evaluation_disabled,
)
eval_snr_filter = gr.Dropdown(
choices=snr_choices,
value=None,
multiselect=True,
label="SNR (Empty = All)",
interactive=not evaluation_disabled,
)
eval_mod_filter = gr.Dropdown(
choices=mod_choices,
value=None,
multiselect=True,
label="Modulation (Empty = All)",
interactive=not evaluation_disabled,
)
eval_mob_filter = gr.Dropdown(
choices=mob_choices,
value=None,
multiselect=True,
label="Mobility (Empty = All)",
interactive=not evaluation_disabled,
)
gr.Markdown("### Prototype Settings")
train_pct = gr.Slider(
minimum=10,
maximum=80,
step=5,
value=60,
label="Training Percentage (%)",
interactive=not evaluation_disabled,
)
seed = gr.Slider(
minimum=0,
maximum=9999,
step=1,
value=42,
label="Random Seed",
interactive=not evaluation_disabled,
)
eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
with gr.Column(scale=3):
with gr.Row():
eval_plot = gr.Plot(label="MoE Prototype Confusion")
eval_plot_raw = gr.Plot(label="Raw Prototype Confusion")
eval_status = gr.Markdown(value="Run an evaluation to compare MoE vs raw baselines.")
eval_btn.click(
run_joint_evaluation,
inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
outputs=[eval_plot, eval_plot_raw, eval_status],
)
if __name__ == "__main__":
demo.launch()
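# Local smoke test (hypothetical, assuming this module is saved as app.py; run in a separate session):
#   from app import load_joint_mapping, load_data
#   df, has_moe = load_data(load_joint_mapping())
#   print(df[["tech", "snr", "mod", "mob"]].head())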