"""
BRACS inference + PDF report engine for the HF Space.

Reuses the LEAN report logic: prediction, per-patch relevance, heatmap,
high-res zooms, patch-level showcase, and the model-report figures.

Research use only — not for clinical diagnosis.
"""
# NOTE(review): io, json, pickle and tempfile are not used by the code in this
# module; they are kept in case other modules import them from here.
import io, json, os, pickle, tempfile  # noqa: F401
import textwrap
import warnings

import numpy as np
import h5py
import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib.cm import ScalarMappable
from scipy.ndimage import gaussian_filter

# ----------------------------------------------------------------------------
# Colors / colormaps
# ----------------------------------------------------------------------------
C_TXT, C_BG, C_GRID = "#1a1a1a", "white", "#dddddd"
C_BEN, C_MAL, C_ACC, C_NEUT = "#2a9d5c", "#d1495b", "#3a6ea5", "#7f8c8d"
# Relevance colormap; its first stop has alpha 00, i.e. fully transparent.
CMAP_ATT = LinearSegmentedColormap.from_list(
    "att", ["#10103000", "#3a6ea5", "#ffd93d", "#d1495b"]
)
# Benign-green -> grey -> malignant-red colormap (confusion matrix).
CMAP_BM = LinearSegmentedColormap.from_list("bm", [C_BEN, "#eeeeee", C_MAL])

# Static model-evaluation summary rendered on report pages 2-3 when the caller
# passes no ``report_meta``.
# NOTE(review): header_cv_auc (0.982) does not match the LR_concat patient-CV
# AUC listed in model_comparison (0.972 ± 0.018), and results_paragraph
# describes the pooled model rather than best_model; header_threshold is not
# read by build_report (it prints the threshold actually used). Confirm which
# figures are current.
DEFAULT_REPORT = {
    "header_model": "LR_concat",
    "header_cv_auc": 0.982,
    "header_threshold": 0.45,
    "model_comparison": [
        ("LR_titan", "0.880 ± 0.072", "0.904"),
        ("LR_pooled", "0.973 ± 0.016", "0.973"),
        ("LR_concat", "0.972 ± 0.018", "0.973"),
        ("ABMIL", "0.953 ± 0.025", "—"),
    ],
    "best_model": "LR_concat",
    "confusion": {"tn": 170, "fp": 14, "fn": 22, "tp": 156},
    "threshold_table": [
        (0.20, 0.933, 0.864),
        (0.30, 0.933, 0.897),
        (0.50, 0.876, 0.924),
        (0.70, 0.831, 0.951),
        (0.80, 0.764, 0.957),
    ],
    "best_balanced_thr": 0.30,
    "dataset_used": [("Benign", 184), ("Malignant", 178)],
    "n_patients": 144,
    "results_paragraph": (
        "The proposed patient-grouped BRACS classifier was evaluated on 362 "
        "whole-slide images from 144 patients. "
        "Logistic Regression using pooled "
        "slide-level features achieved a patient-grouped cross-validation AUC of "
        "0.973 ± 0.016 and an official test-set AUC of 0.973. At the default "
        "threshold of 0.50, sensitivity was 87.6% and specificity 92.4%. Threshold "
        "optimization identified 0.30 as the operating point yielding the highest "
        "balanced accuracy."
    ),
}


# ----------------------------------------------------------------------------
# H5 loading
# ----------------------------------------------------------------------------
def load_slide_h5(path):
    """Load one slide's features from an .h5 file.

    Reads the per-patch ``features`` matrix and ``titan_slide_embedding``.
    Uses ``pooled_conch_slide_feature_mean_max_std`` when present; otherwise
    it recomputes ``[mean | max | std]`` over patches. ``coords`` and its
    ``patch_size_level0`` attribute are optional.

    Returns a dict with keys ``feats``, ``titan``, ``pooled``, ``concat``
    (titan followed by pooled), ``coords`` (or None), ``patch_size``
    (or None), ``slide_id`` and ``path``.
    """
    pooled_key = "pooled_conch_slide_feature_mean_max_std"
    with h5py.File(path, "r") as f:
        feats = np.asarray(f["features"][:], np.float32)
        titan = np.asarray(f["titan_slide_embedding"][:], np.float32).squeeze()
        if pooled_key in f:
            pooled = np.asarray(f[pooled_key][:], np.float32).squeeze()
        else:
            pooled = np.concatenate([feats.mean(0), feats.max(0), feats.std(0)])
        coords = psize = None
        if "coords" in f:
            coords_ds = f["coords"]
            coords = np.asarray(coords_ds[:], np.int64)
            if "patch_size_level0" in coords_ds.attrs:
                psize = int(coords_ds.attrs["patch_size_level0"])
        sid = f.attrs.get("slide_id", os.path.basename(path))
    sid = sid.decode() if isinstance(sid, bytes) else str(sid)
    return {
        "feats": feats, "titan": titan, "pooled": pooled,
        "concat": np.concatenate([titan, pooled]).astype(np.float32),
        "coords": coords, "patch_size": psize, "slide_id": sid, "path": path,
    }


