| |
| """ |
| BrainGemma3D + LIME Interpretability |
| ================================================== |
| Usage: |
| python braingemma3d_interpretability.py \\ |
| --model_dir ./final_model \\ |
| --mri_path /path/to/scan.nii.gz \\ |
| --report "The brain shows a mass in the left frontal lobe..." \\ |
| --output_dir ./lime_output |
| |
| If --report is not provided, the script will generate it first. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import argparse |
| import random |
| import importlib.util |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| |
| from lime import lime_image |
| from skimage.segmentation import slic |
| from scipy.ndimage import binary_closing, binary_opening, binary_fill_holes, binary_erosion |
| from skimage.morphology import ball, remove_small_objects |
| from skimage.measure import label as cc_label |
|
|
|
|
def set_seed(seed: int = 42):
    """Seed every RNG used by this script so runs are repeatable.

    Generators are seeded in this order: Python ``random``, NumPy's global
    RNG, PyTorch's CPU generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def import_architecture_from_model_dir(model_dir):
    """Dynamically import ``braingemma3d_architecture.py`` from a model folder.

    The module is registered in ``sys.modules`` under the name
    ``"braingemma3d_architecture"`` *before* it is executed, following the
    ``importlib`` "import a source file directly" recipe. Without that,
    code in the loaded module that looks itself up through ``sys.modules``
    (e.g. ``dataclasses`` with string annotations, pickling) fails.

    Args:
        model_dir: Directory expected to contain ``braingemma3d_architecture.py``.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the architecture file does not exist.
        ImportError: If no import spec/loader can be created for the file.
    """
    module_name = "braingemma3d_architecture"
    arch_path = os.path.join(model_dir, f"{module_name}.py")
    if not os.path.isfile(arch_path):
        raise FileNotFoundError(f"Architecture file not found: {arch_path}")

    spec = importlib.util.spec_from_file_location(module_name, arch_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {arch_path}")

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered; restore any
        # previously loaded module of the same name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
|
|
|
|
def load_full_model(model_dir, device):
    """Build BrainGemma3D from ``model_dir`` and restore the trained projector.

    Reads ``model_config.json`` for sub-model folders and hyper-parameters,
    instantiates the architecture with frozen vision and language towers,
    then loads ``vision_projector`` weights and (when stored and not
    ``None``) ``vis_scale`` from ``projector_vis_scale.pt``.

    Args:
        model_dir: Folder containing ``braingemma3d_architecture.py``,
            ``model_config.json`` and ``projector_vis_scale.pt``.
        device: ``"cuda"`` or ``"cpu"``. With ``"cuda"`` the model is created
            with ``device_map={"": 0}``; otherwise ``device_map=None``.

    Returns:
        ``(model, load_nifti_volume, CANONICAL_PROMPT)``. The last two come
        from the dynamically imported architecture module.
    """
    arch_module = import_architecture_from_model_dir(model_dir)
    BrainGemma3D = arch_module.BrainGemma3D
    load_nifti_volume = arch_module.load_nifti_volume
    CANONICAL_PROMPT = arch_module.CANONICAL_PROMPT

    with open(os.path.join(model_dir, "model_config.json"), encoding="utf-8") as f:
        cfg = json.load(f)

    model = BrainGemma3D(
        vision_model_dir=os.path.join(model_dir, cfg["vision_model_dir"]),
        language_model_dir=os.path.join(model_dir, cfg["language_model_dir"]),
        depth=cfg["depth"],
        num_vision_tokens=cfg["num_vision_tokens"],
        freeze_vision=True,
        freeze_language=True,
        device_map={"": 0} if device == "cuda" else None,
    )

    # NOTE(review): torch.load unpickles the checkpoint. Only load model
    # folders from trusted sources (torch>=1.13 also offers weights_only=True).
    proj_path = os.path.join(model_dir, "projector_vis_scale.pt")
    ckpt = torch.load(proj_path, map_location=device)
    model.vision_projector.load_state_dict(ckpt["vision_projector"])

    # A missing key is handled like an explicit None: the model keeps its
    # own initial vis_scale.
    vis_scale = ckpt.get("vis_scale")
    if vis_scale is not None:
        if isinstance(vis_scale, torch.Tensor):
            model.vis_scale.data = vis_scale.to(device)
        else:
            model.vis_scale.data.fill_(vis_scale)

    model.eval()
    return model, load_nifti_volume, CANONICAL_PROMPT
|
|
|
|
| |
| |
| |
@torch.no_grad()
def lime_score_report_nll(volumes, model, prompt: str, report_ref: str, batch_size: int = 1):
    """Score (perturbed) volumes by how strongly they support a reference report.

    Scoring:
    - Each volume is teacher-forced as ``[vision tokens | prompt | report]``.
    - The mean token negative log-likelihood (NLL) of the report tokens is
      computed.
    - LIME maximises its target, so ``-NLL`` is returned: higher means the
      volume makes the reference report more likely.

    Why not ``out.loss``: the NLL is computed per sample from the logits
    because Hugging Face causal-LM losses are averaged over the whole batch.
    That would give one score per batch instead of one per volume whenever
    ``batch_size > 1``. For ``batch_size == 1`` both are the mean
    cross-entropy over the report tokens.

    Args:
        volumes: Array-like of shape ``(N, D, H, W)`` or ``(N, 1, D, H, W)``.
        model: BrainGemma3D-like object exposing ``lm_device``, ``tokenizer``,
            ``encode_volume`` and ``language_model``.
        prompt: Prompt text, tokenized with special tokens.
        report_ref: Reference report text, tokenized without special tokens.
        batch_size: Number of volumes per forward pass.

    Returns:
        ``np.ndarray`` of shape ``(N, 1)`` holding ``-NLL`` for each volume.

    Raises:
        ValueError: If ``batch_size < 1`` or ``report_ref`` yields no tokens.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = model.lm_device
    tok = model.tokenizer
    prompt_ids = tok(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
    report_ids = tok(report_ref, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    n_prompt = prompt_ids.size(1)
    n_report = report_ids.size(1)
    if n_report == 0:
        raise ValueError("report_ref produced no tokens; cannot compute its NLL")
    text_ids_1 = torch.cat([prompt_ids, report_ids], dim=1)
    embed = model.language_model.get_input_embeddings()

    # Keep the whole stack on the host; only one batch at a time is moved.
    vols = torch.from_numpy(np.ascontiguousarray(volumes))
    if vols.ndim == 4:
        vols = vols.unsqueeze(1)

    scores = []
    for start in range(0, vols.shape[0], batch_size):
        v = vols[start:start + batch_size].to(device=device, dtype=torch.bfloat16)
        b = v.size(0)

        vision_tokens = model.encode_volume(v)
        text_embeds = embed(text_ids_1.repeat(b, 1))
        inputs_embeds = torch.cat([vision_tokens, text_embeds], dim=1)

        logits = model.language_model(inputs_embeds=inputs_embeds).logits

        # The logit at position t predicts token t+1. Report tokens sit at
        # positions V+P .. V+P+R-1, so they are predicted by V+P-1 .. V+P+R-2.
        first = vision_tokens.size(1) + n_prompt - 1
        report_logits = logits[:, first:first + n_report, :].float()
        token_nll = torch.nn.functional.cross_entropy(
            report_logits.transpose(1, 2),  # (b, vocab, R)
            report_ids.repeat(b, 1),        # (b, R)
            reduction="none",
        )                                   # -> (b, R)
        scores.append((-token_nll.mean(dim=1)).cpu())

    if not scores:
        return np.zeros((0, 1), dtype=np.float32)
    return torch.cat(scores).numpy().reshape(-1, 1)
|
|
|
|
| |
| |
| |
def quick_brain_mask(
    vol_zyx: np.ndarray,
    p_thresh: float = 25,
    min_cc_vox: int = 2000
) -> np.ndarray:
    """Create a rough binary brain mask from a 3D volume.

    Steps:
      1. foreground = voxels above the ``p_thresh`` intensity percentile;
      2. morphological opening (ball radius 1), then closing (ball radius 2);
      3. hole filling;
      4. removal of components smaller than ``min_cc_vox`` voxels;
      5. keep only the largest remaining connected component.

    Args:
        vol_zyx: 3D intensity volume.
        p_thresh: Percentile (0-100) used as the foreground threshold.
        min_cc_vox: Minimum component size passed to ``remove_small_objects``.

    Returns:
        Boolean array with the same shape as ``vol_zyx``.
    """
    v = vol_zyx.astype(np.float32)
    thr = np.percentile(v, p_thresh)
    m = v > thr
    m = binary_opening(m, structure=ball(1))
    m = binary_closing(m, structure=ball(2))
    m = binary_fill_holes(m)
    m = remove_small_objects(m, min_size=min_cc_vox)

    # Keep the largest connected component. A single bincount pass gives all
    # component sizes at once instead of scanning the volume once per label.
    # argmax returns the first (lowest) label on ties, as max() over labels did.
    cc = cc_label(m)
    if cc.max() > 1:
        sizes = np.bincount(cc.ravel())
        sizes[0] = 0  # label 0 is background and must never win
        m = cc == int(np.argmax(sizes))

    return m.astype(bool)
|
|
|
|
def big_supervoxels_brain_only(
    vol_zyx: np.ndarray,
    n_segments: int = 20,
    compactness: float = 0.05,
    sigma: float = 1.0,
    p_thresh: float = 25,
    min_cc_vox: int = 2000,
):
    """Segment ONLY brain tissue into supervoxels using SLIC with a brain mask.

    Returned labels are 0-based and contiguous:
      - 0 = background (not brain)
      - 1, 2, ..., N = brain supervoxels

    This labeling is CRITICAL for LIME 0.2.0.1, which uses feature indices
    directly as segment labels (``mask[segments == feature_idx]``). With
    0-based contiguous labels, feature i maps exactly to segment i.
    Background (0) adds one harmless noise feature to LIME's regression.

    Note: if the brain mask covered every voxel there would be no label 0,
    and relabelling would map the first brain supervoxel to 0.

    Args:
        vol_zyx: 3D intensity volume.
        n_segments: Requested number of SLIC supervoxels.
        compactness: SLIC compactness.
        sigma: SLIC pre-smoothing sigma.
        p_thresh: Percentile threshold forwarded to ``quick_brain_mask``.
        min_cc_vox: Minimum component size forwarded to ``quick_brain_mask``.

    Returns:
        ``(segments, brain_mask)`` with the same shape as ``vol_zyx``.
    """
    brain = quick_brain_mask(vol_zyx, p_thresh=p_thresh, min_cc_vox=min_cc_vox)

    seg = slic(
        vol_zyx,
        n_segments=n_segments,
        compactness=compactness,
        sigma=sigma,
        channel_axis=None,
        start_label=1,
        mask=brain,
    )

    # Any negative labels are treated as background.
    seg[seg < 0] = 0

    # Relabel to exactly 0..K-1 in one pass. np.unique returns sorted values,
    # so label 0 (background) stays 0 whenever it is present.
    unique, inverse = np.unique(seg, return_inverse=True)
    if not np.array_equal(unique, np.arange(len(unique))):
        seg = inverse.reshape(seg.shape).astype(seg.dtype, copy=False)
        print(f"ℹ️ Relabeled segments to contiguous 0..{len(unique)-1}", flush=True)

    n_brain_segs = len(unique) - 1
    print(f"🧩 Brain-only SLIC: {n_brain_segs} brain supervoxels "
          f"(requested {n_segments}), brain covers {100*brain.sum()/brain.size:.1f}% of volume",
          flush=True)

    return seg, brain
|
|
|
|
def make_segmentation_fn(cached_segments: np.ndarray):
    """Build a LIME-compatible segmentation callback with a fixed result.

    The returned callable ignores the volume it receives and always returns
    the very same ``cached_segments`` object. LIME therefore reuses one
    pre-computed supervoxel partition instead of segmenting again.
    """
    def _fixed_segments(vol):
        return cached_segments

    return _fixed_segments
|
|
|
|
| |
| |
| |
def save_slice_png(volume_zyx: np.ndarray, out_path: str, axis: int = 0, idx: int | None = None, rot_k: int = 0):
    """Render one 2D slice of a ``(Z, Y, X)`` volume as a grayscale PNG.

    Args:
        volume_zyx: 3D array indexed as (Z, Y, X).
        out_path: Destination PNG path; missing parent directories are created.
        axis: 0 selects an axial (Z) slice, 1 a coronal (Y) slice; any other
            value selects a sagittal (X) slice.
        idx: Slice index along ``axis``; defaults to the middle slice.
        rot_k: Number of 90-degree rotations applied via ``np.rot90``.
    """
    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    if axis == 0:
        plane, label = volume_zyx[idx, :, :], "Axial (Z)"
    elif axis == 1:
        plane, label = volume_zyx[:, idx, :], "Coronal (Y)"
    else:
        plane, label = volume_zyx[:, :, idx], "Sagittal (X)"

    rotated = np.rot90(plane, k=rot_k)

    plt.figure(figsize=(6, 6))
    plt.imshow(rotated, cmap="gray", origin="lower")
    plt.title(f"{label} slice {idx}")
    plt.axis("off")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()
|
|
|
|
def save_overlay_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
):
    """Save one slice of a volume with a signed heat map blended on top.

    Rendering details:
    - The heat map uses the diverging ``bwr`` colormap on a symmetric range
      ``[-m, m]``.
    - ``m`` is the ``clip_q`` quantile of ``|heat|`` on the displayed slice,
      floored at 1e-8.
    - A colour bar is added.

    Args:
        volume_zyx: 3D background volume indexed (Z, Y, X).
        heat_zyx: Signed heat map with the same shape as ``volume_zyx``.
        out_path: Destination PNG path; missing parent directories are created.
        axis: 0 = axial (Z), 1 = coronal (Y); any other value = sagittal (X).
        idx: Slice index along ``axis``; defaults to the middle slice.
        alpha: Heat-map opacity.
        clip_q: Quantile used for the symmetric colour limit.
        rot_k: Number of 90-degree rotations applied via ``np.rot90``.

    Raises:
        ValueError: If ``volume_zyx`` and ``heat_zyx`` have different shapes.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if volume_zyx.shape != heat_zyx.shape:
        raise ValueError(
            f"Shape mismatch: volume {volume_zyx.shape} vs heat {heat_zyx.shape}"
        )

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    if axis == 0:
        img = volume_zyx[idx, :, :]
        h = heat_zyx[idx, :, :]
        title = f"Axial (Z) overlay slice {idx}"
    elif axis == 1:
        img = volume_zyx[:, idx, :]
        h = heat_zyx[:, idx, :]
        title = f"Coronal (Y) overlay slice {idx}"
    else:
        img = volume_zyx[:, :, idx]
        h = heat_zyx[:, :, idx]
        title = f"Sagittal (X) overlay slice {idx}"

    img = np.rot90(img, k=rot_k)
    h = np.rot90(h, k=rot_k)

    # Symmetric limits keep 0 at the neutral (white) centre of 'bwr'.
    m = float(max(np.quantile(np.abs(h), clip_q), 1e-8))
    h_vis = np.clip(h, -m, m)

    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap="gray", origin="lower")
    im = plt.imshow(h_vis, cmap="bwr", alpha=alpha, origin="lower", vmin=-m, vmax=m)
    plt.title(title)
    plt.axis("off")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()
