# Author: vlbthambawita — "updated app" (commit dec6141)
from __future__ import annotations
"""
MesomorphicECG XAI Gradio app for Hugging Face Spaces.
This version focuses on:
- Selecting sampling rate (100 / 500 Hz), model type (categorical vs single-linear),
and task (norm_vs_cd / norm_vs_hyp / norm_vs_mi / norm_vs_sttc).
- Loading pre-packaged ECG examples from local binary .npz files in this Space.
- Downloading the corresponding IMN checkpoint from
`SEARCH-IHI/mesomorphicECG` on the Hugging Face Hub.
- Running inference and visualizing intrinsic feature attributions
(Impact = w * x) as a lead × segment heatmap plus per-lead ECG traces.
Data binaries
-------------
For each (sampling_rate, task) pair you should provide a `.npz` file as
configured in DATA_FILES below, with keys:
signals : float32 array [N, 12, L]
labels : float32/int array [N] with 0 (NORM) / 1 (POS_CLASS)
reports : object array [N] of clinical notes
age : array [N]
sex : object array [N]
ecg_id : array [N]
"""
import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import gradio as gr # noqa: E402
from huggingface_hub import hf_hub_download, list_repo_files # noqa: E402
import single_linear_imn_core as sl_core # noqa: E402
import categorical_imn_core as cat_core # noqa: E402
# Hugging Face Hub repo that hosts the pre-trained IMN checkpoints.
HF_MODEL_REPO = "SEARCH-IHI/mesomorphicECG"
# Maps each binary task name to the label of its positive class (vs NORM).
TASK_TO_POS = {
    "norm_vs_mi": "MI",
    "norm_vs_sttc": "STTC",
    "norm_vs_cd": "CD",
    "norm_vs_hyp": "HYP",
}
# Standard 12-lead names (from the single-linear core), reused for axis labels.
LEAD_NAMES = sl_core.DEFAULT_LEAD_NAMES
# Mapping from (sampling_rate, task) -> local data binary.
DATA_FILES: Dict[Tuple[int, str], str] = {
    # 100 Hz
    (100, "norm_vs_cd"): "data/ptbxl_100hz_norm_vs_cd_test.npz",
    (100, "norm_vs_hyp"): "data/ptbxl_100hz_norm_vs_hyp_test.npz",
    (100, "norm_vs_mi"): "data/ptbxl_100hz_norm_vs_mi_test.npz",
    (100, "norm_vs_sttc"): "data/ptbxl_100hz_norm_vs_sttc_test.npz",
    # 500 Hz
    (500, "norm_vs_cd"): "data/ptbxl_500hz_norm_vs_cd_test.npz",
    (500, "norm_vs_hyp"): "data/ptbxl_500hz_norm_vs_hyp_test.npz",
    (500, "norm_vs_mi"): "data/ptbxl_500hz_norm_vs_mi_test.npz",
    (500, "norm_vs_sttc"): "data/ptbxl_500hz_norm_vs_sttc_test.npz",
}
# In-process caches: loaded .npz contents keyed by (fs, task) ...
DATA_CACHE: Dict[Tuple[int, str], Dict[str, Any]] = {}
# ... and loaded Lightning models keyed by (model_type, fs, task).
MODEL_CACHE: Dict[Tuple[str, int, str], Dict[str, Any]] = {}
def zscore_per_lead(x: np.ndarray) -> np.ndarray:
    """Normalize each lead of a [12, L] ECG to zero mean / unit std (float32).

    A floor of 1e-6 on the per-lead std avoids division by zero for flat leads.
    """
    centered = x - x.mean(axis=1, keepdims=True)
    scale = np.maximum(x.std(axis=1, keepdims=True), 1e-6)
    return (centered / scale).astype(np.float32)
