#!/usr/bin/env python3
"""
BrainGemma3D + LIME Interpretability
==================================================

Usage:
    python braingemma3d_interpretability.py \\
        --model_dir ./final_model \\
        --mri_path /path/to/scan.nii.gz \\
        --report "The brain shows a mass in the left frontal lobe..." \\
        --output_dir ./lime_output

If --report is not provided, the script will generate it first.
"""

import os
import sys
import json
import argparse
import random
import importlib.util
from pathlib import Path

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # Headless mode; must run before pyplot is imported
import matplotlib.pyplot as plt

# LIME + segmentation
from lime import lime_image
from skimage.segmentation import slic
from scipy.ndimage import binary_closing, binary_opening, binary_fill_holes, binary_erosion
from skimage.morphology import ball, remove_small_objects
from skimage.measure import label as cc_label


def set_seed(seed: int = 42):
    """Set random seed for reproducibility (Python, NumPy, torch CPU and CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def import_architecture_from_model_dir(model_dir):
    """Dynamically import braingemma3d_architecture.py from the model folder.

    Returns the executed module object. Note that this runs the Python file
    found in ``model_dir``, so only point it at trusted model folders.
    """
    arch_path = os.path.join(model_dir, "braingemma3d_architecture.py")
    spec = importlib.util.spec_from_file_location("braingemma3d_architecture", arch_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def load_full_model(model_dir, device):
    """Load the BrainGemma3D model together with its projector weights.

    Args:
        model_dir: Folder containing ``braingemma3d_architecture.py``,
            ``model_config.json`` and ``projector_vis_scale.pt``.
        device: ``"cuda"`` or ``"cpu"``; also used as ``map_location`` for
            the projector checkpoint.

    Returns:
        ``(model, load_nifti_volume, CANONICAL_PROMPT)`` where the last two
        are taken from the dynamically imported architecture module.
    """
    arch_module = import_architecture_from_model_dir(model_dir)
    BrainGemma3D = arch_module.BrainGemma3D
    load_nifti_volume = arch_module.load_nifti_volume
    CANONICAL_PROMPT = arch_module.CANONICAL_PROMPT

    with open(os.path.join(model_dir, "model_config.json"), encoding="utf-8") as f:
        cfg = json.load(f)

    model = BrainGemma3D(
        vision_model_dir=os.path.join(model_dir, cfg["vision_model_dir"]),
        language_model_dir=os.path.join(model_dir, cfg["language_model_dir"]),
        depth=cfg["depth"],
        num_vision_tokens=cfg["num_vision_tokens"],
        freeze_vision=True,
        freeze_language=True,
        device_map={"": 0} if device == "cuda" else None,
    )

    # Load projector.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    proj_path = os.path.join(model_dir, "projector_vis_scale.pt")
    ckpt = torch.load(proj_path, map_location=device)
    model.vision_projector.load_state_dict(ckpt["vision_projector"])

    # A checkpoint without a "vis_scale" entry is treated like vis_scale=None.
    vis_scale = ckpt.get("vis_scale")
    if vis_scale is not None:
        if isinstance(vis_scale, torch.Tensor):
            model.vis_scale.data = vis_scale.to(device)
        else:
            model.vis_scale.data.fill_(vis_scale)

    model.eval()
    return model, load_nifti_volume, CANONICAL_PROMPT


# ============================================================
# LIME SCORING: NLL of reference report
# ============================================================

def _per_sample_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the mean NLL over supervised (non -100) tokens per sample, shape (B,)."""
    # Standard causal-LM shift: the logits at position t predict token t+1.
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:].to(shift_logits.device)
    token_nll = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(shift_labels.shape)
    n_tokens = (shift_labels != -100).sum(dim=1).clamp(min=1)
    return token_nll.sum(dim=1) / n_tokens


