"""
MesomorphicECG XAI Gradio app for Hugging Face Spaces.

This version focuses on:
- Selecting sampling rate (100 / 500 Hz), model type (categorical vs
  single-linear), and task (norm_vs_cd / norm_vs_hyp / norm_vs_mi /
  norm_vs_sttc).
- Loading pre-packaged ECG examples from local binary .npz files in this Space.
- Downloading the corresponding IMN checkpoint from
  `SEARCH-IHI/mesomorphicECG` on the Hugging Face Hub.
- Running inference and visualizing intrinsic feature attributions
  (Impact = w * x) as a lead × segment heatmap plus per-lead ECG traces.

Data binaries
-------------
For each (sampling_rate, task) pair you should provide a `.npz` file as
configured in DATA_FILES below, with keys:

    signals : float32 array [N, 12, L]
    labels  : float32/int array [N] with 0 (NORM) / 1 (POS_CLASS)
    reports : object array [N] of clinical notes
    age     : array [N]
    sex     : object array [N]
    ecg_id  : array [N]
"""

# NOTE: the docstring must precede the __future__ import; a module docstring
# placed after any statement is silently discarded (module.__doc__ is None).
from __future__ import annotations

import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import matplotlib

# Select the non-interactive backend before pyplot is imported (headless Space).
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import gradio as gr  # noqa: E402
from huggingface_hub import hf_hub_download, list_repo_files  # noqa: E402

import single_linear_imn_core as sl_core  # noqa: E402
import categorical_imn_core as cat_core  # noqa: E402

# Hub repo hosting the pretrained IMN checkpoints.
HF_MODEL_REPO = "SEARCH-IHI/mesomorphicECG"

# Task name -> human-readable positive-class label.
TASK_TO_POS = {
    "norm_vs_mi": "MI",
    "norm_vs_sttc": "STTC",
    "norm_vs_cd": "CD",
    "norm_vs_hyp": "HYP",
}

LEAD_NAMES = sl_core.DEFAULT_LEAD_NAMES

# Mapping from (sampling_rate, task) -> local data binary.
DATA_FILES: Dict[Tuple[int, str], str] = {
    # 100 Hz
    (100, "norm_vs_cd"): "data/ptbxl_100hz_norm_vs_cd_test.npz",
    (100, "norm_vs_hyp"): "data/ptbxl_100hz_norm_vs_hyp_test.npz",
    (100, "norm_vs_mi"): "data/ptbxl_100hz_norm_vs_mi_test.npz",
    (100, "norm_vs_sttc"): "data/ptbxl_100hz_norm_vs_sttc_test.npz",
    # 500 Hz
    (500, "norm_vs_cd"): "data/ptbxl_500hz_norm_vs_cd_test.npz",
    (500, "norm_vs_hyp"): "data/ptbxl_500hz_norm_vs_hyp_test.npz",
    (500, "norm_vs_mi"): "data/ptbxl_500hz_norm_vs_mi_test.npz",
    (500, "norm_vs_sttc"): "data/ptbxl_500hz_norm_vs_sttc_test.npz",
}

# In-process caches so repeated UI interactions don't re-read .npz files or
# re-download / re-load checkpoints.
DATA_CACHE: Dict[Tuple[int, str], Dict[str, Any]] = {}
MODEL_CACHE: Dict[Tuple[str, int, str], Dict[str, Any]] = {}


def zscore_per_lead(x: np.ndarray) -> np.ndarray:
    """Per-lead z-score normalization.

    Parameters
    ----------
    x : np.ndarray
        ECG array [12, L]; each lead (row) is normalized independently.

    Returns
    -------
    np.ndarray
        float32 array of the same shape, zero mean / unit std per lead.
    """
    mean = x.mean(axis=1, keepdims=True)
    # Clip std so a flat (constant) lead doesn't divide by ~0.
    std = x.std(axis=1, keepdims=True).clip(min=1e-6)
    return ((x - mean) / std).astype(np.float32)


