Spaces:
Running
Running
| import os | |
| import shutil | |
| import netrc | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import TSNE | |
| from sklearn.metrics import accuracy_score, confusion_matrix, f1_score | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.preprocessing import StandardScaler | |
# Bundled demo artifacts are looked up next to this script.
APP_DIR = Path(__file__).resolve().parent
DEMO_DATA_PATH = APP_DIR.joinpath("demo_data.pt")
MOE_DATA_PATH = APP_DIR.joinpath("demo_data_moe.pt")

# Hub location of the demo tensors. Override these in the Space settings when the
# default repo is private or when an older revision has to be pinned.
HUB_REPO_ID = os.environ.get("LWM_SPECTRO_DEMO_REPO_ID", "wi-lab/lwm-spectro")
HUB_REVISION = os.environ.get("LWM_SPECTRO_DEMO_REVISION")  # optional git sha / tag / branch
HUB_DEMO_DATA_FILENAME = os.environ.get("LWM_SPECTRO_DEMO_DATA_FILENAME", "demo_data.pt")
HUB_MOE_DATA_FILENAME = os.environ.get("LWM_SPECTRO_MOE_DATA_FILENAME", "demo_data_moe.pt")
# Comma-separated list of repo types to try, with blanks dropped (default: "model").
HUB_REPO_TYPES = tuple(
    filter(
        None,
        (part.strip() for part in os.environ.get("LWM_SPECTRO_DEMO_REPO_TYPES", "model").split(",")),
    )
)
| def _get_hf_token() -> str | None: | |
| # Spaces / HF Hub tooling uses a few common names. | |
| token = ( | |
| os.getenv("HF_TOKEN") | |
| or os.getenv("HF_HUB_TOKEN") | |
| or os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
| or os.getenv("HF_API_TOKEN") | |
| or os.getenv("HUGGINGFACE_TOKEN") | |
| or os.getenv("HUGGINGFACE_ACCESS_TOKEN") | |
| ) | |
| if token: | |
| return token | |
| # If a token exists in ~/.netrc (common in some environments), use it. | |
| try: | |
| auth = netrc.netrc().authenticators("huggingface.co") | |
| if auth and auth[2]: | |
| return auth[2] | |
| except Exception: | |
| return None | |
| return None | |
# Token resolved once at import time; falsy when neither the environment nor netrc provides one.
HF_TOKEN = _get_hf_token()
# Canonical SNR ordering (lowest to highest) used for sorting and legends.
SNR_ORDER = ["SNR-5dB", "SNR0dB", "SNR5dB", "SNR10dB", "SNR15dB", "SNR20dB", "SNR25dB"]
# Fixed ordering for the 14 joint SNR/Doppler labels: every SNR level paired with
# "pedestrian" first and "vehicular" second.
JOINT_LABELS = [(snr, mobility) for snr in SNR_ORDER for mobility in ("pedestrian", "vehicular")]
# Position of each technology's expert block.
TECH_EXPERT_ORDER = ["LTE", "WiFi", "5G"]
TECH_TO_EXPERT_IDX = dict(zip(TECH_EXPERT_ORDER, range(len(TECH_EXPERT_ORDER))))
DEFAULT_TSNE_SAMPLES_PER_SNR = 500
def _sort_snrs(labels: List[str] | np.ndarray) -> List[str]:
    """Return ``labels`` sorted by their position in ``SNR_ORDER``.

    Labels that are not in ``SNR_ORDER`` sort after all known ones and keep
    their relative input order (the sort is stable).
    """
    rank = dict(zip(SNR_ORDER, range(len(SNR_ORDER))))
    unknown_rank = len(rank)

    def _position(label):
        return rank.get(label, unknown_rank)

    return sorted(labels, key=_position)
def load_joint_mapping() -> Dict[str, object]:
    """Build the lookup tables for the joint SNR/mobility labels.

    Returns a dict with the raw ``pairs`` (``JOINT_LABELS``), the display
    ``label_names`` formatted as ``"<snr> | <mob>"``, and three lookups:
    ``pair_to_name``, ``name_to_id`` and ``pair_to_id``. Ids are the positions
    of the pairs in ``JOINT_LABELS``.
    """
    label_names: List[str] = []
    pair_to_name = {}
    name_to_id = {}
    pair_to_id = {}
    for position, pair in enumerate(JOINT_LABELS):
        snr, mob = pair
        display = f"{snr} | {mob}"
        label_names.append(display)
        pair_to_name[pair] = display
        name_to_id[display] = position
        pair_to_id[pair] = position
    return dict(
        pairs=JOINT_LABELS,
        label_names=label_names,
        pair_to_name=pair_to_name,
        name_to_id=name_to_id,
        pair_to_id=pair_to_id,
    )
def _safe_load_tensor(path: Path):
    """Load a ``torch.save`` artifact from ``path`` onto the CPU.

    Torch 2.6 defaults to ``weights_only=True``, which breaks our saved dicts,
    so full unpickling is enabled explicitly.

    SECURITY NOTE: ``weights_only=False`` unpickles arbitrary Python objects.
    Only load files from a repository you control or trust.
    """
    # map_location="cpu" lets artifacts that were saved from a GPU session load
    # on CPU-only hosts instead of failing during deserialization.
    return torch.load(path, map_location="cpu", weights_only=False)
def _is_git_lfs_pointer(path: Path) -> bool:
    """Return True when ``path`` looks like a Git LFS pointer stub.

    Only the first 256 bytes are read and searched for the LFS spec URL.
    A file that cannot be opened or read is reported as "not a pointer".
    """
    marker = b"git-lfs.github.com/spec"
    try:
        with path.open("rb") as stream:
            prefix = stream.read(256)
    except OSError:
        return False
    return marker in prefix
def _normalize_tech_label(value: object) -> object:
    """Map a free-form technology label to ``"WiFi"``, ``"LTE"`` or ``"5G"``.

    ``None`` and blank values are returned unchanged. Matching is
    case-insensitive and ignores spaces and hyphens. Unrecognized labels are
    returned as their stripped string form.
    """
    if value is None:
        return value
    stripped = str(value).strip()
    if not stripped:
        return value

    canonical_by_alias = {
        "wifi": "WiFi",
        "wi-fi": "WiFi",
        "wi_fi": "WiFi",
        "lte": "LTE",
        "5g": "5G",
        "nr": "5G",
        "5gnr": "5G",
        "sub6": "5G",
        "sub6ghz": "5G",
        "5gsub6": "5G",
        "5gsub6ghz": "5G",
    }
    lookup_key = stripped.lower().replace(" ", "").replace("-", "")
    return canonical_by_alias.get(lookup_key, stripped)
def _normalize_mobility_label(value: object) -> object:
    """Map a free-form mobility label to ``"pedestrian"`` or ``"vehicular"``.

    ``None`` and blank values are returned unchanged. Matching is
    case-insensitive and ignores spaces and hyphens. Unrecognized labels are
    returned as their stripped string form.
    """
    if value is None:
        return value
    stripped = str(value).strip()
    if not stripped:
        return value

    lookup_key = stripped.lower().replace(" ", "").replace("-", "")
    pedestrian_aliases = ("ped", "pedestrian", "walking")
    vehicular_aliases = ("veh", "vehicular", "vehicle", "driving", "car")
    if lookup_key in pedestrian_aliases:
        return "pedestrian"
    if lookup_key in vehicular_aliases:
        return "vehicular"
    return stripped
