Spaces:
Running
Running
| from __future__ import annotations | |
| """ | |
| MesomorphicECG XAI Gradio app for Hugging Face Spaces. | |
| This version focuses on: | |
| - Selecting sampling rate (100 / 500 Hz), model type (categorical vs single-linear), | |
| and task (norm_vs_cd / norm_vs_hyp / norm_vs_mi / norm_vs_sttc). | |
| - Loading pre-packaged ECG examples from local binary .npz files in this Space. | |
| - Downloading the corresponding IMN checkpoint from | |
| `SEARCH-IHI/mesomorphicECG` on the Hugging Face Hub. | |
| - Running inference and visualizing intrinsic feature attributions | |
| (Impact = w * x) as a lead × segment heatmap plus per-lead ECG traces. | |
| Data binaries | |
| ------------- | |
| For each (sampling_rate, task) pair you should provide a `.npz` file as | |
| configured in DATA_FILES below, with keys: | |
| signals : float32 array [N, 12, L] | |
| labels : float32/int array [N] with 0 (NORM) / 1 (POS_CLASS) | |
| reports : object array [N] of clinical notes | |
| age : array [N] | |
| sex : object array [N] | |
| ecg_id : array [N] | |
| """ | |
| import os | |
| from functools import lru_cache | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| import gradio as gr # noqa: E402 | |
| from huggingface_hub import hf_hub_download, list_repo_files # noqa: E402 | |
| import single_linear_imn_core as sl_core # noqa: E402 | |
| import categorical_imn_core as cat_core # noqa: E402 | |
# Hugging Face Hub repository holding the trained IMN checkpoints.
HF_MODEL_REPO = "SEARCH-IHI/mesomorphicECG"
# Task identifier -> display name of the positive class (classified vs NORM).
TASK_TO_POS = {
    "norm_vs_mi": "MI",
    "norm_vs_sttc": "STTC",
    "norm_vs_cd": "CD",
    "norm_vs_hyp": "HYP",
}
# Standard 12-lead names, shared with the single-linear core module.
LEAD_NAMES = sl_core.DEFAULT_LEAD_NAMES
# Mapping from (sampling_rate, task) -> local data binary.
DATA_FILES: Dict[Tuple[int, str], str] = {
    # 100 Hz
    (100, "norm_vs_cd"): "data/ptbxl_100hz_norm_vs_cd_test.npz",
    (100, "norm_vs_hyp"): "data/ptbxl_100hz_norm_vs_hyp_test.npz",
    (100, "norm_vs_mi"): "data/ptbxl_100hz_norm_vs_mi_test.npz",
    (100, "norm_vs_sttc"): "data/ptbxl_100hz_norm_vs_sttc_test.npz",
    # 500 Hz
    (500, "norm_vs_cd"): "data/ptbxl_500hz_norm_vs_cd_test.npz",
    (500, "norm_vs_hyp"): "data/ptbxl_500hz_norm_vs_hyp_test.npz",
    (500, "norm_vs_mi"): "data/ptbxl_500hz_norm_vs_mi_test.npz",
    (500, "norm_vs_sttc"): "data/ptbxl_500hz_norm_vs_sttc_test.npz",
}
# In-process caches so repeated requests don't reload .npz files or models.
# DATA_CACHE: (fs, task) -> dict of npz arrays.
# MODEL_CACHE: (model_type, fs, task) -> {"path", "model", "device"}.
DATA_CACHE: Dict[Tuple[int, str], Dict[str, Any]] = {}
MODEL_CACHE: Dict[Tuple[str, int, str], Dict[str, Any]] = {}
def zscore_per_lead(x: np.ndarray) -> np.ndarray:
    """Normalize each lead (row) of a [12, L] signal to zero mean, unit std.

    A small floor (1e-6) on the standard deviation prevents division by
    zero on flat leads; the result is cast to float32 for the model.
    """
    centered = x - x.mean(axis=1, keepdims=True)
    scale = np.maximum(x.std(axis=1, keepdims=True), 1e-6)
    return (centered / scale).astype(np.float32)