@torch.no_grad()
def lime_score_report_nll(volumes, model, prompt: str, report_ref: str, batch_size: int = 1):
    """
    Score perturbed volumes with the NLL of the reference report.

    Lower NLL = model more confident in the reference report = better support.
    The function returns -NLL, so a higher score means stronger support.

    Args:
        volumes: Array-like of volumes, either (N, Z, Y, X) or (N, 1, Z, Y, X).
        model: Object exposing ``lm_device``, ``tokenizer``, ``encode_volume``
            and ``language_model`` (as used below).
        prompt: Prompt text placed before the report (not scored).
        report_ref: Reference report whose tokens are scored.
        batch_size: Number of volumes pushed through the model at once.

    Returns:
        ``np.ndarray`` of shape (N, 1) with one score per input volume.

    Raises:
        ValueError: If ``report_ref`` tokenizes to zero tokens (the NLL would
            be undefined).
    """
    device = model.lm_device

    # 1) Tokenize prompt and report separately
    tokenizer = model.tokenizer
    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
    report_ids = tokenizer(report_ref, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    if report_ids.size(1) == 0:
        raise ValueError("report_ref produced no tokens; cannot compute its NLL.")
    text_ids_1 = torch.cat([prompt_ids, report_ids], dim=1)  # (1, P+R)

    # 2) Prepare volumes tensor
    vols = torch.from_numpy(np.asarray(volumes)).to(device)
    if vols.ndim == 4:  # (N, Z, Y, X)
        vols = vols.unsqueeze(1)  # (N, 1, Z, Y, X)

    n_volumes = vols.shape[0]
    if n_volumes == 0:
        return np.zeros((0, 1), dtype=np.float32)

    embed_tokens = model.language_model.get_input_embeddings()
    n_prompt = prompt_ids.size(1)
    scores = []

    for start in range(0, n_volumes, batch_size):
        v = vols[start:start + batch_size].to(dtype=torch.bfloat16)
        b = v.size(0)

        # 3) Encode volume -> vision tokens
        vision_tokens = model.encode_volume(v)  # (B, V, D_lm)

        # 4) Text embeddings
        text_embeds = embed_tokens(text_ids_1.repeat(b, 1))  # (B, P+R, D_lm)

        # 5) Concatenate embeds: [vision | text]
        inputs_embeds = torch.cat([vision_tokens, text_embeds], dim=1)  # (B, V+P+R, D_lm)

        # 6) Labels: -100 on vision + prompt, target on report
        n_ignored = vision_tokens.size(1) + n_prompt
        ignored = torch.full((b, n_ignored), -100, device=device, dtype=torch.long)
        labels = torch.cat([ignored, report_ids.repeat(b, 1)], dim=1)  # (B, V+P+R)

        # 7) Forward LM with labels
        out = model.language_model(inputs_embeds=inputs_embeds, labels=labels)
        if b == 1:
            # Single sample: the model's own (masked mean) loss is the sample NLL.
            batch_scores = (-out.loss).detach().float().reshape(1)
        else:
            # out.loss is one scalar averaged over the whole batch; recompute
            # per-sample so that every volume gets its own score.
            batch_scores = (-_per_sample_nll(out.logits, labels)).detach().float()
        scores.append(batch_scores.cpu())

    return torch.cat(scores).numpy().reshape(-1, 1)


# ============================================================
# 3D BRAIN SEGMENTATION
# ============================================================

def quick_brain_mask(
    vol_zyx: np.ndarray,
    p_thresh: float = 25,
    min_cc_vox: int = 2000
) -> np.ndarray:
    """Create a boolean brain mask from a 3D volume.

    Thresholds at the ``p_thresh`` intensity percentile, cleans the result
    with opening / closing / hole filling, drops components smaller than
    ``min_cc_vox`` voxels and keeps only the largest remaining component.
    """
    v = vol_zyx.astype(np.float32)
    thr = np.percentile(v, p_thresh)
    m = v > thr

    m = binary_opening(m, structure=ball(1))
    m = binary_closing(m, structure=ball(2))
    m = binary_fill_holes(m)
    m = remove_small_objects(m, min_size=min_cc_vox)

    # Keep largest connected component
    cc = cc_label(m)
    if cc.max() > 1:
        sizes = np.bincount(cc.ravel())
        sizes[0] = 0  # label 0 is background, never a candidate
        m = (cc == int(sizes.argmax()))

    return m.astype(bool)


def big_supervoxels_brain_only(
    vol_zyx: np.ndarray,
    n_segments: int = 20,
    compactness: float = 0.05,
    sigma: float = 1.0,
    p_thresh: float = 25,
    min_cc_vox: int = 2000,
):
    """
    Segment ONLY brain tissue using SLIC with a brain mask.

    Returns ``(segments, brain_mask)``. ``segments`` has 0-based contiguous labels:
    - 0 = background (not brain)
    - 1, 2, ..., N = brain supervoxels

    This labeling is CRITICAL for LIME 0.2.0.1 which uses feature indices
    directly as segment labels: mask[segments == feature_idx].
    With 0-based contiguous labels, feature i maps exactly to segment i.
    Background (0) adds one harmless noise feature to LIME's regression.

    Raises:
        ValueError: If the computed brain mask is empty.
    """
    brain = quick_brain_mask(vol_zyx, p_thresh=p_thresh, min_cc_vox=min_cc_vox)
    if not brain.any():
        raise ValueError(
            "Brain mask is empty; cannot build supervoxels "
            f"(p_thresh={p_thresh}, min_cc_vox={min_cc_vox})."
        )

    # Segment ONLY brain tissue using mask parameter.
    # Without mask, SLIC wastes most segments on empty background
    # (e.g. 84.5% background for typical BraTS volumes).
    seg = slic(
        vol_zyx,
        n_segments=n_segments,
        compactness=compactness,
        sigma=sigma,
        channel_axis=None,
        start_label=1,
        mask=brain,  # brain-only segmentation
    )

    # Relabel any negative (background) value to 0 for clean 0-based labels.
    seg[seg < 0] = 0

    # Verify labels are contiguous 0..N (required for LIME feature indexing).
    # NOTE: if the volume had no background voxel at all, the smallest brain
    # label would be mapped to 0 here and later treated as background.
    unique = np.unique(seg)
    expected = np.arange(len(unique))
    if not np.array_equal(unique, expected):
        # `unique` is sorted, so searchsorted gives each label its rank.
        seg = np.searchsorted(unique, seg).astype(seg.dtype)
        print(f"ℹ️ Relabeled segments to contiguous 0..{len(unique)-1}", flush=True)

    n_brain_segs = len(np.unique(seg)) - 1  # exclude background (0)
    print(f"🧩 Brain-only SLIC: {n_brain_segs} brain supervoxels "
          f"(requested {n_segments}), brain covers {100*brain.sum()/brain.size:.1f}% of volume",
          flush=True)

    return seg, brain


def make_segmentation_fn(cached_segments: np.ndarray):
    """Return a segmentation function that always returns pre-computed segments."""
    def segmentation_fn(vol):
        return cached_segments
    return segmentation_fn


# ============================================================
# VISUALIZATION HELPERS (SAVE TO FILE)
# ============================================================

def _take_slice(arr: np.ndarray, axis: int, idx: int, rot_k: int = 0) -> np.ndarray:
    """Return the 2D slice ``idx`` of ``arr`` along ``axis``, rotated by ``rot_k`` * 90 degrees."""
    if axis == 0:
        plane = arr[idx, :, :]
    elif axis == 1:
        plane = arr[:, idx, :]
    else:
        plane = arr[:, :, idx]
    return np.rot90(plane, k=rot_k)


def _axis_label(axis: int, spaced: bool = True) -> str:
    """Return e.g. "Axial (Z)" (or "Axial(Z)" when ``spaced`` is False) for ``axis``."""
    plane, letter = {0: ("Axial", "Z"), 1: ("Coronal", "Y")}.get(axis, ("Sagittal", "X"))
    separator = " " if spaced else ""
    return f"{plane}{separator}({letter})"


def save_slice_png(volume_zyx: np.ndarray, out_path: str, axis: int = 0, idx: int | None = None, rot_k: int = 0):
    """Save one grayscale slice of ``volume_zyx`` to ``out_path``.

    ``idx`` defaults to the middle slice along ``axis``.
    """
    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    img = _take_slice(volume_zyx, axis, idx, rot_k)
    title = f"{_axis_label(axis)} slice {idx}"

    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap="gray", origin="lower")
    plt.title(title)
    plt.axis("off")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def save_overlay_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
):
    """Save one slice with a blue-white-red heatmap overlay and a colorbar.

    The heatmap is clipped symmetrically at the ``clip_q`` quantile of the
    absolute heat values in the selected slice.

    Raises:
        ValueError: If the volume and heatmap shapes differ.
    """
    if volume_zyx.shape != heat_zyx.shape:
        raise ValueError(f"Shape mismatch: volume {volume_zyx.shape} vs heat {heat_zyx.shape}")

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    img = _take_slice(volume_zyx, axis, idx, rot_k)
    h = _take_slice(heat_zyx, axis, idx, rot_k)
    title = f"{_axis_label(axis)} overlay slice {idx}"

    m = float(max(np.quantile(np.abs(h), clip_q), 1e-8))
    h_vis = np.clip(h, -m, m)

    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap="gray", origin="lower")
    im = plt.imshow(h_vis, cmap="bwr", alpha=alpha, origin="lower", vmin=-m, vmax=m)
    plt.title(title)
    plt.axis("off")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def save_overlay_grid_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idxs=None,
    n_cols: int = 6,
    n_slices: int = 36,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    figsize_per_cell: float = 2.2,
    add_colorbar: bool = False,
    suptitle: str | None = None,
):
    """Save a grid of heatmap overlays for several slices along ``axis``.

    When ``idxs`` is None, ``n_slices`` indices are spread evenly over the
    central 10%-90% range of the axis. Clipping uses one global value (the
    ``clip_q`` quantile of the absolute heat over the whole volume) so all
    panels share the same color scale.

    Raises:
        ValueError: If the shapes differ or ``axis`` is not 0, 1 or 2.
    """
    if volume_zyx.shape != heat_zyx.shape:
        raise ValueError(f"Shape mismatch: volume {volume_zyx.shape} vs heat {heat_zyx.shape}")
    if axis not in (0, 1, 2):
        raise ValueError(f"axis must be 0, 1 or 2, got {axis!r}")

    dim = volume_zyx.shape[axis]

    if idxs is None:
        lo = int(0.10 * (dim - 1))
        hi = int(0.90 * (dim - 1))
        if hi <= lo:
            lo, hi = 0, dim - 1
        idxs = np.linspace(lo, hi, n_slices, dtype=int).tolist()
    else:
        idxs = list(map(int, idxs))

    n = len(idxs)
    n_rows = int(np.ceil(n / n_cols))

    m = float(max(np.quantile(np.abs(heat_zyx), clip_q), 1e-8))

    fig_w = n_cols * figsize_per_cell
    fig_h = n_rows * figsize_per_cell
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h))
    axes = np.array(axes).reshape(-1)

    im_for_cbar = None
    for axp, idx in zip(axes, idxs):
        img = _take_slice(volume_zyx, axis, idx, rot_k)
        h_vis = np.clip(_take_slice(heat_zyx, axis, idx, rot_k), -m, m)

        axp.imshow(img, cmap="gray", origin="lower")
        im_for_cbar = axp.imshow(h_vis, cmap="bwr", alpha=alpha, origin="lower", vmin=-m, vmax=m)
        axp.set_title(f"{idx}", fontsize=9)
        axp.axis("off")

    # Turn off unused axes
    for unused in axes[n:]:
        unused.axis("off")

    if suptitle is None:
        name = _axis_label(axis)
        suptitle = f"{name} | rot {rot_k*90}° | clip_q={clip_q} | alpha={alpha}"
    fig.suptitle(suptitle, y=0.98, fontsize=12)

    if add_colorbar and im_for_cbar is not None:
        cbar = fig.colorbar(im_for_cbar, ax=axes[:n], fraction=0.02, pad=0.01)
        cbar.set_label("LIME weight (clipped)", rotation=90)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close(fig)