# ----------------------------------------------------------------------------
# Prediction + relevance
# ----------------------------------------------------------------------------
def _artifact(bundle):
    """Return the model-artifact dict of *bundle*.

    A bundle either carries an explicit ``"artifact"`` dict or keeps the same
    fields at top level. The top-level form is normalised here, including the
    ABMIL fields ``states`` and ``in_dim``. Dropping those fields would make
    ``predict`` fail on such bundles.
    """
    if "artifact" in bundle:
        return bundle["artifact"]
    return {
        "kind": bundle.get("kind", "linear"),
        "rep": bundle.get("rep", "concat"),
        "models": bundle.get("models", []),
        "states": bundle.get("states", []),
        "in_dim": bundle.get("in_dim"),
    }


def _abmil_models(art, in_dim):
    """Yield one eval-mode ABMIL model per saved state dict in ``art["states"]``.

    ``art["in_dim"]`` wins when set; otherwise *in_dim* (the patch-feature
    width) is used.
    """
    for state in art.get("states") or []:
        model = build_abmil(art.get("in_dim") or in_dim)
        model.load_state_dict(state)
        model.eval()
        yield model


def predict(bundle, slide):
    """Return per-model malignant probabilities for *slide* as a 1-D array.

    Linear bundles call ``predict_proba`` on ``slide[rep]`` for every stored
    model. Other bundles are treated as ABMIL ensembles run on
    ``slide["feats"]``.

    Raises ValueError when the bundle holds no fitted model. Averaging an
    empty array would otherwise yield NaN, which compares as "not malignant".
    """
    art = _artifact(bundle)
    if art["kind"] == "linear":
        models = art.get("models") or []
        if not models:
            raise ValueError("bundle contains no fitted linear models")
        x = slide[art["rep"]].reshape(1, -1)
        return np.array([m.predict_proba(x)[0, 1] for m in models])
    import torch
    f = torch.from_numpy(slide["feats"]).float()
    probs = []
    with torch.no_grad():
        for model in _abmil_models(art, f.shape[1]):
            logit, _ = model(f)
            probs.append(torch.sigmoid(logit).item())
    if not probs:
        raise ValueError("bundle contains no ABMIL model states")
    return np.array(probs)


def build_abmil(in_dim, hidden=192, att=128, dropout=0.3):
    """Build a gated-attention MIL classifier (untrained).

    ``forward(x)`` expects a 2-D patch bag ``(n_patches, in_dim)``. Attention
    is softmaxed over dim 0. It returns ``(logit, attention)``: a squeezed
    scalar logit and one attention weight per patch.
    """
    import torch.nn as nn, torch

    class ABMIL(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(dropout))
            self.V = nn.Linear(hidden, att)
            self.U = nn.Linear(hidden, att)
            self.w = nn.Linear(att, 1)
            self.head = nn.Linear(hidden, 1)

        def forward(self, x):
            h = self.proj(x)
            a = torch.softmax(self.w(torch.tanh(self.V(h)) * torch.sigmoid(self.U(h))), 0)
            return self.head((a * h).sum(0, keepdim=True)).squeeze(), a.squeeze(-1)

    return ABMIL()


def norm01(x):
    """Min-max scale *x* to [0, 1].

    Constant input maps to ~0, because 1e-9 guards the division. Empty input
    returns an empty float array instead of raising in ``min()``.
    """
    x = np.asarray(x)
    if x.size == 0:
        return np.zeros(x.shape)
    return (x - x.min()) / (np.ptp(x) + 1e-9)


def _linear_coef(model):
    """Return *model*'s linear coefficients mapped back to raw-feature space.

    Supports sklearn-style pipelines with a ``clf`` step and optional ``pca``
    / ``scaler`` steps (these step names are assumed by the training code),
    as well as bare estimators exposing ``coef_``.
    """
    if not hasattr(model, "named_steps"):
        return np.asarray(model.coef_).squeeze()
    steps = model.named_steps
    c = np.asarray(steps["clf"].coef_).squeeze()
    if "pca" in steps:
        c = c @ steps["pca"].components_          # PCA space -> standardised space
    if "scaler" in steps:
        c = c / (steps["scaler"].scale_ + 1e-9)   # undo standardisation
    return c