def _normalize_sample(sample: Dict[str, object]) -> Dict[str, object]:
    """Return a shallow copy of ``sample`` with short keys and canonical labels.

    Long-form keys (``technology``, ``modulation``, ``mobility``,
    ``snr_label``) are copied to their short names when the short name is
    absent. The ``tech`` and ``mob`` values are then normalized. The input
    dict is not modified.
    """
    normalized = dict(sample)
    # Schema aliases (some artifacts use longer names).
    alias_pairs = (
        ("tech", "technology"),
        ("mod", "modulation"),
        ("mob", "mobility"),
        ("snr", "snr_label"),
    )
    for short_key, long_key in alias_pairs:
        if short_key not in normalized and long_key in normalized:
            normalized[short_key] = normalized.get(long_key)
    normalized["tech"] = _normalize_tech_label(normalized.get("tech"))
    normalized["mob"] = _normalize_mobility_label(normalized.get("mob"))
    return normalized
def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
    """Deprecated stub kept so existing callers still resolve the name.

    Both arguments are ignored and the function always raises.

    Raises:
        RuntimeError: unconditionally.
    """
    reason = "Synthetic on-disk dataset generation disabled"
    raise RuntimeError(reason)
def _create_dummy_samples() -> List[Dict[str, object]]:
    """Generate 30 deterministic random samples entirely in memory.

    Labels cycle through small fixed lists. Each sample carries a 128x128
    float32 ``data`` array and 128-element float32 ``embedding`` and
    ``moe_embedding`` vectors drawn from a generator seeded with 42.
    """
    generator = np.random.default_rng(42)
    tech_cycle = ("LTE", "WiFi", "5G")
    snr_cycle = ("SNR0dB", "SNR10dB", "SNR20dB")
    mod_cycle = ("QPSK", "16QAM", "64QAM")
    mob_cycle = ("pedestrian", "vehicular")

    def _draw(shape):
        return generator.normal(size=shape).astype(np.float32)

    generated: List[Dict[str, object]] = []
    for position in range(30):
        # Draw order (spectrogram, embedding, MoE embedding) fixes the random stream.
        spectrogram = _draw((128, 128))
        embedding = _draw((128,))
        moe_embedding = _draw((128,))
        generated.append(
            {
                "tech": tech_cycle[position % len(tech_cycle)],
                "snr": snr_cycle[position % len(snr_cycle)],
                "mod": mod_cycle[position % len(mod_cycle)],
                "mob": mob_cycle[position % len(mob_cycle)],
                "data": spectrogram,
                "embedding": embedding,
                "moe_embedding": moe_embedding,
            }
        )
    return generated
def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
    """Ensure a file exists locally; try Hub download if missing.

    Args:
        local_path: Expected location of the file next to the app.
        hub_filename: Name of the file inside the Hub repository.

    Returns:
        ``local_path`` when it exists and is not a Git LFS pointer stub.
        Otherwise the path returned by ``hf_hub_download`` for the first
        successful attempt, or ``None`` when every attempt fails.

    Side effects:
        On total failure, stores a shortened error string in the module-level
        ``LAST_DEMO_DOWNLOAD_ERROR`` for the UI status line.
    """
    global LAST_DEMO_DOWNLOAD_ERROR

    if local_path.exists() and not _is_git_lfs_pointer(local_path):
        return local_path

    # Pass the explicit token when we have one. Otherwise pass None so that
    # huggingface_hub uses a locally stored credential if present and falls back
    # to anonymous access if not. The previous `HF_TOKEN or True` made a stored
    # token mandatory (token=True raises when none is found), which broke
    # downloads from public repos on hosts without any credential.
    token = HF_TOKEN or None

    # Try configured repo types (default: model). This Space historically used a model repo.
    last_exc: Exception | None = None
    for repo_type in HUB_REPO_TYPES:
        try:
            cached = hf_hub_download(
                repo_id=HUB_REPO_ID,
                filename=hub_filename,
                token=token,
                repo_type=repo_type,
                revision=HUB_REVISION,
            )
        except Exception as exc:  # best effort: remember the failure and try the next repo type
            last_exc = exc
            continue
        cached_path = Path(cached)
        print(f"[INFO] Using cached Hub file for {hub_filename}: {cached_path} (repo_type={repo_type})")
        return cached_path

    # Final fallback: try downloading from the Space repo itself (useful when artifacts are stored in Space).
    try:
        cached = hf_hub_download(
            repo_id="wi-lab/LWM-Spectro",
            filename=hub_filename,
            token=token,
            repo_type="space",
            revision=None,
        )
    except Exception as exc:
        # Persist a short error string for the UI status line. The error from the
        # configured repo is preferred because it is usually the informative one.
        err = str(last_exc or exc)
        if len(err) > 240:
            err = err[:240] + "..."
        LAST_DEMO_DOWNLOAD_ERROR = err
        print(
            f"[WARN] Could not download {hub_filename} from Hub (repo_id={HUB_REPO_ID}, repo_types={HUB_REPO_TYPES}, revision={HUB_REVISION or 'main'}: {last_exc}) "
            f"or Space repo ({exc}); continuing without it."
        )
        return None

    cached_path = Path(cached)
    print(f"[INFO] Using cached Space file for {hub_filename}: {cached_path}")
    return cached_path
# Set to True by load_augmented_samples() when it falls back to in-memory synthetic samples.
USING_SYNTHETIC_DATA = False
# Shortened message of the last failed download, written by _ensure_local_file(); None if no failure was recorded.
LAST_DEMO_DOWNLOAD_ERROR: str | None = None
def load_augmented_samples() -> Tuple[List[Dict[str, object]], bool]:
    """Load the demo samples, preferring the MoE-augmented artifact.

    Returns ``(samples, has_moe)``. ``has_moe`` is True only when the MoE file
    was loaded. If neither artifact is usable, a synthetic in-memory dataset is
    returned and the module flag ``USING_SYNTHETIC_DATA`` is set.
    """
    global USING_SYNTHETIC_DATA

    def _usable(candidate: Optional[Path]) -> bool:
        # A real artifact must be resolved, present on disk and not an LFS stub.
        if not candidate:
            return False
        return candidate.exists() and not _is_git_lfs_pointer(candidate)

    # Both files are resolved up front, MoE first.
    moe_path = _ensure_local_file(MOE_DATA_PATH, HUB_MOE_DATA_FILENAME)
    base_path = _ensure_local_file(DEMO_DATA_PATH, HUB_DEMO_DATA_FILENAME)

    if _usable(moe_path):
        print(f"[INFO] Loading MoE-augmented dataset from {moe_path}")
        return _safe_load_tensor(moe_path), True

    if _usable(base_path):
        print(f"[WARN] MoE data missing; falling back to base data: {base_path}")
        return _safe_load_tensor(base_path), False

    # Last resort: in-memory synthetic data (keeps app alive, but clearly not the full demo dataset).
    USING_SYNTHETIC_DATA = True
    print(
        "[WARN] Falling back to a tiny synthetic dataset (30 samples). "
        "This usually means the real demo_data*.pt could not be downloaded. "
        "If the Hub repo is private, add a Space secret named HF_TOKEN with read access."
    )
    return _create_dummy_samples(), False