def ablate_and_recompute_imn_single(
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
) -> float:
    """
    Ablation helper for the SINGLE-LINEAR IMN head (one binary logit).

    Zeroes the generated weights at the selected leads / segments (the scalar
    bias is left untouched), recomputes the logit w·x + b and returns
    sigmoid(logit), i.e. the post-ablation P(pos_class).

    Shapes: gen_w [1, 1, 12, L]; x_t [1, 12, L]; gen_b [1, 1] or [1, 1, 1].
    """
    # Work on a [12, L] copy so the caller's weight tensor is never mutated.
    weights = gen_w[0, 0].clone()
    # Zero out whole leads (only the first n_remove_leads entries apply).
    if n_remove_leads > 0:
        for lead_idx in top_leads[:n_remove_leads]:
            weights[lead_idx, :] = 0.0
    # Zero out individual (lead, segment) windows, clamped to signal length.
    if n_remove_segments > 0:
        for lead_idx, seg_idx in top_segments[:n_remove_segments]:
            start = seg_idx * stride
            stop = min(start + window, L)
            weights[lead_idx, start:stop] = 0.0
    signal = x_t[0]  # [12, L]
    # Bias arrives as [1, 1] or [1, 1, 1]; both flatten to a single scalar.
    bias = gen_b.reshape(-1)[0]
    new_logit = (weights.to(signal.device) * signal).sum() + bias
    return float(torch.sigmoid(new_logit).item())
def ablate_and_recompute_imn_categorical(
    logits: torch.Tensor,
    gen_w: torch.Tensor,
    x_t: torch.Tensor,
    gen_b: torch.Tensor,
    top_leads: list[int],
    top_segments: list[tuple[int, int]],
    n_remove_leads: int,
    n_remove_segments: int,
    window: int,
    stride: int,
    L: int,
    pos_class_idx: int = 1,
) -> float:
    """
    Ablation helper for the CATEGORICAL IMN head (2-class softmax).

    Zeroes the positive-class weights at the selected leads / segments,
    recomputes only that class logit (all other class logits stay as passed
    in) and returns the softmax probability of the positive class.

    Shapes: logits [1, num_classes]; gen_w [1, num_classes, 12, L];
    x_t [1, 12, L]; gen_b [1, num_classes] or [1, num_classes, 1].
    """
    # [12, L] copy of the positive-class weight map; caller's tensor untouched.
    weights = gen_w[0, pos_class_idx].clone()
    # Zero out whole leads (first n_remove_leads entries only).
    if n_remove_leads > 0:
        for lead_idx in top_leads[:n_remove_leads]:
            weights[lead_idx, :] = 0.0
    # Zero out individual (lead, segment) windows, clamped to signal length.
    if n_remove_segments > 0:
        for lead_idx, seg_idx in top_segments[:n_remove_segments]:
            start = seg_idx * stride
            stop = min(start + window, L)
            weights[lead_idx, start:stop] = 0.0
    signal = x_t[0]  # [12, L]
    # Bias arrives as [1, num_classes] or [1, num_classes, 1].
    bias = gen_b[0, pos_class_idx, 0] if gen_b.ndim == 3 else gen_b[0, pos_class_idx]
    new_logit = (weights.to(signal.device) * signal).sum() + bias
    # Replace only the positive-class logit; clone so `logits` is not mutated.
    updated = logits[0].clone()
    updated[pos_class_idx] = new_logit
    probs = torch.softmax(updated, dim=0)
    return float(probs[pos_class_idx].item())
