Spaces:
Sleeping
Sleeping
| """ | |
| BRACS inference + PDF report engine for the HF Space. | |
| Reuses the LEAN report logic: prediction, per-patch relevance, heatmap, | |
| high-res zooms, patch-level showcase, and the model-report figures. | |
| Research use only — not for clinical diagnosis. | |
| """ | |
| import os, io, json, pickle, tempfile | |
| import numpy as np | |
| import h5py | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| from matplotlib.backends.backend_pdf import PdfPages | |
| from matplotlib.patches import Rectangle | |
| from matplotlib.collections import PatchCollection | |
| from matplotlib.colors import LinearSegmentedColormap, Normalize | |
| from matplotlib.cm import ScalarMappable | |
| from scipy.ndimage import gaussian_filter | |
| # ---------------------------------------------------------------------------- | |
| # Colors / colormaps | |
| # ---------------------------------------------------------------------------- | |
# Base palette: body text, figure background, and the light grey used for info boxes.
C_TXT, C_BG, C_GRID = "#1a1a1a", "white", "#dddddd"
# Semantic palette: benign (green), malignant (red), accent (table headers and
# figure titles), neutral (captions and sub-headers).
C_BEN, C_MAL, C_ACC, C_NEUT = "#2a9d5c", "#d1495b", "#3a6ea5", "#7f8c8d"
# Relevance/attention colormap. The first stop is an 8-digit "#RRGGBBAA" hex
# whose trailing "00" is the alpha channel, so the low end of the map is
# fully transparent.
CMAP_ATT = LinearSegmentedColormap.from_list("att", ["#10103000", "#3a6ea5", "#ffd93d", "#d1495b"])
# Benign -> light grey -> malignant map; used for the confusion-matrix image.
CMAP_BM = LinearSegmentedColormap.from_list("bm", [C_BEN, "#eeeeee", C_MAL])
# Static model-report metadata rendered when build_report() is called without
# an explicit report_meta. All numbers here are literals; nothing in this file
# recomputes them.
DEFAULT_REPORT = {
    # Shown in the sub-header of the summary page.
    "header_model": "LR_concat",
    # NOTE(review): 0.982 does not match the LR_concat row below (0.972) — confirm.
    "header_cv_auc": 0.982,
    # NOTE(review): not read anywhere in the visible code (build_report prints
    # the runtime threshold instead) — confirm whether this is still needed.
    "header_threshold": 0.45,
    # Rows of Figure 1: (model name, patient-CV AUC, test AUC), all as display strings.
    "model_comparison": [
        ("LR_titan", "0.880 ± 0.072", "0.904"),
        ("LR_pooled", "0.973 ± 0.016", "0.973"),
        ("LR_concat", "0.972 ± 0.018", "0.973"),
        ("ABMIL", "0.953 ± 0.025", "—"),
    ],
    # Row highlighted in Figure 1 and named in the Figure 2 title.
    "best_model": "LR_concat",
    # Counts for Figure 2; sensitivity/specificity are derived from these.
    "confusion": {"tn": 170, "fp": 14, "fn": 22, "tp": 156},
    # Rows of Figure 3: (threshold, sensitivity, specificity).
    "threshold_table": [
        (0.20, 0.933, 0.864), (0.30, 0.933, 0.897),
        (0.50, 0.876, 0.924), (0.70, 0.831, 0.951), (0.80, 0.764, 0.957),
    ],
    # Drawn as a vertical line in Figure 3 and used to highlight the matching table row.
    "best_balanced_thr": 0.30,
    # Rows of Figure 4: (category, slide count); the total is computed at render time.
    "dataset_used": [("Benign", 184), ("Malignant", 178)],
    "n_patients": 144,
    # Free text rendered in the "Results" box.
    # NOTE(review): the paragraph cites the pooled-feature model while
    # best_model is "LR_concat" — confirm which one the report should describe.
    "results_paragraph": (
        "The proposed patient-grouped BRACS classifier was evaluated on 362 "
        "whole-slide images from 144 patients. Logistic Regression using pooled "
        "slide-level features achieved a patient-grouped cross-validation AUC of "
        "0.973 ± 0.016 and an official test-set AUC of 0.973. At the default "
        "threshold of 0.50, sensitivity was 87.6% and specificity 92.4%. Threshold "
        "optimization identified 0.30 as the operating point yielding the highest "
        "balanced accuracy."
    ),
}
| # ---------------------------------------------------------------------------- | |
| # H5 loading | |
| # ---------------------------------------------------------------------------- | |
def load_slide_h5(path):
    """Read one slide's feature file and return its arrays as a dict.

    Required datasets: ``features`` and ``titan_slide_embedding``.
    Optional: ``pooled_conch_slide_feature_mean_max_std`` (rebuilt from the
    patch features as mean/max/std when absent) and ``coords`` (with an
    optional ``patch_size_level0`` attribute).

    Returns a dict with keys ``feats``, ``titan``, ``pooled``, ``concat``
    (titan followed by pooled, float32), ``coords`` (or None),
    ``patch_size`` (or None), ``slide_id`` and ``path``.
    """
    pooled_key = "pooled_conch_slide_feature_mean_max_std"
    with h5py.File(path, "r") as h5:
        patch_feats = np.asarray(h5["features"][:], np.float32)
        slide_emb = np.asarray(h5["titan_slide_embedding"][:], np.float32).squeeze()

        if pooled_key not in h5:
            # No stored pooled vector: derive it from the patch features.
            pooled_vec = np.concatenate(
                [patch_feats.mean(0), patch_feats.max(0), patch_feats.std(0)]
            )
        else:
            pooled_vec = np.asarray(h5[pooled_key][:], np.float32).squeeze()

        xy = None
        level0_size = None
        if "coords" in h5:
            xy = np.asarray(h5["coords"][:], np.int64)
            if "patch_size_level0" in h5["coords"].attrs:
                level0_size = int(h5["coords"].attrs["patch_size_level0"])

        # Fall back to the file name when the file carries no slide_id attribute.
        raw_id = h5.attrs.get("slide_id", os.path.basename(path))
        slide_id = raw_id.decode() if isinstance(raw_id, bytes) else str(raw_id)

    record = {
        "feats": patch_feats,
        "titan": slide_emb,
        "pooled": pooled_vec,
        "concat": np.concatenate([slide_emb, pooled_vec]).astype(np.float32),
        "coords": xy,
        "patch_size": level0_size,
        "slide_id": slide_id,
        "path": path,
    }
    return record