def _load_data_as_numpy(value: object) -> np.ndarray:
    """Convert a tensor or array-like to a NumPy array (tensors are detached and moved to CPU)."""
    if isinstance(value, torch.Tensor):
        # detach() + cpu() make the conversion safe for tensors that require
        # grad or live on an accelerator; a bare .numpy() raises for both.
        return value.detach().cpu().numpy()
    return np.asarray(value)


def load_data(mapping: Dict[str, object]):
    """Load the demo samples into a DataFrame with one row per valid sample.

    Args:
        mapping: Result of ``load_joint_mapping()``. Only ``pair_to_name`` and
            ``pair_to_id`` are read.

    Returns:
        ``(df, has_moe)``: the DataFrame, and the flag reported by
        ``load_augmented_samples()``. Samples that are not dicts, or that lack
        tech/snr/mob/mod labels or the ``embedding``/``data`` keys, are skipped
        and counted in a warning. The DataFrame always has the full column set,
        even when no sample survives, so downstream column lookups don't fail.
    """
    data, has_moe = load_augmented_samples()
    pair_to_name = mapping["pair_to_name"]
    pair_to_id = mapping["pair_to_id"]

    columns = [
        "index",
        "tech",
        "snr",
        "mod",
        "mob",
        "embedding",
        "tech_embedding",
        "moe_embedding",
        "spectrogram",
        "joint_label",
        "joint_label_id",
        "tsne_x",
        "tsne_y",
        "tsne_raw_x",
        "tsne_raw_y",
    ]
    required_labels = ("tech", "snr", "mob", "mod")

    records = []
    skipped = 0
    for i, sample in enumerate(data):
        if not isinstance(sample, dict):
            skipped += 1
            continue
        sample = _normalize_sample(sample)
        if not all(sample.get(key) for key in required_labels):
            skipped += 1
            continue
        if "embedding" not in sample or "data" not in sample:
            skipped += 1
            continue

        base_embedding = _load_data_as_numpy(sample["embedding"])
        flat_spec = _load_data_as_numpy(sample["data"]).flatten()

        moe_embedding = sample.get("moe_embedding")
        if moe_embedding is not None:
            moe_embedding = _load_data_as_numpy(moe_embedding)

        tech_embedding = sample.get("tech_embedding")
        if tech_embedding is not None:
            tech_embedding = _load_data_as_numpy(tech_embedding)

        # Optional metadata with the per-expert embedding width.
        embed_dim_hint = sample.get("tech_embedding_dim") or sample.get("embedding_dim")
        try:
            embed_dim_hint = int(embed_dim_hint) if embed_dim_hint is not None else None
        except (TypeError, ValueError):
            embed_dim_hint = None

        # No explicit per-technology embedding: derive it from the base embedding.
        if tech_embedding is None:
            tech_embedding = _select_tech_embedding(base_embedding, sample["tech"], embed_dim_hint)
        if tech_embedding is not None:
            tech_embedding = tech_embedding.astype(np.float32, copy=False)

        # Pairs outside the fixed label table yield None for both fields.
        pair = (sample["snr"], sample["mob"])
        records.append(
            {
                "index": i,
                "tech": sample["tech"],
                "snr": sample["snr"],
                "mod": sample["mod"],
                "mob": sample["mob"],
                "embedding": base_embedding,
                "tech_embedding": tech_embedding,
                "moe_embedding": moe_embedding,
                "spectrogram": flat_spec,
                "joint_label": pair_to_name.get(pair),
                "joint_label_id": pair_to_id.get(pair),
                "tsne_x": sample.get("tsne_x"),
                "tsne_y": sample.get("tsne_y"),
                "tsne_raw_x": sample.get("tsne_raw_x"),
                "tsne_raw_y": sample.get("tsne_raw_y"),
            }
        )

    # Explicit columns keep the schema stable when `records` is empty.
    df = pd.DataFrame(records, columns=columns)
    if skipped:
        print(f"[WARN] Skipped {skipped} malformed samples while loading demo data")
    print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
    return df, has_moe
def apply_filters(
    dataframe: pd.DataFrame,
    tech_filter,
    snr_filter,
    mod_filter,
    mob_filter,
) -> pd.DataFrame:
    """Return a filtered copy of ``dataframe``.

    Each filter is a collection of allowed values for its column (``tech``,
    ``snr``, ``mod``, ``mob``). A falsy filter (``None``, empty list) leaves
    that column unrestricted. The input frame is never modified.
    """
    result = dataframe.copy()
    selections = (
        ("tech", tech_filter),
        ("snr", snr_filter),
        ("mod", mod_filter),
        ("mob", mob_filter),
    )
    for column, allowed in selections:
        if allowed:
            result = result[result[column].isin(allowed)]
    return result
| def _select_tech_embedding(flat_embedding: np.ndarray | None, tech: str, embed_dim: Optional[int]) -> Optional[np.ndarray]: | |
| """Extract the technology-specific expert embedding. | |
| Some artifacts don't include an explicit embedding dimension hint. In that case, | |
| infer `embed_dim = total_dim / num_experts` when divisible. | |
| """ | |
| if flat_embedding is None: | |
| return None | |
| flat_embedding = np.asarray(flat_embedding).reshape(-1) | |
| total = flat_embedding.size | |
| blocks = len(TECH_EXPERT_ORDER) | |
| if blocks <= 0: | |
| return None | |
| inferred_dim = embed_dim | |
| if inferred_dim is None: | |
| if total % blocks != 0: | |
| return None | |
| inferred_dim = total // blocks | |
| try: | |
| inferred_dim = int(inferred_dim) | |
| except (TypeError, ValueError): | |
| return None | |
| if inferred_dim <= 0: | |
| return None | |
| expected = blocks * inferred_dim | |
| if expected != total: | |
| # If metadata is wrong, don't crash; fall back to an even split only if possible. | |
| if total % blocks != 0: | |
| return None | |
| inferred_dim = total // blocks | |
| try: | |
| arr = flat_embedding.reshape(blocks, inferred_dim) | |
| except ValueError: | |
| return None | |
| tech_idx = TECH_TO_EXPERT_IDX.get(str(tech)) | |
| if tech_idx is None or tech_idx >= arr.shape[0]: | |
| return arr.mean(axis=0) | |
| return arr[tech_idx] | |
def _sample_balanced_by_snr(dataframe: pd.DataFrame, samples_per_snr: int, seed: int) -> pd.DataFrame:
    """Draw up to ``samples_per_snr`` rows from each SNR group.

    Groups are processed in sorted label order. A group is kept whole when
    ``samples_per_snr`` is non-positive or not smaller than the group. Sampling
    is reproducible for a given ``seed``. An empty input frame is returned
    as-is; otherwise the result has a fresh 0..n-1 index.
    """
    if dataframe.empty:
        return dataframe

    generator = np.random.default_rng(int(seed))

    by_snr = {}
    for label, members in dataframe.groupby("snr"):
        if not members.empty:
            by_snr[label] = members
    if not by_snr:
        return dataframe.iloc[0:0]

    keep_everything = samples_per_snr <= 0
    pieces: List[pd.DataFrame] = []
    for label in sorted(by_snr):
        members = by_snr[label]
        if keep_everything or samples_per_snr >= len(members):
            pieces.append(members)
        else:
            # A per-group seed is drawn only for groups that are actually down-sampled.
            draw_seed = int(generator.integers(0, 1_000_000_000))
            pieces.append(members.sample(n=samples_per_snr, random_state=draw_seed))

    if not pieces:
        return dataframe.iloc[0:0]
    return pd.concat(pieces).reset_index(drop=True)