def ablate_and_recompute_imn_single(
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
) -> float:
    """
    Recompute P(pos_class) for the SINGLE-LINEAR IMN after ablation.

    gen_w: [1, 1, 12, L], x_t: [1, 12, L], gen_b: [1, 1] or [1, 1, 1].
    Weights at the selected leads/segments are zeroed on a working copy
    (the scalar bias is left untouched), the single binary logit is
    recomputed, and its sigmoid is returned as P(pos_class).
    """
    weights = gen_w[0, 0].clone()  # working copy, shape [12, L]
    # Zero out whole leads first.
    if n_remove_leads > 0 and top_leads:
        for lead_idx in top_leads[:n_remove_leads]:
            weights[lead_idx, :] = 0.0
    # Then zero out individual (lead, segment) windows.
    if n_remove_segments > 0 and top_segments:
        for lead_idx, seg_idx in top_segments[:n_remove_segments]:
            start = seg_idx * stride
            stop = min(start + window, L)
            weights[lead_idx, start:stop] = 0.0
    signal = x_t[0]  # [12, L]
    # The bias may arrive as [1, 1] or [1, 1, 1]; extract the scalar.
    bias = gen_b[0, 0, 0] if gen_b.ndim == 3 else gen_b[0, 0]
    new_logit = (weights.to(signal.device) * signal).sum() + bias
    return float(torch.sigmoid(new_logit).item())
def ablate_and_recompute_imn_categorical(
    logits: torch.Tensor,
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
    pos_class_idx: int = 1,
) -> float:
    """
    Recompute P(pos_class) for the CATEGORICAL IMN (2-class softmax) after ablation.

    Only the positive class's weight map (gen_w[0, pos_class_idx]) is modified,
    on a working copy; all other class logits are kept exactly as produced by
    the model, and the softmax probability of the positive class is returned.
    """
    # gen_w: [1, num_classes, 12, L]
    weights = gen_w[0, pos_class_idx].clone()  # working copy, shape [12, L]
    # Zero out whole leads first.
    if n_remove_leads > 0 and top_leads:
        for lead_idx in top_leads[:n_remove_leads]:
            weights[lead_idx, :] = 0.0
    # Then zero out individual (lead, segment) windows.
    if n_remove_segments > 0 and top_segments:
        for lead_idx, seg_idx in top_segments[:n_remove_segments]:
            start = seg_idx * stride
            stop = min(start + window, L)
            weights[lead_idx, start:stop] = 0.0
    signal = x_t[0]  # [12, L]
    # The bias may arrive as [1, num_classes] or [1, num_classes, 1].
    bias = gen_b[0, pos_class_idx, 0] if gen_b.ndim == 3 else gen_b[0, pos_class_idx]
    ablated_logit = (weights.to(signal.device) * signal).sum() + bias
    # logits: [1, num_classes]; replace only the positive class's entry.
    updated = logits[0].clone()
    updated[pos_class_idx] = ablated_logit
    return float(torch.softmax(updated, dim=0)[pos_class_idx].item())