def _abmil_attention(art, feats):
    """Return the ABMIL attention weight of every patch, averaged over fold models."""
    import torch
    f = torch.from_numpy(feats).float()
    maps = []
    with torch.no_grad():
        for model in _abmil_models(art, f.shape[1]):
            _, a = model(f)
            maps.append(a.reshape(-1).cpu().numpy())
    return np.mean(maps, 0)


def patch_relevance(bundle, slide):
    """Score every patch's relevance to the prediction, scaled to [0, 1].

    * Linear models on the ``concat`` representation: each patch's features
      are projected onto the fold-averaged coefficients of the pooled
      "mean" block. The slide mean is linear in the patches, so this ranks
      each patch's contribution through that block.
    * ABMIL bundles: fold-averaged attention weights.
    * Otherwise, or if the above fails: per-patch feature L2 norm.
    """
    feats = slide["feats"]
    if feats.shape[0] == 0:
        return np.zeros(0)
    art = _artifact(bundle)
    try:
        if art.get("kind") == "linear":
            if art.get("rep") == "concat" and art.get("models"):
                coef = np.mean([_linear_coef(m) for m in art["models"]], 0)
                d = feats.shape[1]
                # concat = [titan | mean | max | std] (assuming the stored pooled
                # vector follows the mean_max_std layout its key name suggests),
                # so the mean block starts right after the TITAN embedding.
                start = int(np.size(slide["titan"]))
                mean_block = coef[start:start + d]
                if mean_block.shape[0] == d:
                    return norm01(feats @ mean_block)
        elif art.get("states"):
            return norm01(_abmil_attention(art, feats))
    except Exception as err:  # best-effort: a visual aid must not abort the report
        warnings.warn(f"patch relevance fell back to feature norm: {err!r}",
                      RuntimeWarning, stacklevel=2)
    return norm01(np.linalg.norm(feats, axis=1))


# ----------------------------------------------------------------------------
# Slide image backend (OpenSlide or PIL)
# ----------------------------------------------------------------------------
class SlideImage:
    """Read-only access to a slide image through OpenSlide, falling back to PIL.

    Attributes: ``path``, ``backend`` ("openslide" or "pil"), and ``w`` / ``h``
    (level-0 size in pixels). Usable as a context manager.
    """

    def __init__(self, path):
        self.path = path
        try:
            import openslide
            self.osh = openslide.OpenSlide(path)
            self.w, self.h = self.osh.dimensions
            self.backend = "openslide"
        except Exception:
            from PIL import Image
            # Disables Pillow's decompression-bomb guard process-wide; WSIs are huge.
            Image.MAX_IMAGE_PIXELS = None
            # Close the lazily-opened source file once the RGB copy is decoded.
            with Image.open(path) as src:
                self.img = src.convert("RGB")
            self.w, self.h = self.img.size
            self.backend = "pil"

    def thumbnail(self, max_size):
        """Return ``(rgb_array, scale)``, where the longest side is <= *max_size* and scale <= 1."""
        scale = min(max_size / self.w, max_size / self.h, 1.0)
        tw, th = max(1, int(self.w * scale)), max(1, int(self.h * scale))
        if self.backend == "openslide":
            thumb = self.osh.get_thumbnail((tw, th))
        else:
            thumb = self.img.copy()
            thumb.thumbnail((tw, th))
        return np.array(thumb), scale

    def read_region_l0(self, x, y, w, h):
        """Return a level-0 RGB crop. The origin and size are clamped to the slide bounds."""
        x = max(0, min(int(x), self.w - 1))
        y = max(0, min(int(y), self.h - 1))
        w = max(1, min(int(w), self.w - x))
        h = max(1, min(int(h), self.h - y))
        if self.backend == "openslide":
            return np.array(self.osh.read_region((x, y), 0, (w, h)).convert("RGB"))
        return np.array(self.img.crop((x, y, x + w, y + h)))

    def close(self):
        """Release the underlying OpenSlide handle or PIL image."""
        if self.backend == "openslide" and hasattr(self, "osh"):
            self.osh.close()
        elif self.backend == "pil" and hasattr(self, "img"):
            self.img.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
        return False