def plot_tsne(
    tech_filter,
    snr_filter,
    mod_filter,
    mob_filter,
    representation,
    color_label,
    perplexity,
    n_iter,
    samples_per_snr,
    sampling_seed,
):
    """Project the filtered samples to 2-D with t-SNE and return a Plotly scatter.

    Reads the module-level ``df`` and ``COLOR_OPTIONS``.

    Args:
        tech_filter, snr_filter, mod_filter, mob_filter: Allowed values per
            column, passed to ``apply_filters``.
        representation: ``"LWM Embedding"`` uses the per-technology embedding
            (or the full embedding when fewer than 5 rows have one). Any other
            value uses the raw spectrogram.
        color_label: Key into ``COLOR_OPTIONS``; unknown keys color by ``snr``.
        perplexity: Requested t-SNE perplexity, clamped to a valid range.
        n_iter: Requested number of optimization iterations.
        samples_per_snr, sampling_seed: Passed to ``_sample_balanced_by_snr``.

    Returns:
        A Plotly figure. When there are too few samples or no features, the
        figure is empty and carries an explanatory title.
    """

    def _blank(message):
        placeholder = go.Figure()
        placeholder.update_layout(title=message, xaxis=dict(visible=False), yaxis=dict(visible=False))
        return placeholder

    filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
    sampled_df = _sample_balanced_by_snr(filtered_df, samples_per_snr, sampling_seed)
    if len(sampled_df) < 5:
        return _blank(f"Not enough samples to plot (n={len(sampled_df)}). Widen filters or increase samples.")

    sampled_df = sampled_df.copy()
    color_column = COLOR_OPTIONS.get(color_label, "snr")

    if representation == "LWM Embedding":
        embed_mask = sampled_df["tech_embedding"].apply(lambda x: x is not None)
        if embed_mask.sum() >= 5:
            sampled_df = sampled_df.loc[embed_mask].reset_index(drop=True)
            features = np.stack(sampled_df["tech_embedding"].values)
            title_prefix = "t-SNE of LWM Embedding"
        else:
            # Fallback: use the full embedding vector so the UI doesn't go blank when
            # per-expert metadata is missing in the artifact.
            base_mask = sampled_df["embedding"].apply(lambda x: x is not None)
            if base_mask.sum() < 5:
                return _blank("No embeddings available for the selected filters.")
            sampled_df = sampled_df.loc[base_mask].reset_index(drop=True)
            features = np.stack(sampled_df["embedding"].values)
            title_prefix = "t-SNE of LWM Embedding (full vector)"
    else:
        features = build_tsne_raw_vectors(sampled_df["spectrogram"])
        title_prefix = "t-SNE of Raw Spectrogram"

    if features.size == 0:
        return _blank("No features available for t-SNE.")

    features = _standardize_for_tsne(features)

    # t-SNE needs perplexity < n_samples. Raise the request to the floor of 5
    # first, then cap at n-1, so the cap always wins. The old order (cap, then
    # floor) gave perplexity 5 for exactly 5 samples, which is invalid and
    # silently sent the plot to the PCA fallback.
    max_perplexity = len(sampled_df) - 1
    eff_perplexity = min(max(perplexity, 5), max_perplexity)

    tsne_kwargs = dict(
        n_components=2,
        perplexity=eff_perplexity,
        random_state=42,
        init="pca",
        learning_rate="auto",
    )
    # The iteration-count keyword was renamed from `n_iter` to `max_iter` in
    # newer scikit-learn. Try the new name first, then the old one, so the
    # requested count is honoured on both instead of being dropped.
    tsne = None
    if n_iter is not None:
        iterations = int(n_iter)
        for iter_keyword in ("max_iter", "n_iter"):
            try:
                tsne = TSNE(**tsne_kwargs, **{iter_keyword: iterations})
            except TypeError:
                continue
            break
    if tsne is None:
        tsne = TSNE(**tsne_kwargs)

    try:
        projections = tsne.fit_transform(features)
    except Exception as exc:
        # Best effort: keep the UI alive with a linear projection, but say so.
        print(f"[WARN] t-SNE failed ({exc}); falling back to a 2-D PCA projection")
        pca = PCA(n_components=2, random_state=42)
        projections = pca.fit_transform(features)

    sampled_df["x"] = projections[:, 0]
    sampled_df["y"] = projections[:, 1]

    category_orders = {}
    if color_column == "snr":
        present = set(sampled_df["snr"].unique())
        category_orders["snr"] = [snr for snr in SNR_ORDER if snr in present]

    fig = px.scatter(
        sampled_df,
        x="x",
        y="y",
        color=color_column,
        hover_data=["tech", "snr", "mod", "mob"],
        title=f"{title_prefix} ({len(sampled_df)} samples)",
        template="plotly_white",
        category_orders=category_orders,
    )
    height = 680 if color_label == "SNR" else 640
    fig.update_layout(legend_title_text=color_label, width=640, height=height)
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig
def build_raw_feature_matrix(
    samples: pd.Series,
    max_components: Optional[int] = 256,
    *,
    normalize: bool = True,
    reduce_dim: bool = True,
) -> np.ndarray:
    """Stack flattened spectrograms into a 2-D float feature matrix.

    Args:
        samples: Iterable of array-likes; each is flattened to one row.
        max_components: Upper bound on PCA components; falsy disables PCA.
        normalize: Standardize each feature column with ``StandardScaler``.
        reduce_dim: Apply PCA when it would actually reduce the feature count.

    Returns:
        The feature matrix with NaNs replaced via ``np.nan_to_num``. It is
        optionally standardized and optionally PCA-reduced.
    """
    flattened = [np.asarray(spec, dtype=np.float32).reshape(-1) for spec in samples]
    features = np.nan_to_num(np.stack(flattened), copy=False)

    if normalize:
        features = StandardScaler().fit_transform(features)

    if not (reduce_dim and max_components):
        return features

    # Cap n_components to valid PCA range: <= min(n_samples-1, n_features)
    n_rows, n_cols = features.shape
    ceiling = min(n_cols, max(n_rows - 1, 1)) if n_rows > 1 else 1
    wanted = max(1, min(max_components, ceiling))
    if wanted >= n_cols:
        return features

    primary = PCA(n_components=wanted, random_state=42)
    try:
        return primary.fit_transform(features)
    except ValueError:
        # Retry with a more conservative component count.
        reduced = min(max(1, min(n_rows, n_cols) - 1), wanted)
        if reduced >= 1:
            return PCA(n_components=reduced, random_state=42).fit_transform(features)
        return features
def build_tsne_raw_vectors(samples: pd.Series, eps: float = 1e-6) -> np.ndarray:
    """Z-score each spectrogram individually and stack the results.

    Each sample is flattened, shifted by its own mean and divided by its own
    standard deviation (floored at ``eps`` to avoid division by ~0). Returns a
    float32 matrix with one row per sample, or an empty ``(0, 0)`` float32
    array when ``samples`` is empty.
    """

    def _zscore(spec) -> np.ndarray:
        flat = np.asarray(spec, dtype=np.float32).reshape(-1)
        center = float(flat.mean())
        spread = float(flat.std())
        if spread < eps:
            spread = eps
        return ((flat - center) / spread).astype(np.float32, copy=False)

    rows: List[np.ndarray] = [_zscore(spec) for spec in samples]
    if rows:
        return np.stack(rows)
    return np.empty((0, 0), dtype=np.float32)