def build_fig_imn_with_highlights(
    x_np: np.ndarray,
    seg_hm: np.ndarray,
    window: int,
    stride: int,
    T: int,
    pred: str,
    prob_pos: float,
    pos_class_name: str,
    sampling_rate: int,
    top_leads: Optional[list[int]] = None,
    top_segments: Optional[list[tuple[int, int]]] = None,
    removed_leads: Optional[list[int]] = None,
    removed_segments: Optional[list[tuple[int, int]]] = None,
    prob_abl: Optional[float] = None,
    lead_imp_signed: Optional[np.ndarray] = None,
) -> plt.Figure:
    """
    Build a matplotlib figure with:
      - top heatmap of segment-wise importance (per lead),
      - 12 ECG traces with overlays highlighting important / removed segments and leads.

    Parameters
    ----------
    x_np : [12, L] ECG signal to plot.
    seg_hm : [12, T] segment heatmap; rendered with vmin=0 / vmax=1.
    window, stride : segment aggregation parameters (samples), shown in labels.
    T : number of segment columns in seg_hm.
    pred, prob_pos, pos_class_name : prediction info for the figure title.
    sampling_rate : sampling frequency (Hz) for the x-axis label.
    top_leads / top_segments : gold lead shading / lime-edged segment spans.
    removed_leads / removed_segments : red outline rectangles (ablation marks).
    prob_abl : post-ablation probability appended to the title when given.
    lead_imp_signed : [12] signed per-lead contribution; when given, per-lead
        percentages are signed shares of total |contribution| (may be negative).
    """
    # Local import keeps module import time light; patches only needed here.
    import matplotlib.patches as mpatches
    L = x_np.shape[1]
    # Per-lead contribution share: use signed contribution so percentages match top leads.
    if lead_imp_signed is not None:
        denom = np.abs(lead_imp_signed).sum() + 1e-9
        lead_pct = 100.0 * lead_imp_signed / denom
    else:
        # Fallback: unsigned share derived from the heatmap itself.
        lead_abs = np.abs(seg_hm).sum(axis=1)
        lead_pct = 100.0 * lead_abs / (lead_abs.sum() + 1e-9)
    cmap = "Reds"
    shade_color = "red"
    # Sets give O(1) membership tests inside the per-segment loop below.
    rem_lead_set = set(removed_leads or [])
    rem_seg_set = set(removed_segments or [])
    top_seg_set = set(top_segments or [])
    # Layout: 1 heatmap row (tall) + 12 trace rows + 1 short footer row.
    fig = plt.figure(figsize=(12, 14))
    gs = fig.add_gridspec(14, 1, height_ratios=[2] + [1] * 12 + [0.5])
    # Top heatmap
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(seg_hm, aspect="auto", vmin=0.0, vmax=1.0, cmap=cmap)
    ax0.set_yticks(range(12))
    ax0.set_yticklabels(LEAD_NAMES)
    ax0.set_xlabel(f"Segments (window={window}, stride={stride}, fs={sampling_rate}Hz)")
    prob_str = f"P({pos_class_name})={prob_pos:.3f}"
    title = f"IMN Intrinsic Explanation | {pred} | {prob_str}"
    if prob_abl is not None:
        # Extra decimals for very small probabilities so they don't render as 0.000.
        p_str = f"{prob_abl:.4f}" if prob_abl < 0.001 else f"{prob_abl:.3f}"
        title += f" -> Ablated P({pos_class_name}) = {p_str}"
    ax0.set_title(title)
    fig.colorbar(im, ax=ax0, fraction=0.02, pad=0.01)
    # Highlight removed leads in the heatmap
    for rl in rem_lead_set:
        # imshow cells are centered on integer coords, hence the -0.5 offsets.
        rect = mpatches.Rectangle(
            (-0.5, rl - 0.5),
            T,
            1,
            fill=False,
            edgecolor="red",
            linewidth=2.5,
            zorder=10,
        )
        ax0.add_patch(rect)
    # Lead-wise ECG traces with overlays
    for lead in range(12):
        ax = fig.add_subplot(gs[lead + 1, 0])
        ax.plot(x_np[lead], linewidth=0.8, color="black", alpha=0.6)
        ax.set_xlim(0, L - 1)
        ax.set_ylabel(
            f"{LEAD_NAMES[lead]} {lead_pct[lead]:.1f}%",
            rotation=0,
            labelpad=20,
            va="center",
        )
        ylo, yhi = ax.get_ylim()
        # Shade top leads
        if top_leads and lead in top_leads:
            ax.axhspan(ylo, yhi, alpha=0.15, color="gold", zorder=0)
        # Mark removed whole leads
        if lead in rem_lead_set:
            ax.add_patch(
                mpatches.Rectangle(
                    (0, ylo),
                    L - 1,
                    yhi - ylo,
                    fill=False,
                    edgecolor="red",
                    linewidth=2.5,
                    zorder=10,
                )
            )
        contrib = seg_hm[lead]
        for t in range(T):
            a = float(contrib[t])
            # Map importance to shading opacity, capped at 0.5.
            alpha = min(0.5, a * 0.6)
            if alpha <= 0.05:
                # Skip near-zero segments to keep the plot readable and fast.
                continue
            start = t * stride
            end = min(start + window, L)
            hi = (lead, t) in top_seg_set
            is_rem = (lead, t) in rem_seg_set
            if is_rem:
                # Removed segment: shading plus a red outline box on top.
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
                ax.add_patch(
                    mpatches.Rectangle(
                        (start, ylo),
                        end - start,
                        yhi - ylo,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                        zorder=10,
                    )
                )
            elif hi:
                # Top segment (not removed): lime edge marks it as important.
                ax.axvspan(
                    start,
                    end,
                    alpha=alpha,
                    facecolor=shade_color,
                    edgecolor="lime",
                    linewidth=1.5,
                    zorder=1,
                )
            else:
                # Ordinary segment: plain shading proportional to importance.
                ax.axvspan(start, end, alpha=alpha, facecolor=shade_color, zorder=0)
        ax.set_xticks([])
    # Footer
    axf = fig.add_subplot(gs[13, 0])
    axf.axis("off")
    leg = (
        f"IMN Feature Attribution: |w(x)·x| aggregated by segment "
        f"(window={window}, stride={stride}). Gold/Lime = top leads/segments by signed contribution "
        f"(highest positive = most evidence for {pos_class_name}). "
    )
    if rem_lead_set or rem_seg_set:
        leg += "Red boxes = removed (ablation)."
    axf.text(0.5, 0.5, leg, fontsize=9, wrap=True, transform=axf.transAxes, ha="center", va="center")
    fig.tight_layout()
    return fig