def tissue_bbox(thumb_shape, coords, psize, scale, pad_frac=0.04):
    """Return a padded thumbnail-pixel box ``(x0, y0, x1, y1)`` around all patches.

    Without coords the whole thumbnail is returned.
    """
    h, w = thumb_shape[:2]
    if coords is None or len(coords) == 0:
        return 0, 0, w, h
    xs, ys = coords[:, 0] * scale, coords[:, 1] * scale
    sp = psize * scale
    x0, y0, x1, y1 = xs.min(), ys.min(), xs.max() + sp, ys.max() + sp
    pad = pad_frac * max(x1 - x0, y1 - y0)
    return (int(max(0, x0 - pad)), int(max(0, y0 - pad)),
            int(min(w, x1 + pad)), int(min(h, y1 + pad)))


def heatmap_canvas(shape, coords, imp, psize, scale):
    """Rasterise per-patch scores onto the thumbnail grid.

    Overlapping patches are averaged, then the map is Gaussian-smoothed and
    min-max scaled to [0, 1].
    """
    h, w = shape[:2]
    acc = np.zeros((h, w), np.float32)
    cnt = np.zeros((h, w), np.float32)
    sp = max(1, int(psize * scale))
    for i in range(len(imp)):
        x = int(coords[i, 0] * scale)
        y = int(coords[i, 1] * scale)
        x2, y2 = min(x + sp, w), min(y + sp, h)
        acc[y:y2, x:x2] += imp[i]
        cnt[y:y2, x:x2] += 1
    covered = cnt > 0
    acc[covered] /= cnt[covered]
    return norm01(gaussian_filter(acc, sigma=max(2.0, min(h, w) / 60)))


def overlay(ax, thumb, coords, imp, psize, scale, alpha=0.55):
    """Draw *thumb* with the relevance heatmap on top. The alpha grows with relevance."""
    ax.imshow(thumb)
    if coords is not None and len(coords):
        hm = heatmap_canvas(thumb.shape, coords, imp, psize, scale)
        rgba = CMAP_ATT(hm)
        rgba[..., 3] = alpha * hm
        ax.imshow(rgba)
    ax.set_xticks([])
    ax.set_yticks([])


# ----------------------------------------------------------------------------
# Model-report figures (pages 2-3)
# ----------------------------------------------------------------------------
def _styled_table(ax, header, rows, highlight_row=None, fontsize=11):
    """Draw a centred table. The header is accent-coloured and *highlight_row* (1-based) is bold green."""
    tbl = ax.table(cellText=rows, colLabels=header, loc="center", cellLoc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(fontsize)
    tbl.scale(1, 1.9)
    for (r, c), cell in tbl.get_celld().items():
        cell.set_edgecolor("#cccccc")
        if r == 0:
            cell.set_facecolor(C_ACC)
            cell.set_text_props(color="white", fontweight="bold")
        elif highlight_row is not None and r == highlight_row:
            cell.set_facecolor("#eaf4ec")
            cell.set_text_props(fontweight="bold", color=C_BEN)
    return tbl


def _model_eval_page(pdf, R):
    """Page 2: Figure 1 (model comparison table) and Figure 2 (confusion matrix)."""
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Model Development & Evaluation", fontsize=18, fontweight="bold",
                 y=0.97, color=C_TXT)

    ax1 = fig.add_axes([0.07, 0.60, 0.86, 0.27])
    ax1.axis("off")
    ax1.set_title("Figure 1 — Model Comparison", fontweight="bold", fontsize=12,
                  loc="left", color=C_ACC, pad=10)
    best = R.get("best_model")
    comparison = R["model_comparison"]
    rows = [[m, cv, te] for (m, cv, te) in comparison]
    hl = next((i + 1 for i, (m, _, _) in enumerate(comparison) if m == best), None)
    _styled_table(ax1, ["Model", "Patient-CV AUC", "Test AUC"], rows, highlight_row=hl)

    cm = R["confusion"]
    tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
    sens = tp / (tp + fn + 1e-9)
    spec = tn / (tn + fp + 1e-9)
    bacc = 0.5 * (sens + spec)
    cm_thr = R.get("confusion_threshold", 0.50)  # threshold the matrix was computed at

    axcm = fig.add_axes([0.10, 0.12, 0.36, 0.36])
    axcm.set_title(f"Figure 2 — Confusion Matrix ({best}, thr={cm_thr:.2f})",
                   fontweight="bold", fontsize=12, loc="left", color=C_ACC, pad=10)
    mat = np.array([[tn, fp], [fn, tp]])
    # vmax >= 1 keeps the colour scale non-degenerate for an all-zero matrix.
    axcm.imshow(mat, cmap=CMAP_BM, vmin=0, vmax=max(mat.max(), 1))
    axcm.set_xticks([0, 1])
    axcm.set_yticks([0, 1])
    axcm.set_xticklabels(["Pred Benign", "Pred Malignant"], fontsize=9)
    axcm.set_yticklabels(["True Benign", "True Malignant"], fontsize=9, rotation=90, va="center")
    for i in range(2):
        for j in range(2):
            axcm.text(j, i, f"{mat[i, j]}", ha="center", va="center", fontsize=18,
                      fontweight="bold",
                      color="white" if mat[i, j] > mat.max() / 2 else C_TXT)

    axm = fig.add_axes([0.55, 0.12, 0.38, 0.36])
    axm.axis("off")
    axm.text(0.0, 0.78,
             f"Sensitivity = {sens:.3f}\nSpecificity = {spec:.3f}\n"
             f"Balanced Accuracy = {bacc:.3f}",
             fontsize=13, family="monospace", va="top", color=C_TXT,
             bbox=dict(boxstyle="round,pad=0.8", facecolor="#f4f6f8", edgecolor=C_ACC, lw=1.5))
    pdf.savefig(fig, dpi=150)
    plt.close(fig)