def _standardize_for_tsne(features: np.ndarray) -> np.ndarray:
    """Standardize features column-wise and sanitize them for t-SNE.

    After ``StandardScaler``, NaN and infinite entries are replaced by 0 and
    values are clipped to ``[-1e6, 1e6]``. The result is float32. An empty
    input is returned unchanged.
    """
    if features.size:
        standardized = StandardScaler().fit_transform(features)
        standardized = np.nan_to_num(standardized, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
        bounded = np.clip(standardized, -1e6, 1e6)
        return bounded.astype(np.float32, copy=False)
    return features
def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Split row labels into train/test sets, stratified by ``joint_label_id``.

    Args:
        filtered_df: Frame with a ``joint_label_id`` column. Its index labels
            are what gets returned.
        train_ratio: Fraction of each class assigned to training. At least one
            sample per class always lands on each side.
        seed: Seed for the shuffling generator.

    Returns:
        ``(train_indices, test_indices)`` as NumPy arrays of index labels.

    Raises:
        ValueError: If any class has fewer than 2 samples.
    """
    rng = np.random.default_rng(int(seed))
    train_indices = []
    test_indices = []
    for label_id, group in filtered_df.groupby("joint_label_id"):
        # Shuffle a private copy: the array from Index.to_numpy() may share
        # memory with the index, and shuffling it in place would corrupt it.
        indices = group.index.to_numpy(copy=True)
        if indices.size < 2:
            raise ValueError(f"Class '{CLASS_LABELS[int(label_id)]}' needs at least 2 samples for evaluation.")
        rng.shuffle(indices)
        split = int(round(indices.size * train_ratio))
        # Keep at least one sample on each side of the split.
        split = max(1, min(indices.size - 1, split))
        train_indices.extend(indices[:split])
        test_indices.extend(indices[split:])
    return np.array(train_indices), np.array(test_indices)
def select_knn_k(train_labels: np.ndarray, max_k: int = 9) -> int:
    """Pick a neighbour count for k-NN from the training labels.

    ``k`` is the smallest of ``max_k``, ``floor(sqrt(n_train))`` and the size
    of the rarest class, and is never below 1. An even ``k`` above 1 is reduced
    by one so majority votes cannot tie between two classes. Empty input gives
    ``k = 1``.
    """
    n_train = train_labels.size
    if n_train == 0:
        return 1

    rarest_class = int(pd.Series(train_labels).value_counts().min())
    sqrt_rule = int(np.sqrt(n_train))
    k = max(1, min(max_k, sqrt_rule, rarest_class))
    if k > 1 and k % 2 == 0:
        k -= 1
    return k
def plot_confusion_heatmap(
    confusion: np.ndarray, label_names: List[str], title: str = "Prototype Classifier Confusion Matrix"
) -> go.Figure:
    """Render a confusion matrix as a Plotly heatmap.

    ``label_names`` is used for both axes. The x axis is labelled
    "Predicted" and the y axis "True".
    """
    heatmap = go.Heatmap(
        z=confusion,
        x=label_names,
        y=label_names,
        colorscale="Viridis",
        hovertemplate="Predicted %{x}<br>True %{y}<br>Count %{z}<extra></extra>",
    )
    layout_options = dict(
        title=title,
        xaxis_title="Predicted",
        yaxis_title="True",
        xaxis=dict(tickangle=45),
    )
    figure = go.Figure(data=heatmap)
    figure.update_layout(**layout_options)
    return figure
def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
    """Compare k-NN on MoE embeddings vs. raw spectrograms for joint SNR/Doppler labels.

    Reads the module-level ``evaluation_disabled``, ``joint_eval_df`` and
    ``CLASS_LABELS``. Returns ``(moe_figure, raw_figure, markdown_status)``.
    When evaluation cannot run, both figures are the same empty placeholder and
    the status text explains why.
    """

    def _abort(title, message):
        # Same placeholder figure for both outputs, plus the explanation.
        placeholder = go.Figure()
        placeholder.update_layout(title=title, xaxis=dict(visible=False), yaxis=dict(visible=False))
        return placeholder, placeholder, message

    if evaluation_disabled:
        return _abort("MoE embeddings unavailable", "MoE embeddings are not available in this Space build.")

    filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
    if filtered.empty:
        return _abort("No samples after filtering", "No samples match the selected filters.")
    if filtered["joint_label_id"].nunique() < 2:
        return _abort("Need at least two classes", "Need at least two joint SNR/Doppler classes to evaluate.")

    # Positional index so the split indices can address the stacked feature arrays.
    filtered = filtered.reset_index(drop=True)
    try:
        train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
    except ValueError as exc:
        return _abort("Unable to split dataset", str(exc))

    labels = filtered["joint_label_id"].to_numpy(dtype=int)
    moe_features = np.stack(filtered["moe_embedding"].values)
    raw_features = build_raw_feature_matrix(
        filtered["spectrogram"],
        max_components=None,
        normalize=False,
        reduce_dim=False,
    )

    knn_k = select_knn_k(labels[train_idx])
    moe_metrics = compute_knn_metrics(moe_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
    raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)

    moe_fig = plot_confusion_heatmap(
        moe_metrics["confusion"],
        moe_metrics["label_names"],
        title=f"MoE Embedding Confusion (k={moe_metrics['k']})",
    )
    raw_fig = plot_confusion_heatmap(
        raw_metrics["confusion"],
        raw_metrics["label_names"],
        title=f"Raw Spectrogram Confusion (k={raw_metrics['k']})",
    )

    status_lines = [
        "### Joint SNR/Doppler Metrics",
        f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Train %:** {train_pct}% | **Seed:** {seed} | **k-NN k:** {knn_k}",
        "",
        "| Representation | Accuracy | Macro F1 |",
        "| --- | --- | --- |",
        f"| **MoE Embedding** | {moe_metrics['accuracy'] * 100:.2f}% | {moe_metrics['macro_f1']:.3f} |",
        f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |",
    ]
    return moe_fig, raw_fig, "\n".join(status_lines)
def stratified_split_mod(df_subset: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Split row labels into train/test sets, stratified by the ``mod`` column.

    Args:
        df_subset: Frame with a ``mod`` column. Its index labels are what gets
            returned.
        train_ratio: Fraction of each modulation class assigned to training.
            At least one sample per class always lands on each side.
        seed: Seed for the shuffling generator.

    Returns:
        ``(train_idx, test_idx)`` as NumPy arrays of index labels.

    Raises:
        ValueError: If any modulation class has fewer than 2 samples.
    """
    rng = np.random.default_rng(int(seed))
    train_idx = []
    test_idx = []
    for _, group in df_subset.groupby("mod"):
        # Shuffle a private copy: the array from Index.to_numpy() may share
        # memory with the index, and shuffling it in place would corrupt it.
        indices = group.index.to_numpy(copy=True)
        if indices.size < 2:
            raise ValueError("Each modulation needs at least 2 samples.")
        rng.shuffle(indices)
        split = int(round(len(indices) * train_ratio))
        # Keep at least one sample on each side of the split.
        split = max(1, min(len(indices) - 1, split))
        train_idx.extend(indices[:split])
        test_idx.extend(indices[split:])
    return np.array(train_idx), np.array(test_idx)
def compute_knn_metrics(
    features: np.ndarray,
    labels: np.ndarray,
    train_idx: np.ndarray,
    test_idx: np.ndarray,
    knn_k: int,
    label_lookup: List[str] | None = None,
) -> Dict[str, object]:
    """Fit a Euclidean k-NN classifier on the train rows and score the test rows.

    Args:
        features: Feature matrix, indexed by ``train_idx`` / ``test_idx``.
        labels: Label per row of ``features``.
        train_idx, test_idx: Row selections for fitting and scoring.
        knn_k: Requested neighbour count. It is capped at the training-set
            size and reduced by one when even and above 1.
        label_lookup: Optional list mapping integer labels to display names.
            Without it labels are shown via ``str``.

    Returns:
        Dict with ``accuracy``, ``macro_f1``, ``confusion`` matrix,
        ``label_names`` (for every label seen in train, test or predictions)
        and the ``k`` actually used.
    """
    x_train, x_test = features[train_idx], features[test_idx]
    y_train, y_test = labels[train_idx], labels[test_idx]

    effective_k = max(1, min(int(knn_k), len(y_train)))
    if effective_k > 1 and effective_k % 2 == 0:
        effective_k -= 1

    classifier = KNeighborsClassifier(n_neighbors=effective_k, metric="euclidean")
    classifier.fit(x_train, y_train)
    predicted = classifier.predict(x_test)

    accuracy = accuracy_score(y_test, predicted)
    seen_labels = np.unique(np.concatenate([y_train, y_test, predicted]))
    macro_f1 = f1_score(y_test, predicted, labels=seen_labels, average="macro", zero_division=0)

    if label_lookup is not None:
        display_names = [label_lookup[int(lbl)] for lbl in seen_labels]
    else:
        display_names = [str(lbl) for lbl in seen_labels]

    matrix = confusion_matrix(y_test, predicted, labels=seen_labels)
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "confusion": matrix,
        "label_names": display_names,
        "k": effective_k,
    }