def build_fig_imn_with_highlights(
    x_np: np.ndarray,
    seg_hm: np.ndarray,
    window: int,
    stride: int,
    T: int,
    pred: str,
    prob_pos: float,
    pos_class_name: str,
    sampling_rate: int,
    top_leads: Optional[list[int]] = None,
    top_segments: Optional[list[tuple[int, int]]] = None,
    removed_leads: Optional[list[int]] = None,
    removed_segments: Optional[list[tuple[int, int]]] = None,
    prob_abl: Optional[float] = None,
    lead_imp_signed: Optional[np.ndarray] = None,
) -> plt.Figure:
    """
    Build a matplotlib figure with:
    - a top heatmap of segment-wise importance (per lead),
    - 12 ECG traces with overlays highlighting important / removed segments
      and leads,
    - a footer legend describing the color coding.

    Parameters
    ----------
    x_np : [12, L] ECG signal (already normalized upstream).
    seg_hm : [12, T] segment importance heatmap; rendered with vmin=0, vmax=1.
    window, stride : segment aggregation parameters (samples).
    T : number of segments (columns of seg_hm).
    pred, prob_pos, pos_class_name : prediction label and P(pos_class) for the title.
    sampling_rate : sampling frequency in Hz (display only).
    top_leads / top_segments : highlighted items (gold band / lime-edged span).
    removed_leads / removed_segments : ablated items, outlined in red.
    prob_abl : post-ablation P(pos_class); appended to the title when given.
    lead_imp_signed : [12] signed per-lead contribution; when given, the
        y-label percentages are signed shares of total |contribution|.
    """
    import matplotlib.patches as mpatches
    L = x_np.shape[1]
    # Per-lead contribution share: use signed contribution so percentages match top leads.
    if lead_imp_signed is not None:
        # Signed share of total absolute contribution (entries can be negative).
        denom = np.abs(lead_imp_signed).sum() + 1e-9
        lead_pct = 100.0 * lead_imp_signed / denom
    else:
        # Fallback: unsigned share derived from the heatmap itself.
        lead_abs = np.abs(seg_hm).sum(axis=1)
        lead_pct = 100.0 * lead_abs / (lead_abs.sum() + 1e-9)
    cmap = "Reds"
    shade_color = "red"
    # Sets give O(1) membership checks in the per-segment drawing loop below.
    rem_lead_set = set(removed_leads or [])
    rem_seg_set = set(removed_segments or [])
    top_seg_set = set(top_segments or [])
    # Layout: 1 heatmap row + 12 trace rows + 1 footer row.
    fig = plt.figure(figsize=(12, 14))
    gs = fig.add_gridspec(14, 1, height_ratios=[2] + [1] * 12 + [0.5])
    # Top heatmap
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(seg_hm, aspect="auto", vmin=0.0, vmax=1.0, cmap=cmap)
    ax0.set_yticks(range(12))
    ax0.set_yticklabels(LEAD_NAMES)
    ax0.set_xlabel(f"Segments (window={window}, stride={stride}, fs={sampling_rate}Hz)")
    prob_str = f"P({pos_class_name})={prob_pos:.3f}"
    title = f"IMN Intrinsic Explanation | {pred} | {prob_str}"
    if prob_abl is not None:
        # Extra decimals for very small probabilities so the drop is visible.
        p_str = f"{prob_abl:.4f}" if prob_abl < 0.001 else f"{prob_abl:.3f}"
        title += f" -> Ablated P({pos_class_name}) = {p_str}"
    ax0.set_title(title)
    fig.colorbar(im, ax=ax0, fraction=0.02, pad=0.01)
    # Highlight removed leads in the heatmap (full-width red outline per row).
    for rl in rem_lead_set:
        rect = mpatches.Rectangle(
            (-0.5, rl - 0.5),
            T,
            1,
            fill=False,
            edgecolor="red",
            linewidth=2.5,
            zorder=10,
        )
        ax0.add_patch(rect)
    # Lead-wise ECG traces with overlays
    for lead in range(12):
        ax = fig.add_subplot(gs[lead + 1, 0])
        ax.plot(x_np[lead], linewidth=0.8, color="black", alpha=0.6)
        ax.set_xlim(0, L - 1)
        ax.set_ylabel(
            f"{LEAD_NAMES[lead]} {lead_pct[lead]:.1f}%",
            rotation=0,
            labelpad=20,
            va="center",
        )
        ylo, yhi = ax.get_ylim()
        # Shade top leads
        if top_leads and lead in top_leads:
            ax.axhspan(ylo, yhi, alpha=0.15, color="gold", zorder=0)
        # Mark removed whole leads
        if lead in rem_lead_set:
            ax.add_patch(
                mpatches.Rectangle(
                    (0, ylo),
                    L - 1,
                    yhi - ylo,
                    fill=False,
                    edgecolor="red",
                    linewidth=2.5,
                    zorder=10,
                )
            )
        contrib = seg_hm[lead]
        for t in range(T):
            a = float(contrib[t])
            # Opacity scales with importance, capped at 0.5; skip faint spans.
            alpha = min(0.5, a * 0.6)
            if alpha <= 0.05:
                continue
            start = t * stride
            end = min(start + window, L)
            hi = (lead, t) in top_seg_set
            is_rem = (lead, t) in rem_seg_set
            if is_rem:
                # Removed segment: shaded fill plus red rectangle outline.
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
                ax.add_patch(
                    mpatches.Rectangle(
                        (start, ylo),
                        end - start,
                        yhi - ylo,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                        zorder=10,
                    )
                )
            elif hi:
                # Top segment: shaded fill with a lime edge.
                ax.axvspan(
                    start,
                    end,
                    alpha=alpha,
                    facecolor=shade_color,
                    edgecolor="lime",
                    linewidth=1.5,
                    zorder=1,
                )
            else:
                # Ordinary segment: plain shading proportional to importance.
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
        ax.set_xticks([])
    # Footer
    axf = fig.add_subplot(gs[13, 0])
    axf.axis("off")
    leg = (
        f"IMN Feature Attribution: |w(x)·x| aggregated by segment "
        f"(window={window}, stride={stride}). Gold/Lime = top leads/segments by signed contribution "
        f"(highest positive = most evidence for {pos_class_name}). "
    )
    if rem_lead_set or rem_seg_set:
        leg += "Red boxes = removed (ablation)."
    axf.text(0.5, 0.5, leg, fontsize=9, wrap=True, transform=axf.transAxes, ha="center", va="center")
    fig.tight_layout()
    return fig