def _threshold_dataset_page(pdf, R):
    """Page 3: Figure 3 (threshold sweep), Figure 4 (dataset composition) and the Results text."""
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Threshold Analysis & Dataset", fontsize=18, fontweight="bold",
                 y=0.97, color=C_TXT)

    table = R["threshold_table"]
    tt = np.array(table, dtype=float)
    bthr = R.get("best_balanced_thr")
    ax3 = fig.add_axes([0.08, 0.58, 0.52, 0.30])
    ax3.set_title("Figure 3 — Threshold Analysis", fontweight="bold", fontsize=12,
                  loc="left", color=C_ACC, pad=10)
    ax3.plot(tt[:, 0], tt[:, 1], "-o", color=C_MAL, lw=2, label="Sensitivity")
    ax3.plot(tt[:, 0], tt[:, 2], "-s", color=C_BEN, lw=2, label="Specificity")
    if bthr is not None:
        ax3.axvline(bthr, color=C_ACC, ls="--", lw=2, label=f"Best balanced ({bthr})")
    ax3.set_xlabel("Threshold")
    ax3.set_ylabel("Metric")
    ax3.grid(True, alpha=0.3)
    ax3.legend(fontsize=8, loc="lower center")
    # Lower bound defaults to 0.7 but widens so metrics below it are not clipped.
    ax3.set_ylim(min(0.7, float(np.nanmin(tt[:, 1:])) - 0.02), 1.0)

    ax3t = fig.add_axes([0.64, 0.58, 0.30, 0.30])
    ax3t.axis("off")
    rows = [[f"{t:.2f}", f"{s:.3f}", f"{sp:.3f}"] for (t, s, sp) in table]
    hl = None
    if bthr is not None:  # explicit None check: 0.0 is a valid threshold
        hl = next((i + 1 for i, (t, _, _) in enumerate(table) if abs(t - bthr) < 1e-9), None)
    _styled_table(ax3t, ["Thr", "Sens", "Spec"], rows, highlight_row=hl, fontsize=10)
    caption = ("Operating point can be tuned by whether "
               "missing cancers or false alarms are costlier.")
    if bthr is not None:
        caption = f"Best balanced accuracy at {bthr}. " + caption
    fig.text(0.08, 0.46, caption, fontsize=9.5, style="italic", color=C_NEUT)

    used = R["dataset_used"]
    total_used = sum(v for _, v in used)
    ax4 = fig.add_axes([0.30, 0.27, 0.40, 0.16])
    ax4.axis("off")
    ax4.set_title("Figure 4 — Dataset Composition", fontweight="bold", fontsize=12,
                  loc="center", color=C_ACC, pad=8)
    rows = [[c, str(v)] for c, v in used] + [["Total Used", str(total_used)]]
    _styled_table(ax4, ["Category", "Slides"], rows, highlight_row=len(rows), fontsize=10)
    fig.text(0.5, 0.245,
             f"{R['n_patients']} unique patients • patient-grouped CV • no patient leakage",
             ha="center", fontsize=9.5, style="italic", color=C_NEUT)

    axres = fig.add_axes([0.08, 0.05, 0.84, 0.17])
    axres.axis("off")
    axres.text(0.0, 0.95, "Results", fontsize=13, fontweight="bold", va="top", color=C_TXT)
    axres.text(0.0, 0.72, "\n".join(textwrap.wrap(R["results_paragraph"], width=110)),
               fontsize=10.5, va="top", color=C_TXT, linespacing=1.6,
               bbox=dict(boxstyle="round,pad=0.9", facecolor="#f4f6f8", edgecolor=C_ACC, lw=1.2))
    pdf.savefig(fig, dpi=150)
    plt.close(fig)