def evaluate_modulation(tech: str, train_pct: int, seed: int):
    """Compare k-NN modulation classification on embeddings vs. raw spectrograms.

    Uses the rows of the module-level ``df`` whose ``tech`` equals ``tech``.
    Returns ``(embedding_figure, raw_figure, markdown_summary)``. When
    evaluation cannot run, both figures are the same empty placeholder and the
    text explains why.
    """

    def _abort(title, message):
        placeholder = go.Figure()
        placeholder.update_layout(title=title, xaxis=dict(visible=False), yaxis=dict(visible=False))
        return placeholder, placeholder, message

    if not tech:
        return _abort("Select a technology to evaluate.", "No technology selected.")

    subset = df[df["tech"] == tech].copy().reset_index(drop=True)
    if subset.empty or subset["mod"].nunique() < 2:
        return _abort(
            "Need at least two modulation classes for this technology.",
            "Not enough modulation classes.",
        )

    try:
        train_idx, test_idx = stratified_split_mod(subset, train_pct / 100.0, seed)
    except ValueError as exc:
        reason = str(exc)
        return _abort(reason, reason)

    labels = subset["mod"].astype(str).to_numpy()
    emb_features = np.stack(subset["embedding"].values)
    raw_features = build_raw_feature_matrix(
        subset["spectrogram"],
        max_components=None,
        normalize=False,
        reduce_dim=False,
    )

    train_labels = labels[train_idx]
    if pd.Series(train_labels).value_counts().empty:
        return _abort("No modulation classes found.", "No modulation classes found.")

    knn_k = select_knn_k(train_labels)
    emb_metrics = compute_knn_metrics(emb_features, labels, train_idx, test_idx, knn_k)
    raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k)

    emb_fig = plot_confusion_heatmap(emb_metrics["confusion"], emb_metrics["label_names"], title="Embedding Confusion")
    raw_fig = plot_confusion_heatmap(raw_metrics["confusion"], raw_metrics["label_names"], title="Raw Confusion")

    summary_lines = [
        f"### {tech} Modulation Metrics",
        f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Classifier:** k-NN (k = {emb_metrics['k']})",
        "",
        "| Representation | Accuracy | Macro F1 |",
        "| --- | --- | --- |",
        f"| **LWM Embedding** | {emb_metrics['accuracy'] * 100:.2f}% | {emb_metrics['macro_f1']:.3f} |",
        f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |",
    ]
    return emb_fig, raw_fig, "\n".join(summary_lines)
def _reshape_spectrogram(spec: np.ndarray) -> np.ndarray:
    """Return ``spec`` as an array suitable for image display.

    * 1-D input with a perfect-square length becomes ``(side, side)``.
    * Other 1-D input becomes ``(-1, width)``, where ``width`` is the largest
      divisor of the length not exceeding the rounded square root. The
      rounded root itself is used when it divides the length; a prime length
      falls back to a single column.
    * 3-D input has its singleton axes squeezed out.
    * Anything else is returned as-is.
    """
    arr = np.asarray(spec)
    if arr.ndim == 1:
        side = int(round(arr.size ** 0.5))
        if side * side == arr.size:
            return arr.reshape(side, side)
        # reshape(-1, side) raises when `side` does not divide the length
        # (e.g. 10 elements -> side 3), so walk down to the nearest divisor.
        # `width` stops at 1 at the latest, which divides every length.
        width = max(side, 1)
        while arr.size % width:
            width -= 1
        return arr.reshape(-1, width)
    if arr.ndim == 3:
        return arr.squeeze()
    return arr
def _spectrogram_to_image(spec: np.ndarray, title: str) -> np.ndarray:
    """Render a 2-D spectrogram to an RGB image array with a title and colorbar.

    Args:
        spec: Array-like accepted by ``imshow``. NaNs are replaced via
            ``np.nan_to_num`` and values are min-max scaled to [0, 1] when the
            range is non-zero.
        title: Text drawn above the image.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` copied from the Agg
        canvas (alpha channel dropped).
    """
    # np.asarray lets plain sequences through as well as ndarrays.
    normalized = np.asarray(spec).astype(np.float32)
    if np.isnan(normalized).any():
        normalized = np.nan_to_num(normalized)
    vmin, vmax = normalized.min(), normalized.max()
    if vmax - vmin > 0:
        normalized = (normalized - vmin) / (vmax - vmin)

    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        im = ax.imshow(normalized, cmap="turbo", aspect="auto", origin="lower")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontsize=8)
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=6)
        fig.tight_layout(pad=0.5)
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        width, height = canvas.get_width_height()
        buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
        # Copy before closing: `buf` is a view into the canvas buffer.
        image = buf[..., :3].copy()
    finally:
        # pyplot keeps every open figure registered; always release it, even
        # when drawing fails, so repeated gallery renders don't leak figures.
        plt.close(fig)
    return image