def create_overlay_from_segments(segments_2d: np.ndarray, weights: dict, alpha=0.5) -> np.ndarray:
    """
    Create an RGBA overlay from segments and LIME weights.

    Red = positive (supportive), Blue = negative (contradictory).
    Visualizes ALL supervoxels; opacity scales with |weight| / max |weight|,
    and each segment's edge is drawn at twice that opacity (capped at 1).
    Background (segment 0) is never drawn.

    Returns a (H, W, 4) float32 RGBA array.
    """
    H, W = segments_2d.shape
    overlay = np.zeros((H, W, 4), dtype=np.float32)

    # Find max absolute weight for normalization
    all_weights = [float(v) for k, v in weights.items() if int(k) != 0]
    if not all_weights:
        return overlay

    max_abs_weight = max(abs(w) for w in all_weights)
    if max_abs_weight < 1e-8:
        return overlay

    for seg_id_str, weight in weights.items():
        seg_id = int(seg_id_str)
        if seg_id == 0:  # Skip background
            continue

        mask = (segments_2d == seg_id)
        if not mask.any():
            continue

        norm_weight = weight / max_abs_weight

        # Edges are drawn brighter for better visibility
        edge_mask = mask & (~binary_erosion(mask))

        # Positive = red; otherwise blue with a little green for visibility
        rgb = (1.0, 0.0, 0.0) if weight > 0 else (0.0, 0.4, 1.0)
        fill_alpha = alpha * abs(norm_weight)

        overlay[mask, 0:3] = rgb
        overlay[mask, 3] = fill_alpha
        overlay[edge_mask, 3] = min(1.0, fill_alpha * 2.0)

    return overlay