def report_pages(pdf, R):
    """Append the two model-report pages (Figures 1-4 + Results) to *pdf*.

    *R* is a report-metadata dict shaped like ``DEFAULT_REPORT``; a falsy *R*
    adds nothing. The optional key ``confusion_threshold`` (default 0.50)
    labels the confusion-matrix title.
    """
    if not R:
        return
    _model_eval_page(pdf, R)
    _threshold_dataset_page(pdf, R)


# ----------------------------------------------------------------------------
# MAIN: build the PDF for ONE slide
# ----------------------------------------------------------------------------
def _summary_page(pdf, slide, R, thr, label, pclass_conf, stability, color):
    """Page 1: title, model header line, one-row prediction table and disclaimer."""
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.text(0.5, 0.95, "BRACS Explainability Report", ha="center", fontsize=20,
             fontweight="bold", color=C_TXT)
    fig.text(0.5, 0.91,
             f"Model: {R.get('header_model','?')} | CV AUC: "
             f"{R.get('header_cv_auc', float('nan')):.3f} | Threshold: {thr:.2f}",
             ha="center", fontsize=10, color=C_NEUT)
    ax = fig.add_axes([0.1, 0.4, 0.8, 0.4])
    ax.axis("off")
    header = ["Slide", "Prediction", "Confidence", "Stability", "Patches"]
    values = [slide["slide_id"][:30], label, f"{pclass_conf*100:.1f}%",
              f"{stability*100:.0f}%", str(slide["feats"].shape[0])]
    tbl = ax.table(cellText=[values], colLabels=header, loc="center", cellLoc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(11)
    tbl.scale(1, 2.2)
    for (r, c), cell in tbl.get_celld().items():
        if r == 0:
            cell.set_facecolor(C_ACC)
            cell.set_text_props(color="white", fontweight="bold")
        elif c == 1:  # prediction column, coloured by predicted class
            cell.set_text_props(color=color, fontweight="bold")
    fig.text(0.5, 0.30, "Research use only — not for clinical diagnosis.", ha="center",
             fontsize=9, style="italic", color=C_NEUT)
    pdf.savefig(fig, dpi=150)
    plt.close(fig)


def _zoom_panel(axz, simg, coords, pi, rel, psize, zoom_context, color):
    """Draw a level-0 crop of *zoom_context* patch widths centred on patch *pi*, with an outlined patch box."""
    win = psize * zoom_context
    cx, cy = coords[pi, 0] + psize // 2, coords[pi, 1] + psize // 2
    x0, y0 = cx - win // 2, cy - win // 2
    region = simg.read_region_l0(x0, y0, win, win)
    axz.imshow(region, interpolation="lanczos")
    # read_region_l0 clamps a negative origin to 0, so offset from the clamped origin.
    px, py = coords[pi, 0] - max(0, x0), coords[pi, 1] - max(0, y0)
    for gw, ga in [(7, 0.12), (5, 0.18), (3.2, 0.30)]:  # soft "glow" layers
        axz.add_patch(Rectangle((px, py), psize, psize, fill=False,
                                edgecolor="#ffd93d", lw=gw, alpha=ga))
    axz.add_patch(Rectangle((px, py), psize, psize, fill=False, edgecolor="#ffd93d", lw=1.6))
    axz.text(0.04, 0.95, f"relevance {rel*100:.0f}%", transform=axz.transAxes,
             fontsize=8.5, fontweight="bold", va="top", color="white",
             bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.85, edgecolor="none"))