def render_spectrogram_gallery(tech, snr, mod, mob, sample_count, seed):
    """Build gallery images for spectrograms matching the given single-value filters.

    Reads the module-level ``df``. Each filter is a single value (or falsy for
    "any"). Up to ``sample_count`` matching rows (at least 1) are chosen at
    random with the given ``seed``.

    Returns:
        ``(items, status)``. ``items`` is a list of ``(image, caption)``
        tuples and ``status`` a short summary. With no matches, ``items`` is
        empty.
    """

    def _as_selection(choice):
        # apply_filters expects a collection, or None for "no restriction".
        return [choice] if choice else None

    matches = apply_filters(
        df,
        _as_selection(tech),
        _as_selection(snr),
        _as_selection(mod),
        _as_selection(mob),
    )
    if matches.empty:
        return [], "No spectrograms match the selected filters."

    limit = max(1, int(sample_count))
    generator = np.random.default_rng(int(seed))
    if len(matches) > limit:
        chosen_labels = generator.choice(matches.index.to_numpy(), size=limit, replace=False)
        shown = matches.loc[chosen_labels]
    else:
        shown = matches

    gallery_items = []
    for _, row in shown.iterrows():
        caption = f"{row['tech']} | {row['mod']} | {row['snr']} | {row['mob']}"
        image = _spectrogram_to_image(_reshape_spectrogram(row["spectrogram"]), caption)
        gallery_items.append((image, caption))

    return gallery_items, f"Showing {len(shown)} of {len(matches)} matches."
# --- Module-level state: load the dataset once and derive UI inputs from it ---
mapping_info = load_joint_mapping()
df, has_moe_embeddings = load_data(mapping_info)
CLASS_LABELS = mapping_info["label_names"]

# Status banner shown at the top of the page; fields are joined with " | ".
_status_fields = [
    f"Dataset loaded: {len(df)} samples",
    f"MoE embeddings: {'yes' if has_moe_embeddings else 'no'}",
    f"HF token detected: {'yes' if HF_TOKEN else 'no'}",
    f"Synthetic fallback: {'yes' if USING_SYNTHETIC_DATA else 'no'}",
    f"Demo repo: {HUB_REPO_ID}@{HUB_REVISION or 'main'} ({','.join(HUB_REPO_TYPES)})",
]
if LAST_DEMO_DOWNLOAD_ERROR:
    _status_fields.append(f"Download error: {LAST_DEMO_DOWNLOAD_ERROR}")
DATASET_STATUS = " | ".join(_status_fields)

# Rows usable for the joint evaluation need both an MoE embedding and a joint label.
has_moe_column = df["moe_embedding"].apply(lambda value: value is not None)
_has_joint_label = df["joint_label_id"].notna()
joint_eval_df = df[has_moe_column & _has_joint_label]

# Choice lists for the filter widgets.
tech_choices = sorted(df["tech"].unique())
snr_choices = _sort_snrs(df["snr"].unique())
mod_choices = sorted(df["mod"].unique())
mob_choices = sorted(df["mob"].unique())

# Modulations observed for each technology, used to narrow the modulation dropdown.
TECH_TO_MODS: Dict[str, List[str]] = {}
for _tech_name in tech_choices:
    _rows_for_tech = df["tech"] == _tech_name
    TECH_TO_MODS[_tech_name] = sorted(df.loc[_rows_for_tech, "mod"].unique().tolist())

# Display label -> dataframe column for the "Color By" control.
COLOR_OPTIONS: Dict[str, str] = {
    "SNR": "snr",
    "Modulation": "mod",
    "Mobility": "mob",
}

# First technology (if any) is preselected; slicing an empty list yields [].
default_tech = tech_choices[:1]
if default_tech:
    initial_spec_mod_choices = TECH_TO_MODS.get(default_tech[0], mod_choices)
else:
    initial_spec_mod_choices = mod_choices