def get_top_positive_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """Return the segment ID with the highest positive LIME weight (most RED / supportive).

    Ignores background segment 0 by default. If no weight is positive, the
    ID with the maximum weight is returned anyway.

    Raises:
        ValueError: If ``weights`` is empty or contains only ignored segments.
    """
    items = [(int(k), float(v)) for k, v in weights.items() if int(k) not in ignore_ids]
    if not items:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    pos = [(k, v) for k, v in items if v > 0]
    if pos:
        return max(pos, key=lambda kv: kv[1])[0]
    return max(items, key=lambda kv: kv[1])[0]


def get_top_negative_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """
    Return the segment ID with the most negative weight (most 'blue').

    If there are no negative weights, the ID with the minimum weight is
    returned anyway (even if that weight is positive).

    Raises:
        ValueError: If ``weights`` is empty or contains only ignored segments.
    """
    items = [(int(k), float(v)) for k, v in weights.items() if int(k) not in ignore_ids]
    if not items:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    neg = [(k, v) for k, v in items if v < 0]
    if neg:
        return min(neg, key=lambda kv: kv[1])[0]  # most negative = minimum
    return min(items, key=lambda kv: kv[1])[0]


def _rgba_overlay_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), alpha=0.45) -> np.ndarray:
    """
    Build a (H, W, 4) overlay filled with one color where ``mask2d`` is set.

    mask2d: float/bool (H, W), 1 where the color should be drawn.
    rgba: (R, G, B) in [0, 1] (only the first three components are read).
    """
    m = mask2d.astype(np.float32)
    overlay = np.zeros((m.shape[0], m.shape[1], 4), dtype=np.float32)
    overlay[..., 0] = float(rgba[0])
    overlay[..., 1] = float(rgba[1])
    overlay[..., 2] = float(rgba[2])
    overlay[..., 3] = float(alpha) * m
    return overlay