@lru_cache(maxsize=1)
def _list_model_repo_files() -> List[str]:
    """Return all filenames in the HF model repo, cached for the process.

    `_resolve_ckpt_filename` calls this for every inference request; without
    caching each call is a network round-trip to the Hub. The repo contents
    are effectively static for the lifetime of a Space, so a single cached
    listing suffices (restart the Space to pick up newly uploaded files).
    Note: `lru_cache` was already imported at the top of the file but unused.
    """
    return list_repo_files(repo_id=HF_MODEL_REPO, repo_type="model")
def _resolve_ckpt_filename(model_type: str, sampling_rate: int, task: str) -> str:
    """Pick the checkpoint path inside the Hub repo for one configuration.

    Looks under `<family>_<fs>hz/<task>/` for `.ckpt` files, preferring the
    `best-imn-epoch=` naming scheme, and returns the lexicographically last
    match. Raises FileNotFoundError when nothing is found.
    """
    family = "single_linear_imn" if model_type == "single_linear" else "categorical_imn"
    prefix = f"{family}_{sampling_rate}hz/{task}/"
    repo_files = _list_model_repo_files()
    matches = [name for name in repo_files if name.startswith(prefix) and name.endswith(".ckpt")]
    if not matches:
        raise FileNotFoundError(
            f"No checkpoint (.ckpt) found in repo {HF_MODEL_REPO} under {prefix}. "
            "Ensure upload_best_checkpoints_to_hf.py has populated this path."
        )
    preferred = [name for name in matches if "best-imn-epoch=" in name]
    return sorted(preferred or matches)[-1]
def load_imn_model(
    model_type: str,
    sampling_rate: int,
    task: str,
) -> Tuple[torch.nn.Module, str]:
    """Download (if needed) and return the IMN model plus its device string.

    Results are memoized in MODEL_CACHE keyed by (model_type, fs, task), so
    the checkpoint is fetched from the Hub and deserialized at most once.
    """
    cache_key = (model_type, sampling_rate, task)
    entry = MODEL_CACHE.get(cache_key)
    if entry and entry["model"] is not None:
        return entry["model"], entry["device"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_name = _resolve_ckpt_filename(model_type, sampling_rate, task)
    local_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=ckpt_name)
    # Both core modules expose the same IMNLightning entry point.
    core = sl_core if model_type == "single_linear" else cat_core
    model = core.IMNLightning.load_from_checkpoint(local_path, map_location=device)
    model.eval()
    model.to(device)
    MODEL_CACHE[cache_key] = {"path": local_path, "model": model, "device": device}
    return model, device