def ablate_and_recompute_imn_single(
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
) -> float:
    """
    Ablation helper for SINGLE-LINEAR IMN (binary logit).

    gen_w: [1, 1, 12, L], x_t: [1, 12, L], gen_b: [1, 1] or [1, 1, 1].
    We zero weights at selected leads/segments, keep the scalar bias
    unchanged, then recompute the single logit and its sigmoid P(pos_class).
    """
    # Extract [12, L] weight map
    w_single = gen_w[0, 0].clone()  # [12, L]
    # Remove whole leads
    if n_remove_leads > 0 and top_leads:
        for li in top_leads[:n_remove_leads]:
            w_single[li, :] = 0.0
    # Remove segments
    if n_remove_segments > 0 and top_segments:
        for lead, t in top_segments[:n_remove_segments]:
            s = t * stride
            e = min(s + window, L)
            w_single[lead, s:e] = 0.0
    x_exp = x_t[0]  # [12, L]
    # Bias: handle [1, 1] or [1, 1, 1]
    if gen_b.ndim == 3:
        b = gen_b[0, 0, 0]
    else:
        b = gen_b[0, 0]
    new_logit = (w_single.to(x_exp.device) * x_exp).sum() + b
    prob_pos = float(torch.sigmoid(new_logit).item())
    return prob_pos


def ablate_and_recompute_imn_categorical(
    logits: torch.Tensor,
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
    pos_class_idx: int = 1,
) -> float:
    """
    Ablation helper for CATEGORICAL IMN (2-class softmax).

    We zero weights for the positive class at selected leads/segments, keep
    all other class logits unchanged, and recompute P(pos_class) via softmax.
    """
    # gen_w: [1, num_classes, 12, L]
    w_pos = gen_w[0, pos_class_idx].clone()  # [12, L]
    # Remove whole leads
    if n_remove_leads > 0 and top_leads:
        for li in top_leads[:n_remove_leads]:
            w_pos[li, :] = 0.0
    # Remove segments
    if n_remove_segments > 0 and top_segments:
        for lead, t in top_segments[:n_remove_segments]:
            s = t * stride
            e = min(s + window, L)
            w_pos[lead, s:e] = 0.0
    x_exp = x_t[0]  # [12, L]
    # Bias: handle [1, num_classes] or [1, num_classes, 1]
    if gen_b.ndim == 3:
        b_pos = gen_b[0, pos_class_idx, 0]
    else:
        b_pos = gen_b[0, pos_class_idx]
    new_logit_pos = (w_pos.to(x_exp.device) * x_exp).sum() + b_pos
    # logits: [1, num_classes] — only the positive-class logit is replaced.
    orig_logits = logits[0].clone()
    orig_logits[pos_class_idx] = new_logit_pos
    probs = torch.softmax(orig_logits, dim=0)
    return float(probs[pos_class_idx].item())