def _analysis_page(pdf, slide, simg, thumb, scale, bb, coords, imp, psize,
                   zoom_context, label, pclass_conf, stability, color):
    """Per-slide page: original WSI, heatmap, info box, and zooms on the top-3 patches."""
    filled = int(stability * 10)
    conf_bar = "█" * filled + "░" * (10 - filled)
    fig = plt.figure(figsize=(16, 10), facecolor=C_BG)
    gs = gridspec.GridSpec(2, 3, height_ratios=[1.3, 1], hspace=0.28, wspace=0.18,
                           width_ratios=[1, 1, 0.55])
    fig.suptitle(f"{slide['slide_id']} | {label} confidence {pclass_conf*100:.1f}% [{conf_bar}]",
                 color=color, fontsize=14, fontweight="bold", y=0.99)

    axo = fig.add_subplot(gs[0, 0])
    axo.set_title("Original WSI", fontweight="bold", fontsize=11, pad=8)
    if thumb is not None:
        axo.imshow(thumb)
        axo.set_xlim(bb[0], bb[2])
        axo.set_ylim(bb[3], bb[1])  # inverted y: image coordinates
    else:
        axo.text(0.5, 0.5, "WSI not provided\n(.h5 only)", ha="center", va="center", fontsize=11)
    axo.set_xticks([])
    axo.set_yticks([])

    axb = fig.add_subplot(gs[0, 1])
    axb.set_title("Attention Heatmap", fontweight="bold", fontsize=11, pad=8)
    if thumb is not None:
        overlay(axb, thumb, coords, imp, psize, scale)
        axb.set_xlim(bb[0], bb[2])
        axb.set_ylim(bb[3], bb[1])
        sm = ScalarMappable(cmap=CMAP_ATT, norm=Normalize(0, 1))
        sm.set_array([])
        fig.colorbar(sm, ax=axb, shrink=0.6, label="importance").ax.tick_params(labelsize=8)
    else:
        axb.text(0.5, 0.5, "WSI not provided", ha="center", va="center", fontsize=11)
    axb.set_xticks([])
    axb.set_yticks([])

    axc = fig.add_subplot(gs[0, 2])
    axc.axis("off")
    axc.text(0.05, 0.95,
             f"Prediction: {label}\nConfidence: {pclass_conf*100:.0f}%\n"
             f"Stability: {stability*100:.0f}%\nPatches: {slide['feats'].shape[0]}\n"
             f"WSI: {'loaded' if simg is not None else 'not provided'}",
             transform=axc.transAxes, fontsize=10, va="top", family="monospace", color=C_TXT,
             bbox=dict(boxstyle="round,pad=0.6", facecolor=C_GRID, alpha=0.5, edgecolor=color, lw=2))

    top3 = np.argsort(imp)[::-1][:3] if len(imp) else []
    for k, pi in enumerate(top3):
        axz = fig.add_subplot(gs[1, k])
        if simg is not None and coords is not None:
            _zoom_panel(axz, simg, coords, pi, imp[pi], psize, zoom_context, color)
            axz.set_title(f"Region #{k+1}", fontweight="bold", fontsize=10, pad=5, color=color)
            axz.set_xticks([])
            axz.set_yticks([])
        else:
            axz.text(0.5, 0.5, "no WSI/coords", ha="center", va="center", fontsize=10)
            axz.axis("off")
    pdf.savefig(fig, dpi=160)
    plt.close(fig)


def _showcase_page(pdf, slide, thumb, scale, bb, coords, imp, psize, label, pclass_conf, color):
    """Showcase page: tissue thumbnail next to every patch coloured by relevance, top-5 numbered."""
    fig = plt.figure(figsize=(17, 10), facecolor=C_BG)
    fig.suptitle(f"Patch-Level Visualization — {slide['slide_id']}", fontsize=17,
                 fontweight="bold", y=0.97, color=C_TXT)
    npatch = slide["feats"].shape[0]
    fig.text(0.5, 0.925, f"{label} confidence {pclass_conf*100:.1f}% | "
             f"all {npatch} patches colored by relevance",
             ha="center", fontsize=11, color=C_NEUT)
    sp = max(1.0, psize * scale)

    axL = fig.add_axes([0.04, 0.06, 0.44, 0.82])
    axL.imshow(thumb)
    axL.set_xlim(bb[0], bb[2])
    axL.set_ylim(bb[3], bb[1])
    axL.set_xticks([])
    axL.set_yticks([])
    axL.set_title("Original Tissue", fontweight="bold", fontsize=12, pad=8)

    axR = fig.add_axes([0.50, 0.06, 0.44, 0.82])
    axR.imshow(thumb, alpha=0.25)
    # Ascending relevance order: the most relevant patches are added last.
    order = np.argsort(imp)
    rects = [Rectangle((coords[i, 0] * scale, coords[i, 1] * scale), sp, sp) for i in order]
    pc = PatchCollection(rects, cmap=CMAP_ATT, alpha=0.85, edgecolor="white", linewidth=0.15)
    pc.set_array(np.asarray(imp)[order])
    pc.set_clim(0, 1)
    axR.add_collection(pc)
    axR.set_xlim(bb[0], bb[2])
    axR.set_ylim(bb[3], bb[1])
    axR.set_xticks([])
    axR.set_yticks([])
    axR.set_title(f"Relevance Map ({npatch} patches)", fontweight="bold", fontsize=12, pad=8)
    for rank, pi in enumerate(order[::-1][:5]):
        axR.text(coords[pi, 0] * scale + sp / 2, coords[pi, 1] * scale + sp / 2, str(rank + 1),
                 ha="center", va="center", fontsize=8, fontweight="bold", color="white",
                 bbox=dict(boxstyle="circle,pad=0.15", facecolor=color,
                           edgecolor="white", lw=0.8, alpha=0.95))
    fig.colorbar(pc, ax=axR, shrink=0.55, pad=0.01,
                 label="patch relevance").ax.tick_params(labelsize=8)
    pdf.savefig(fig, dpi=180)
    plt.close(fig)