def _rgba_edge_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), edge_alpha=1.0) -> np.ndarray:
    """Build a (H, W, 4) overlay that colors only the one-pixel border of ``mask2d``."""
    m = mask2d.astype(bool)
    edge = m & (~binary_erosion(m))
    overlay = np.zeros((m.shape[0], m.shape[1], 4), dtype=np.float32)
    overlay[..., 0] = float(rgba[0])
    overlay[..., 1] = float(rgba[1])
    overlay[..., 2] = float(rgba[2])
    overlay[..., 3] = float(edge_alpha) * edge.astype(np.float32)
    return overlay


def _draw_top_supervoxels(ax, img, m_red, m_blue, alpha, edge_alpha, origin):
    """Draw ``img`` on ``ax`` with the blue mask underneath and the red mask on top."""
    ax.imshow(img, cmap="gray", origin=origin)
    # Blue first, red on top (so red wins where they overlap)
    ax.imshow(_rgba_overlay_from_mask(m_blue, rgba=(0.0, 0.4, 1.0), alpha=alpha), origin=origin)
    ax.imshow(_rgba_edge_from_mask(m_blue, rgba=(0.0, 0.4, 1.0), edge_alpha=edge_alpha), origin=origin)
    ax.imshow(_rgba_overlay_from_mask(m_red, rgba=(1.0, 0.0, 0.0), alpha=alpha), origin=origin)
    ax.imshow(_rgba_edge_from_mask(m_red, rgba=(1.0, 0.0, 0.0), edge_alpha=edge_alpha), origin=origin)


def save_overlay_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    edge_alpha: float = 1.0,
):
    """
    Save a single-slice overlay highlighting:
    - the most 'red' supervoxel (highest positive weight) in bright red
    - the most 'blue' supervoxel (most negative weight) in bright blue

    ``idx`` defaults to the middle slice along ``axis``.

    Returns (best_red_id, best_blue_id).
    """
    best_red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    best_blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    mask_red_3d = (segments_zyx == best_red_id).astype(np.float32)
    mask_blue_3d = (segments_zyx == best_blue_id).astype(np.float32)

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    img = _take_slice(volume_zyx, axis, idx, rot_k)
    m_red = _take_slice(mask_red_3d, axis, idx, rot_k)
    m_blue = _take_slice(mask_blue_3d, axis, idx, rot_k)
    title = f"{_axis_label(axis, spaced=False)} slice {idx} | red={best_red_id} | blue={best_blue_id}"

    plt.figure(figsize=(6, 6))
    _draw_top_supervoxels(plt.gca(), img, m_red, m_blue, alpha, edge_alpha, origin)

    plt.title(title)
    plt.axis("off")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()

    return best_red_id, best_blue_id


def save_overlay_grid_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    n_cols: int = 8,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    suptitle: str | None = None,
    edge_alpha: float = 1.0,
):
    """
    Overlay grid with ALL slices along ``axis``, laid out like save_flair_grid_all:
    - the most 'red' supervoxel in bright red
    - the most 'blue' supervoxel in bright blue

    Each panel is titled with the axis letter and slice index (e.g. "z=12").

    Returns (best_red_id, best_blue_id).
    """
    best_red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    best_blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    mask_red_3d = (segments_zyx == best_red_id).astype(np.float32)
    mask_blue_3d = (segments_zyx == best_blue_id).astype(np.float32)

    dim = volume_zyx.shape[axis]
    n_rows = int(np.ceil(dim / n_cols))

    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(n_cols * 2, n_rows * 2),
        facecolor="black"
    )
    axes = np.array(axes).reshape(-1)

    # Label panels with the axis actually being sliced (previously always "z").
    axis_letter = {0: "z", 1: "y"}.get(axis, "x")

    for i in range(dim):
        img = _take_slice(volume_zyx, axis, i, rot_k)
        m_red = _take_slice(mask_red_3d, axis, i, rot_k)
        m_blue = _take_slice(mask_blue_3d, axis, i, rot_k)

        _draw_top_supervoxels(axes[i], img, m_red, m_blue, alpha, edge_alpha, origin)

        axes[i].set_title(
            f"{axis_letter}={i}",
            color="cyan",
            fontsize=9,
            fontweight='bold'
        )
        axes[i].axis("off")

    # Turn off unused axes
    for unused in axes[dim:]:
        unused.axis("off")

    if suptitle is None:
        name = _axis_label(axis, spaced=False)
        suptitle = f"{name} | red={best_red_id} | blue={best_blue_id} | rot {rot_k*90}°"
    fig.suptitle(suptitle)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    return best_red_id, best_blue_id