def build_fig_imn_with_highlights(
    x_np: np.ndarray,
    seg_hm: np.ndarray,
    window: int,
    stride: int,
    T: int,
    pred: str,
    prob_pos: float,
    pos_class_name: str,
    sampling_rate: int,
    top_leads: Optional[list[int]] = None,
    top_segments: Optional[list[tuple[int, int]]] = None,
    removed_leads: Optional[list[int]] = None,
    removed_segments: Optional[list[tuple[int, int]]] = None,
    prob_abl: Optional[float] = None,
    lead_imp_signed: Optional[np.ndarray] = None,
) -> plt.Figure:
    """
    Build a matplotlib figure with:
    - top heatmap of segment-wise importance (per lead),
    - 12 ECG traces with overlays highlighting important / removed segments
      and leads.
    """
    import matplotlib.patches as mpatches

    L = x_np.shape[1]
    # Per-lead contribution share: use signed contribution so percentages
    # match top leads.
    if lead_imp_signed is not None:
        denom = np.abs(lead_imp_signed).sum() + 1e-9
        lead_pct = 100.0 * lead_imp_signed / denom
    else:
        lead_abs = np.abs(seg_hm).sum(axis=1)
        lead_pct = 100.0 * lead_abs / (lead_abs.sum() + 1e-9)
    cmap = "Reds"
    shade_color = "red"
    rem_lead_set = set(removed_leads or [])
    rem_seg_set = set(removed_segments or [])
    top_seg_set = set(top_segments or [])

    # Layout: 1 heatmap row + 12 lead rows + 1 footer row.
    fig = plt.figure(figsize=(12, 14))
    gs = fig.add_gridspec(14, 1, height_ratios=[2] + [1] * 12 + [0.5])

    # Top heatmap
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(seg_hm, aspect="auto", vmin=0.0, vmax=1.0, cmap=cmap)
    ax0.set_yticks(range(12))
    ax0.set_yticklabels(LEAD_NAMES)
    ax0.set_xlabel(f"Segments (window={window}, stride={stride}, fs={sampling_rate}Hz)")
    prob_str = f"P({pos_class_name})={prob_pos:.3f}"
    title = f"IMN Intrinsic Explanation | {pred} | {prob_str}"
    if prob_abl is not None:
        # Extra precision for near-zero ablated probabilities.
        p_str = f"{prob_abl:.4f}" if prob_abl < 0.001 else f"{prob_abl:.3f}"
        title += f" -> Ablated P({pos_class_name}) = {p_str}"
    ax0.set_title(title)
    fig.colorbar(im, ax=ax0, fraction=0.02, pad=0.01)

    # Highlight removed leads in the heatmap
    for rl in rem_lead_set:
        rect = mpatches.Rectangle(
            (-0.5, rl - 0.5),
            T,
            1,
            fill=False,
            edgecolor="red",
            linewidth=2.5,
            zorder=10,
        )
        ax0.add_patch(rect)

    # Lead-wise ECG traces with overlays
    for lead in range(12):
        ax = fig.add_subplot(gs[lead + 1, 0])
        ax.plot(x_np[lead], linewidth=0.8, color="black", alpha=0.6)
        ax.set_xlim(0, L - 1)
        ax.set_ylabel(
            f"{LEAD_NAMES[lead]} {lead_pct[lead]:.1f}%",
            rotation=0,
            labelpad=20,
            va="center",
        )
        ylo, yhi = ax.get_ylim()
        # Shade top leads
        if top_leads and lead in top_leads:
            ax.axhspan(ylo, yhi, alpha=0.15, color="gold", zorder=0)
        # Mark removed whole leads
        if lead in rem_lead_set:
            ax.add_patch(
                mpatches.Rectangle(
                    (0, ylo),
                    L - 1,
                    yhi - ylo,
                    fill=False,
                    edgecolor="red",
                    linewidth=2.5,
                    zorder=10,
                )
            )
        contrib = seg_hm[lead]
        for t in range(T):
            a = float(contrib[t])
            # Shade intensity tracks contribution; skip near-zero segments.
            alpha = min(0.5, a * 0.6)
            if alpha <= 0.05:
                continue
            start = t * stride
            end = min(start + window, L)
            hi = (lead, t) in top_seg_set
            is_rem = (lead, t) in rem_seg_set
            if is_rem:
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
                ax.add_patch(
                    mpatches.Rectangle(
                        (start, ylo),
                        end - start,
                        yhi - ylo,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                        zorder=10,
                    )
                )
            elif hi:
                ax.axvspan(
                    start,
                    end,
                    alpha=alpha,
                    facecolor=shade_color,
                    edgecolor="lime",
                    linewidth=1.5,
                    zorder=1,
                )
            else:
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
        ax.set_xticks([])

    # Footer
    axf = fig.add_subplot(gs[13, 0])
    axf.axis("off")
    leg = (
        f"IMN Feature Attribution: |w(x)·x| aggregated by segment "
        f"(window={window}, stride={stride}). Gold/Lime = top leads/segments by signed contribution "
        f"(highest positive = most evidence for {pos_class_name}). "
    )
    if rem_lead_set or rem_seg_set:
        leg += "Red boxes = removed (ablation)."
    axf.text(0.5, 0.5, leg, fontsize=9, wrap=True, transform=axf.transAxes, ha="center", va="center")
    fig.tight_layout()
    return fig


@lru_cache(maxsize=None)
def _list_model_repo_files() -> List[str]:
    """List (and cache) all files in the Hub model repo."""
    return list_repo_files(repo_id=HF_MODEL_REPO, repo_type="model")


