# hist/engine.py (Hugging Face Space upload by jehadcheyi, commit be9af69
# "Upload 7 files", 23.5 kB).
"""
BRACS inference + PDF report engine for the HF Space.
Reuses the LEAN report logic: prediction, per-patch relevance, heatmap,
high-res zooms, patch-level showcase, and the model-report figures.
Research use only — not for clinical diagnosis.
"""
import os, io, json, pickle, tempfile
import numpy as np
import h5py
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib.cm import ScalarMappable
from scipy.ndimage import gaussian_filter
# ----------------------------------------------------------------------------
# Colors / colormaps
# ----------------------------------------------------------------------------
# Text, page-background and grid colours shared by every page.
C_TXT = "#1a1a1a"
C_BG = "white"
C_GRID = "#dddddd"
# Semantic accents: benign, malignant, accent (headers/titles), neutral (captions).
C_BEN = "#2a9d5c"
C_MAL = "#d1495b"
C_ACC = "#3a6ea5"
C_NEUT = "#7f8c8d"
# Relevance colormap. Its first stop "#10103000" has alpha 00, so the lowest
# scores are fully transparent over the tissue.
CMAP_ATT = LinearSegmentedColormap.from_list(
    "att", ["#10103000", "#3a6ea5", "#ffd93d", "#d1495b"]
)
# Benign -> light grey -> malignant map, used for the confusion-matrix image.
CMAP_BM = LinearSegmentedColormap.from_list("bm", [C_BEN, "#eeeeee", C_MAL])
# Fallback evaluation summary rendered on the model-report pages when
# build_report() receives no ``report_meta``.
# NOTE(review): "header_cv_auc" (0.982) does not match the LR_concat CV AUC in
# "model_comparison" (0.972 ± 0.018), and "results_paragraph" describes the
# pooled-feature model although "best_model" is LR_concat — confirm which
# figures are intended.
# NOTE(review): "header_threshold" is not read by any function in this file; the
# page header shows the threshold build_report() actually applies.
DEFAULT_REPORT = dict(
    header_model="LR_concat",
    header_cv_auc=0.982,
    header_threshold=0.45,
    # Rows: (model, patient-grouped CV AUC "mean ± std", test AUC).
    model_comparison=[
        ("LR_titan", "0.880 ± 0.072", "0.904"),
        ("LR_pooled", "0.973 ± 0.016", "0.973"),
        ("LR_concat", "0.972 ± 0.018", "0.973"),
        ("ABMIL", "0.953 ± 0.025", "—"),
    ],
    best_model="LR_concat",
    # Confusion-matrix counts (labelled as thr=0.50 in the report).
    confusion=dict(tn=170, fp=14, fn=22, tp=156),
    # Rows: (threshold, sensitivity, specificity).
    threshold_table=[
        (0.20, 0.933, 0.864),
        (0.30, 0.933, 0.897),
        (0.50, 0.876, 0.924),
        (0.70, 0.831, 0.951),
        (0.80, 0.764, 0.957),
    ],
    best_balanced_thr=0.30,
    # Rows: (category, number of slides).
    dataset_used=[("Benign", 184), ("Malignant", 178)],
    n_patients=144,
    results_paragraph=(
        "The proposed patient-grouped BRACS classifier was evaluated on 362 "
        "whole-slide images from 144 patients. Logistic Regression using pooled "
        "slide-level features achieved a patient-grouped cross-validation AUC of "
        "0.973 ± 0.016 and an official test-set AUC of 0.973. At the default "
        "threshold of 0.50, sensitivity was 87.6% and specificity 92.4%. Threshold "
        "optimization identified 0.30 as the operating point yielding the highest "
        "balanced accuracy."
    ),
)
# ----------------------------------------------------------------------------
# H5 loading
# ----------------------------------------------------------------------------
def load_slide_h5(path):
    """Read one slide's features and metadata from an HDF5 file.

    Always read (a missing one raises): ``features`` (one row per patch) and
    ``titan_slide_embedding``. Optional:
      * ``pooled_conch_slide_feature_mean_max_std`` — when absent it is rebuilt as
        the per-column [mean | max | std] of ``features``;
      * ``coords`` dataset and its ``patch_size_level0`` attribute;
      * file attribute ``slide_id`` (defaults to the file's basename).

    Returns:
        dict with ``feats``, ``titan``, ``pooled``, ``concat`` (titan followed by
        pooled, float32), ``coords`` (int64 or None), ``patch_size`` (int or None),
        ``slide_id`` (str) and ``path``.
    """
    with h5py.File(path, "r") as h5:
        feats = np.asarray(h5["features"][:], np.float32)
        titan = np.asarray(h5["titan_slide_embedding"][:], np.float32).squeeze()

        pooled_key = "pooled_conch_slide_feature_mean_max_std"
        if pooled_key in h5:
            pooled = np.asarray(h5[pooled_key][:], np.float32).squeeze()
        else:
            # Same [mean | max | std] layout the stored key name describes.
            pooled = np.concatenate([feats.mean(0), feats.max(0), feats.std(0)])

        coords = None
        patch_size = None
        if "coords" in h5:
            coords_ds = h5["coords"]
            coords = np.asarray(coords_ds[:], np.int64)
            if "patch_size_level0" in coords_ds.attrs:
                patch_size = int(coords_ds.attrs["patch_size_level0"])

        slide_id = h5.attrs.get("slide_id", os.path.basename(path))
        if isinstance(slide_id, bytes):
            slide_id = slide_id.decode()
        else:
            slide_id = str(slide_id)

    concat = np.concatenate([titan, pooled]).astype(np.float32)
    return {
        "feats": feats,
        "titan": titan,
        "pooled": pooled,
        "concat": concat,
        "coords": coords,
        "patch_size": patch_size,
        "slide_id": slide_id,
        "path": path,
    }