def load_data_binary(sampling_rate: int, task: str) -> Dict[str, Any]:
    """Load (and cache) the packaged .npz example set for one configuration.

    Raises FileNotFoundError when no file is configured for the pair or the
    configured path does not exist, and KeyError when required arrays are
    missing from the archive.
    """
    cache_key = (sampling_rate, task)
    cached = DATA_CACHE.get(cache_key)
    if cached is not None:
        return cached
    path = DATA_FILES.get(cache_key)
    if path is None:
        raise FileNotFoundError(f"No data file configured for (fs={sampling_rate}, task={task}).")
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Data file not found at '{path}'. "
            "Upload a .npz with signals, labels, reports, age, sex, ecg_id."
        )
    required = ["signals", "labels", "reports", "age", "sex", "ecg_id"]
    with np.load(path, allow_pickle=True) as archive:
        absent = [key for key in required if key not in archive]
        if absent:
            raise KeyError(f"Data file '{path}' missing keys: {absent}")
        # Materialize the arrays before the archive is closed.
        data = {key: archive[key] for key in required}
    DATA_CACHE[cache_key] = data
    return data
def on_load_records(
    sampling_rate: int,
    task: str,
    state: Optional[dict],
):
    """Gradio callback: load the example set and populate the record dropdown.

    Returns a 5-tuple matching the UI outputs:
    (status_markdown, dropdown_update, new_state, report_text, gt_text).
    """
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        # Keep the previous state (or an empty dict) and clear the dropdown.
        return (
            f"Load error: {e}",
            gr.update(choices=[], value=None),
            state or {},
            "—",
            "—",
        )
    signals = data["signals"]
    labels = data["labels"]
    reports = data["reports"]
    age = data["age"]
    sex = data["sex"]
    ecg_id = data["ecg_id"]
    N, C, L = signals.shape
    pos_class = TASK_TO_POS.get(task, "MI")
    # One summary dict per example; label >= 0.5 means the positive class.
    records: List[Dict[str, Any]] = [
        {
            "index": int(i),
            "ecg_id": int(ecg_id[i]),
            "gt": pos_class if float(labels[i]) >= 0.5 else "NORM",
            "report": str(reports[i]) if reports is not None else "",
            "age": age[i] if age is not None else "",
            "sex": str(sex[i]) if sex is not None else "",
        }
        for i in range(N)
    ]
    choices = [f"{r['index']} | {r['ecg_id']} | {r['gt']} | age {r['age']} {r['sex']}" for r in records]
    value = choices[0] if choices else None
    state = {
        "records": records,
        "fs": int(sampling_rate),
        "task": task,
        "pos_class": pos_class,
    }
    # Pre-fill the notes/GT boxes with the first record's info.
    if records:
        report = records[0]["report"] or "(no clinical notes)"
        gt = records[0]["gt"]
    else:
        report = "—"
        gt = "—"
    if N > 0:
        status = f"Loaded {N} examples (fs={sampling_rate}Hz, {pos_class} vs NORM, L={L})."
    else:
        status = "No examples found in data file."
    return status, gr.update(choices=choices, value=value), state, report, gt
def on_select_record(choice: str, state: Optional[dict]):
    """Gradio callback: show clinical report and ground truth for a choice.

    The dropdown label starts with the record index ("<idx> | ..."), which
    is parsed back out to find the matching record in the stored state.
    Returns ("—", "—") whenever the selection can't be resolved.
    """
    records = (state or {}).get("records")
    if not records or not choice:
        return "—", "—"
    try:
        wanted = int(choice.split("|")[0].strip())
    except Exception:
        return "—", "—"
    match = next((r for r in records if r["index"] == wanted), None)
    if match is None:
        return "—", "—"
    return match["report"] or "(no clinical notes)", match["gt"]