| # ---------------------------------------------------------------------------- | |
| # Prediction + relevance | |
| # ---------------------------------------------------------------------------- | |
| def _artifact(bundle): | |
| if "artifact" in bundle: | |
| return bundle["artifact"] | |
| return {"kind": bundle.get("kind", "linear"), | |
| "rep": bundle.get("rep", "concat"), | |
| "models": bundle.get("models", [])} | |
def predict(bundle, slide):
    """Return one malignant-class probability per ensemble member.

    For ``kind == "linear"`` each model's ``predict_proba`` is applied to the
    slide representation named by ``rep``. Otherwise an ABMIL network is
    rebuilt for every stored state dict and run on the patch features.
    """
    spec = _artifact(bundle)

    if spec["kind"] != "linear":
        import torch
        bag = torch.from_numpy(slide["feats"]).float()
        outputs = []
        for state in spec["states"]:
            net = build_abmil(spec["in_dim"])
            net.load_state_dict(state)
            net.eval()
            with torch.no_grad():
                score, _attention = net(bag)
                outputs.append(torch.sigmoid(score).item())
        return np.array(outputs)

    # Linear models expect a single-row 2-D input.
    row = slide[spec["rep"]].reshape(1, -1)
    per_model = []
    for estimator in spec["models"]:
        per_model.append(estimator.predict_proba(row)[0, 1])
    return np.array(per_model)