# ----------------------------------------------------------------------------
# Prediction + relevance
# ----------------------------------------------------------------------------
def _artifact(bundle):
if "artifact" in bundle:
return bundle["artifact"]
return {"kind": bundle.get("kind", "linear"),
"rep": bundle.get("rep", "concat"),
"models": bundle.get("models", [])}
def predict(bundle, slide):
    """Return one positive-class (index 1) probability per ensemble member.

    * ``kind == "linear"`` (the default, as in _artifact): every fitted model in
      ``art["models"]`` is applied through ``predict_proba`` to the slide vector
      ``slide[art["rep"]]`` (e.g. "concat"), keeping column 1.
    * any other kind: every state dict in ``art["states"]`` is loaded into a fresh
      network from build_abmil(in_dim) and run on ``slide["feats"]``; the sigmoid
      of its logit is kept. ``in_dim`` defaults to the feature width when the
      artifact does not store it.

    Returns:
        1-D numpy array with one probability per ensemble member.

    Raises:
        ValueError: if the artifact holds no models / state dicts. An empty
            ensemble would otherwise be averaged into NaN by the caller, and
            ``NaN >= threshold`` is False, so it would silently be reported as
            the negative class.
    """
    art = _artifact(bundle)
    if art.get("kind", "linear") == "linear":
        models = art.get("models")
        if models is None or len(models) == 0:
            raise ValueError("model bundle contains no fitted linear models")
        x = slide[art.get("rep", "concat")].reshape(1, -1)
        return np.array([m.predict_proba(x)[0, 1] for m in models])

    import torch  # only needed for the ABMIL path

    states = art.get("states")
    if states is None or len(states) == 0:
        raise ValueError("model bundle contains no ABMIL state dicts")
    feats = torch.from_numpy(slide["feats"]).float()
    in_dim = art.get("in_dim", feats.shape[1])
    probs = []
    with torch.no_grad():
        for state in states:
            model = build_abmil(in_dim)
            model.load_state_dict(state)
            model.eval()
            logit, _ = model(feats)
            probs.append(torch.sigmoid(logit).item())
    return np.array(probs)