def explain_record(
    model_type: str,
    sampling_rate: int,
    task: str,
    record_choice: str,
    state: Optional[dict],
    window: int,
    stride: int,
    topk_leads: int,
    topk_segments: int,
    remove_leads: bool,
    n_remove_leads: int,
    remove_segments: bool,
    n_remove_segments: int,
):
    """
    Gradio callback: run IMN inference + intrinsic explanation for one record.

    Returns an 8-tuple matching the UI outputs:
    (summary_markdown, figure, clinical_notes, ground_truth, ecg_id,
     top_leads_str, top_segments_str, ablation_str).
    On any error the figure slot is None and the summary carries the message.
    """
    # Default error payload shaped like the normal 8-tuple return.
    err = "Select a record and Load records first.", None, "—", "—", "—", "—", "—", "—"
    if not state or not state.get("records") or not record_choice:
        return err
    try:
        # Dropdown labels are "<index> | <ecg_id> | <gt> | age ...".
        rec_idx = int(record_choice.split("|")[0].strip())
    except Exception:
        return err
    rec = next((r for r in state["records"] if r["index"] == rec_idx), None)
    if not rec:
        return err
    # NOTE(review): fs/pos_class come from the state captured at "Load records"
    # time, while the data/model below use the *current* sampling_rate/task
    # radio values — these can disagree if the user changes the radios without
    # reloading records. Confirm whether that mismatch should be rejected.
    fs = state["fs"]
    pos_class_name = state.get("pos_class", "MI")
    report = rec["report"] or "(no clinical notes)"
    gt = rec["gt"]
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        return f"Data error: {e}", None, report, gt, "—", "—", "—", "—"
    try:
        model, device = load_imn_model(model_type, int(sampling_rate), task)
    except Exception as e:
        return f"Checkpoint error: {e}", None, report, gt, "—", "—", "—", "—"
    signals = data["signals"]
    if rec_idx < 0 or rec_idx >= signals.shape[0]:
        return f"Invalid record index {rec_idx}.", None, report, gt, "—", "—", "—", "—"
    x = signals[rec_idx]  # [12, L]
    if x.shape[0] != 12:
        return f"Expected 12 leads, got {x.shape[0]}.", None, report, gt, "—", "—", "—", "—"
    # The model records the signal length it was trained on in its hparams.
    signal_len_model = int(model.hparams["signal_len"])
    if x.shape[1] != signal_len_model:
        return (
            f"ECG length {x.shape[1]} != model {signal_len_model}. "
            "Ensure data binaries match the training window length.",
            None,
            report,
            gt,
            "—",
            "—",
            "—",
            "—",
        )
    # Per-lead normalization; presumably matches training preprocessing — verify.
    x = zscore_per_lead(x)
    x_t = torch.from_numpy(x).float().unsqueeze(0).to(device)
    with torch.no_grad():
        # IMN forward pass yields logits plus generated weights and bias.
        logits, gen_w, gen_b = model.model(x_t)
    if model_type == "single_linear":
        # Single binary logit -> sigmoid; weight map lives at class slot 0.
        logit = logits.squeeze()
        prob_pos = float(torch.sigmoid(logit).item())
        w_used = gen_w[0, 0, :, :].cpu().numpy()
    else:
        # 2-class softmax; the positive class is index 1.
        probs = torch.softmax(logits, dim=1)
        prob_pos = float(probs[0, 1].item())
        w_used = gen_w[0, 1, :, :].cpu().numpy()
    # Ensure int types for sliders / numbers
    window = max(1, int(window))
    stride = max(1, int(stride))
    topk_leads = int(topk_leads)
    topk_segments = int(topk_segments)
    n_remove_leads = int(n_remove_leads)
    n_remove_segments = int(n_remove_segments)
    x_np = x.astype(np.float64)
    # Intrinsic attribution: elementwise contribution of each sample to the logit.
    impact = w_used * x_np  # [12, L]
    seg_hm = sl_core.imn_weights_to_segments(impact, window=window, stride=stride)  # [12, T]
    # Build advanced figure with important leads/segments highlighted
    L = x_np.shape[1]
    T = seg_hm.shape[1]
    # Top-k leads: rank by SIGNED contribution (impact) to the positive logit.
    lead_imp_signed = impact.sum(axis=1)  # [12] total contribution per lead
    k_leads = min(max(0, topk_leads), 12)
    top_leads = np.argsort(lead_imp_signed)[::-1][:k_leads].tolist() if k_leads else []
    # Top-k segments: rank by SIGNED segment contribution (mean impact per segment)
    seg_signed = np.zeros((12, T), dtype=np.float64)
    for t in range(T):
        s = t * stride
        e = min(t * stride + window, L)
        seg_signed[:, t] = impact[:, s:e].mean(axis=1)
    seg_flat = seg_signed.flatten()
    k_seg = min(max(0, topk_segments), seg_flat.size)
    top_flat_idx = np.argsort(seg_flat)[::-1][:k_seg]
    # Flat index -> (lead, segment) coordinates in the [12, T] grid.
    top_segments = [(idx // T, idx % T) for idx in top_flat_idx] if k_seg else []
    # Optional ablation (remove top leads/segments and recompute P(pos_class))
    prob_abl: Optional[float] = None
    nr = n_remove_leads if remove_leads else 0
    ns = n_remove_segments if remove_segments else 0
    removed_leads: list[int] = []
    removed_segments: list[tuple[int, int]] = []
    if (nr or ns) and (top_leads or top_segments):
        if model_type == "single_linear":
            prob_abl = ablate_and_recompute_imn_single(
                gen_w,
                x_t,
                gen_b,
                top_leads,
                top_segments,
                nr,
                ns,
                window,
                stride,
                L,
            )
        else:
            prob_abl = ablate_and_recompute_imn_categorical(
                logits,
                gen_w,
                x_t,
                gen_b,
                top_leads,
                top_segments,
                nr,
                ns,
                window,
                stride,
                L,
                pos_class_idx=1,
            )
    # Record which items were actually removed, for highlighting in the figure.
    if nr > 0 and top_leads:
        removed_leads = top_leads[:nr]
    if ns > 0 and top_segments:
        removed_segments = top_segments[:ns]
    pred = pos_class_name if prob_pos >= 0.5 else "NORM"
    fig = build_fig_imn_with_highlights(
        x_np,
        seg_hm,
        window,
        stride,
        T,
        pred,
        prob_pos,
        pos_class_name,
        fs,
        top_leads=top_leads or None,
        top_segments=top_segments or None,
        removed_leads=removed_leads or None,
        removed_segments=removed_segments or None,
        prob_abl=prob_abl,
        lead_imp_signed=lead_imp_signed,
    )
    top_leads_str = ", ".join(LEAD_NAMES[i] for i in top_leads) if top_leads else "—"
    # Show at most 12 (lead, segment) pairs in the textbox.
    top_segments_str = (
        ", ".join(f"({LEAD_NAMES[l]},{t})" for l, t in top_segments[:12])
        if top_segments
        else "—"
    )
    if len(top_segments) > 12:
        top_segments_str += " ..."

    def _fmt_prob(p: float) -> str:
        # Extra decimals for tiny probabilities (mirrors the figure title).
        return f"{p:.4f}" if p < 0.001 else f"{p:.3f}"

    abl_str = f"Ablated P({pos_class_name}) = {_fmt_prob(prob_abl)}" if prob_abl is not None else "—"
    summary = (
        f"**{pred}** | P({pos_class_name}) = {prob_pos:.3f}"
        + (f" → **{_fmt_prob(prob_abl)}** (after ablation)" if prob_abl is not None else "")
        + f" | Ground truth: **{gt}** | fs={fs}Hz, window={window}, stride={stride}"
    )
    return summary, fig, report, gt, f"{rec['ecg_id']}", top_leads_str, top_segments_str, abl_str
def main():
    """Build and launch the Gradio Blocks UI for the IMN XAI viewer."""
    demo = gr.Blocks(
        title="MesomorphicECG XAI (IMN categorical + single-linear)",
        theme=gr.themes.Soft(),
    )
    with demo:
        # Header / description.
        gr.Markdown(
            "# MesomorphicECG XAI\n"
            "Interactive XAI viewer for Interpretable Mesomorphic Networks (IMN) on PTB-XL ECGs.\n\n"
            "- Models and checkpoints from "
            "[SEARCH-IHI/mesomorphicECG](https://huggingface.co/SEARCH-IHI/mesomorphicECG).\n"
            "- Data samples loaded from binary `.npz` files stored in this Space.\n"
            "- Heatmaps show segment-wise IMN contribution per lead."
        )
        # Configuration row: sampling rate, model head type, classification task.
        with gr.Row():
            sampling_rate = gr.Radio(
                label="Sampling rate",
                choices=[100, 500],
                value=500,
            )
            model_type = gr.Radio(
                label="Model type",
                choices=["single_linear", "categorical"],
                value="single_linear",
                info="single_linear: single linear head; categorical: 2-class head.",
            )
            task = gr.Radio(
                label="Task (positive class vs NORM)",
                choices=list(TASK_TO_POS.keys()),
                value="norm_vs_mi",
            )
        load_btn = gr.Button("Load records", variant="secondary")
        load_status = gr.Markdown()
        # Holds the records list plus loaded fs/task/pos_class between callbacks.
        records_state = gr.State(value=None)
        with gr.Row():
            record_dd = gr.Dropdown(
                label="Record (index | ecg_id | GT | age sex)",
                choices=[],
                value=None,
            )
        with gr.Row():
            clinical_notes = gr.Textbox(
                label="Clinical notes (report)",
                value="",
                lines=4,
                max_lines=8,
                interactive=False,
            )
            ground_truth = gr.Textbox(
                label="Ground truth",
                value="—",
                interactive=False,
            )
        # Wire the load button and record dropdown to their callbacks.
        load_btn.click(
            fn=on_load_records,
            inputs=[sampling_rate, task, records_state],
            outputs=[load_status, record_dd, records_state, clinical_notes, ground_truth],
        )
        record_dd.change(
            fn=on_select_record,
            inputs=[record_dd, records_state],
            outputs=[clinical_notes, ground_truth],
        )
        gr.Markdown("### Window & stride (segment aggregation)")
        with gr.Row():
            window = gr.Slider(
                label="Window size",
                minimum=10,
                maximum=500,
                value=50,
                step=5,
                info="Segment width for aggregating point-wise attributions",
            )
            stride = gr.Slider(
                label="Stride",
                minimum=5,
                maximum=250,
                value=25,
                step=5,
                info="Step between segments (typically window/2)",
            )
        gr.Markdown(
            "### Top-k leads & segments\n"
            "Ranked by **signed** contribution (w·x) to the positive logit: "
            "highest positive = most evidence for the selected positive class. "
            "Removing them should decrease P(pos_class) for positive-predicting leads."
        )
        with gr.Row():
            topk_leads = gr.Number(
                label="Top-k leads",
                value=3,
                minimum=0,
                maximum=12,
                step=1,
            )
            topk_segments = gr.Number(
                label="Top-k segments",
                value=10,
                minimum=0,
                maximum=200,
                step=1,
            )
        gr.Markdown("### Ablation (remove top leads/segments and recompute)")
        with gr.Row():
            remove_leads = gr.Checkbox(label="Remove top leads", value=False)
            n_remove_leads = gr.Number(
                label="Num. leads to remove",
                value=1,
                minimum=0,
                maximum=12,
                step=1,
            )
            remove_segments = gr.Checkbox(label="Remove top segments", value=False)
            n_remove_segments = gr.Number(
                label="Num. segments to remove",
                value=5,
                minimum=0,
                maximum=100,
                step=1,
            )
        run_btn = gr.Button("Run IMN explanation", variant="primary")
        # Output widgets — order must match explain_record's 8-tuple return.
        out_summary = gr.Markdown()
        out_plot = gr.Plot()
        out_notes = gr.Textbox(label="Clinical notes", lines=3, interactive=False)
        out_gt = gr.Textbox(label="Ground truth", interactive=False)
        out_meta = gr.Textbox(label="ECG ID", interactive=False)
        out_leads = gr.Textbox(label="Top leads", interactive=False)
        out_segments = gr.Textbox(label="Top segments (lead, seg_idx)", interactive=False)
        out_abl = gr.Textbox(label="Ablated", interactive=False)
        run_btn.click(
            fn=explain_record,
            inputs=[
                model_type,
                sampling_rate,
                task,
                record_dd,
                records_state,
                window,
                stride,
                topk_leads,
                topk_segments,
                remove_leads,
                n_remove_leads,
                remove_segments,
                n_remove_segments,
            ],
            outputs=[
                out_summary,
                out_plot,
                out_notes,
                out_gt,
                out_meta,
                out_leads,
                out_segments,
                out_abl,
            ],
        )
    demo.launch()
# Launch the Gradio app when run as a script (e.g. by the Spaces runtime).
if __name__ == "__main__":
    main()