def _resolve_ckpt_filename(model_type: str, sampling_rate: int, task: str) -> str:
    """Resolve the .ckpt path inside HF_MODEL_REPO for a model/rate/task combo.

    Prefers checkpoints named by the training callback ("best-imn-epoch=");
    otherwise falls back to the lexicographically last candidate.

    Raises
    ------
    FileNotFoundError
        If no .ckpt exists under the expected repo prefix.
    """
    if model_type == "single_linear":
        category = f"single_linear_imn_{sampling_rate}hz"
    else:
        category = f"categorical_imn_{sampling_rate}hz"
    prefix = f"{category}/{task}/"
    files = _list_model_repo_files()
    candidates = [f for f in files if f.startswith(prefix) and f.endswith(".ckpt")]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint (.ckpt) found in repo {HF_MODEL_REPO} under {prefix}. "
            "Ensure upload_best_checkpoints_to_hf.py has populated this path."
        )
    best_style = [f for f in candidates if "best-imn-epoch=" in f]
    chosen = sorted(best_style or candidates)[-1]
    return chosen


def load_imn_model(
    model_type: str,
    sampling_rate: int,
    task: str,
) -> Tuple[torch.nn.Module, str]:
    """Download (if needed) and load the IMN checkpoint; returns (model, device).

    Results are memoized in MODEL_CACHE keyed by (model_type, fs, task).
    """
    key = (model_type, sampling_rate, task)
    cached = MODEL_CACHE.get(key)
    if cached and cached["model"] is not None:
        return cached["model"], cached["device"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    filename = _resolve_ckpt_filename(model_type, sampling_rate, task)
    ckpt_local = hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
    if model_type == "single_linear":
        model = sl_core.IMNLightning.load_from_checkpoint(ckpt_local, map_location=device)
    else:
        model = cat_core.IMNLightning.load_from_checkpoint(ckpt_local, map_location=device)
    model.eval()
    model.to(device)
    MODEL_CACHE[key] = {"path": ckpt_local, "model": model, "device": device}
    return model, device


def load_data_binary(sampling_rate: int, task: str) -> Dict[str, Any]:
    """Load (and cache) the .npz example file for (sampling_rate, task).

    Raises
    ------
    FileNotFoundError
        If no file is configured or it is missing on disk.
    KeyError
        If the .npz lacks any of the required arrays.
    """
    key = (sampling_rate, task)
    if key in DATA_CACHE:
        return DATA_CACHE[key]
    path = DATA_FILES.get(key)
    if path is None:
        raise FileNotFoundError(f"No data file configured for (fs={sampling_rate}, task={task}).")
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Data file not found at '{path}'. "
            "Upload a .npz with signals, labels, reports, age, sex, ecg_id."
        )
    with np.load(path, allow_pickle=True) as npz:
        required = ["signals", "labels", "reports", "age", "sex", "ecg_id"]
        missing = [k for k in required if k not in npz]
        if missing:
            raise KeyError(f"Data file '{path}' missing keys: {missing}")
        # Materialize arrays before the NpzFile is closed.
        data = {k: npz[k] for k in required}
    DATA_CACHE[key] = data
    return data


def on_load_records(
    sampling_rate: int,
    task: str,
    state: Optional[dict],
):
    """Gradio callback: load examples and populate the record dropdown.

    Returns (status_markdown, dropdown_update, new_state, first_report, first_gt).
    """
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        return (
            f"Load error: {e}",
            gr.update(choices=[], value=None),
            state or {},
            "—",
            "—",
        )
    signals = data["signals"]
    labels = data["labels"]
    reports = data["reports"]
    age = data["age"]
    sex = data["sex"]
    ecg_id = data["ecg_id"]
    N, C, L = signals.shape
    pos_class = TASK_TO_POS.get(task, "MI")
    records: List[Dict[str, Any]] = []
    for i in range(N):
        gt = pos_class if float(labels[i]) >= 0.5 else "NORM"
        records.append(
            {
                "index": int(i),
                "ecg_id": int(ecg_id[i]),
                "gt": gt,
                "report": str(reports[i]) if reports is not None else "",
                "age": age[i] if age is not None else "",
                "sex": str(sex[i]) if sex is not None else "",
            }
        )
    choices = [f"{r['index']} | {r['ecg_id']} | {r['gt']} | age {r['age']} {r['sex']}" for r in records]
    value = choices[0] if choices else None
    # Fresh state: records are only valid for this exact (fs, task) pair.
    state = {
        "records": records,
        "fs": int(sampling_rate),
        "task": task,
        "pos_class": pos_class,
    }
    report = (records[0]["report"] or "(no clinical notes)") if records else "—"
    gt = records[0]["gt"] if records else "—"
    status = (
        f"Loaded {N} examples (fs={sampling_rate}Hz, {pos_class} vs NORM, L={L})."
        if N > 0
        else "No examples found in data file."
    )
    return status, gr.update(choices=choices, value=value), state, report, gt