def save_volume_slices_overlay(
    vol: torch.Tensor,
    heat: np.ndarray,
    save_path: str,
    title: str = "Volume overlay",
    ncols: int = 8,
    is_healthy: bool = False,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    brain_mask: np.ndarray | None = None,
):
    """Save a grid with every slice along the first axis and its heatmap overlay.

    Args:
        vol: Tensor of shape (D, H, W), (1, D, H, W) or (1, 1, D, H, W);
            leading singleton dims are indexed away.
        heat: Heatmap array with the same (D, H, W) shape as the squeezed volume.
        save_path: Output image path (parent folders are created).
        title: Figure title; "(Healthy)" or "(Pathological)" is appended
            depending on ``is_healthy``.
        brain_mask: Optional (D, H, W) mask; when given, both image and
            heatmap are hidden outside of it.

    Raises:
        ValueError: On a shape mismatch between volume, heatmap or brain mask.
    """
    # --- squeeze to (D,H,W)
    if vol.ndim == 5:
        vol = vol[0, 0]
    elif vol.ndim == 4:
        vol = vol[0]

    vol_np = vol.detach().cpu().numpy().astype(np.float32)
    heat_np = heat.astype(np.float32)

    if vol_np.shape != heat_np.shape:
        raise ValueError(f"Shape mismatch: vol {vol_np.shape} vs heat {heat_np.shape}")

    if brain_mask is not None:
        if brain_mask.shape != vol_np.shape:
            raise ValueError(f"Brain mask shape mismatch: {brain_mask.shape} vs {vol_np.shape}")
        brain_np = brain_mask.astype(bool)
    else:
        brain_np = None

    D, H, W = vol_np.shape
    nrows = int(np.ceil(D / ncols))

    # Global clipping value so that all panels share one color scale
    m = float(max(np.quantile(np.abs(heat_np), clip_q), 1e-8))

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), facecolor="black")
    # np.array(...) also covers the 1x1 case, where subplots returns a bare Axes.
    axes = np.array(axes).reshape(-1)

    for i in range(D):
        img = np.rot90(vol_np[i], k=rot_k)
        h_vis = np.clip(np.rot90(heat_np[i], k=rot_k), -m, m)
        b = np.rot90(brain_np[i], k=rot_k) if brain_np is not None else None

        ax = axes[i]
        ax.set_facecolor("black")

        if b is not None:
            # Masked arrays: everything outside the brain is left transparent
            img_ma = np.ma.array(img, mask=~b)
            ax.imshow(img_ma, cmap="gray", origin="lower")

            h_ma = np.ma.array(h_vis, mask=~b)
            ax.imshow(h_ma, cmap="bwr", alpha=alpha, vmin=-m, vmax=m, origin="lower")
        else:
            ax.imshow(img, cmap="gray", origin="lower")
            ax.imshow(h_vis, cmap="bwr", alpha=alpha, vmin=-m, vmax=m, origin="lower")

        ax.set_title(f"z={i}", color="cyan", fontsize=9, fontweight="bold")
        ax.axis("off")

    # Turn off unused axes
    for unused in axes[D:]:
        unused.set_facecolor("black")
        unused.axis("off")

    fig.suptitle(f"{title} {'(Healthy)' if is_healthy else '(Pathological)'}", color="white")

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)


def save_flair_grid_all(nifti_path: str, save_path: str, load_nifti_volume_fn, ncols: int = 8):
    """
    Save a grid of all slices (along the first axis) from a NIfTI file.

    Note: load_nifti_volume_fn must be provided (get it from
    import_architecture_from_model_dir). Its result is squeezed twice on
    dim 0 before plotting.
    """
    vol = load_nifti_volume_fn(nifti_path)
    vol = vol.squeeze(0).squeeze(0).detach().cpu().numpy()

    D = vol.shape[0]
    nrows = int(np.ceil(D / ncols))

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * 2, nrows * 2),
        facecolor="black"
    )
    # np.array(...) also covers the 1x1 case, where subplots returns a bare Axes.
    axes = np.array(axes).reshape(-1)

    for i in range(D):
        axes[i].imshow(vol[i], cmap="gray", origin="lower")
        axes[i].set_title(
            f"z={i}",
            color="cyan",
            fontsize=9,
            fontweight='bold'
        )
        axes[i].axis("off")

    # Turn off unused axes
    for unused in axes[D:]:
        unused.axis("off")

    # Create the output folder if it does not exist
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


# ============================================================
# MAIN INTERPRETABILITY PIPELINE
# ============================================================