def build_abmil(in_dim, hidden=192, att=128, dropout=0.3):
    """Construct an untrained gated-attention MIL network.

    Architecture:
      * input rows are projected to ``hidden`` dims (Linear -> ReLU -> Dropout);
      * gated scores ``w(tanh(V h) * sigmoid(U h))`` are softmax-normalised over
        dim 0 (the patch axis);
      * the attention-weighted sum of the projections is mapped to one logit.

    Calling the module on a 2-D patch-feature matrix (as predict() does) returns
    ``(logit, attention)``: the squeezed logit and one weight per row. The
    submodule names ``proj``/``V``/``U``/``w``/``head`` form the state-dict keys
    that ``load_state_dict`` in predict() relies on.
    """
    import torch
    import torch.nn as nn

    class ABMIL(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            self.V = nn.Linear(hidden, att)
            self.U = nn.Linear(hidden, att)
            self.w = nn.Linear(att, 1)
            self.head = nn.Linear(hidden, 1)

        def forward(self, x):
            embedded = self.proj(x)
            gated = torch.tanh(self.V(embedded)) * torch.sigmoid(self.U(embedded))
            weights = torch.softmax(self.w(gated), 0)
            bag = (weights * embedded).sum(0, keepdim=True)
            return self.head(bag).squeeze(), weights.squeeze(-1)

    return ABMIL()
def norm01(x):
    """Min-max scale ``x`` into [0, 1].

    A 1e-9 epsilon in the denominator keeps constant input from dividing by zero
    (such input maps to all zeros). Array-likes are accepted. Empty input returns
    an empty float array, because ``min()`` is undefined on an empty array and
    callers may pass zero patches.
    """
    x = np.asarray(x)
    if x.size == 0:
        return x.astype(float)
    return (x - x.min()) / (np.ptp(x) + 1e-9)
def patch_relevance(bundle, slide):
    """Per-patch relevance scores in [0, 1], one per row of ``slide["feats"]``.

    Coefficient path: applies to a linear ensemble whose input vector contains
    the mean of the patch features. Two layouts qualify:
      * ``rep`` "concat": [titan | mean | max | std];
      * ``rep`` "pooled": [mean | max | std], the layout load_slide_h5 builds.
    For each model, the coefficients are mapped back to raw-feature space — first
    through an optional "pca" step's ``components_``, then divided by an optional
    "scaler" step's ``scale_``. They are averaged over the ensemble, and the slice
    that multiplies the mean block is dotted with every patch's features.

    Fallback: if the artifact is not of that form, or the extraction fails (a
    RuntimeWarning is emitted), the per-patch L2 norm of the features is used.
    """
    import warnings

    feats = slide["feats"]
    art = _artifact(bundle)
    rep = art.get("rep")
    if art.get("kind") == "linear" and rep in ("concat", "pooled") and art.get("models"):
        try:
            d = feats.shape[1]
            # Start of the mean-of-patches block inside the model's input vector.
            # For "concat" it follows the TITAN embedding, whose length need not
            # equal the patch-feature width d.
            if rep == "concat":
                offset = int(np.size(slide["titan"])) if "titan" in slide else d
            else:
                offset = 0
            coefs = []
            for m in art["models"]:
                if hasattr(m, "named_steps"):
                    steps = m.named_steps
                    c = np.asarray(steps["clf"].coef_).squeeze()
                    if "pca" in steps:
                        c = c @ steps["pca"].components_
                    if "scaler" in steps:
                        c = c / (steps["scaler"].scale_ + 1e-9)
                else:
                    c = np.asarray(m.coef_).squeeze()
                coefs.append(c)
            coef = np.mean(coefs, 0)
            mean_block = coef[offset:offset + d]
            if mean_block.shape[0] == d:
                return norm01(feats @ mean_block)
        except Exception as err:  # deliberate best-effort: fall back below
            warnings.warn(
                f"patch_relevance: coefficient-based relevance failed ({err!r}); "
                "using feature norms instead",
                RuntimeWarning,
            )
    return norm01(np.linalg.norm(feats, axis=1))
# ----------------------------------------------------------------------------
# Slide image backend (OpenSlide or PIL)
# ----------------------------------------------------------------------------
class SlideImage:
    """Level-0 pixel access to a whole-slide image (OpenSlide, with a PIL fallback).

    OpenSlide is tried first. If it is not installed or cannot open ``path``, the
    image is fully decoded into memory with PIL as RGB. ``w``/``h`` hold the
    level-0 width and height. The object can be used as a context manager;
    ``close()`` releases whichever backend object is held and may be called more
    than once.
    """

    def __init__(self, path):
        self.path = path
        try:
            import openslide
            self.osh = openslide.OpenSlide(path)
            self.w, self.h = self.osh.dimensions
            self.backend = "openslide"
        except Exception:
            from PIL import Image
            # NOTE(review): this disables PIL's decompression-bomb guard for the
            # whole process so very large slides can load. If `path` can come from
            # untrusted uploads, consider a generous finite limit instead of None.
            Image.MAX_IMAGE_PIXELS = None
            self.img = Image.open(path).convert("RGB")
            self.w, self.h = self.img.size
            self.backend = "pil"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def thumbnail(self, max_size):
        """Return ``(rgb_array, scale)`` with the slide shrunk to fit ``max_size``.

        ``scale`` (capped at 1, so never upsampling) is the factor callers use to
        map level-0 coordinates onto the thumbnail. The backends keep the aspect
        ratio, so the array may be a pixel smaller than ``w * scale``.
        """
        scale = min(max_size / self.w, max_size / self.h, 1.0)
        tw, th = max(1, int(self.w * scale)), max(1, int(self.h * scale))
        if self.backend == "openslide":
            thumb = self.osh.get_thumbnail((tw, th))
        else:
            thumb = self.img.copy()
            thumb.thumbnail((tw, th))
        return np.array(thumb), scale

    def read_region_l0(self, x, y, w, h):
        """Read an RGB region of size ``w`` x ``h`` at level 0 with top-left ``(x, y)``.

        The origin is clamped into the slide, and the size is cropped at the
        right/bottom edges. The result can therefore be smaller than requested,
        and it is shifted (not padded) when ``x``/``y`` are negative.
        """
        x = max(0, min(int(x), self.w - 1))
        y = max(0, min(int(y), self.h - 1))
        w = max(1, min(int(w), self.w - x))
        h = max(1, min(int(h), self.h - y))
        if self.backend == "openslide":
            return np.array(self.osh.read_region((x, y), 0, (w, h)).convert("RGB"))
        return np.array(self.img.crop((x, y, x + w, y + h)))

    def close(self):
        """Release the OpenSlide handle and/or decoded PIL image; repeat calls are no-ops."""
        for attr in ("osh", "img"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)
def tissue_bbox(thumb_shape, coords, psize, scale, pad_frac=0.04):
    """Return the ``(x0, y0, x1, y1)`` thumbnail-pixel box enclosing all patches.

    Steps:
      1. Scale the level-0 patch origins ``coords`` by ``scale``.
      2. Extend the far edges by one scaled patch side.
      3. Pad every side by ``pad_frac`` of the longer box side.
      4. Clip to the thumbnail.
    Without coords the whole thumbnail ``(0, 0, width, height)`` is returned.
    """
    height, width = thumb_shape[:2]
    if coords is None or not len(coords):
        return 0, 0, width, height
    xs = coords[:, 0] * scale
    ys = coords[:, 1] * scale
    side = psize * scale
    left, top = xs.min(), ys.min()
    right, bottom = xs.max() + side, ys.max() + side
    pad = pad_frac * max(right - left, bottom - top)
    return (
        int(max(0, left - pad)),
        int(max(0, top - pad)),
        int(min(width, right + pad)),
        int(min(height, bottom + pad)),
    )
def heatmap_canvas(shape, coords, imp, psize, scale):
    """Rasterise per-patch scores into a smoothed [0, 1] map of size ``shape[:2]``.

    Each patch paints its score over its scaled square footprint (at least 1 px,
    clipped at the canvas edge), and overlapping footprints are averaged. The
    canvas is then Gaussian-blurred with ``sigma = max(2, min(h, w) / 60)`` and
    rescaled with norm01.
    """
    height, width = shape[:2]
    total = np.zeros((height, width), np.float32)
    hits = np.zeros((height, width), np.float32)
    side = max(1, int(psize * scale))
    for idx, score in enumerate(imp):
        left = int(coords[idx, 0] * scale)
        top = int(coords[idx, 1] * scale)
        right = min(left + side, width)
        bottom = min(top + side, height)
        total[top:bottom, left:right] += score
        hits[top:bottom, left:right] += 1
    covered = hits > 0
    total[covered] /= hits[covered]
    sigma = max(2.0, min(height, width) / 60)
    return norm01(gaussian_filter(total, sigma=sigma))
def overlay(ax, thumb, coords, imp, psize, scale, alpha=0.55):
    """Draw ``thumb`` on ``ax`` and blend the relevance heatmap on top of it.

    The heatmap's per-pixel opacity is ``alpha * heat``, so low-relevance tissue
    stays visible. The heatmap is skipped when there are no patch coords. Axis
    ticks are removed.
    """
    ax.imshow(thumb)
    has_patches = coords is not None and len(coords) > 0
    if has_patches:
        heat = heatmap_canvas(thumb.shape, coords, imp, psize, scale)
        layer = CMAP_ATT(heat)
        layer[..., 3] = alpha * heat
        ax.imshow(layer)
    ax.set_xticks([])
    ax.set_yticks([])
# ----------------------------------------------------------------------------
# Model-report figures (pages 2-3)
# ----------------------------------------------------------------------------
def _styled_table(ax, header, rows, highlight_row=None, fontsize=11):
    """Render ``rows`` under ``header`` as a centred table on ``ax`` and return it.

    Row 0 of the table is the header; it gets the accent fill and bold white
    text. ``highlight_row`` is compared against the table's row index, in which
    data rows start at 1. The matching row is drawn in bold green on a pale fill.
    """
    table = ax.table(cellText=rows, colLabels=header, loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(fontsize)
    table.scale(1, 1.9)
    for (row_idx, _col), cell in table.get_celld().items():
        cell.set_edgecolor("#cccccc")
        if row_idx == 0:
            cell.set_facecolor(C_ACC)
            cell.set_text_props(color="white", fontweight="bold")
            continue
        if highlight_row is not None and row_idx == highlight_row:
            cell.set_facecolor("#eaf4ec")
            cell.set_text_props(fontweight="bold", color=C_BEN)
    return table
def report_pages(pdf, R):
    """Append the two model-report pages to ``pdf``; do nothing when ``R`` is falsy.

    ``R`` follows the DEFAULT_REPORT layout.

    Required keys:
      * ``model_comparison``: (model, cv_auc, test_auc) rows;
      * ``confusion``: tn/fp/fn/tp counts;
      * ``threshold_table``: (threshold, sensitivity, specificity) rows;
      * ``dataset_used``: (category, n_slides) rows;
      * ``n_patients`` and ``results_paragraph``.

    Optional keys:
      * ``best_model``: comparison row to highlight;
      * ``best_balanced_thr``: threshold to mark and highlight;
      * ``confusion_threshold``: threshold shown in the confusion-matrix title,
        default 0.50.
    """
    if not R:
        return
    for build_page in (_report_page_model_eval, _report_page_threshold_dataset):
        fig = build_page(R)
        pdf.savefig(fig, dpi=150)
        plt.close(fig)


def _report_page_model_eval(R):
    """Build the page holding Figure 1 (model comparison) and Figure 2 (confusion matrix)."""
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Model Development & Evaluation", fontsize=18, fontweight="bold",
                 y=0.97, color=C_TXT)

    # Figure 1: comparison table with the best model's row highlighted
    # (table rows are 1-based because row 0 is the header).
    ax1 = fig.add_axes([0.07, 0.60, 0.86, 0.27])
    ax1.axis("off")
    ax1.set_title("Figure 1 — Model Comparison", fontweight="bold", fontsize=12,
                  loc="left", color=C_ACC, pad=10)
    best = R.get("best_model")
    comparison = R["model_comparison"]
    rows = [[m, cv, te] for (m, cv, te) in comparison]
    hl = next((i + 1 for i, (m, _, _) in enumerate(comparison) if m == best), None)
    _styled_table(ax1, ["Model", "Patient-CV AUC", "Test AUC"], rows, highlight_row=hl)

    # Figure 2: confusion matrix and derived metrics (the epsilon avoids 0/0).
    cm = R["confusion"]
    tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
    sens = tp / (tp + fn + 1e-9)
    spec = tn / (tn + fp + 1e-9)
    bacc = 0.5 * (sens + spec)
    cm_thr = R.get("confusion_threshold", 0.50)
    axcm = fig.add_axes([0.10, 0.12, 0.36, 0.36])
    axcm.set_title(f"Figure 2 — Confusion Matrix ({best}, thr={cm_thr:.2f})",
                   fontweight="bold", fontsize=12, loc="left", color=C_ACC, pad=10)
    mat = np.array([[tn, fp], [fn, tp]])
    axcm.imshow(mat, cmap=CMAP_BM, vmin=0, vmax=mat.max())
    axcm.set_xticks([0, 1])
    axcm.set_yticks([0, 1])
    axcm.set_xticklabels(["Pred Benign", "Pred Malignant"], fontsize=9)
    axcm.set_yticklabels(["True Benign", "True Malignant"], fontsize=9,
                         rotation=90, va="center")
    for i in range(2):
        for j in range(2):
            axcm.text(j, i, f"{mat[i, j]}", ha="center", va="center", fontsize=18,
                      fontweight="bold",
                      color="white" if mat[i, j] > mat.max() / 2 else C_TXT)
    axm = fig.add_axes([0.55, 0.12, 0.38, 0.36])
    axm.axis("off")
    axm.text(0.0, 0.78, f"Sensitivity = {sens:.3f}\nSpecificity = {spec:.3f}\n"
             f"Balanced Accuracy = {bacc:.3f}", fontsize=13, family="monospace",
             va="top", color=C_TXT,
             bbox=dict(boxstyle="round,pad=0.8", facecolor="#f4f6f8",
                       edgecolor=C_ACC, lw=1.5))
    return fig


def _report_page_threshold_dataset(R):
    """Build the page holding Figures 3-4 (thresholds, dataset) and the Results text."""
    import textwrap

    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Threshold Analysis & Dataset", fontsize=18, fontweight="bold",
                 y=0.97, color=C_TXT)

    # Figure 3: sensitivity/specificity vs. threshold, plus the same data as a table.
    table = R["threshold_table"]
    tt = np.array(table, dtype=float)
    bthr = R.get("best_balanced_thr")
    ax3 = fig.add_axes([0.08, 0.58, 0.52, 0.30])
    ax3.set_title("Figure 3 — Threshold Analysis", fontweight="bold", fontsize=12,
                  loc="left", color=C_ACC, pad=10)
    ax3.plot(tt[:, 0], tt[:, 1], "-o", color=C_MAL, lw=2, label="Sensitivity")
    ax3.plot(tt[:, 0], tt[:, 2], "-s", color=C_BEN, lw=2, label="Specificity")
    if bthr is not None:
        ax3.axvline(bthr, color=C_ACC, ls="--", lw=2, label=f"Best balanced ({bthr})")
    ax3.set_xlabel("Threshold")
    ax3.set_ylabel("Metric")
    ax3.grid(True, alpha=0.3)
    ax3.legend(fontsize=8, loc="lower center")
    # Keep the 0.7 floor, but lower it when a metric would otherwise be clipped.
    ymin = min(0.7, float(tt[:, 1:3].min()) - 0.05)
    ax3.set_ylim(ymin, 1.0)

    ax3t = fig.add_axes([0.64, 0.58, 0.30, 0.30])
    ax3t.axis("off")
    rows = [[f"{t:.2f}", f"{s:.3f}", f"{sp:.3f}"] for (t, s, sp) in table]
    # Test `is not None` explicitly so a best threshold of 0.0 is still highlighted.
    hl = None
    if bthr is not None:
        hl = next((i + 1 for i, (t, _, _) in enumerate(table)
                   if abs(t - bthr) < 1e-9), None)
    _styled_table(ax3t, ["Thr", "Sens", "Spec"], rows, highlight_row=hl, fontsize=10)
    caption = ("Operating point can be tuned by whether "
               "missing cancers or false alarms are costlier.")
    if bthr is not None:
        caption = f"Best balanced accuracy at {bthr}. " + caption
    fig.text(0.08, 0.46, caption, fontsize=9.5, style="italic", color=C_NEUT)

    # Figure 4: dataset composition; the final "Total Used" row is highlighted.
    used = R["dataset_used"]
    total_used = sum(v for _, v in used)
    ax4 = fig.add_axes([0.30, 0.27, 0.40, 0.16])
    ax4.axis("off")
    ax4.set_title("Figure 4 — Dataset Composition", fontweight="bold", fontsize=12,
                  loc="center", color=C_ACC, pad=8)
    rows = [[c, str(v)] for c, v in used] + [["Total Used", str(total_used)]]
    _styled_table(ax4, ["Category", "Slides"], rows, highlight_row=len(rows), fontsize=10)
    fig.text(0.5, 0.245,
             f"{R['n_patients']} unique patients • patient-grouped CV • no patient leakage",
             ha="center", fontsize=9.5, style="italic", color=C_NEUT)

    # Results paragraph, wrapped to fit the box.
    axres = fig.add_axes([0.08, 0.05, 0.84, 0.17])
    axres.axis("off")
    axres.text(0.0, 0.95, "Results", fontsize=13, fontweight="bold", va="top", color=C_TXT)
    axres.text(0.0, 0.72, "\n".join(textwrap.wrap(R["results_paragraph"], width=110)),
               fontsize=10.5, va="top", color=C_TXT, linespacing=1.6,
               bbox=dict(boxstyle="round,pad=0.9", facecolor="#f4f6f8",
                         edgecolor=C_ACC, lw=1.2))
    return fig
