---
license: cc-by-4.0
language:
- en
tags:
- multimodal
- vision-language
- medical
- neuroradiology
- brain-mri
- report-generation
- 3d-vision
- medgemma
- medsiglip
datasets:
- BraTS2020
- TextBraTS2021
- MPI-Leipzig_Mind-Brain-Body
base_model:
- google/medgemma-1.5-4b-it
- google/medsiglip-448
---

# BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D

BrainGemma3D is a **multimodal vision-language model** that generates clinically accurate radiology reports directly from **native 3D brain MRI** volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context critical for accurate neuroradiological interpretation.
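The "inflated" encoder named in the title can be illustrated with a small NumPy sketch. It assumes that center-frame initialization (the term used in the Architecture section) means copying each 2D kernel into the central depth slice of a zero-filled 3D kernel. The function names `inflate_center_frame` and `patch_response`, the kernel depth of 3, and the toy shapes are illustrative assumptions, not code from this repository.

```python
import numpy as np

def inflate_center_frame(w2d: np.ndarray, depth: int) -> np.ndarray:
    """Lift a 2D conv kernel (out, in, kh, kw) to 3D (out, in, depth, kh, kw).

    Only the central depth slice receives the 2D weights; all other slices stay zero.
    """
    out_ch, in_ch, kh, kw = w2d.shape
    w3d = np.zeros((out_ch, in_ch, depth, kh, kw), dtype=w2d.dtype)
    w3d[:, :, depth // 2] = w2d
    return w3d

def patch_response(weight: np.ndarray, patch: np.ndarray) -> np.ndarray:
    """Response of every output channel to one input patch (a single convolution step)."""
    return np.tensordot(weight, patch, axes=patch.ndim)

rng = np.random.default_rng(0)
w2d = rng.normal(size=(4, 1, 2, 2))        # toy stand-in for a 2D patch-embedding kernel
w3d = inflate_center_frame(w2d, depth=3)   # hypothetical kernel depth

patch3d = rng.normal(size=(1, 3, 2, 2))    # (in, depth, kh, kw)
center_slice = patch3d[:, 1]               # (in, kh, kw): the middle depth slice

# Right after inflation, the 3D kernel responds exactly like the 2D kernel on the center slice.
same = np.allclose(patch_response(w3d, patch3d), patch_response(w2d, center_slice))
print(same)
```

Under this assumption the inflated encoder starts out reproducing the pretrained 2D behavior on the central slice, and the zero-initialized off-center slices are free to learn through-plane context during training.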
---

## Key Features

- **Native 3D Processing**: A 2D medical vision encoder ([MedSigLIP](https://huggingface.co/google/medsiglip-448)) inflated to 3D for volumetric understanding
- **Clinical Accuracy**: 95.1% F1 score on pathology entity recognition (on the BraTS dataset)
- **Spatial Awareness**: 68.9% laterality F1 (correct left/right hemisphere localization)
- **Interpretable**: LIME-based 3D attribution maps show which brain regions drive predictions
- **Efficient**: Processes full 3D volumes with 32 compressed visual tokens
- **Research-Ready**: Pre-trained on 369 brain tumor cases + 99 healthy controls

---

## Architecture

BrainGemma3D combines:

1. **3D Vision Encoder**: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
   *Base model: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)*
2. **Token Compressor**: 2-layer Perceiver that reduces 3D patches to 32 visual tokens
3. **Vision-Language Projector**: 2-layer MLP that projects visual tokens into the language model's embedding space
4. **Language Model**: 4-bit quantized [MedGemma-1.5-4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) with LoRA adapters
   *Base model: [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it)*

---

## Usage

### Requirements

```bash
pip install torch torchvision transformers nibabel scikit-image lime
```

### Model Download

```python
from huggingface_hub import snapshot_download

# 1. Download the repository containing our custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"Repository downloaded to: {local_dir}")
```

### Quick Start

```python
import os
import sys

import torch

# `local_dir` is the path returned by snapshot_download in the Model Download step
sys.path.append(local_dir)
from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT

# Automatically select the hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")

# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
    vision_model_dir=f"{local_dir}/vision_model",
    language_model_dir=f"{local_dir}/language_model",
    depth=2,
    num_vision_tokens=32,
    freeze_vision=True,
    freeze_language=True,
    device_map={"": 0} if device == "cuda" else None,
)

# 3. Load projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")

# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)

# Inject the weights into the visual projector (which bridges vision and language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])

# Load the visual scaling factor, which may be stored as a tensor or a plain number
if ckpt.get("vis_scale") is not None:
    if isinstance(ckpt["vis_scale"], torch.Tensor):
        model.vis_scale.data = ckpt["vis_scale"].to(device)
    else:
        model.vis_scale.data.fill_(ckpt["vis_scale"])

# Switch the model to evaluation mode for inference
model.eval()
print("BrainGemma3D is fully loaded and ready for inference!")

# 4. Load MRI scan
volume = load_nifti_volume(
    "path/to/brain_flair.nii.gz",
    target_size=(32, 128, 128),
).to(device)

# Add a batch dimension if the loader returned a 4D tensor
if volume.ndim == 4:
    volume = volume.unsqueeze(0)

# 5. Generate report
with torch.no_grad():
    report = model.generate_report(
        volume,
        prompt=CANONICAL_PROMPT,
        max_new_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )

print("\n===== GENERATED REPORT =====\n")
print(report)
```

### Expected Output Example

```
Generated Report:
The lesion area is in the left parietal and frontal lobes with mixed high-signal areas. Edema signals are mainly observed around these lesions, indicating significant edema presence affecting parts of both frontal and temporal regions as well as some portions within the parietal lobe. Necrosis may be present at low signal intensity or scattered throughout certain sections of the brain tissue affected by edema. Ventricular compression effects on adjacent ventricles can occur due to pressure from surrounding tissues near the ventricular system.
```

---

## Training Pipeline

BrainGemma3D is trained in **three progressive stages** to prevent catastrophic forgetting:

### Phase 1: Contrastive Grounding (Image-Text Alignment)

- **Goal**: Align 3D visual features with textual report embeddings
- **Loss**: InfoNCE (CLIP-style contrastive learning)
- **Trainable**: 3D Vision Encoder + Projector
- **Frozen**: Language Model
- **Epochs**: 100

### Phase 2A: Projector Warmup

- **Goal**: Train the projector to condition the LM effectively
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector only
- **Frozen**: Vision Encoder + Language Model
- **Epochs**: 100

### Phase 2B: LoRA Linguistic Specialization

- **Goal**: Adapt the LM to generate structured clinical reports
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector + LoRA adapters (rank=4) on LM attention layers
- **Frozen**: Vision Encoder + LM base weights
- **Epochs**: 100

**Dataset**:

- 369 [BraTS 2020](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation) brain tumor MRI cases with radiologist-written reports from [TextBraTS 2021](https://github.com/Jupitern52/TextBraTS)
- 99 healthy control scans with synthetic reports from [MPI-Leipzig Mind-Brain-Body](https://openneuro.org/datasets/ds000221/versions/00002)
- Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage

---

## Performance

Evaluated on **468 subjects** (369 BraTS pathological + 99 healthy controls) with group-based splits.

### Quantitative Results (Test Set)

| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Lat F1 | Anat F1 | **Path F1** |
|-------|--------|--------|---------|-------|--------|---------|-------------|
| Med3DVLM *(3D Generalist)* | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 *(2D Slice)* | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| **BrainGemma3D** *(Ours)* | **0.302** | **0.098** | **0.289** | **0.293** | **0.689** | **0.691** | **0.951** |

### Clinical Metrics Breakdown

- **Laterality F1**: 0.689 (correct left/right hemispheric localization)
- **Anatomy F1**: 0.691 (accurate anatomical structure identification)
- **Pathology F1**: 0.951 (near-perfect pathological entity recognition)
- **Healthy Specificity**: 1.0 (zero hallucinations on healthy controls)

> **Key Insight**: The **+130% relative gain in Pathology F1** (0.951 vs. 0.413 for the 2D slice baseline) demonstrates that native 3D processing is essential for diagnostic accuracy in neuroradiology.

---

## Interpretability

BrainGemma3D includes **LIME-based 3D interpretability** to visualize which brain regions drive diagnostic predictions.

```python
from braingemma3d_interpretability import run_interpretability

# 6. Run interpretability analysis
weights, wvol = run_interpretability(
    model=model,
    load_nifti_volume=load_nifti_volume,
    CANONICAL_PROMPT=CANONICAL_PROMPT,
    mri_path="path/to/brain_flair.nii.gz",
    report=report,
    output_dir="./interpretability_output",
    lime_samples=100,  # Number of perturbations (more = better but slower)
    n_segments=20,     # Number of brain regions to analyze
    alpha=0.45,        # Overlay transparency
    clip_q=0.99,       # Heatmap clipping
    seed=42,
)
```

**Output**:

- `overlay_slices.png`: Full 3D heatmap (red = supportive, blue = contradicting)
- `lime_2x3_grid.png`: 2×3 grid with selected slices (original + LIME overlay)
- `lime_top_supervoxels_grid.png`: Most influential supervoxels
- `lime_weights.json`: Supervoxel importance scores

### Expected Output Example
Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.
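To make the idea behind these attribution maps concrete, here is a simplified NumPy sketch of LIME-style supervoxel attribution. It is not the code behind `run_interpretability`, whose internals are not shown here. The sketch swaps SLIC supervoxels for a regular grid, uses plain least squares instead of LIME's proximity-weighted regression, and replaces the model with a toy scoring function. All names (`grid_supervoxels`, `lime_weights`, `score_fn`) are hypothetical.

```python
import numpy as np

def grid_supervoxels(shape, blocks):
    """Label a volume with a regular grid of blocks (a simple stand-in for SLIC supervoxels)."""
    idx = [np.arange(n) * b // n for n, b in zip(shape, blocks)]
    d, h, w = np.meshgrid(*idx, indexing="ij")
    return (d * blocks[1] + h) * blocks[2] + w

def lime_weights(volume, segments, score_fn, n_samples=100, seed=42):
    """Fit a linear surrogate: how much does keeping each supervoxel visible raise the score?"""
    rng = np.random.default_rng(seed)
    n_seg = int(segments.max()) + 1
    masks = rng.integers(0, 2, size=(n_samples, n_seg))  # 1 = keep, 0 = occlude
    masks[0] = 1                                         # always include the unperturbed volume
    scores = np.empty(n_samples)
    for i, m in enumerate(masks):
        perturbed = volume * m[segments]                 # zero out the occluded supervoxels
        scores[i] = score_fn(perturbed)
    X = np.hstack([masks, np.ones((n_samples, 1))])      # mask features plus an intercept column
    coef, *_ = np.linalg.lstsq(X, scores, rcond=None)
    return coef[:-1]                                     # one weight per supervoxel

# Toy volume with a bright "lesion" confined to a single supervoxel.
volume = np.zeros((4, 8, 8))
volume[:, 4:, :4] = 1.0
segments = grid_supervoxels(volume.shape, blocks=(1, 2, 2))

# Toy score: total intensity, standing in for the model's support for a given report.
weights = lime_weights(volume, segments, score_fn=lambda v: float(v.sum()))
print(int(np.argmax(weights)))
```

In this toy setup the largest weight falls on the supervoxel containing the bright region, which is the behavior the red areas in Figure 1 are meant to convey. The `n_samples` argument plays the role of `lime_samples`, and the number of grid blocks plays the role of `n_segments`.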
Built with ❤️ for the MedGemma Impact Challenge
Advancing Medical AI with Google's Health AI Developer Foundations