def run_interpretability(
    model,
    load_nifti_volume,
    CANONICAL_PROMPT,
    mri_path: str,
    report: str,
    output_dir: str,
    lime_samples: int = 100,
    n_segments: int = 20,
    hide_color: float = 0.0,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    seed: int = 42,
):
    """Run LIME interpretability on a single MRI scan.

    Loads the volume, builds brain-only supervoxels, explains the reference
    ``report`` with LIME (score = -NLL of the report) and writes overlay
    images, a text summary, ``lime_weights.json`` and the raw ``.npy``
    arrays into ``output_dir``.

    Returns:
        ``(weights, wvol)``: the dict of LIME weights keyed by segment ID and
        the per-voxel weight volume (zero outside the brain mask).
    """
    set_seed(seed)
    device = next(model.parameters()).device
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("🔍 BrainGemma3D LIME Interpretability")
    print(f"{'='*60}")
    print(f"📂 MRI: {mri_path}")
    print(f"📝 Report: {report[:100]}...")
    print(f"💾 Output: {output_dir}")
    print(f"{'='*60}\n")

    # Load volume
    print("📥 Loading MRI volume...")
    volume = load_nifti_volume(mri_path, target_size=(64, 128, 128)).to(device)
    if volume.ndim == 4:
        volume = volume.unsqueeze(0)

    vol_np = volume.squeeze().cpu().numpy()  # (D, H, W)
    print(f" Shape: {vol_np.shape}")

    # Create supervoxels
    print(f"\n🧩 Creating {n_segments} brain supervoxels...")
    segments, brain_mask = big_supervoxels_brain_only(vol_np, n_segments=n_segments)

    # Prepare LIME explainer
    print(f"\n🔬 Running LIME with {lime_samples} samples...")
    segmentation_fn = make_segmentation_fn(segments)
    explainer = lime_image.LimeImageExplainer()

    def predict_fn(vols_4d):
        """
        vols_4d: (n_samples, D, H, W) - LIME perturbed volumes
        Returns: (n_samples, 1) array of scores (a single pseudo-class)
        """
        # Add channel dim: (n_samples, 1, D, H, W)
        vols_5d = vols_4d[:, np.newaxis, :, :, :]
        scores = lime_score_report_nll(
            vols_5d,
            model,
            prompt=CANONICAL_PROMPT,
            report_ref=report,
            batch_size=1,
        )
        return scores

    # Run LIME
    explanation = explainer.explain_instance(
        vol_np,  # (D, H, W)
        predict_fn,
        top_labels=1,
        hide_color=hide_color,
        num_samples=lime_samples,
        segmentation_fn=segmentation_fn,
    )

    # Get weights
    label = explanation.top_labels[0]
    weights = dict(explanation.local_exp[label])

    print(f"\n✅ LIME completed!")
    print(f" Supervoxel weights (sample): {list(weights.items())[:5]}")

    # Build weight volume
    print("\n📊 Building weight volume...")
    wvol = np.zeros_like(vol_np, dtype=np.float32)
    for seg_id, w in weights.items():
        seg_id = int(seg_id)
        if seg_id == 0:  # Skip background (segment 0)
            continue
        wvol[segments == seg_id] = float(w)

    # Safety: zero out anything outside brain mask
    wvol[~brain_mask] = 0.0

    # Save visualizations
    print("\n💾 Saving visualizations...")

    # 1. Full volume overlay with brain mask
    save_volume_slices_overlay(
        volume,  # torch tensor
        wvol,  # numpy array heatmap
        str(out_dir / "overlay_slices.png"),
        title="LIME Interpretability",
        ncols=8,
        is_healthy=False,
        alpha=alpha,
        clip_q=clip_q,
        rot_k=0,
        brain_mask=brain_mask,
    )

    # 2. Top supervoxels
    save_overlay_grid_single_supervoxel_png(
        vol_np,
        segments,
        weights,
        out_path=str(out_dir / "lime_top_supervoxels_grid.png"),
        axis=0,
        n_cols=8,
        alpha=0.55,
        suptitle="Top Supportive (Red) and Contradicting (Blue) Supervoxels"
    )

    # 3. 2x3 grid - Selected slices (Original + LIME overlay)
    print("\n💾 Creating 2x3 grid figure (original + LIME overlay)...")
    D = vol_np.shape[0]

    # Select 3 representative slices between 30% and 70% of the depth
    lo = int(0.30 * D)
    hi = int(0.70 * D)
    selected_slices = np.linspace(lo, hi, 3, dtype=int).tolist()

    n_slices = len(selected_slices)
    fig, axes = plt.subplots(2, n_slices, figsize=(n_slices * 4, 2 * 4))

    for col, slice_idx in enumerate(selected_slices):
        # Extract axial slice
        img_slice = vol_np[slice_idx, :, :]
        seg_slice = segments[slice_idx, :, :]

        # Row 0: Original
        axes[0, col].imshow(img_slice, cmap='gray', origin='lower', interpolation='bilinear')
        axes[0, col].set_title(f'Slice {slice_idx}', fontsize=12, fontweight='bold')
        axes[0, col].axis('off')

        # Row 1: LIME Overlay
        axes[1, col].imshow(img_slice, cmap='gray', origin='lower', interpolation='bilinear')
        overlay = create_overlay_from_segments(seg_slice, weights, alpha=0.5)
        axes[1, col].imshow(overlay, origin='lower', interpolation='nearest')
        axes[1, col].axis('off')

    # Add row labels
    axes[0, 0].text(-0.15, 0.5, 'Original', transform=axes[0, 0].transAxes,
                    fontsize=14, fontweight='bold', va='center', rotation=90)
    axes[1, 0].text(-0.15, 0.5, 'LIME Overlay', transform=axes[1, 0].transAxes,
                    fontsize=14, fontweight='bold', va='center', rotation=90)

    plt.tight_layout()
    plt.savefig(str(out_dir / "lime_2x3_grid.png"), dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"✅ Saved 2x3 grid (slices {selected_slices})")

    # 4. Save report and weights
    # Drop background BEFORE taking the top 20, so the list is not cut short
    # when segment 0 happens to rank among the 20 largest weights.
    brain_weights = [(seg_id, w) for seg_id, w in weights.items() if int(seg_id) != 0]
    brain_weights.sort(key=lambda item: abs(item[1]), reverse=True)
    with open(out_dir / "lime_report.txt", "w", encoding="utf-8") as f:
        f.write(f"Reference Report:\n{report}\n\n")
        f.write(f"LIME Supervoxel Weights (top 20):\n")
        for seg_id, weight in brain_weights[:20]:
            f.write(f" Segment {seg_id}: {weight:.4f}\n")

    # Save supervoxel weights in JSON format (exclude background)
    weights_dict = {int(k): float(v) for k, v in weights.items() if int(k) != 0}
    with open(out_dir / "lime_weights.json", "w", encoding="utf-8") as f:
        json.dump(weights_dict, f, indent=2)
    print(f"💾 Saved lime_weights.json ({len(weights_dict)} brain supervoxels)", flush=True)

    # Save numpy arrays for further analysis
    np.save(str(out_dir / "lime_wvol.npy"), wvol)
    np.save(str(out_dir / "lime_segments.npy"), segments)
    print(f"✅ Saved wvol/segments arrays", flush=True)
    print(f" wvol stats: shape={wvol.shape} min={wvol.min():.4g} max={wvol.max():.4g}", flush=True)

    print(f"\n{'='*60}")
    print("✅ Interpretability analysis completed!")
    print(f" Results saved to: {output_dir}")
    print(f"{'='*60}\n")

    return weights, wvol