|
|
|
|
def save_overlay_grid_png(
    volume_zyx: np.ndarray,
    heat_zyx: np.ndarray,
    out_path: str,
    axis: int = 0,
    idxs=None,
    n_cols: int = 6,
    n_slices: int = 36,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    figsize_per_cell: float = 2.2,
    add_colorbar: bool = False,
    suptitle: str | None = None,
):
    """Save a grid of slices along ``axis`` with a signed heat-map overlay.

    Args:
        volume_zyx: 3D background volume indexed (Z, Y, X).
        heat_zyx: Signed heat map with the same shape as ``volume_zyx``.
        out_path: Destination PNG path; missing parent directories are created.
        axis: 0 = axial (Z), 1 = coronal (Y), 2 = sagittal (X).
        idxs: Explicit slice indices. When ``None``, ``n_slices`` indices are
            spread evenly over the central 10%-90% of the axis.
        n_cols: Number of grid columns.
        n_slices: Number of slices used when ``idxs`` is ``None``.
        alpha: Heat-map opacity.
        clip_q: Quantile of ``|heat_zyx|`` over the whole volume used as the
            symmetric colour limit, so every cell shares one scale.
        rot_k: Number of 90-degree rotations applied to every slice.
        figsize_per_cell: Width and height, in inches, of one grid cell.
        add_colorbar: Whether to add a shared colour bar.
        suptitle: Figure title; a description of the view is used when ``None``.

    Raises:
        ValueError: On shape mismatch, an invalid ``axis``, or no slices to draw.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if volume_zyx.shape != heat_zyx.shape:
        raise ValueError(
            f"Shape mismatch: volume {volume_zyx.shape} vs heat {heat_zyx.shape}"
        )
    if axis not in (0, 1, 2):
        raise ValueError(f"axis must be 0, 1 or 2, got {axis}")

    dim = volume_zyx.shape[axis]
    if idxs is None:
        lo = int(0.10 * (dim - 1))
        hi = int(0.90 * (dim - 1))
        if hi <= lo:
            lo, hi = 0, dim - 1
        idxs = np.linspace(lo, hi, n_slices, dtype=int).tolist()
    else:
        idxs = [int(i) for i in idxs]

    n = len(idxs)
    if n == 0:
        raise ValueError("No slices to draw (empty idxs or n_slices == 0)")
    n_rows = int(np.ceil(n / n_cols))

    # One colour scale for the whole grid so cells are comparable.
    m = float(max(np.quantile(np.abs(heat_zyx), clip_q), 1e-8))

    fig_w = n_cols * figsize_per_cell
    fig_h = n_rows * figsize_per_cell
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h))
    axes = np.array(axes).reshape(-1)

    def get_slice(arr, ax, i):
        if ax == 0:
            s = arr[i, :, :]
        elif ax == 1:
            s = arr[:, i, :]
        else:
            s = arr[:, :, i]
        return np.rot90(s, k=rot_k)

    im_for_cbar = None
    for j, idx in enumerate(idxs):
        axp = axes[j]
        img = get_slice(volume_zyx, axis, idx)
        h_vis = np.clip(get_slice(heat_zyx, axis, idx), -m, m)

        axp.imshow(img, cmap="gray", origin="lower")
        im_for_cbar = axp.imshow(h_vis, cmap="bwr", alpha=alpha, origin="lower", vmin=-m, vmax=m)
        axp.set_title(f"{idx}", fontsize=9)
        axp.axis("off")

    for k in range(n, n_rows * n_cols):
        axes[k].axis("off")

    if suptitle is None:
        name = "Axial (Z)" if axis == 0 else ("Coronal (Y)" if axis == 1 else "Sagittal (X)")
        suptitle = f"{name} | rot {rot_k*90}° | clip_q={clip_q} | alpha={alpha}"
    fig.suptitle(suptitle, y=0.98, fontsize=12)

    if add_colorbar and im_for_cbar is not None:
        cbar = fig.colorbar(im_for_cbar, ax=axes[:n], fraction=0.02, pad=0.01)
        cbar.set_label("LIME weight (clipped)", rotation=90)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close(fig)
|
|
|
|
def create_overlay_from_segments(segments_2d: np.ndarray, weights: dict, alpha=0.5) -> np.ndarray:
    """Colour every supervoxel of a 2D label slice by its LIME weight.

    Colouring rules:
    - Positive weights (supportive) are painted red; other weights are
      painted blue-ish ``(0, 0.4, 1)``.
    - Weights are normalised by the largest absolute weight over all
      non-background segments.
    - The fill alpha is ``alpha * |normalised weight|``. The one-pixel
      outline (pixels removed by one ``binary_erosion``) gets twice that,
      capped at 1.
    - Segment 0 (background) is never drawn.
    - An all-transparent overlay is returned when there are no brain weights
      or they are all (near) zero.

    Returns:
        ``(H, W, 4)`` float32 RGBA array.
    """
    height, width = segments_2d.shape
    rgba = np.zeros((height, width, 4), dtype=np.float32)

    scale = max(
        (abs(float(v)) for k, v in weights.items() if int(k) != 0),
        default=None,
    )
    if scale is None or scale < 1e-8:
        return rgba

    red, blue = (1.0, 0.0, 0.0), (0.0, 0.4, 1.0)
    for key, w in weights.items():
        seg_id = int(key)
        if seg_id == 0:
            continue
        region = segments_2d == seg_id
        if not region.any():
            continue

        strength = abs(w / scale)
        outline = region & ~binary_erosion(region)

        rgba[region, :3] = red if w > 0 else blue
        rgba[region, 3] = alpha * strength
        rgba[outline, 3] = min(1.0, alpha * strength * 2.0)

    return rgba
|
|
|
|
def get_top_positive_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """Return the segment id with the highest LIME weight (most red / supportive).

    Only strictly positive weights are considered when any exist; otherwise
    the overall maximum is returned. On ties, the first id in ``weights``
    order wins.

    Args:
        weights: Mapping of segment id (int or int-like) to weight.
        ignore_ids: Segment ids to skip; background 0 by default.

    Raises:
        ValueError: If no segment remains after filtering ``ignore_ids``.
    """
    candidates = [
        (int(seg), float(w)) for seg, w in weights.items() if int(seg) not in ignore_ids
    ]
    if not candidates:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    positive = [c for c in candidates if c[1] > 0]
    pool = positive if positive else candidates
    best_id, _ = max(pool, key=lambda c: c[1])
    return best_id
|
|
|
|
def get_top_negative_supervoxel_id(weights: dict, ignore_ids=(0,)) -> int:
    """Return the segment id with the most negative LIME weight (most blue).

    If no negative weight exists, the minimum weight is still returned, even
    when it is positive. On ties, the first id in ``weights`` order wins.

    Args:
        weights: Mapping of segment id (int or int-like) to weight.
        ignore_ids: Segment ids to skip; background 0 by default.

    Raises:
        ValueError: If no segment remains after filtering ``ignore_ids``.
    """
    candidates = [
        (int(seg), float(w)) for seg, w in weights.items() if int(seg) not in ignore_ids
    ]
    if not candidates:
        raise ValueError("weights vuoto o contiene solo segmenti ignorati.")

    negative = [c for c in candidates if c[1] < 0]
    pool = negative if negative else candidates
    worst_id, _ = min(pool, key=lambda c: c[1])
    return worst_id
|
|
|
|
def _rgba_overlay_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), alpha=0.45) -> np.ndarray:
    """Turn a 2D mask into a uniformly coloured RGBA layer.

    Args:
        mask2d: Float or bool ``(H, W)`` array; its value (1 where to draw)
            scales the alpha channel.
        rgba: Colour as ``(R, G, B)`` in [0, 1]; only the first three items
            are used.
        alpha: Maximum opacity.

    Returns:
        ``(H, W, 4)`` float32 array. RGB holds the colour everywhere and
        alpha is ``alpha * mask2d``.
    """
    weight = mask2d.astype(np.float32)
    layer = np.zeros((weight.shape[0], weight.shape[1], 4), dtype=np.float32)
    layer[..., :3] = [float(rgba[0]), float(rgba[1]), float(rgba[2])]
    layer[..., 3] = float(alpha) * weight
    return layer
|
|
|
|
def _rgba_edge_from_mask(mask2d: np.ndarray, rgba=(1.0, 0.0, 0.0), edge_alpha=1.0) -> np.ndarray:
    """Build an RGBA layer that paints only the outline of a 2D mask.

    Outline pixels are the mask pixels removed by one ``binary_erosion`` step
    with its default structuring element. RGB holds ``rgba[:3]`` everywhere;
    alpha is ``edge_alpha`` on the outline and 0 elsewhere.

    Returns:
        ``(H, W, 4)`` float32 array.
    """
    inside = mask2d.astype(bool)
    outline = np.logical_and(inside, np.logical_not(binary_erosion(inside)))
    layer = np.zeros((inside.shape[0], inside.shape[1], 4), dtype=np.float32)
    layer[..., :3] = [float(rgba[0]), float(rgba[1]), float(rgba[2])]
    layer[..., 3] = float(edge_alpha) * outline.astype(np.float32)
    return layer
|
|
|
|
def save_overlay_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    idx: int | None = None,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    edge_alpha: float = 1.0,
):
    """Save one slice with the two most influential supervoxels highlighted.

    Highlighting:
    - The supervoxel with the largest positive LIME weight (most supportive)
      is filled and outlined in bright red.
    - The supervoxel with the most negative weight (most contradicting) is
      filled and outlined in bright blue.
    - Blue is drawn first, so red wins where both coincide.
    - Background segment 0 is ignored.

    Args:
        volume_zyx: 3D background volume indexed (Z, Y, X).
        segments_zyx: Supervoxel labels with the same indexing.
        weights: Segment id -> LIME weight.
        out_path: Destination PNG path; missing parent directories are created.
        axis: 0 = axial (Z), 1 = coronal (Y); any other value = sagittal (X).
        idx: Slice index along ``axis``; defaults to the middle slice.
        rot_k: Number of 90-degree rotations applied via ``np.rot90``.
        alpha: Fill opacity.
        origin: ``imshow`` origin.
        edge_alpha: Outline opacity.

    Returns:
        ``(best_red_id, best_blue_id)``.
    """
    red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    if idx is None:
        idx = volume_zyx.shape[axis] // 2

    if axis == 0:
        take, plane_name = (lambda a: a[idx, :, :]), "Axial(Z)"
    elif axis == 1:
        take, plane_name = (lambda a: a[:, idx, :]), "Coronal(Y)"
    else:
        take, plane_name = (lambda a: a[:, :, idx]), "Sagittal(X)"

    # Slice the labels once, then build only the 2D masks that are drawn.
    seg_plane = take(segments_zyx)
    image = np.rot90(take(volume_zyx), k=rot_k)
    red_plane = np.rot90((seg_plane == red_id).astype(np.float32), k=rot_k)
    blue_plane = np.rot90((seg_plane == blue_id).astype(np.float32), k=rot_k)

    plt.figure(figsize=(6, 6))
    plt.imshow(image, cmap="gray", origin=origin)

    layers = (((0.0, 0.4, 1.0), blue_plane), ((1.0, 0.0, 0.0), red_plane))
    for color, plane in layers:
        plt.imshow(_rgba_overlay_from_mask(plane, rgba=color, alpha=alpha), origin=origin)
        plt.imshow(_rgba_edge_from_mask(plane, rgba=color, edge_alpha=edge_alpha), origin=origin)

    plt.title(f"{plane_name} slice {idx} | red={red_id} | blue={blue_id}")
    plt.axis("off")

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()

    return red_id, blue_id
|
|
|
|
def save_overlay_grid_single_supervoxel_png(
    volume_zyx: np.ndarray,
    segments_zyx: np.ndarray,
    weights: dict,
    out_path: str,
    axis: int = 0,
    n_cols: int = 8,
    rot_k: int = 0,
    alpha: float = 0.45,
    origin: str = "lower",
    suptitle: str | None = None,
    edge_alpha: float = 1.0,
):
    """Save a grid of EVERY slice along ``axis`` highlighting the top supervoxels.

    The layout mirrors ``save_flair_grid_all``: black background and
    ``n_cols`` columns. In each cell:
      - the supervoxel with the largest positive LIME weight is filled and
        outlined in bright red;
      - the supervoxel with the most negative weight is filled and outlined
        in bright blue, drawn first so red wins on overlap.
    Background segment 0 is ignored.

    Args:
        volume_zyx: 3D background volume indexed (Z, Y, X).
        segments_zyx: Supervoxel labels with the same indexing.
        weights: Segment id -> LIME weight.
        out_path: Destination PNG path; missing parent directories are created.
        axis: 0 = axial (Z), 1 = coronal (Y); any other value = sagittal (X).
        n_cols: Number of grid columns.
        rot_k: Number of 90-degree rotations applied to every slice.
        alpha: Fill opacity.
        origin: ``imshow`` origin.
        suptitle: Figure title; a description of the view is used when ``None``.
        edge_alpha: Outline opacity.

    Returns:
        ``(best_red_id, best_blue_id)``.
    """
    best_red_id = get_top_positive_supervoxel_id(weights, ignore_ids=(0,))
    best_blue_id = get_top_negative_supervoxel_id(weights, ignore_ids=(0,))

    mask_red_3d = (segments_zyx == best_red_id).astype(np.float32)
    mask_blue_3d = (segments_zyx == best_blue_id).astype(np.float32)

    dim = volume_zyx.shape[axis]
    n_rows = int(np.ceil(dim / n_cols))
    # Cell labels name the axis actually being sliced (previously always "z").
    axis_letter = "z" if axis == 0 else ("y" if axis == 1 else "x")

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 2, n_rows * 2),
        facecolor="black"
    )
    axes = np.array(axes).reshape(-1)

    def get_slice(arr, ax, i):
        if ax == 0:
            s = arr[i, :, :]
        elif ax == 1:
            s = arr[:, i, :]
        else:
            s = arr[:, :, i]
        return np.rot90(s, k=rot_k)

    layers = (((0.0, 0.4, 1.0), mask_blue_3d), ((1.0, 0.0, 0.0), mask_red_3d))
    for i in range(dim):
        cell = axes[i]
        cell.imshow(get_slice(volume_zyx, axis, i), cmap="gray", origin=origin)
        for color, mask_3d in layers:
            m2d = get_slice(mask_3d, axis, i)
            cell.imshow(_rgba_overlay_from_mask(m2d, rgba=color, alpha=alpha), origin=origin)
            cell.imshow(_rgba_edge_from_mask(m2d, rgba=color, edge_alpha=edge_alpha), origin=origin)
        cell.set_title(
            f"{axis_letter}={i}",
            color="cyan",
            fontsize=9,
            fontweight='bold'
        )
        cell.axis("off")

    # Hide unused cells.
    for i in range(dim, len(axes)):
        axes[i].axis("off")

    if suptitle is None:
        name = "Axial(Z)" if axis == 0 else ("Coronal(Y)" if axis == 1 else "Sagittal(X)")
        suptitle = f"{name} | red={best_red_id} | blue={best_blue_id} | rot {rot_k*90}°"
    # White text: the default black title is invisible on the black figure.
    fig.suptitle(suptitle, color="white")

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)

    return best_red_id, best_blue_id
|
|
|
|
def save_volume_slices_overlay(
    vol: torch.Tensor,
    heat: np.ndarray,
    save_path: str,
    title: str = "Volume overlay",
    ncols: int = 8,
    is_healthy: bool = False,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    rot_k: int = 0,
    brain_mask: np.ndarray | None = None,
):
    """Save a grid of every axial (Z) slice with a signed heat-map overlay.

    Args:
        vol: Volume as a torch tensor or NumPy array of shape ``(D, H, W)``,
            ``(1, D, H, W)`` or ``(1, 1, D, H, W)``; the first item/channel
            is used.
        heat: Signed heat map of shape ``(D, H, W)``.
        save_path: Destination PNG path; missing parent directories are created.
        title: Figure title prefix; "(Healthy)" or "(Pathological)" is appended.
        ncols: Number of grid columns.
        is_healthy: Selects the title suffix.
        alpha: Heat-map opacity.
        clip_q: Quantile of ``|heat|`` over the whole volume used as the
            symmetric colour limit.
        rot_k: Number of 90-degree rotations applied to every slice.
        brain_mask: Optional ``(D, H, W)`` mask. When given, both the image and
            the heat map are hidden outside it (shown as black).

    Raises:
        ValueError: If ``heat`` or ``brain_mask`` does not match the volume shape.
    """
    if vol.ndim == 5:
        vol = vol[0, 0]
    elif vol.ndim == 4:
        vol = vol[0]

    if isinstance(vol, torch.Tensor):
        # Convert via float32: NumPy has no bfloat16, so .numpy() fails on bf16.
        vol_np = vol.detach().float().cpu().numpy()
    else:
        vol_np = np.asarray(vol, dtype=np.float32)
    heat_np = np.asarray(heat, dtype=np.float32)

    if vol_np.shape != heat_np.shape:
        raise ValueError(f"Shape mismatch: vol {vol_np.shape} vs heat {heat_np.shape}")

    brain_np = None
    if brain_mask is not None:
        if brain_mask.shape != vol_np.shape:
            raise ValueError(f"Brain mask shape mismatch: {brain_mask.shape} vs {vol_np.shape}")
        brain_np = brain_mask.astype(bool)

    D, _, _ = vol_np.shape
    nrows = int(np.ceil(D / ncols))

    # One symmetric colour limit for the whole volume.
    m = float(max(np.quantile(np.abs(heat_np), clip_q), 1e-8))

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), facecolor="black")
    # np.array(...).reshape(-1) also handles the single-Axes case (1x1 grid).
    axes = np.array(axes).reshape(-1)

    for i in range(D):
        img = np.rot90(vol_np[i], k=rot_k)
        h_vis = np.clip(np.rot90(heat_np[i], k=rot_k), -m, m)

        ax = axes[i]
        ax.set_facecolor("black")

        if brain_np is not None:
            # Masked voxels are not drawn, so the black facecolor shows through.
            outside = ~np.rot90(brain_np[i], k=rot_k)
            ax.imshow(np.ma.array(img, mask=outside), cmap="gray", origin="lower")
            ax.imshow(np.ma.array(h_vis, mask=outside), cmap="bwr", alpha=alpha,
                      vmin=-m, vmax=m, origin="lower")
        else:
            ax.imshow(img, cmap="gray", origin="lower")
            ax.imshow(h_vis, cmap="bwr", alpha=alpha, vmin=-m, vmax=m, origin="lower")

        ax.set_title(f"z={i}", color="cyan", fontsize=9, fontweight="bold")
        ax.axis("off")

    for i in range(D, len(axes)):
        axes[i].set_facecolor("black")
        axes[i].axis("off")

    fig.suptitle(f"{title} {'(Healthy)' if is_healthy else '(Pathological)'}", color="white")

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
|
|
|
|
def save_flair_grid_all(nifti_path: str, save_path: str, load_nifti_volume_fn, ncols: int = 8):
    """Save a grid showing every axial (Z) slice of a NIfTI volume.

    Args:
        nifti_path: Path to the ``.nii``/``.nii.gz`` file.
        save_path: Destination PNG path; missing parent directories are created.
        load_nifti_volume_fn: Loader returning a torch tensor with up to two
            leading singleton dims (get it from
            ``import_architecture_from_model_dir``).
        ncols: Number of grid columns.
    """
    vol = load_nifti_volume_fn(nifti_path)
    # Convert via float32: NumPy has no bfloat16, so .numpy() fails on bf16.
    vol = vol.squeeze(0).squeeze(0).detach().float().cpu().numpy()
    D = vol.shape[0]
    nrows = int(np.ceil(D / ncols))

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 2, nrows * 2),
        facecolor="black"
    )
    # np.array(...).reshape(-1) also handles the single-Axes case (1x1 grid).
    axes = np.array(axes).reshape(-1)

    for i in range(D):
        axes[i].imshow(vol[i], cmap="gray", origin="lower")
        axes[i].set_title(
            f"z={i}",
            color="cyan",
            fontsize=9,
            fontweight='bold'
        )
        axes[i].axis("off")

    # Hide unused cells.
    for i in range(D, len(axes)):
        axes[i].axis("off")

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
|
|
|
|
| |
| |
| |
def _build_weight_volume(segments: np.ndarray, weights: dict, brain_mask: np.ndarray) -> np.ndarray:
    """Paint each brain supervoxel with its LIME weight; background/non-brain voxels stay 0."""
    wvol = np.zeros(segments.shape, dtype=np.float32)
    for seg_id, w in weights.items():
        seg_id = int(seg_id)
        if seg_id == 0:
            continue
        wvol[segments == seg_id] = float(w)
    wvol[~brain_mask] = 0.0
    return wvol


def _save_lime_2x3_grid(vol_np: np.ndarray, segments: np.ndarray, weights: dict, out_path: str) -> list[int]:
    """Save a 2-row figure: raw axial slices on top, LIME overlay below.

    Three slices evenly spaced between 30% and 70% of the depth are shown.
    Returns the selected slice indices.
    """
    depth = vol_np.shape[0]
    selected = np.linspace(int(0.30 * depth), int(0.70 * depth), 3, dtype=int).tolist()
    n_cols = len(selected)

    fig, axes = plt.subplots(2, n_cols, figsize=(n_cols * 4, 2 * 4))
    for col, z in enumerate(selected):
        img = vol_np[z, :, :]
        top, bottom = axes[0, col], axes[1, col]

        top.imshow(img, cmap='gray', origin='lower', interpolation='bilinear')
        top.set_title(f'Slice {z}', fontsize=12, fontweight='bold')
        top.axis('off')

        bottom.imshow(img, cmap='gray', origin='lower', interpolation='bilinear')
        overlay = create_overlay_from_segments(segments[z, :, :], weights, alpha=0.5)
        bottom.imshow(overlay, origin='lower', interpolation='nearest')
        bottom.axis('off')

    axes[0, 0].text(-0.15, 0.5, 'Original', transform=axes[0, 0].transAxes,
                    fontsize=14, fontweight='bold', va='center', rotation=90)
    axes[1, 0].text(-0.15, 0.5, 'LIME Overlay', transform=axes[1, 0].transAxes,
                    fontsize=14, fontweight='bold', va='center', rotation=90)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    return selected


def _write_lime_outputs(out_dir: Path, report: str, weights: dict, wvol: np.ndarray, segments: np.ndarray) -> None:
    """Write lime_report.txt, lime_weights.json and the wvol/segments .npy arrays."""
    # Drop background first so the report really lists up to 20 brain segments.
    brain_items = [(k, v) for k, v in weights.items() if int(k) != 0]
    top_items = sorted(brain_items, key=lambda kv: abs(kv[1]), reverse=True)[:20]
    with open(out_dir / "lime_report.txt", "w", encoding="utf-8") as f:
        f.write(f"Reference Report:\n{report}\n\n")
        f.write("LIME Supervoxel Weights (top 20):\n")
        for seg_id, weight in top_items:
            f.write(f" Segment {seg_id}: {weight:.4f}\n")

    weights_dict = {int(k): float(v) for k, v in brain_items}
    with open(out_dir / "lime_weights.json", "w", encoding="utf-8") as f:
        json.dump(weights_dict, f, indent=2)
    print(f"💾 Saved lime_weights.json ({len(weights_dict)} brain supervoxels)", flush=True)

    np.save(str(out_dir / "lime_wvol.npy"), wvol)
    np.save(str(out_dir / "lime_segments.npy"), segments)
    print("✅ Saved wvol/segments arrays", flush=True)
    print(f" wvol stats: shape={wvol.shape} min={wvol.min():.4g} max={wvol.max():.4g}", flush=True)


def run_interpretability(
    model,
    load_nifti_volume,
    CANONICAL_PROMPT,
    mri_path: str,
    report: str,
    output_dir: str,
    lime_samples: int = 100,
    n_segments: int = 20,
    hide_color: float = 0.0,
    alpha: float = 0.45,
    clip_q: float = 0.99,
    seed: int = 42,
):
    """Run LIME on one MRI scan and save explanations for a reference report.

    Pipeline:
      1. Load the scan resized to (64, 128, 128).
      2. Split the brain into about ``n_segments`` SLIC supervoxels.
      3. LIME switches supervoxels off (filled with ``hide_color``) and scores
         each perturbed volume by ``-NLL`` of ``report``.
      4. Save the overlays, the top-supervoxel grid, a 2x3 summary figure,
         the weights (txt + json) and the weight/segment arrays (npy).

    Returns:
        ``(weights, wvol)``: the LIME weight for each segment id (including
        background id 0) and the float32 weight volume (0 outside the brain).
    """
    set_seed(seed)
    device = next(model.parameters()).device
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("🔍 BrainGemma3D LIME Interpretability")
    print(f"{'='*60}")
    print(f"📁 MRI: {mri_path}")
    print(f"📝 Report: {report[:100]}...")
    print(f"💾 Output: {output_dir}")
    print(f"{'='*60}\n")

    print("📥 Loading MRI volume...")
    volume = load_nifti_volume(mri_path, target_size=(64, 128, 128)).to(device)
    while volume.ndim < 5:
        volume = volume.unsqueeze(0)
    # Index (batch, channel) explicitly rather than squeeze(), and go through
    # float32 because NumPy has no bfloat16.
    vol_np = volume[0, 0].detach().float().cpu().numpy()
    print(f" Shape: {vol_np.shape}")

    print(f"\n🧩 Creating {n_segments} brain supervoxels...")
    segments, brain_mask = big_supervoxels_brain_only(vol_np, n_segments=n_segments)

    print(f"\n🔬 Running LIME with {lime_samples} samples...")
    # Explicit random_state makes LIME's sampling independent of any other
    # consumer of NumPy's global RNG.
    explainer = lime_image.LimeImageExplainer(random_state=seed)

    def predict_fn(vols_4d):
        """Score LIME's perturbed volumes.

        vols_4d: (n_samples, D, H, W) array built by LIME.
        Returns an (n_samples, 1) array of -NLL scores.
        """
        return lime_score_report_nll(
            vols_4d[:, np.newaxis, :, :, :],
            model,
            prompt=CANONICAL_PROMPT,
            report_ref=report,
            batch_size=1,
        )

    explanation = explainer.explain_instance(
        vol_np,
        predict_fn,
        top_labels=1,
        hide_color=hide_color,
        num_samples=lime_samples,
        segmentation_fn=make_segmentation_fn(segments),
    )

    label = explanation.top_labels[0]
    weights = dict(explanation.local_exp[label])

    print("\n✅ LIME completed!")
    print(f" Supervoxel weights (sample): {list(weights.items())[:5]}")

    print("\n📊 Building weight volume...")
    wvol = _build_weight_volume(segments, weights, brain_mask)

    print("\n💾 Saving visualizations...")
    save_volume_slices_overlay(
        volume,
        wvol,
        str(out_dir / "overlay_slices.png"),
        title="LIME Interpretability",
        ncols=8,
        is_healthy=False,
        alpha=alpha,
        clip_q=clip_q,
        rot_k=0,
        brain_mask=brain_mask,
    )

    save_overlay_grid_single_supervoxel_png(
        vol_np, segments, weights,
        out_path=str(out_dir / "lime_top_supervoxels_grid.png"),
        axis=0, n_cols=8, alpha=0.55,
        suptitle="Top Supportive (Red) and Contradicting (Blue) Supervoxels"
    )

    print("\n💾 Creating 2x3 grid figure (original + LIME overlay)...")
    selected_slices = _save_lime_2x3_grid(vol_np, segments, weights, str(out_dir / "lime_2x3_grid.png"))
    print(f"✅ Saved 2x3 grid (slices {selected_slices})")

    _write_lime_outputs(out_dir, report, weights, wvol, segments)

    print(f"\n{'='*60}")
    print("✅ Interpretability analysis completed!")
    print(f" Results saved to: {output_dir}")
    print(f"{'='*60}\n")

    return weights, wvol
|
|
|
|
| |
| |
| |
def main():
    """Command-line entry point.

    Parses arguments, loads BrainGemma3D, obtains a reference report (from
    ``--report`` or by generating one), then runs LIME interpretability.
    """
    parser = argparse.ArgumentParser(description="BrainGemma3D LIME Interpretability")

    # Required inputs
    parser.add_argument("--model_dir", required=True, help="Path to BrainGemma3D model folder")
    parser.add_argument("--mri_path", required=True, help="Path to .nii/.nii.gz MRI scan")

    # Report / output
    parser.add_argument("--report", default=None, help="Reference report text. If not provided, will generate it first.")
    parser.add_argument("--output_dir", default="./lime_output", help="Output directory for results")

    # Generation settings (used only when --report is missing)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_p", type=float, default=0.9)

    # LIME settings
    parser.add_argument("--lime_samples", type=int, default=100, help="Number of LIME samples")
    parser.add_argument("--n_segments", type=int, default=20, help="Number of supervoxels")
    parser.add_argument("--hide_color", type=float, default=0.0, help="Hide color for LIME perturbations")

    # Visualization settings
    parser.add_argument("--alpha", type=float, default=0.45, help="Overlay transparency")
    parser.add_argument("--clip_q", type=float, default=0.99, help="Heatmap clipping quantile")

    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🚀 Loading BrainGemma3D model from {args.model_dir}...")
    model, load_nifti_volume, CANONICAL_PROMPT = load_full_model(args.model_dir, device)
    print("✅ Model loaded successfully!")

    # An empty/whitespace-only --report would leave nothing to score, so it is
    # treated like a missing report.
    report = args.report
    if report is None or not report.strip():
        print("\n📝 No report provided, generating one...")
        volume = load_nifti_volume(args.mri_path, target_size=(64, 128, 128)).to(device)
        if volume.ndim == 4:
            volume = volume.unsqueeze(0)

        with torch.no_grad():
            report = model.generate_report(
                volume,
                prompt=CANONICAL_PROMPT,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
        print(f"✅ Generated report: {report}")

    run_interpretability(
        model=model,
        load_nifti_volume=load_nifti_volume,
        CANONICAL_PROMPT=CANONICAL_PROMPT,
        mri_path=args.mri_path,
        report=report,
        output_dir=args.output_dir,
        lime_samples=args.lime_samples,
        n_segments=args.n_segments,
        hide_color=args.hide_color,
        alpha=args.alpha,
        clip_q=args.clip_q,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
|
|