# Joint evaluation is only available with MoE embeddings and at least one labelled row.
evaluation_disabled = not (has_moe_embeddings and not joint_eval_df.empty)
def update_modulation_choices(selected_tech: Optional[str]):
    """Return a dropdown update restricting modulations to the chosen technology.

    Args:
        selected_tech: Currently selected technology, or a falsy value when
            nothing is selected.

    Returns:
        A Gradio update payload that sets the dropdown's ``choices`` (all
        modulations when no technology is selected or it is unknown to
        ``TECH_TO_MODS``) and clears its current value.
    """
    choices = mod_choices
    if selected_tech:
        choices = TECH_TO_MODS.get(selected_tech, mod_choices)
    # ``gr.Dropdown.update`` was removed in Gradio 4; the generic ``gr.update``
    # is equivalent on 3.x and still works on newer releases.
    return gr.update(choices=choices, value=None)
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been stripped — confirm container placement against the app.
with gr.Blocks(title="LWM-Spectro Lab") as demo:
    gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
    gr.Markdown(f"**{DATASET_STATUS}**")
    gr.Markdown(
        """
        Interactive lab for exploring LWM spectrogram embeddings and cached MoE probes.
        """
    )

    with gr.Tabs():
        # ------------------------------------------------------------------
        # Tab 1: raw spectrogram gallery
        # ------------------------------------------------------------------
        with gr.Tab("Spectrograms"):
            gr.Markdown(
                """
                ### 🔎 Spectrogram Studio
                - Peek at the raw 128×128 Sub-6 GHz I/Q baseband spectrograms that drive the SNR/mobility recognition tasks.
                - Filter by technology/SNR/modulation/mobility to understand how diverse the training pool is across scenarios.
                - Use the gallery to sanity-check preprocessing before sending samples through LWM or downstream models.
                """
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=320):
                    spec_tech = gr.Dropdown(
                        choices=tech_choices,
                        value=default_tech[0] if default_tech else None,
                        label="Technology",
                    )
                    spec_snr = gr.Dropdown(choices=snr_choices, value=None, label="SNR (optional)")
                    spec_mod = gr.Dropdown(
                        choices=initial_spec_mod_choices,
                        value=None,
                        label="Modulation (optional)",
                    )
                    spec_mob = gr.Dropdown(choices=mob_choices, value=None, label="Mobility (optional)")
                    spec_count = gr.Slider(minimum=1, maximum=12, step=1, value=6, label="Samples to show")
                    spec_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=0, label="Random seed")
                    spec_btn = gr.Button("Show spectrograms", variant="primary")
                with gr.Column(scale=3):
                    gallery = gr.Gallery(
                        label="Spectrogram Samples",
                        columns=[3],
                        rows=[3],
                        height=560,
                        preview=True,
                    )
                    gallery_status = gr.Textbox(label="Status", interactive=False)

            # The same handler runs on button click and on initial page load.
            spec_inputs = [spec_tech, spec_snr, spec_mod, spec_mob, spec_count, spec_seed]
            spec_outputs = [gallery, gallery_status]
            spec_btn.click(render_spectrogram_gallery, inputs=spec_inputs, outputs=spec_outputs)
            demo.load(render_spectrogram_gallery, inputs=spec_inputs, outputs=spec_outputs)
            # Changing the technology narrows the modulation dropdown.
            spec_tech.change(update_modulation_choices, inputs=spec_tech, outputs=spec_mod)

        # ------------------------------------------------------------------
        # Tab 2: t-SNE projection of embeddings or raw spectrograms
        # ------------------------------------------------------------------
        with gr.Tab("t-SNE Analysis"):
            gr.Markdown(
                """
                ### 🌀 Embedding vs. Raw Space
                - Run quick t-SNE sweeps on either LWM embeddings or raw spectrogram vectors.
                - Toggle **Color By** to mirror the "colored by modulation vs. SNR" comparisons from the CLI examples.
                - Balanced per-SNR sampling plus configurable perplexity help match the figures you generate locally with `plot/plot_tsne.py`.
                """
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    gr.Markdown("### Filters")
                    tech_filter = gr.CheckboxGroup(
                        choices=tech_choices,
                        value=default_tech,
                        label="Technology",
                    )
                    snr_filter = gr.Dropdown(
                        choices=snr_choices,
                        value=None,
                        multiselect=True,
                        label="SNR (Empty = All)",
                    )
                    mod_filter = gr.Dropdown(
                        choices=mod_choices,
                        value=None,
                        multiselect=True,
                        label="Modulation (Empty = All)",
                    )
                    mob_filter = gr.Dropdown(
                        choices=mob_choices,
                        value=None,
                        multiselect=True,
                        label="Mobility (Empty = All)",
                    )
                    gr.Markdown("### Visualization Settings")
                    representation = gr.Radio(
                        choices=["LWM Embedding", "Raw Spectrogram"],
                        value="LWM Embedding",
                        label="Representation",
                    )
                    color_by = gr.Dropdown(
                        choices=list(COLOR_OPTIONS.keys()),
                        value="SNR",
                        label="Color By",
                    )
                    with gr.Accordion("Advanced t-SNE Settings", open=False):
                        perplexity = gr.Slider(
                            minimum=5, maximum=50, value=30, step=1, label="Perplexity"
                        )
                        n_iter = gr.Slider(
                            minimum=250, maximum=2000, value=1000, step=50, label="Iterations"
                        )
                        samples_per_snr = gr.Slider(
                            minimum=20,
                            maximum=500,
                            value=DEFAULT_TSNE_SAMPLES_PER_SNR,
                            step=10,
                            label="Samples per SNR",
                        )
                        sampling_seed = gr.Slider(
                            minimum=0,
                            maximum=9999,
                            value=42,
                            step=1,
                            label="Sampling Seed",
                        )
                    btn = gr.Button("Update Plot", variant="primary")
                with gr.Column(scale=3):
                    plot = gr.Plot(label="t-SNE Visualization")

            # Argument order must match the positional parameters of plot_tsne.
            tsne_inputs = [
                tech_filter,
                snr_filter,
                mod_filter,
                mob_filter,
                representation,
                color_by,
                perplexity,
                n_iter,
                samples_per_snr,
                sampling_seed,
            ]
            btn.click(plot_tsne, inputs=tsne_inputs, outputs=[plot])
            demo.load(plot_tsne, inputs=tsne_inputs, outputs=[plot])

        # ------------------------------------------------------------------
        # Tab 3: per-technology modulation classification with k-NN
        # ------------------------------------------------------------------
        with gr.Tab("Modulation Classification"):
            gr.Markdown(
                """
                ### 🎯 Lightweight Modulation Head
                - Prototype how well the frozen LWM backbone separates modulation formats for each technology using spectrograms as input.
                - The adaptive k-NN classifier approximates the behavior of the downstream residual 1D-CNN before heavy training; each tech is evaluated separately to measure its expert’s modulation discrimination.
                - Sweep train/test splits and seeds to gauge robustness when only a portion of the dataset is labeled.
                """
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=320):
                    mod_tech = gr.Dropdown(
                        choices=tech_choices,
                        value=default_tech[0] if default_tech else None,
                        label="Technology",
                    )
                    mod_train = gr.Slider(
                        minimum=50, maximum=90, step=5, value=70, label="Training Percentage (%)"
                    )
                    mod_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
                    gr.Markdown(
                        "k-NN uses an adaptive k based on the number of modulation classes and available training samples."
                    )
                    mod_btn = gr.Button("Run modulation evaluation", variant="primary")
                with gr.Column(scale=3):
                    with gr.Row():
                        emb_plot = gr.Plot(label="Embedding Confusion Matrix")
                        raw_plot = gr.Plot(label="Raw Confusion Matrix")
                    mod_summary = gr.Markdown(
                        value="Select a technology and run the evaluation to view metrics."
                    )

            mod_inputs = [mod_tech, mod_train, mod_seed]
            mod_outputs = [emb_plot, raw_plot, mod_summary]
            mod_btn.click(evaluate_modulation, inputs=mod_inputs, outputs=mod_outputs)

        # ------------------------------------------------------------------
        # Tab 4: joint SNR/Doppler evaluation on cached MoE embeddings
        # ------------------------------------------------------------------
        with gr.Tab("Joint SNR/Doppler Evaluation"):
            gr.Markdown(
                """
                ### 🌪️ Joint Channel Dynamics Benchmark
                - Evaluate the precomputed MoE embeddings on the 14-class joint SNR/Doppler recognition task.
                - Mirrors the second stage of our reliability workflow where, without an explicit technology label, the MoE router sends samples to the most relevant expert and mobility-aware cues guide SNR-aware routing.
                - Upload or reference Hub-hosted tensors to compare MoE vs. raw spectrogram baselines before fine-tuning heavier heads.
                """
            )
            if evaluation_disabled:
                gr.Markdown(
                    "⚠️ Precomputed MoE embeddings are not bundled in this Space build. Upload a dataset locally to run evaluations."
                )

            # Every control in this tab is locked when evaluation is unavailable.
            eval_interactive = not evaluation_disabled
            with gr.Row():
                with gr.Column(scale=1, min_width=320):
                    gr.Markdown("### Evaluation Filters")
                    eval_tech_filter = gr.CheckboxGroup(
                        choices=tech_choices,
                        value=default_tech,
                        label="Technology",
                        interactive=eval_interactive,
                    )
                    eval_snr_filter = gr.Dropdown(
                        choices=snr_choices,
                        value=None,
                        multiselect=True,
                        label="SNR (Empty = All)",
                        interactive=eval_interactive,
                    )
                    eval_mod_filter = gr.Dropdown(
                        choices=mod_choices,
                        value=None,
                        multiselect=True,
                        label="Modulation (Empty = All)",
                        interactive=eval_interactive,
                    )
                    eval_mob_filter = gr.Dropdown(
                        choices=mob_choices,
                        value=None,
                        multiselect=True,
                        label="Mobility (Empty = All)",
                        interactive=eval_interactive,
                    )
                    gr.Markdown("### Prototype Settings")
                    train_pct = gr.Slider(
                        minimum=10,
                        maximum=80,
                        step=5,
                        value=60,
                        label="Training Percentage (%)",
                        interactive=eval_interactive,
                    )
                    seed = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        step=1,
                        value=42,
                        label="Random Seed",
                        interactive=eval_interactive,
                    )
                    eval_btn = gr.Button(
                        "Run evaluation",
                        variant="primary",
                        interactive=eval_interactive,
                    )
                with gr.Column(scale=3):
                    with gr.Row():
                        eval_plot = gr.Plot(label="MoE Prototype Confusion")
                        eval_plot_raw = gr.Plot(label="Raw Prototype Confusion")
                    eval_status = gr.Markdown(
                        value="Run an evaluation to compare MoE vs raw baselines."
                    )

            # Argument order must match the positional parameters of run_joint_evaluation.
            eval_inputs = [
                train_pct,
                seed,
                eval_tech_filter,
                eval_snr_filter,
                eval_mod_filter,
                eval_mob_filter,
            ]
            eval_outputs = [eval_plot, eval_plot_raw, eval_status]
            eval_btn.click(run_joint_evaluation, inputs=eval_inputs, outputs=eval_outputs)


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()