# ============================================================
# MAIN SCRIPT
# ============================================================

def main():
    """Command-line entry point: parse arguments, load the model, run LIME.

    When ``--report`` is omitted, a report is first generated from the scan
    with ``model.generate_report`` and then used as the reference text.
    """
    parser = argparse.ArgumentParser(description="BrainGemma3D LIME Interpretability")

    # Required
    parser.add_argument("--model_dir", required=True, help="Path to BrainGemma3D model folder")
    parser.add_argument("--mri_path", required=True, help="Path to .nii/.nii.gz MRI scan")

    # Optional
    parser.add_argument("--report", default=None,
                        help="Reference report text. If not provided, will generate it first.")
    parser.add_argument("--output_dir", default="./lime_output", help="Output directory for results")

    # Generation params (if report not provided)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_p", type=float, default=0.9)

    # LIME params
    parser.add_argument("--lime_samples", type=int, default=100, help="Number of LIME samples")
    parser.add_argument("--n_segments", type=int, default=20, help="Number of supervoxels")
    parser.add_argument("--hide_color", type=float, default=0.0, help="Hide color for LIME perturbations")

    # Visualization
    parser.add_argument("--alpha", type=float, default=0.45, help="Overlay transparency")
    parser.add_argument("--clip_q", type=float, default=0.99, help="Heatmap clipping quantile")

    # Misc
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🚀 Loading BrainGemma3D model from {args.model_dir}...")
    model, load_nifti_volume, CANONICAL_PROMPT = load_full_model(args.model_dir, device)
    print("✅ Model loaded successfully!")

    # Generate report if not provided
    if args.report is None:
        print("\n📝 No report provided, generating one...")
        volume = load_nifti_volume(args.mri_path, target_size=(64, 128, 128)).to(device)
        if volume.ndim == 4:
            volume = volume.unsqueeze(0)

        with torch.no_grad():
            report = model.generate_report(
                volume,
                prompt=CANONICAL_PROMPT,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
        print(f"✅ Generated report: {report}")
    else:
        report = args.report

    # Run interpretability
    run_interpretability(
        model=model,
        load_nifti_volume=load_nifti_volume,
        CANONICAL_PROMPT=CANONICAL_PROMPT,
        mri_path=args.mri_path,
        report=report,
        output_dir=args.output_dir,
        lime_samples=args.lime_samples,
        n_segments=args.n_segments,
        hide_color=args.hide_color,
        alpha=args.alpha,
        clip_q=args.clip_q,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()