def build_abmil(in_dim, hidden=192, att=128, dropout=0.3):
    """Construct a gated-attention MIL network for bags of ``in_dim`` features.

    The returned module maps an ``(n_patches, in_dim)`` tensor to a tuple
    ``(logit, attention)`` where ``attention`` is a softmax over the patch
    axis. Sub-module names (``proj``, ``V``, ``U``, ``w``, ``head``) define
    the state-dict keys and must not change.
    """
    import torch
    from torch import nn

    class ABMIL(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            self.V = nn.Linear(hidden, att)
            self.U = nn.Linear(hidden, att)
            self.w = nn.Linear(att, 1)
            self.head = nn.Linear(hidden, 1)

        def forward(self, x):
            embedded = self.proj(x)
            # Gated attention: tanh branch modulated by a sigmoid gate.
            gated = torch.tanh(self.V(embedded)) * torch.sigmoid(self.U(embedded))
            weights = torch.softmax(self.w(gated), 0)
            pooled = (weights * embedded).sum(0, keepdim=True)
            return self.head(pooled).squeeze(), weights.squeeze(-1)

    return ABMIL()
def norm01(x):
    """Rescale *x* linearly so its minimum maps to 0 and its maximum to ~1.

    A small epsilon in the denominator keeps constant inputs from dividing
    by zero (they map to all zeros). An empty input returns an empty float
    array instead of raising, so slides with no patches do not crash the
    callers. Array-likes are accepted and converted with ``np.asarray``.
    """
    arr = np.asarray(x)
    if arr.size == 0:
        # min()/ptp() are undefined on empty arrays; nothing to normalise.
        return arr.astype(np.result_type(arr.dtype, np.float32))
    return (arr - arr.min()) / (np.ptp(arr) + 1e-9)
def patch_relevance(bundle, slide):
    """Return a per-patch relevance score in [0, 1] for *slide*.

    For a linear ensemble trained on the ``concat`` representation, the
    ensemble-averaged coefficients are mapped back to input space (undoing
    PCA and scaling when the model is a pipeline with ``pca`` / ``scaler``
    steps) and the block that multiplies the mean-pooled patch features is
    projected onto every patch. In every other case, or if that derivation
    fails, the L2 norm of each patch feature vector is used instead.

    Parameters
    ----------
    bundle : dict
        Model bundle; see ``_artifact`` for the accepted layouts.
    slide : dict
        Slide record with at least ``feats``; ``titan`` is used, when
        present, to locate the mean-feature block inside the coefficients.

    Returns
    -------
    numpy.ndarray
        One score per patch (empty when the slide has no patches).
    """
    feats = slide["feats"]
    if feats.shape[0] == 0:
        # No patches: nothing to score, and min/max normalisation is undefined.
        return np.zeros(0, dtype=np.float32)

    art = _artifact(bundle)
    if art.get("kind") == "linear" and art.get("rep") == "concat" and art.get("models"):
        try:
            coefs = []
            for m in art["models"]:
                if hasattr(m, "named_steps"):
                    c = np.asarray(m.named_steps["clf"].coef_).squeeze()
                    if "pca" in m.named_steps:
                        # Back-project from PCA space to the scaled input space.
                        c = c @ m.named_steps["pca"].components_
                    if "scaler" in m.named_steps:
                        c = c / (m.named_steps["scaler"].scale_ + 1e-9)
                else:
                    c = np.asarray(m.coef_).squeeze()
                coefs.append(c)
            coef = np.mean(coefs, 0)
            d = feats.shape[1]
            # "concat" is [titan, mean, max, std]; the mean block therefore
            # starts right after the slide embedding, whose length need not
            # equal the patch feature dimension. Fall back to d (the previous
            # hard-coded assumption) when the embedding is unavailable.
            titan = slide.get("titan")
            offset = int(np.size(titan)) if titan is not None else d
            mean_block = coef[offset:offset + d]
            if mean_block.shape[0] == d:
                return norm01(feats @ mean_block)
        except Exception:
            # Best effort: fall through to the norm-based score, but leave a
            # trace so a silently degraded heatmap can be diagnosed.
            import logging
            logging.getLogger(__name__).warning(
                "coefficient-based patch relevance failed; using feature norms",
                exc_info=True,
            )
    return norm01(np.linalg.norm(feats, axis=1))
| # ---------------------------------------------------------------------------- | |
| # Slide image backend (OpenSlide or PIL) | |
| # ---------------------------------------------------------------------------- | |
class SlideImage:
    """Uniform read access to a slide image via OpenSlide, or PIL as fallback.

    Attributes set by ``__init__``: ``path``, ``w`` / ``h`` (level-0 size in
    pixels), ``backend`` (``"openslide"`` or ``"pil"``) and either ``osh``
    (OpenSlide handle) or ``img`` (RGB PIL image held in memory).

    The object can be used as a context manager; ``close()`` releases the
    underlying handle for both backends.
    """

    def __init__(self, path):
        self.path = path
        try:
            import openslide
            self.osh = openslide.OpenSlide(path)
            self.w, self.h = self.osh.dimensions
            self.backend = "openslide"
        except Exception:
            # OpenSlide missing or unable to read this file: use PIL instead.
            from PIL import Image
            # NOTE: disables Pillow's decompression-bomb guard process-wide so
            # very large slide exports can be opened.
            Image.MAX_IMAGE_PIXELS = None
            # convert() returns an independent in-memory image, so the source
            # file handle can be released immediately.
            with Image.open(path) as src:
                self.img = src.convert("RGB")
            self.w, self.h = self.img.size
            self.backend = "pil"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def thumbnail(self, max_size):
        """Return ``(rgb_array, scale)`` with the longer side at most *max_size*.

        ``scale`` is thumbnail pixels per level-0 pixel and never exceeds 1.
        """
        scale = min(max_size / self.w, max_size / self.h, 1.0)
        tw, th = max(1, int(self.w * scale)), max(1, int(self.h * scale))
        if self.backend == "openslide":
            thumb = self.osh.get_thumbnail((tw, th))
        else:
            # Image.thumbnail works in place, so operate on a copy.
            thumb = self.img.copy()
            thumb.thumbnail((tw, th))
        return np.array(thumb), scale

    def read_region_l0(self, x, y, w, h):
        """Return the level-0 RGB region at (x, y) of size (w, h).

        The origin is clamped into the image and the size is clamped so the
        region stays inside it (minimum 1x1).
        """
        x = max(0, min(int(x), self.w - 1))
        y = max(0, min(int(y), self.h - 1))
        w = max(1, min(int(w), self.w - x))
        h = max(1, min(int(h), self.h - y))
        if self.backend == "openslide":
            return np.array(self.osh.read_region((x, y), 0, (w, h)).convert("RGB"))
        return np.array(self.img.crop((x, y, x + w, y + h)))

    def close(self):
        """Release the backend handle (OpenSlide handle or PIL image)."""
        backend = getattr(self, "backend", None)
        if backend == "openslide" and hasattr(self, "osh"):
            self.osh.close()
        elif backend == "pil" and hasattr(self, "img"):
            self.img.close()
def tissue_bbox(thumb_shape, coords, psize, scale, pad_frac=0.04):
    """Return ``(x0, y0, x1, y1)`` enclosing all patches, in thumbnail pixels.

    The box is padded by ``pad_frac`` of its longer side and clipped to the
    thumbnail. With no coordinates the whole thumbnail is returned.
    """
    height, width = thumb_shape[:2]
    if coords is None or len(coords) == 0:
        return 0, 0, width, height

    cols = coords[:, 0] * scale
    rows = coords[:, 1] * scale
    side = psize * scale

    left = cols.min()
    top = rows.min()
    right = cols.max() + side
    bottom = rows.max() + side

    margin = pad_frac * max(right - left, bottom - top)
    clipped_left = int(max(0, left - margin))
    clipped_top = int(max(0, top - margin))
    clipped_right = int(min(width, right + margin))
    clipped_bottom = int(min(height, bottom + margin))
    return (clipped_left, clipped_top, clipped_right, clipped_bottom)
def heatmap_canvas(shape, coords, imp, psize, scale):
    """Rasterise per-patch scores onto a thumbnail-sized, smoothed [0, 1] map.

    Each patch paints its score onto its footprint; overlapping footprints
    are averaged, the result is Gaussian-blurred and then min-max normalised.
    """
    n_rows, n_cols = shape[:2]
    total = np.zeros((n_rows, n_cols), np.float32)
    hits = np.zeros((n_rows, n_cols), np.float32)
    side = max(1, int(psize * scale))

    for idx, weight in enumerate(imp):
        left = int(coords[idx, 0] * scale)
        top = int(coords[idx, 1] * scale)
        right = min(left + side, n_cols)
        bottom = min(top + side, n_rows)
        total[top:bottom, left:right] += weight
        hits[top:bottom, left:right] += 1

    # Average only where at least one patch landed.
    covered = hits > 0
    total[covered] /= hits[covered]

    blur_sigma = max(2.0, min(n_rows, n_cols) / 60)
    smoothed = gaussian_filter(total, sigma=blur_sigma)
    return norm01(smoothed)
def overlay(ax, thumb, coords, imp, psize, scale, alpha=0.55):
    """Draw *thumb* on *ax* and, when patches exist, blend the heatmap on top.

    Heatmap opacity is proportional to the heat value (scaled by *alpha*), so
    low-relevance areas leave the tissue visible. Axis ticks are removed.
    """
    ax.imshow(thumb)
    has_patches = coords is not None and len(coords)
    if has_patches:
        heat = heatmap_canvas(thumb.shape, coords, imp, psize, scale)
        layer = CMAP_ATT(heat)
        layer[..., 3] = alpha * heat
        ax.imshow(layer)
    ax.set_xticks([])
    ax.set_yticks([])
| # ---------------------------------------------------------------------------- | |
| # Model-report figures (pages 2-3) | |
| # ---------------------------------------------------------------------------- | |
def _styled_table(ax, header, rows, highlight_row=None, fontsize=11):
    """Render a centred table on *ax* with the report's house styling.

    The header row (row index 0) gets the accent background with bold white
    text; ``highlight_row`` (1-based over the data rows, since 0 is the
    header) gets a pale green background with bold green text. Returns the
    matplotlib table.
    """
    table = ax.table(cellText=rows, colLabels=header, loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(fontsize)
    table.scale(1, 1.9)

    for (row_idx, _col_idx), cell in table.get_celld().items():
        cell.set_edgecolor("#cccccc")
        if row_idx == 0:
            cell.set_facecolor(C_ACC)
            cell.set_text_props(color="white", fontweight="bold")
            continue
        if highlight_row is not None and row_idx == highlight_row:
            cell.set_facecolor("#eaf4ec")
            cell.set_text_props(fontweight="bold", color=C_BEN)
    return table
def report_pages(pdf, R):
    """Append the two model-report pages to *pdf*.

    Page one holds Figure 1 (model comparison table) and Figure 2 (confusion
    matrix with derived sensitivity / specificity / balanced accuracy). Page
    two holds Figure 3 (threshold analysis plot and table), Figure 4 (dataset
    composition) and the free-text Results box.

    Parameters
    ----------
    pdf : PdfPages-like object providing ``savefig``.
    R : dict
        Report metadata. Reads ``model_comparison``, ``confusion``,
        ``threshold_table``, ``dataset_used``, ``n_patients`` and
        ``results_paragraph`` (required) plus ``best_model`` and
        ``best_balanced_thr`` (optional). A falsy *R* makes this a no-op.
    """
    if not R:
        return
    # Figures 1 & 2
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Model Development & Evaluation", fontsize=18, fontweight="bold", y=0.97, color=C_TXT)
    ax1 = fig.add_axes([0.07, 0.60, 0.86, 0.27]); ax1.axis("off")
    ax1.set_title("Figure 1 — Model Comparison", fontweight="bold", fontsize=12, loc="left", color=C_ACC, pad=10)
    best = R.get("best_model")
    rows = [[m, cv, te] for (m, cv, te) in R["model_comparison"]]
    # Table row 0 is the header, so data row i lives at table row i + 1.
    hl = next((i + 1 for i, (m, _, _) in enumerate(R["model_comparison"]) if m == best), None)
    _styled_table(ax1, ["Model", "Patient-CV AUC", "Test AUC"], rows, highlight_row=hl)
    cm = R["confusion"]; tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
    # The 1e-9 keeps the ratios defined when a class has no samples.
    sens = tp / (tp + fn + 1e-9); spec = tn / (tn + fp + 1e-9); bacc = 0.5 * (sens + spec)
    axcm = fig.add_axes([0.10, 0.12, 0.36, 0.36])
    # NOTE(review): "thr=0.50" is hard-coded in the title; confirm it matches
    # the threshold the supplied confusion counts were computed at.
    axcm.set_title(f"Figure 2 — Confusion Matrix ({best}, thr=0.50)", fontweight="bold",
                   fontsize=12, loc="left", color=C_ACC, pad=10)
    # Rows are the true class, columns the predicted class.
    mat = np.array([[tn, fp], [fn, tp]])
    axcm.imshow(mat, cmap=CMAP_BM, vmin=0, vmax=mat.max())
    axcm.set_xticks([0, 1]); axcm.set_yticks([0, 1])
    axcm.set_xticklabels(["Pred Benign", "Pred Malignant"], fontsize=9)
    axcm.set_yticklabels(["True Benign", "True Malignant"], fontsize=9, rotation=90, va="center")
    for i in range(2):
        for j in range(2):
            # White text on cells above half the maximum count, dark text otherwise.
            axcm.text(j, i, f"{mat[i, j]}", ha="center", va="center", fontsize=18,
                      fontweight="bold", color="white" if mat[i, j] > mat.max() / 2 else C_TXT)
    axm = fig.add_axes([0.55, 0.12, 0.38, 0.36]); axm.axis("off")
    axm.text(0.0, 0.78, f"Sensitivity = {sens:.3f}\nSpecificity = {spec:.3f}\n"
             f"Balanced Accuracy = {bacc:.3f}", fontsize=13, family="monospace", va="top",
             color=C_TXT, bbox=dict(boxstyle="round,pad=0.8", facecolor="#f4f6f8", edgecolor=C_ACC, lw=1.5))
    pdf.savefig(fig, dpi=150); plt.close(fig)
    # Figures 3 & 4 + Results
    fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
    fig.suptitle("Threshold Analysis & Dataset", fontsize=18, fontweight="bold", y=0.97, color=C_TXT)
    # Columns: threshold, sensitivity, specificity.
    tt = np.array(R["threshold_table"], dtype=float)
    ax3 = fig.add_axes([0.08, 0.58, 0.52, 0.30])
    ax3.set_title("Figure 3 — Threshold Analysis", fontweight="bold", fontsize=12, loc="left", color=C_ACC, pad=10)
    ax3.plot(tt[:, 0], tt[:, 1], "-o", color=C_MAL, lw=2, label="Sensitivity")
    ax3.plot(tt[:, 0], tt[:, 2], "-s", color=C_BEN, lw=2, label="Specificity")
    bthr = R.get("best_balanced_thr")
    if bthr is not None:
        ax3.axvline(bthr, color=C_ACC, ls="--", lw=2, label=f"Best balanced ({bthr})")
    ax3.set_xlabel("Threshold"); ax3.set_ylabel("Metric"); ax3.grid(True, alpha=0.3)
    # Fixed y-range: metric values below 0.7 fall outside the visible area.
    ax3.legend(fontsize=8, loc="lower center"); ax3.set_ylim(0.7, 1.0)
    ax3t = fig.add_axes([0.64, 0.58, 0.30, 0.30]); ax3t.axis("off")
    rows = [[f"{t:.2f}", f"{s:.3f}", f"{sp:.3f}"] for (t, s, sp) in R["threshold_table"]]
    # "bthr or -1" turns a missing threshold into -1 so that no row matches.
    # NOTE(review): a threshold of exactly 0 is also treated as missing here.
    hl = next((i + 1 for i, (t, _, _) in enumerate(R["threshold_table"]) if abs(t - (bthr or -1)) < 1e-9), None)
    _styled_table(ax3t, ["Thr", "Sens", "Spec"], rows, highlight_row=hl, fontsize=10)
    # NOTE(review): when bthr is None this caption reads "... at None." — confirm intended.
    fig.text(0.08, 0.46, f"Best balanced accuracy at {bthr}. Operating point can be tuned by whether "
             "missing cancers or false alarms are costlier.", fontsize=9.5, style="italic", color=C_NEUT)
    used = R["dataset_used"]; total_used = sum(v for _, v in used)
    ax4 = fig.add_axes([0.30, 0.27, 0.40, 0.16]); ax4.axis("off")
    ax4.set_title("Figure 4 — Dataset Composition", fontweight="bold", fontsize=12, loc="center", color=C_ACC, pad=8)
    rows = [[c, str(v)] for c, v in used] + [["Total Used", str(total_used)]]
    # len(rows) is the table index of the last data row, i.e. the "Total Used" line.
    _styled_table(ax4, ["Category", "Slides"], rows, highlight_row=len(rows), fontsize=10)
    fig.text(0.5, 0.245, f"{R['n_patients']} unique patients • patient-grouped CV • no patient leakage",
             ha="center", fontsize=9.5, style="italic", color=C_NEUT)
    import textwrap
    axres = fig.add_axes([0.08, 0.05, 0.84, 0.17]); axres.axis("off")
    axres.text(0.0, 0.95, "Results", fontsize=13, fontweight="bold", va="top", color=C_TXT)
    # The paragraph is wrapped manually at 110 characters before being drawn.
    axres.text(0.0, 0.72, "\n".join(textwrap.wrap(R["results_paragraph"], width=110)),
               fontsize=10.5, va="top", color=C_TXT, linespacing=1.6,
               bbox=dict(boxstyle="round,pad=0.9", facecolor="#f4f6f8", edgecolor=C_ACC, lw=1.2))
    pdf.savefig(fig, dpi=150); plt.close(fig)
| # ---------------------------------------------------------------------------- | |
| # MAIN: build the PDF for ONE slide | |
| # ---------------------------------------------------------------------------- | |
def build_report(bundle, slide, wsi_path, out_pdf,
                 thumb_max=4000, zoom_context=4, threshold=None,
                 report_meta=None, include_model_report=True):
    """Run inference on one slide and write its explainability PDF.

    Pages written, in order: a summary page; the model-report pages (when
    ``include_model_report`` is true); a per-slide analysis page (thumbnail,
    heatmap, top-3 zooms); and, when both a slide image and patch coordinates
    are available, a patch-level showcase page.

    Parameters
    ----------
    bundle : dict
        Model bundle. Optional keys read here: ``threshold`` (default 0.5)
        and ``class_names`` (default ``["Benign", "Malignant"]``).
    slide : dict
        Slide record as produced by ``load_slide_h5``.
    wsi_path : str or None
        Path to the slide image. If missing or unreadable, image-based panels
        are replaced by placeholders.
    out_pdf : str
        Destination PDF path.
    thumb_max : int
        Maximum side length of the slide thumbnail, in pixels.
    zoom_context : int
        Side of each zoom window, in multiples of the patch size.
    threshold : float or None
        Decision threshold; overrides the bundle's value when given.
    report_meta : dict or None
        Report metadata; ``DEFAULT_REPORT`` is used when falsy.
    include_model_report : bool
        Whether to include the model-report pages.

    Returns
    -------
    dict
        ``slide_id``, ``prediction``, ``p_malignant``, ``confidence``
        (probability of the predicted class), ``stability`` (1 minus the
        ensemble standard deviation), ``n_patches``, ``threshold`` and
        ``had_wsi``.
    """
    thr = threshold if threshold is not None else bundle.get("threshold", 0.5)
    names = bundle.get("class_names", ["Benign", "Malignant"])
    probs = predict(bundle, slide)
    prob = float(np.mean(probs))
    spread = float(np.std(probs))
    conf = 1.0 - spread
    pred = int(prob >= thr)
    pclass_conf = prob if pred == 1 else (1.0 - prob)
    # Patch size defaults to 256 when the feature file did not record one.
    psize = slide["patch_size"] or 256
    coords = slide["coords"]
    imp = patch_relevance(bundle, slide)
    color = C_MAL if pred else C_BEN
    filled = int(conf * 10)
    conf_bar = "█" * filled + "░" * (10 - filled)
    R = report_meta or DEFAULT_REPORT

    simg = None
    if wsi_path and os.path.isfile(wsi_path):
        try:
            simg = SlideImage(wsi_path)
        except Exception:
            # Best effort: the report is still produced without image panels.
            simg = None
    had_wsi = simg is not None

    # try/finally guarantees the slide handle is released even when a page
    # fails to render.
    try:
        with PdfPages(out_pdf) as pdf:
            # ---- Page 1: summary ----
            fig = plt.figure(figsize=(14, 10), facecolor=C_BG)
            fig.text(0.5, 0.95, "BRACS Explainability Report", ha="center",
                     fontsize=20, fontweight="bold", color=C_TXT)
            fig.text(0.5, 0.91, f"Model: {R.get('header_model','?')} | CV AUC: "
                     f"{R.get('header_cv_auc', float('nan')):.3f} | Threshold: {thr:.2f}",
                     ha="center", fontsize=10, color=C_NEUT)
            ax = fig.add_axes([0.1, 0.4, 0.8, 0.4])
            ax.axis("off")
            rows = [["Slide", "Prediction", "Confidence", "Stability", "Patches"],
                    [slide["slide_id"][:30], names[pred], f"{pclass_conf*100:.1f}%",
                     f"{conf*100:.0f}%", str(slide["feats"].shape[0])]]
            tbl = ax.table(cellText=rows[1:], colLabels=rows[0], loc="center", cellLoc="center")
            tbl.auto_set_font_size(False)
            tbl.set_fontsize(11)
            tbl.scale(1, 2.2)
            for (r, c), cell in tbl.get_celld().items():
                if r == 0:
                    cell.set_facecolor(C_ACC)
                    cell.set_text_props(color="white", fontweight="bold")
                elif c == 1:
                    # Prediction column is tinted with the class color.
                    cell.set_text_props(color=color, fontweight="bold")
            fig.text(0.5, 0.30, "Research use only — not for clinical diagnosis.",
                     ha="center", fontsize=9, style="italic", color=C_NEUT)
            pdf.savefig(fig, dpi=150)
            plt.close(fig)

            # ---- Model report pages ----
            if include_model_report:
                report_pages(pdf, R)

            # ---- Per-slide analysis page ----
            fig = plt.figure(figsize=(16, 10), facecolor=C_BG)
            gs = gridspec.GridSpec(2, 3, height_ratios=[1.3, 1], hspace=0.28,
                                   wspace=0.18, width_ratios=[1, 1, 0.55])
            fig.suptitle(f"{slide['slide_id']} | {names[pred]} confidence {pclass_conf*100:.1f}% [{conf_bar}]",
                         color=color, fontsize=14, fontweight="bold", y=0.99)
            # Computed once here and reused by the showcase page below.
            thumb = scale = bb = None
            if simg is not None:
                thumb, scale = simg.thumbnail(thumb_max)
                bb = tissue_bbox(thumb.shape, coords, psize, scale)

            axo = fig.add_subplot(gs[0, 0])
            axo.set_title("Original WSI", fontweight="bold", fontsize=11, pad=8)
            if thumb is not None:
                axo.imshow(thumb)
                axo.set_xlim(bb[0], bb[2])
                # Reversed y-limits keep the image origin at the top-left.
                axo.set_ylim(bb[3], bb[1])
            else:
                axo.text(0.5, 0.5, "WSI not provided\n(.h5 only)", ha="center", va="center", fontsize=11)
            axo.set_xticks([])
            axo.set_yticks([])

            axb = fig.add_subplot(gs[0, 1])
            axb.set_title("Attention Heatmap", fontweight="bold", fontsize=11, pad=8)
            if thumb is not None:
                overlay(axb, thumb, coords, imp, psize, scale)
                axb.set_xlim(bb[0], bb[2])
                axb.set_ylim(bb[3], bb[1])
                sm = ScalarMappable(cmap=CMAP_ATT, norm=Normalize(0, 1))
                sm.set_array([])
                fig.colorbar(sm, ax=axb, shrink=0.6, label="importance").ax.tick_params(labelsize=8)
            else:
                axb.text(0.5, 0.5, "WSI not provided", ha="center", va="center", fontsize=11)
            axb.set_xticks([])
            axb.set_yticks([])

            axc = fig.add_subplot(gs[0, 2])
            axc.axis("off")
            axc.text(0.05, 0.95, f"Prediction: {names[pred]}\nConfidence: {pclass_conf*100:.0f}%\n"
                     f"Stability: {conf*100:.0f}%\nPatches: {slide['feats'].shape[0]}\n"
                     f"WSI: {'loaded' if simg else 'not provided'}",
                     transform=axc.transAxes, fontsize=10, va="top", family="monospace", color=C_TXT,
                     bbox=dict(boxstyle="round,pad=0.6", facecolor=C_GRID, alpha=0.5, edgecolor=color, lw=2))

            top3 = np.argsort(imp)[::-1][:3] if len(imp) else []
            for k, pi in enumerate(top3):
                axz = fig.add_subplot(gs[1, k])
                if simg is not None and coords is not None:
                    win = psize * zoom_context
                    cx, cy = coords[pi, 0] + psize // 2, coords[pi, 1] + psize // 2
                    x0, y0 = cx - win // 2, cy - win // 2
                    region = simg.read_region_l0(x0, y0, win, win)
                    axz.imshow(region, interpolation="lanczos")
                    # read_region_l0 clamps the origin at 0, so the patch
                    # offset is measured from the clamped origin.
                    px, py = coords[pi, 0] - max(0, x0), coords[pi, 1] - max(0, y0)
                    # Wide translucent outlines under a crisp one give a glow.
                    for gw, ga in [(7, 0.12), (5, 0.18), (3.2, 0.30)]:
                        axz.add_patch(Rectangle((px, py), psize, psize, fill=False,
                                                edgecolor="#ffd93d", lw=gw, alpha=ga))
                    axz.add_patch(Rectangle((px, py), psize, psize, fill=False, edgecolor="#ffd93d", lw=1.6))
                    axz.text(0.04, 0.95, f"relevance {imp[pi]*100:.0f}%", transform=axz.transAxes,
                             fontsize=8.5, fontweight="bold", va="top", color="white",
                             bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.85, edgecolor="none"))
                    axz.set_title(f"Region #{k+1}", fontweight="bold", fontsize=10, pad=5, color=color)
                    axz.set_xticks([])
                    axz.set_yticks([])
                else:
                    axz.text(0.5, 0.5, "no WSI/coords", ha="center", va="center", fontsize=10)
                    axz.axis("off")
            pdf.savefig(fig, dpi=160)
            plt.close(fig)

            # ---- Showcase page: all patches ----
            if simg is not None and coords is not None:
                fig = plt.figure(figsize=(17, 10), facecolor=C_BG)
                fig.suptitle(f"Patch-Level Visualization — {slide['slide_id']}", fontsize=17,
                             fontweight="bold", y=0.97, color=C_TXT)
                npatch = slide["feats"].shape[0]
                fig.text(0.5, 0.925, f"{names[pred]} confidence {pclass_conf*100:.1f}% | "
                         f"all {npatch} patches colored by relevance", ha="center", fontsize=11, color=C_NEUT)
                # thumb / scale / bb were already computed for the analysis
                # page with the same arguments; decoding the slide again
                # would only repeat that work.
                sp = max(1.0, psize * scale)
                axL = fig.add_axes([0.04, 0.06, 0.44, 0.82])
                axL.imshow(thumb)
                axL.set_xlim(bb[0], bb[2])
                axL.set_ylim(bb[3], bb[1])
                axL.set_xticks([])
                axL.set_yticks([])
                axL.set_title("Original Tissue", fontweight="bold", fontsize=12, pad=8)
                axR = fig.add_axes([0.50, 0.06, 0.44, 0.82])
                axR.imshow(thumb, alpha=0.25)
                rects, cols = [], []
                # Ascending order so the most relevant patches are drawn last (on top).
                for i in np.argsort(imp):
                    rects.append(Rectangle((coords[i, 0] * scale, coords[i, 1] * scale), sp, sp))
                    cols.append(imp[i])
                pc = PatchCollection(rects, cmap=CMAP_ATT, alpha=0.85, edgecolor="white", linewidth=0.15)
                pc.set_array(np.array(cols))
                pc.set_clim(0, 1)
                axR.add_collection(pc)
                axR.set_xlim(bb[0], bb[2])
                axR.set_ylim(bb[3], bb[1])
                axR.set_xticks([])
                axR.set_yticks([])
                axR.set_title(f"Relevance Map ({npatch} patches)", fontweight="bold", fontsize=12, pad=8)
                for rank, pi in enumerate(np.argsort(imp)[::-1][:5]):
                    axR.text(coords[pi, 0] * scale + sp / 2, coords[pi, 1] * scale + sp / 2, str(rank + 1),
                             ha="center", va="center", fontsize=8, fontweight="bold", color="white",
                             bbox=dict(boxstyle="circle,pad=0.15", facecolor=color, edgecolor="white", lw=0.8, alpha=0.95))
                fig.colorbar(pc, ax=axR, shrink=0.55, pad=0.01, label="patch relevance").ax.tick_params(labelsize=8)
                pdf.savefig(fig, dpi=180)
                plt.close(fig)
    finally:
        if simg is not None:
            simg.close()

    return {
        "slide_id": slide["slide_id"], "prediction": names[pred],
        "p_malignant": prob, "confidence": pclass_conf, "stability": conf,
        "n_patches": int(slide["feats"].shape[0]), "threshold": thr,
        "had_wsi": had_wsi,
    }