def build_report(bundle, slide, wsi_path, out_pdf, thumb_max=4000, zoom_context=4,
                 threshold=None, report_meta=None, include_model_report=True):
    """Render the explainability PDF for one slide and return a summary dict.

    Parameters
    ----------
    bundle : dict
        Trained model bundle (see ``_artifact``). Optional keys: ``threshold``
        (default 0.5) and ``class_names`` (default Benign/Malignant).
    slide : dict
        As returned by ``load_slide_h5``.
    wsi_path : str or None
        Whole-slide image. If it is missing or unreadable, the image panels
        show placeholders and the showcase page is skipped.
    out_pdf
        Destination passed unchanged to ``PdfPages``.
    thumb_max : int
        Longest side, in pixels, of the overview thumbnail.
    zoom_context : int
        Side of each zoom window, measured in patch widths.
    threshold : float, optional
        Overrides the bundle threshold.
    report_meta : dict, optional
        Model-report metadata. A falsy value falls back to ``DEFAULT_REPORT``.
    include_model_report : bool
        Whether to insert the two model-report pages after the summary page.

    Returns
    -------
    dict
        Keys: slide_id, prediction, p_malignant (mean over models),
        confidence (probability of the predicted class), stability
        (1 - std of model probabilities), n_patches, threshold, had_wsi.

    Raises
    ------
    ValueError
        If the bundle holds no fitted models (raised by ``predict``).
    """
    thr = threshold if threshold is not None else bundle.get("threshold", 0.5)
    names = bundle.get("class_names", ["Benign", "Malignant"])
    probs = predict(bundle, slide)
    prob = float(np.mean(probs))
    stability = 1.0 - float(np.std(probs))  # agreement between ensemble members
    pred = int(prob >= thr)
    pclass_conf = prob if pred == 1 else (1.0 - prob)
    label = names[pred]
    color = C_MAL if pred else C_BEN
    psize = slide["patch_size"] or 256
    coords = slide["coords"]
    imp = patch_relevance(bundle, slide)
    R = report_meta or DEFAULT_REPORT

    simg = None
    if wsi_path and os.path.isfile(wsi_path):
        try:
            simg = SlideImage(wsi_path)
        except Exception as err:  # best-effort: continue in .h5-only mode
            warnings.warn(f"could not open WSI {wsi_path!r}: {err!r}",
                          RuntimeWarning, stacklevel=2)
            simg = None

    try:
        with PdfPages(out_pdf) as pdf:
            _summary_page(pdf, slide, R, thr, label, pclass_conf, stability, color)
            if include_model_report:
                report_pages(pdf, R)

            # The thumbnail is computed once and shared by both image pages.
            thumb = scale = bb = None
            if simg is not None:
                thumb, scale = simg.thumbnail(thumb_max)
                bb = tissue_bbox(thumb.shape, coords, psize, scale)

            _analysis_page(pdf, slide, simg, thumb, scale, bb, coords, imp, psize,
                           zoom_context, label, pclass_conf, stability, color)
            if simg is not None and coords is not None:
                _showcase_page(pdf, slide, thumb, scale, bb, coords, imp, psize,
                               label, pclass_conf, color)
    finally:
        # Release the slide handle even if rendering fails.
        if simg is not None:
            simg.close()

    return {
        "slide_id": slide["slide_id"], "prediction": label, "p_malignant": prob,
        "confidence": pclass_conf, "stability": stability,
        "n_patches": int(slide["feats"].shape[0]), "threshold": thr,
        "had_wsi": simg is not None,
    }