# ----------------------------------------------------------------------------
# MAIN: build the PDF for ONE slide
# ----------------------------------------------------------------------------
def build_report(bundle, slide, wsi_path, out_pdf,
                 thumb_max=4000, zoom_context=4, threshold=None,
                 report_meta=None, include_model_report=True):
    """Render the explainability PDF for one slide and return a summary dict.

    Pages, in order:
      1. prediction summary;
      2-3. model-report pages from ``report_meta`` (DEFAULT_REPORT when falsy),
           if ``include_model_report``;
      4. WSI, attention heatmap and the top-3 zoomed regions;
      5. an all-patch relevance map, only when a WSI and patch coords exist.
    If ``wsi_path`` is missing or unreadable, the image panels show placeholders.
    The opened slide image is always closed, even if rendering fails.

    Args:
        bundle: model bundle; optional keys ``threshold`` and ``class_names``.
        slide: dict as returned by load_slide_h5().
        wsi_path: whole-slide image path, or a falsy value for an .h5-only report.
        out_pdf: destination handed to matplotlib's PdfPages.
        thumb_max: longest side of the overview thumbnail, in pixels.
        zoom_context: side of each zoom window, in multiples of the patch size.
        threshold: decision threshold; defaults to ``bundle["threshold"]`` or 0.5.
        report_meta: report dict for the model pages.
        include_model_report: whether to add the model-report pages.

    Returns:
        dict with keys:
          * slide_id, prediction (class name);
          * p_malignant: ensemble-mean probability;
          * confidence: probability of the predicted class;
          * stability: 1 - ensemble std;
          * n_patches, threshold, had_wsi.
    """
    thr = threshold if threshold is not None else bundle.get("threshold", 0.5)
    names = bundle.get("class_names", ["Benign", "Malignant"])
    probs = predict(bundle, slide)
    prob = float(np.mean(probs))
    conf = 1.0 - float(np.std(probs))  # "stability": agreement across members
    pred = int(prob >= thr)
    pclass_conf = prob if pred == 1 else (1.0 - prob)
    psize = slide["patch_size"] or 256
    coords = slide["coords"]
    imp = patch_relevance(bundle, slide)
    color = C_MAL if pred else C_BEN
    n_full = int(conf * 10)
    conf_bar = "█" * n_full + "░" * (10 - n_full)
    n_patches = int(slide["feats"].shape[0])
    R = report_meta or DEFAULT_REPORT

    simg = None
    if wsi_path and os.path.isfile(wsi_path):
        try:
            simg = SlideImage(wsi_path)
        except Exception:
            simg = None  # unreadable WSI: fall back to the .h5-only report
    had_wsi = simg is not None

    try:
        with PdfPages(out_pdf) as pdf:
            # ---- Page 1: summary ----
            fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
            fig.text(0.5, 0.95, "BRACS Explainability Report", ha="center",
                     fontsize=20, fontweight="bold", color=C_TXT)
            fig.text(0.5, 0.91, f"Model: {R.get('header_model','?')} | CV AUC: "
                     f"{R.get('header_cv_auc', float('nan')):.3f} | Threshold: {thr:.2f}",
                     ha="center", fontsize=10, color=C_NEUT)
            ax = fig.add_axes([0.1, 0.4, 0.8, 0.4])
            ax.axis("off")
            rows = [["Slide", "Prediction", "Confidence", "Stability", "Patches"],
                    [slide["slide_id"][:30], names[pred], f"{pclass_conf*100:.1f}%",
                     f"{conf*100:.0f}%", str(n_patches)]]
            tbl = ax.table(cellText=rows[1:], colLabels=rows[0], loc="center",
                           cellLoc="center")
            tbl.auto_set_font_size(False)
            tbl.set_fontsize(11)
            tbl.scale(1, 2.2)
            for (r, c), cell in tbl.get_celld().items():
                if r == 0:
                    cell.set_facecolor(C_ACC)
                    cell.set_text_props(color="white", fontweight="bold")
                elif c == 1:
                    cell.set_text_props(color=color, fontweight="bold")
            fig.text(0.5, 0.30, "Research use only — not for clinical diagnosis.",
                     ha="center", fontsize=9, style="italic", color=C_NEUT)
            pdf.savefig(fig, dpi=150)
            plt.close(fig)

            # ---- Model report pages ----
            if include_model_report:
                report_pages(pdf, R)

            # ---- Per-slide analysis page ----
            fig = plt.figure(figsize=(16, 10), facecolor=C_BG)
            gs = gridspec.GridSpec(2, 3, height_ratios=[1.3, 1], hspace=0.28,
                                   wspace=0.18, width_ratios=[1, 1, 0.55])
            fig.suptitle(f"{slide['slide_id']} | {names[pred]} confidence "
                         f"{pclass_conf*100:.1f}% [{conf_bar}]",
                         color=color, fontsize=14, fontweight="bold", y=0.99)
            # The (potentially large) thumbnail is computed once and reused by
            # the showcase page below.
            thumb = scale = bb = None
            if simg is not None:
                thumb, scale = simg.thumbnail(thumb_max)
                bb = tissue_bbox(thumb.shape, coords, psize, scale)

            axo = fig.add_subplot(gs[0, 0])
            axo.set_title("Original WSI", fontweight="bold", fontsize=11, pad=8)
            if thumb is not None:
                axo.imshow(thumb)
                axo.set_xlim(bb[0], bb[2])
                axo.set_ylim(bb[3], bb[1])
            else:
                axo.text(0.5, 0.5, "WSI not provided\n(.h5 only)", ha="center",
                         va="center", fontsize=11)
            axo.set_xticks([])
            axo.set_yticks([])

            axb = fig.add_subplot(gs[0, 1])
            axb.set_title("Attention Heatmap", fontweight="bold", fontsize=11, pad=8)
            if thumb is not None:
                overlay(axb, thumb, coords, imp, psize, scale)
                axb.set_xlim(bb[0], bb[2])
                axb.set_ylim(bb[3], bb[1])
                sm = ScalarMappable(cmap=CMAP_ATT, norm=Normalize(0, 1))
                sm.set_array([])
                fig.colorbar(sm, ax=axb, shrink=0.6,
                             label="importance").ax.tick_params(labelsize=8)
            else:
                axb.text(0.5, 0.5, "WSI not provided", ha="center", va="center",
                         fontsize=11)
            axb.set_xticks([])
            axb.set_yticks([])

            axc = fig.add_subplot(gs[0, 2])
            axc.axis("off")
            axc.text(0.05, 0.95, f"Prediction: {names[pred]}\nConfidence: {pclass_conf*100:.0f}%\n"
                     f"Stability: {conf*100:.0f}%\nPatches: {n_patches}\n"
                     f"WSI: {'loaded' if had_wsi else 'not provided'}",
                     transform=axc.transAxes, fontsize=10, va="top", family="monospace",
                     color=C_TXT,
                     bbox=dict(boxstyle="round,pad=0.6", facecolor=C_GRID, alpha=0.5,
                               edgecolor=color, lw=2))

            # Bottom row: level-0 zooms around the three most relevant patches.
            top3 = np.argsort(imp)[::-1][:3] if len(imp) else []
            win = psize * zoom_context
            for k, pi in enumerate(top3):
                axz = fig.add_subplot(gs[1, k])
                if simg is not None and coords is not None:
                    cx, cy = coords[pi, 0] + psize // 2, coords[pi, 1] + psize // 2
                    x0, y0 = cx - win // 2, cy - win // 2
                    region = simg.read_region_l0(x0, y0, win, win)
                    axz.imshow(region, interpolation="lanczos")
                    # read_region_l0 clamps the origin at 0, so offset by the
                    # clamped origin.
                    px, py = coords[pi, 0] - max(0, x0), coords[pi, 1] - max(0, y0)
                    for gw, ga in [(7, 0.12), (5, 0.18), (3.2, 0.30)]:  # glow effect
                        axz.add_patch(Rectangle((px, py), psize, psize, fill=False,
                                                edgecolor="#ffd93d", lw=gw, alpha=ga))
                    axz.add_patch(Rectangle((px, py), psize, psize, fill=False,
                                            edgecolor="#ffd93d", lw=1.6))
                    axz.text(0.04, 0.95, f"relevance {imp[pi]*100:.0f}%",
                             transform=axz.transAxes, fontsize=8.5, fontweight="bold",
                             va="top", color="white",
                             bbox=dict(boxstyle="round,pad=0.3", facecolor=color,
                                       alpha=0.85, edgecolor="none"))
                    axz.set_title(f"Region #{k+1}", fontweight="bold", fontsize=10,
                                  pad=5, color=color)
                    axz.set_xticks([])
                    axz.set_yticks([])
                else:
                    axz.text(0.5, 0.5, "no WSI/coords", ha="center", va="center",
                             fontsize=10)
                    axz.axis("off")
            pdf.savefig(fig, dpi=160)
            plt.close(fig)

            # ---- Showcase page: all patches ----
            if simg is not None and coords is not None:
                fig = plt.figure(figsize=(17, 10), facecolor=C_BG)
                fig.suptitle(f"Patch-Level Visualization — {slide['slide_id']}",
                             fontsize=17, fontweight="bold", y=0.97, color=C_TXT)
                fig.text(0.5, 0.925, f"{names[pred]} confidence {pclass_conf*100:.1f}% | "
                         f"all {n_patches} patches colored by relevance", ha="center",
                         fontsize=11, color=C_NEUT)
                sp = max(1.0, psize * scale)
                axL = fig.add_axes([0.04, 0.06, 0.44, 0.82])
                axL.imshow(thumb)
                axL.set_xlim(bb[0], bb[2])
                axL.set_ylim(bb[3], bb[1])
                axL.set_xticks([])
                axL.set_yticks([])
                axL.set_title("Original Tissue", fontweight="bold", fontsize=12, pad=8)

                axR = fig.add_axes([0.50, 0.06, 0.44, 0.82])
                axR.imshow(thumb, alpha=0.25)
                # Ascending order so the most relevant patches are drawn last (on top).
                order = np.argsort(imp)
                rects = [Rectangle((coords[i, 0] * scale, coords[i, 1] * scale), sp, sp)
                         for i in order]
                cols = [imp[i] for i in order]
                pc = PatchCollection(rects, cmap=CMAP_ATT, alpha=0.85, edgecolor="white",
                                     linewidth=0.15)
                pc.set_array(np.array(cols))
                pc.set_clim(0, 1)
                axR.add_collection(pc)
                axR.set_xlim(bb[0], bb[2])
                axR.set_ylim(bb[3], bb[1])
                axR.set_xticks([])
                axR.set_yticks([])
                axR.set_title(f"Relevance Map ({n_patches} patches)", fontweight="bold",
                              fontsize=12, pad=8)
                for rank, pi in enumerate(order[::-1][:5]):
                    axR.text(coords[pi, 0] * scale + sp / 2, coords[pi, 1] * scale + sp / 2,
                             str(rank + 1), ha="center", va="center", fontsize=8,
                             fontweight="bold", color="white",
                             bbox=dict(boxstyle="circle,pad=0.15", facecolor=color,
                                       edgecolor="white", lw=0.8, alpha=0.95))
                fig.colorbar(pc, ax=axR, shrink=0.55, pad=0.01,
                             label="patch relevance").ax.tick_params(labelsize=8)
                pdf.savefig(fig, dpi=180)
                plt.close(fig)
    finally:
        if simg is not None:
            simg.close()

    return {
        "slide_id": slide["slide_id"], "prediction": names[pred],
        "p_malignant": prob, "confidence": pclass_conf, "stability": conf,
        "n_patches": n_patches, "threshold": thr,
        "had_wsi": had_wsi,
    }