@lru_cache(maxsize=None)
def _list_model_repo_files() -> List[str]:
    """Fetch the model repo's file listing from the Hub (memoized per process)."""
    repo_files = list_repo_files(repo_id=HF_MODEL_REPO, repo_type="model")
    return repo_files
def _resolve_ckpt_filename(model_type: str, sampling_rate: int, task: str) -> str:
    """Pick the checkpoint path in the Hub repo for (model_type, fs, task).

    Prefers files named 'best-imn-epoch=...' and, among equals, the
    lexicographically last candidate. Raises FileNotFoundError when the
    expected folder holds no .ckpt file.
    """
    family = "single_linear_imn" if model_type == "single_linear" else "categorical_imn"
    prefix = f"{family}_{sampling_rate}hz/{task}/"
    candidates = [
        name
        for name in _list_model_repo_files()
        if name.startswith(prefix) and name.endswith(".ckpt")
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint (.ckpt) found in repo {HF_MODEL_REPO} under {prefix}. "
            "Ensure upload_best_checkpoints_to_hf.py has populated this path."
        )
    preferred = [name for name in candidates if "best-imn-epoch=" in name]
    # max() over strings == sorted(...)[-1]: lexicographically last wins.
    return max(preferred or candidates)
def load_imn_model(
    model_type: str,
    sampling_rate: int,
    task: str,
) -> Tuple[torch.nn.Module, str]:
    """Download (if needed) and load the IMN checkpoint for this configuration.

    Results are memoized in MODEL_CACHE keyed by (model_type, fs, task).
    Returns (model in eval mode on the chosen device, device string).
    """
    cache_key = (model_type, sampling_rate, task)
    entry = MODEL_CACHE.get(cache_key)
    if entry and entry["model"] is not None:
        return entry["model"], entry["device"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_name = _resolve_ckpt_filename(model_type, sampling_rate, task)
    local_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=ckpt_name)
    # Both cores expose the same IMNLightning entry point; pick by model type.
    core = sl_core if model_type == "single_linear" else cat_core
    model = core.IMNLightning.load_from_checkpoint(local_path, map_location=device)
    model.eval()
    model.to(device)
    MODEL_CACHE[cache_key] = {"path": local_path, "model": model, "device": device}
    return model, device