def on_select_record(choice: str, state: Optional[dict]):
    """Gradio callback: show the report / ground truth for the chosen record."""
    if not state or not state.get("records") or not choice:
        return "—", "—"
    try:
        # Dropdown format: "index | ecg_id | GT | age sex"
        idx = int(choice.split("|")[0].strip())
    except Exception:
        return "—", "—"
    for r in state["records"]:
        if r["index"] == idx:
            return r["report"] or "(no clinical notes)", r["gt"]
    return "—", "—"


def explain_record(
    model_type: str,
    sampling_rate: int,
    task: str,
    record_choice: str,
    state: Optional[dict],
    window: int,
    stride: int,
    topk_leads: int,
    topk_segments: int,
    remove_leads: bool,
    n_remove_leads: int,
    remove_segments: bool,
    n_remove_segments: int,
):
    """Gradio callback: run IMN inference + intrinsic explanation for a record.

    Returns an 8-tuple matching the UI outputs:
    (summary, figure, clinical_notes, ground_truth, ecg_id,
     top_leads_str, top_segments_str, ablated_str).
    """
    err = "Select a record and Load records first.", None, "—", "—", "—", "—", "—", "—"
    if not state or not state.get("records") or not record_choice:
        return err
    try:
        rec_idx = int(record_choice.split("|")[0].strip())
    except Exception:
        return err
    rec = next((r for r in state["records"] if r["index"] == rec_idx), None)
    if not rec:
        return err
    fs = state["fs"]
    pos_class_name = state.get("pos_class", "MI")
    report = rec["report"] or "(no clinical notes)"
    gt = rec["gt"]
    # Guard against stale state: the record list was built for one specific
    # (sampling_rate, task) pair. If the radio buttons changed since loading,
    # rec_idx would index a different dataset and pair the wrong signal with
    # the displayed report / ground truth.
    if state.get("fs") != int(sampling_rate) or state.get("task") != task:
        return (
            "Sampling rate or task changed since records were loaded. "
            "Click 'Load records' again.",
            None, report, gt, "—", "—", "—", "—",
        )
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        return f"Data error: {e}", None, report, gt, "—", "—", "—", "—"
    try:
        model, device = load_imn_model(model_type, int(sampling_rate), task)
    except Exception as e:
        return f"Checkpoint error: {e}", None, report, gt, "—", "—", "—", "—"
    signals = data["signals"]
    if rec_idx < 0 or rec_idx >= signals.shape[0]:
        return f"Invalid record index {rec_idx}.", None, report, gt, "—", "—", "—", "—"
    x = signals[rec_idx]  # [12, L]
    if x.shape[0] != 12:
        return f"Expected 12 leads, got {x.shape[0]}.", None, report, gt, "—", "—", "—", "—"
    signal_len_model = int(model.hparams["signal_len"])
    if x.shape[1] != signal_len_model:
        return (
            f"ECG length {x.shape[1]} != model {signal_len_model}. "
            "Ensure data binaries match the training window length.",
            None,
            report,
            gt,
            "—",
            "—",
            "—",
            "—",
        )
    x = zscore_per_lead(x)
    x_t = torch.from_numpy(x).float().unsqueeze(0).to(device)
    with torch.no_grad():
        logits, gen_w, gen_b = model.model(x_t)
        if model_type == "single_linear":
            logit = logits.squeeze()
            prob_pos = float(torch.sigmoid(logit).item())
            w_used = gen_w[0, 0, :, :].cpu().numpy()
        else:
            probs = torch.softmax(logits, dim=1)
            prob_pos = float(probs[0, 1].item())
            w_used = gen_w[0, 1, :, :].cpu().numpy()
    # Ensure int types for sliders / numbers
    window = max(1, int(window))
    stride = max(1, int(stride))
    topk_leads = int(topk_leads)
    topk_segments = int(topk_segments)
    n_remove_leads = int(n_remove_leads)
    n_remove_segments = int(n_remove_segments)
    x_np = x.astype(np.float64)
    impact = w_used * x_np  # [12, L] intrinsic attribution
    seg_hm = sl_core.imn_weights_to_segments(impact, window=window, stride=stride)  # [12, T]
    # Build advanced figure with important leads/segments highlighted
    L = x_np.shape[1]
    T = seg_hm.shape[1]
    # Top-k leads: rank by SIGNED contribution (impact) to the positive logit.
    lead_imp_signed = impact.sum(axis=1)  # [12] total contribution per lead
    k_leads = min(max(0, topk_leads), 12)
    top_leads = np.argsort(lead_imp_signed)[::-1][:k_leads].tolist() if k_leads else []
    # Top-k segments: rank by SIGNED segment contribution (mean impact per segment)
    seg_signed = np.zeros((12, T), dtype=np.float64)
    for t in range(T):
        s = t * stride
        e = min(t * stride + window, L)
        seg_signed[:, t] = impact[:, s:e].mean(axis=1)
    seg_flat = seg_signed.flatten()
    k_seg = min(max(0, topk_segments), seg_flat.size)
    top_flat_idx = np.argsort(seg_flat)[::-1][:k_seg]
    # Cast numpy indices to plain ints so (lead, t) tuples compare cleanly.
    top_segments = [(int(idx) // T, int(idx) % T) for idx in top_flat_idx] if k_seg else []
    # Optional ablation (remove top leads/segments and recompute P(pos_class))
    prob_abl: Optional[float] = None
    nr = n_remove_leads if remove_leads else 0
    ns = n_remove_segments if remove_segments else 0
    removed_leads: list[int] = []
    removed_segments: list[tuple[int, int]] = []
    if (nr or ns) and (top_leads or top_segments):
        if model_type == "single_linear":
            prob_abl = ablate_and_recompute_imn_single(
                gen_w, x_t, gen_b, top_leads, top_segments, nr, ns, window, stride, L,
            )
        else:
            prob_abl = ablate_and_recompute_imn_categorical(
                logits, gen_w, x_t, gen_b, top_leads, top_segments,
                nr, ns, window, stride, L, pos_class_idx=1,
            )
        if nr > 0 and top_leads:
            removed_leads = top_leads[:nr]
        if ns > 0 and top_segments:
            removed_segments = top_segments[:ns]
    pred = pos_class_name if prob_pos >= 0.5 else "NORM"
    fig = build_fig_imn_with_highlights(
        x_np,
        seg_hm,
        window,
        stride,
        T,
        pred,
        prob_pos,
        pos_class_name,
        fs,
        top_leads=top_leads or None,
        top_segments=top_segments or None,
        removed_leads=removed_leads or None,
        removed_segments=removed_segments or None,
        prob_abl=prob_abl,
        lead_imp_signed=lead_imp_signed,
    )
    top_leads_str = ", ".join(LEAD_NAMES[i] for i in top_leads) if top_leads else "—"
    top_segments_str = (
        ", ".join(f"({LEAD_NAMES[l]},{t})" for l, t in top_segments[:12]) if top_segments else "—"
    )
    if len(top_segments) > 12:
        top_segments_str += " ..."

    def _fmt_prob(p: float) -> str:
        # Extra precision for near-zero probabilities after ablation.
        return f"{p:.4f}" if p < 0.001 else f"{p:.3f}"

    abl_str = f"Ablated P({pos_class_name}) = {_fmt_prob(prob_abl)}" if prob_abl is not None else "—"
    summary = (
        f"**{pred}** | P({pos_class_name}) = {prob_pos:.3f}"
        + (f" → **{_fmt_prob(prob_abl)}** (after ablation)" if prob_abl is not None else "")
        + f" | Ground truth: **{gt}** | fs={fs}Hz, window={window}, stride={stride}"
    )
    return summary, fig, report, gt, f"{rec['ecg_id']}", top_leads_str, top_segments_str, abl_str


def main():
    """Build the Gradio Blocks UI, wire the callbacks, and launch the app."""
    demo = gr.Blocks(
        title="MesomorphicECG XAI (IMN categorical + single-linear)",
        theme=gr.themes.Soft(),
    )
    with demo:
        gr.Markdown(
            "# MesomorphicECG XAI\n"
            "Interactive XAI viewer for Interpretable Mesomorphic Networks (IMN) on PTB-XL ECGs.\n\n"
            "- Models and checkpoints from "
            "[SEARCH-IHI/mesomorphicECG](https://huggingface.co/SEARCH-IHI/mesomorphicECG).\n"
            "- Data samples loaded from binary `.npz` files stored in this Space.\n"
            "- Heatmaps show segment-wise IMN contribution per lead."
        )
        with gr.Row():
            sampling_rate = gr.Radio(
                label="Sampling rate",
                choices=[100, 500],
                value=500,
            )
            model_type = gr.Radio(
                label="Model type",
                choices=["single_linear", "categorical"],
                value="single_linear",
                info="single_linear: single linear head; categorical: 2-class head.",
            )
            task = gr.Radio(
                label="Task (positive class vs NORM)",
                choices=list(TASK_TO_POS.keys()),
                value="norm_vs_mi",
            )
        load_btn = gr.Button("Load records", variant="secondary")
        load_status = gr.Markdown()
        records_state = gr.State(value=None)
        with gr.Row():
            record_dd = gr.Dropdown(
                label="Record (index | ecg_id | GT | age sex)",
                choices=[],
                value=None,
            )
        with gr.Row():
            clinical_notes = gr.Textbox(
                label="Clinical notes (report)",
                value="",
                lines=4,
                max_lines=8,
                interactive=False,
            )
            ground_truth = gr.Textbox(
                label="Ground truth",
                value="—",
                interactive=False,
            )
        load_btn.click(
            fn=on_load_records,
            inputs=[sampling_rate, task, records_state],
            outputs=[load_status, record_dd, records_state, clinical_notes, ground_truth],
        )
        record_dd.change(
            fn=on_select_record,
            inputs=[record_dd, records_state],
            outputs=[clinical_notes, ground_truth],
        )
        gr.Markdown("### Window & stride (segment aggregation)")
        with gr.Row():
            window = gr.Slider(
                label="Window size",
                minimum=10,
                maximum=500,
                value=50,
                step=5,
                info="Segment width for aggregating point-wise attributions",
            )
            stride = gr.Slider(
                label="Stride",
                minimum=5,
                maximum=250,
                value=25,
                step=5,
                info="Step between segments (typically window/2)",
            )
        gr.Markdown(
            "### Top-k leads & segments\n"
            "Ranked by **signed** contribution (w·x) to the positive logit: "
            "highest positive = most evidence for the selected positive class. "
            "Removing them should decrease P(pos_class) for positive-predicting leads."
        )
        with gr.Row():
            topk_leads = gr.Number(
                label="Top-k leads",
                value=3,
                minimum=0,
                maximum=12,
                step=1,
            )
            topk_segments = gr.Number(
                label="Top-k segments",
                value=10,
                minimum=0,
                maximum=200,
                step=1,
            )
        gr.Markdown("### Ablation (remove top leads/segments and recompute)")
        with gr.Row():
            remove_leads = gr.Checkbox(label="Remove top leads", value=False)
            n_remove_leads = gr.Number(
                label="Num. leads to remove",
                value=1,
                minimum=0,
                maximum=12,
                step=1,
            )
            remove_segments = gr.Checkbox(label="Remove top segments", value=False)
            n_remove_segments = gr.Number(
                label="Num. segments to remove",
                value=5,
                minimum=0,
                maximum=100,
                step=1,
            )
        run_btn = gr.Button("Run IMN explanation", variant="primary")
        out_summary = gr.Markdown()
        out_plot = gr.Plot()
        out_notes = gr.Textbox(label="Clinical notes", lines=3, interactive=False)
        out_gt = gr.Textbox(label="Ground truth", interactive=False)
        out_meta = gr.Textbox(label="ECG ID", interactive=False)
        out_leads = gr.Textbox(label="Top leads", interactive=False)
        out_segments = gr.Textbox(label="Top segments (lead, seg_idx)", interactive=False)
        out_abl = gr.Textbox(label="Ablated", interactive=False)
        run_btn.click(
            fn=explain_record,
            inputs=[
                model_type,
                sampling_rate,
                task,
                record_dd,
                records_state,
                window,
                stride,
                topk_leads,
                topk_segments,
                remove_leads,
                n_remove_leads,
                remove_segments,
                n_remove_segments,
            ],
            outputs=[
                out_summary,
                out_plot,
                out_notes,
                out_gt,
                out_meta,
                out_leads,
                out_segments,
                out_abl,
            ],
        )
    # Launch after exiting the Blocks context.
    demo.launch()


if __name__ == "__main__":
    main()