def load_data_binary(sampling_rate: int, task: str) -> Dict[str, Any]:
    """Load (and memoize in DATA_CACHE) the local .npz example set for (fs, task).

    Raises FileNotFoundError when no file is configured or present, and
    KeyError when the archive lacks any of the required arrays.
    """
    cache_key = (sampling_rate, task)
    try:
        return DATA_CACHE[cache_key]
    except KeyError:
        pass
    path = DATA_FILES.get(cache_key)
    if path is None:
        raise FileNotFoundError(f"No data file configured for (fs={sampling_rate}, task={task}).")
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Data file not found at '{path}'. "
            "Upload a .npz with signals, labels, reports, age, sex, ecg_id."
        )
    required = ["signals", "labels", "reports", "age", "sex", "ecg_id"]
    # allow_pickle is needed for the object arrays (reports / sex).
    with np.load(path, allow_pickle=True) as npz:
        missing = [key for key in required if key not in npz]
        if missing:
            raise KeyError(f"Data file '{path}' missing keys: {missing}")
        data = {key: npz[key] for key in required}
    DATA_CACHE[cache_key] = data
    return data
def on_load_records(
    sampling_rate: int,
    task: str,
    state: Optional[dict],
):
    """Gradio callback: load the example set and populate the record dropdown.

    Returns (status_markdown, dropdown_update, new_state, report_text, ground_truth).
    On load failure the status carries the error and the dropdown is emptied.
    """
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        return (
            f"Load error: {e}",
            gr.update(choices=[], value=None),
            state or {},
            "—",
            "—",
        )
    signals = data["signals"]
    labels, reports = data["labels"], data["reports"]
    age, sex, ecg_id = data["age"], data["sex"], data["ecg_id"]
    N, C, L = signals.shape
    pos_class = TASK_TO_POS.get(task, "MI")
    # One summary dict per example; labels >= 0.5 count as the positive class.
    records: List[Dict[str, Any]] = [
        {
            "index": int(i),
            "ecg_id": int(ecg_id[i]),
            "gt": pos_class if float(labels[i]) >= 0.5 else "NORM",
            "report": str(reports[i]) if reports is not None else "",
            "age": age[i] if age is not None else "",
            "sex": str(sex[i]) if sex is not None else "",
        }
        for i in range(N)
    ]
    choices = [
        f"{r['index']} | {r['ecg_id']} | {r['gt']} | age {r['age']} {r['sex']}"
        for r in records
    ]
    new_state = {
        "records": records,
        "fs": int(sampling_rate),
        "task": task,
        "pos_class": pos_class,
    }
    first = records[0] if records else None
    report = (first["report"] or "(no clinical notes)") if first else "—"
    gt = first["gt"] if first else "—"
    if N > 0:
        status = f"Loaded {N} examples (fs={sampling_rate}Hz, {pos_class} vs NORM, L={L})."
    else:
        status = "No examples found in data file."
    value = choices[0] if choices else None
    return status, gr.update(choices=choices, value=value), new_state, report, gt
def on_select_record(choice: str, state: Optional[dict]):
    """Gradio callback: show the clinical report and ground truth for a record.

    `choice` is the dropdown string "index | ecg_id | GT | age sex"; the
    leading index is matched against the records stored in `state`.
    Returns ("—", "—") whenever the selection cannot be resolved.
    """
    fallback = ("—", "—")
    if not choice or not state or not state.get("records"):
        return fallback
    try:
        idx = int(choice.split("|", 1)[0])
    except Exception:
        return fallback
    match = next((r for r in state["records"] if r["index"] == idx), None)
    if match is None:
        return fallback
    return match["report"] or "(no clinical notes)", match["gt"]
def explain_record(
    model_type: str,
    sampling_rate: int,
    task: str,
    record_choice: str,
    state: Optional[dict],
    window: int,
    stride: int,
    topk_leads: int,
    topk_segments: int,
    remove_leads: bool,
    n_remove_leads: int,
    remove_segments: bool,
    n_remove_segments: int,
):
    """
    Gradio callback: run IMN inference on the selected record and build the
    intrinsic-explanation figure, optionally ablating top leads/segments.

    Returns an 8-tuple matching the outputs wired in main():
    (summary_md, figure, clinical_notes, ground_truth, ecg_id_str,
     top_leads_str, top_segments_str, ablation_str).
    On any error the first element carries the message and the figure is None.
    """
    # Default error payload: same arity as the success return.
    err = "Select a record and Load records first.", None, "—", "—", "—", "—", "—", "—"
    if not state or not state.get("records") or not record_choice:
        return err
    try:
        # Dropdown choice format: "index | ecg_id | GT | age sex".
        rec_idx = int(record_choice.split("|")[0].strip())
    except Exception:
        return err
    rec = next((r for r in state["records"] if r["index"] == rec_idx), None)
    if not rec:
        return err
    fs = state["fs"]
    pos_class_name = state.get("pos_class", "MI")
    report = rec["report"] or "(no clinical notes)"
    gt = rec["gt"]
    try:
        data = load_data_binary(int(sampling_rate), task)
    except Exception as e:
        return f"Data error: {e}", None, report, gt, "—", "—", "—", "—"
    try:
        model, device = load_imn_model(model_type, int(sampling_rate), task)
    except Exception as e:
        return f"Checkpoint error: {e}", None, report, gt, "—", "—", "—", "—"
    signals = data["signals"]
    if rec_idx < 0 or rec_idx >= signals.shape[0]:
        return f"Invalid record index {rec_idx}.", None, report, gt, "—", "—", "—", "—"
    x = signals[rec_idx]  # [12, L]
    if x.shape[0] != 12:
        return f"Expected 12 leads, got {x.shape[0]}.", None, report, gt, "—", "—", "—", "—"
    # Guard: data length must match the length the checkpoint was trained on.
    signal_len_model = int(model.hparams["signal_len"])
    if x.shape[1] != signal_len_model:
        return (
            f"ECG length {x.shape[1]} != model {signal_len_model}. "
            "Ensure data binaries match the training window length.",
            None,
            report,
            gt,
            "—",
            "—",
            "—",
            "—",
        )
    x = zscore_per_lead(x)
    x_t = torch.from_numpy(x).float().unsqueeze(0).to(device)
    with torch.no_grad():
        # IMN forward returns (logits, generated weights, generated bias).
        logits, gen_w, gen_b = model.model(x_t)
    if model_type == "single_linear":
        # Single binary logit -> sigmoid; weights live at class index 0.
        logit = logits.squeeze()
        prob_pos = float(torch.sigmoid(logit).item())
        w_used = gen_w[0, 0, :, :].cpu().numpy()
    else:
        # 2-class softmax; positive class is index 1 by convention here.
        probs = torch.softmax(logits, dim=1)
        prob_pos = float(probs[0, 1].item())
        w_used = gen_w[0, 1, :, :].cpu().numpy()
    # Ensure int types for sliders / numbers
    window = max(1, int(window))
    stride = max(1, int(stride))
    topk_leads = int(topk_leads)
    topk_segments = int(topk_segments)
    n_remove_leads = int(n_remove_leads)
    n_remove_segments = int(n_remove_segments)
    x_np = x.astype(np.float64)
    impact = w_used * x_np  # [12, L]
    seg_hm = sl_core.imn_weights_to_segments(impact, window=window, stride=stride)  # [12, T]
    # Build advanced figure with important leads/segments highlighted
    L = x_np.shape[1]
    T = seg_hm.shape[1]
    # Top-k leads: rank by SIGNED contribution (impact) to the positive logit.
    lead_imp_signed = impact.sum(axis=1)  # [12] total contribution per lead
    k_leads = min(max(0, topk_leads), 12)
    top_leads = np.argsort(lead_imp_signed)[::-1][:k_leads].tolist() if k_leads else []
    # Top-k segments: rank by SIGNED segment contribution (mean impact per segment)
    seg_signed = np.zeros((12, T), dtype=np.float64)
    for t in range(T):
        s = t * stride
        e = min(t * stride + window, L)
        seg_signed[:, t] = impact[:, s:e].mean(axis=1)
    seg_flat = seg_signed.flatten()
    k_seg = min(max(0, topk_segments), seg_flat.size)
    top_flat_idx = np.argsort(seg_flat)[::-1][:k_seg]
    # Flat index over [12, T]: row = lead, column = segment index.
    top_segments = [(idx // T, idx % T) for idx in top_flat_idx] if k_seg else []
    # Optional ablation (remove top leads/segments and recompute P(pos_class))
    prob_abl: Optional[float] = None
    nr = n_remove_leads if remove_leads else 0
    ns = n_remove_segments if remove_segments else 0
    removed_leads: list[int] = []
    removed_segments: list[tuple[int, int]] = []
    if (nr or ns) and (top_leads or top_segments):
        if model_type == "single_linear":
            prob_abl = ablate_and_recompute_imn_single(
                gen_w,
                x_t,
                gen_b,
                top_leads,
                top_segments,
                nr,
                ns,
                window,
                stride,
                L,
            )
        else:
            prob_abl = ablate_and_recompute_imn_categorical(
                logits,
                gen_w,
                x_t,
                gen_b,
                top_leads,
                top_segments,
                nr,
                ns,
                window,
                stride,
                L,
                pos_class_idx=1,
            )
        if nr > 0 and top_leads:
            removed_leads = top_leads[:nr]
        if ns > 0 and top_segments:
            removed_segments = top_segments[:ns]
    pred = pos_class_name if prob_pos >= 0.5 else "NORM"
    fig = build_fig_imn_with_highlights(
        x_np,
        seg_hm,
        window,
        stride,
        T,
        pred,
        prob_pos,
        pos_class_name,
        fs,
        top_leads=top_leads or None,
        top_segments=top_segments or None,
        removed_leads=removed_leads or None,
        removed_segments=removed_segments or None,
        prob_abl=prob_abl,
        lead_imp_signed=lead_imp_signed,
    )
    top_leads_str = ", ".join(LEAD_NAMES[i] for i in top_leads) if top_leads else "—"
    # Cap the textual segment list at 12 entries to keep the textbox readable.
    top_segments_str = (
        ", ".join(f"({LEAD_NAMES[l]},{t})" for l, t in top_segments[:12])
        if top_segments
        else "—"
    )
    if len(top_segments) > 12:
        top_segments_str += " ..."
    def _fmt_prob(p: float) -> str:
        # Extra decimals for tiny probabilities so they don't render as 0.000.
        return f"{p:.4f}" if p < 0.001 else f"{p:.3f}"
    abl_str = f"Ablated P({pos_class_name}) = {_fmt_prob(prob_abl)}" if prob_abl is not None else "—"
    summary = (
        f"**{pred}** | P({pos_class_name}) = {prob_pos:.3f}"
        + (f" → **{_fmt_prob(prob_abl)}** (after ablation)" if prob_abl is not None else "")
        + f" | Ground truth: **{gt}** | fs={fs}Hz, window={window}, stride={stride}"
    )
    return summary, fig, report, gt, f"{rec['ecg_id']}", top_leads_str, top_segments_str, abl_str
def main():
    """Build the Gradio Blocks UI for the IMN XAI viewer and launch it.

    Wires three callbacks: on_load_records (Load button), on_select_record
    (dropdown change) and explain_record (Run button). The output component
    order of each .click/.change call must match the callback's return tuple.
    """
    demo = gr.Blocks(
        title="MesomorphicECG XAI (IMN categorical + single-linear)",
        theme=gr.themes.Soft(),
    )
    with demo:
        gr.Markdown(
            "# MesomorphicECG XAI\n"
            "Interactive XAI viewer for Interpretable Mesomorphic Networks (IMN) on PTB-XL ECGs.\n\n"
            "- Models and checkpoints from "
            "[SEARCH-IHI/mesomorphicECG](https://huggingface.co/SEARCH-IHI/mesomorphicECG).\n"
            "- Data samples loaded from binary `.npz` files stored in this Space.\n"
            "- Heatmaps show segment-wise IMN contribution per lead."
        )
        # Configuration row: sampling rate, model head type and binary task.
        with gr.Row():
            sampling_rate = gr.Radio(
                label="Sampling rate",
                choices=[100, 500],
                value=500,
            )
            model_type = gr.Radio(
                label="Model type",
                choices=["single_linear", "categorical"],
                value="single_linear",
                info="single_linear: single linear head; categorical: 2-class head.",
            )
            task = gr.Radio(
                label="Task (positive class vs NORM)",
                choices=list(TASK_TO_POS.keys()),
                value="norm_vs_mi",
            )
        load_btn = gr.Button("Load records", variant="secondary")
        load_status = gr.Markdown()
        # Holds the loaded record list + metadata between callback invocations.
        records_state = gr.State(value=None)
        with gr.Row():
            record_dd = gr.Dropdown(
                label="Record (index | ecg_id | GT | age sex)",
                choices=[],
                value=None,
            )
        with gr.Row():
            clinical_notes = gr.Textbox(
                label="Clinical notes (report)",
                value="",
                lines=4,
                max_lines=8,
                interactive=False,
            )
            ground_truth = gr.Textbox(
                label="Ground truth",
                value="—",
                interactive=False,
            )
        load_btn.click(
            fn=on_load_records,
            inputs=[sampling_rate, task, records_state],
            outputs=[load_status, record_dd, records_state, clinical_notes, ground_truth],
        )
        record_dd.change(
            fn=on_select_record,
            inputs=[record_dd, records_state],
            outputs=[clinical_notes, ground_truth],
        )
        gr.Markdown("### Window & stride (segment aggregation)")
        with gr.Row():
            window = gr.Slider(
                label="Window size",
                minimum=10,
                maximum=500,
                value=50,
                step=5,
                info="Segment width for aggregating point-wise attributions",
            )
            stride = gr.Slider(
                label="Stride",
                minimum=5,
                maximum=250,
                value=25,
                step=5,
                info="Step between segments (typically window/2)",
            )
        gr.Markdown(
            "### Top-k leads & segments\n"
            "Ranked by **signed** contribution (w·x) to the positive logit: "
            "highest positive = most evidence for the selected positive class. "
            "Removing them should decrease P(pos_class) for positive-predicting leads."
        )
        with gr.Row():
            topk_leads = gr.Number(
                label="Top-k leads",
                value=3,
                minimum=0,
                maximum=12,
                step=1,
            )
            topk_segments = gr.Number(
                label="Top-k segments",
                value=10,
                minimum=0,
                maximum=200,
                step=1,
            )
        gr.Markdown("### Ablation (remove top leads/segments and recompute)")
        with gr.Row():
            remove_leads = gr.Checkbox(label="Remove top leads", value=False)
            n_remove_leads = gr.Number(
                label="Num. leads to remove",
                value=1,
                minimum=0,
                maximum=12,
                step=1,
            )
            remove_segments = gr.Checkbox(label="Remove top segments", value=False)
            n_remove_segments = gr.Number(
                label="Num. segments to remove",
                value=5,
                minimum=0,
                maximum=100,
                step=1,
            )
        run_btn = gr.Button("Run IMN explanation", variant="primary")
        # Output components, in the same order as explain_record's return tuple.
        out_summary = gr.Markdown()
        out_plot = gr.Plot()
        out_notes = gr.Textbox(label="Clinical notes", lines=3, interactive=False)
        out_gt = gr.Textbox(label="Ground truth", interactive=False)
        out_meta = gr.Textbox(label="ECG ID", interactive=False)
        out_leads = gr.Textbox(label="Top leads", interactive=False)
        out_segments = gr.Textbox(label="Top segments (lead, seg_idx)", interactive=False)
        out_abl = gr.Textbox(label="Ablated", interactive=False)
        run_btn.click(
            fn=explain_record,
            inputs=[
                model_type,
                sampling_rate,
                task,
                record_dd,
                records_state,
                window,
                stride,
                topk_leads,
                topk_segments,
                remove_leads,
                n_remove_leads,
                remove_segments,
                n_remove_segments,
            ],
            outputs=[
                out_summary,
                out_plot,
                out_notes,
                out_gt,
                out_meta,
                out_leads,
                out_segments,
                out_abl,
            ],
        )
    demo.launch()
# Script entry point (Hugging Face Spaces runs this module directly).
if __name__ == "__main__":
    main()