---
license: cc-by-4.0
language:
- en
tags:
- multimodal
- vision-language
- medical
- neuroradiology
- brain-mri
- report-generation
- 3d-vision
- medgemma
- medsiglip
datasets:
- BraTS2020
- TextBraTS2021
- MPI-Leipzig_Mind-Brain-Body
base_model:
- google/medgemma-1.5-4b-it
- google/medsiglip-448
---

# 🧠 BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D

BrainGemma3D is a **multimodal vision-language model** that generates clinically accurate radiology reports directly from **native 3D brain MRI** volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context that is critical for accurate neuroradiological interpretation.
---

## 🎯 Key Features

- **🔬 Native 3D Processing**: 2D medical vision encoder ([MedSigLIP](https://huggingface.co/google/medsiglip-448)) inflated to 3D for volumetric understanding
- **📝 Clinical Accuracy**: 95.1% F1 on pathology entity recognition (BraTS test split)
- **🧭 Spatial Awareness**: 68.9% laterality F1 (correct left/right hemisphere localization)
- **🔍 Interpretable**: LIME-based 3D attribution maps show which brain regions drive predictions
- **🚀 Efficient**: Compresses each full 3D volume into 32 visual tokens
- **🏥 Research-Ready**: Trained and evaluated on 369 brain tumor cases + 99 healthy controls

---

## 🏗️ Architecture

BrainGemma3D combines:

1. **3D Vision Encoder**: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
   *Base model: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)*
2. **Token Compressor**: 2-layer Perceiver that reduces the 3D patch tokens to 32 visual tokens
3. **Vision-Language Projector**: 2-layer MLP that projects the visual tokens into the language model's embedding space
4. **Language Model**: 4-bit quantized [MedGemma-1.5-4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) with LoRA adapters
   *Base model: [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it)*

---

## 🚀 Usage

### Requirements

```bash
pip install torch torchvision transformers accelerate huggingface_hub nibabel scikit-image lime
```

### Model Download

```python
from huggingface_hub import snapshot_download

# 1. Download the repository containing our custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"✅ Repository downloaded to: {local_dir}")
```

### Quick Start

```python
import os
import sys

import torch

# Make the custom modules shipped in the downloaded repository importable
# (local_dir comes from the Model Download step above)
sys.path.append(local_dir)
from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT

# Automatically select the optimal hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")

# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
    vision_model_dir=f"{local_dir}/vision_model",
    language_model_dir=f"{local_dir}/language_model",
    depth=2,
    num_vision_tokens=32,
    freeze_vision=True,
    freeze_language=True,
    device_map={"": 0} if device == "cuda" else None,
)

# 3. Load the projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")

# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)

# Inject the weights into the visual projector (which bridges vision and language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])

# Load the visual scaling factor, handling both tensor and plain-float checkpoints
if ckpt.get("vis_scale") is not None:
    if isinstance(ckpt["vis_scale"], torch.Tensor):
        model.vis_scale.data = ckpt["vis_scale"].to(device)
    else:
        model.vis_scale.data.fill_(ckpt["vis_scale"])

# Switch the model to evaluation mode for inference
model.eval()
print("✅ BrainGemma3D is fully loaded and ready for inference!")

# 4. Load the MRI scan
volume = load_nifti_volume(
    "path/to/brain_flair.nii.gz",
    target_size=(32, 128, 128),
).to(device)
if volume.ndim == 4:
    volume = volume.unsqueeze(0)  # add a leading batch dimension

# 5. Generate the report
with torch.no_grad():
    report = model.generate_report(
        volume,
        prompt=CANONICAL_PROMPT,
        max_new_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )

print("\n===== GENERATED REPORT =====\n")
print(report)
```

### Expected Output Example

```
Generated Report:
The lesion area is in the left parietal and frontal lobes with mixed high-signal areas. Edema signals are mainly observed around these lesions, indicating significant edema presence affecting parts of both frontal and temporal regions as well as some portions within the parietal lobe. Necrosis may be present at low signal intensity or scattered throughout certain sections of the brain tissue affected by edema. Ventricular compression effects on adjacent ventricles can occur due to pressure from surrounding tissues near the ventricular system.
```

---

## 🎓 Training Pipeline

BrainGemma3D is trained in **three progressive stages** to mitigate catastrophic forgetting:

### Phase 1: Contrastive Grounding (Image-Text Alignment)

- **Goal**: Align 3D visual features with textual report embeddings
- **Loss**: InfoNCE (CLIP-style contrastive learning)
- **Trainable**: 3D Vision Encoder + Projector
- **Frozen**: Language Model
- **Epochs**: 100

### Phase 2A: Projector Warmup

- **Goal**: Train the projector to condition the LM effectively
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector only
- **Frozen**: Vision Encoder + Language Model
- **Epochs**: 100

### Phase 2B: LoRA Linguistic Specialization

- **Goal**: Adapt the LM to generate structured clinical reports
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector + LoRA adapters (rank=4) on the LM attention layers
- **Frozen**: Vision Encoder + LM base weights
- **Epochs**: 100

**Dataset**:

- 369 [BraTS 2020](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation) brain tumor MRI cases with radiologist-written reports from [TextBraTS 2021](https://github.com/Jupitern52/TextBraTS)
- 99 healthy control scans with synthetic reports from [MPI-Leipzig Mind-Brain-Body](https://openneuro.org/datasets/ds000221/versions/00002)
- Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage

---

## 📊 Performance

The cohort comprises **468 subjects** (369 BraTS pathological cases + 99 healthy controls) divided with group-based splits; all results below are on the held-out test split.

### Quantitative Results (Test Set)

| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Lat F1 | Anat F1 | **Path F1** |
|-------|--------|--------|---------|-------|--------|---------|-------------|
| Med3DVLM *(3D Generalist)* | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 *(2D Slice)* | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| **BrainGemma3D** *(Ours)* | **0.302** | **0.098** | **0.289** | **0.293** | **0.689** | **0.691** | **0.951** |

### Clinical Metrics Breakdown

- **Laterality F1**: 0.689 (hemispheric localization, left vs. right)
- **Anatomy F1**: 0.691 (anatomical structure identification)
- **Pathology F1**: 0.951 (near-perfect pathological entity recognition)
- **Healthy Specificity**: 1.0 (no pathology reported for any healthy control)

> **Key Insight**: The **+130% relative gain in Pathology F1** (0.951 vs. 0.413 for the 2D-slice MedGemma 1.5 baseline) demonstrates that native 3D processing is essential for diagnostic accuracy in neuroradiology.

---

## 🔍 Interpretability

BrainGemma3D includes **LIME-based 3D interpretability** to visualize which brain regions drive diagnostic predictions.

```python
from braingemma3d_interpretability import run_interpretability

# 6. Run interpretability analysis (reuses model, load_nifti_volume,
#    CANONICAL_PROMPT and report from the Quick Start)
weights, wvol = run_interpretability(
    model=model,
    load_nifti_volume=load_nifti_volume,
    CANONICAL_PROMPT=CANONICAL_PROMPT,
    mri_path="path/to/brain_flair.nii.gz",
    report=report,
    output_dir="./interpretability_output",
    lime_samples=100,  # Number of perturbations (more = better but slower)
    n_segments=20,     # Number of supervoxels (brain regions) to analyze
    alpha=0.45,        # Overlay transparency
    clip_q=0.99,       # Quantile used to clip the heatmap
    seed=42,
)
```

**Output**:

- `overlay_slices.png`: heatmap overlaid on slices through the full 3D volume (red = supportive, blue = contradicting)
- `lime_2x3_grid.png`: 2×3 grid of selected slices (original + LIME overlay)
- `lime_top_supervoxels_grid.png`: most influential supervoxels
- `lime_weights.json`: supervoxel importance scores

### Expected Output Example
*[Figure: LIME Interpretability]*

Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.
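The supervoxel scores written to `lime_weights.json` come from LIME. As a rough illustration, the NumPy sketch below walks through the generic LIME procedure on a 3D volume; it is not the code inside `braingemma3d_interpretability`. It rests on several assumptions. A regular block grid (`grid_supervoxels`, hypothetical) stands in for the supervoxel segmentation controlled by `n_segments`. "Switched-off" supervoxels are filled with the volume mean. `score_fn` is a hypothetical callable that maps a perturbed volume to a scalar, for example how strongly the model's output supports the reference report. Samples are weighted with an exponential kernel on cosine distance, similar to the default in the `lime` package.

```python
import numpy as np


def grid_supervoxels(shape, n_per_axis=(2, 2, 2)):
    """Hypothetical stand-in for a real supervoxel method: label a (D, H, W) grid of blocks."""
    labels = np.zeros(shape, dtype=int)
    edges = [np.linspace(0, s, n + 1).astype(int) for s, n in zip(shape, n_per_axis)]
    k = 0
    for zi in range(n_per_axis[0]):
        for yi in range(n_per_axis[1]):
            for xi in range(n_per_axis[2]):
                labels[edges[0][zi]:edges[0][zi + 1],
                       edges[1][yi]:edges[1][yi + 1],
                       edges[2][xi]:edges[2][xi + 1]] = k
                k += 1
    return labels


def lime_supervoxel_weights(volume, labels, score_fn, n_samples=100,
                            kernel_width=0.25, ridge=1.0, seed=42):
    """Fit a weighted linear surrogate: score ~ intercept + sum_k w_k * keep_k."""
    rng = np.random.default_rng(seed)
    n_seg = int(labels.max()) + 1

    # Random on/off masks over supervoxels; row 0 is the unperturbed volume
    masks = rng.integers(0, 2, size=(n_samples, n_seg))
    masks[0] = 1

    # Score each perturbed volume; "off" supervoxels are replaced by the volume mean
    fill = float(volume.mean())
    scores = np.empty(n_samples)
    for s, m in enumerate(masks):
        keep = m[labels].astype(bool)
        scores[s] = score_fn(np.where(keep, volume, fill))

    # Cosine distance between each binary mask and the all-ones mask: 1 - sqrt(kept / n_seg)
    dist = 1.0 - np.sqrt(masks.sum(axis=1) / n_seg)
    pi = np.sqrt(np.exp(-(dist ** 2) / kernel_width ** 2))

    # Weighted ridge regression in closed form; the intercept is not penalized
    X = np.hstack([np.ones((n_samples, 1)), masks])
    reg = ridge * np.eye(n_seg + 1)
    reg[0, 0] = 0.0
    XtW = X.T * pi  # equivalent to X.T @ np.diag(pi)
    coef = np.linalg.solve(XtW @ X + reg, XtW @ scores)
    return coef[1:]  # one importance weight per supervoxel


# Toy check: the score only looks at the first block, so supervoxel 0 should dominate
vol = np.zeros((8, 8, 8))
vol[:4, :4, :4] = 10.0
labels = grid_supervoxels(vol.shape)
w = lime_supervoxel_weights(vol, labels, lambda v: float(v[:4, :4, :4].mean()), ridge=1e-3)
print(np.round(w, 2))
```

Because the toy score depends on a single block, essentially all of the attribution lands on supervoxel 0. This makes a quick sanity check for any LIME-style attribution pipeline before plugging in an expensive model call.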

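The `clip_q` and `alpha` arguments of `run_interpretability` control how supervoxel weights become the red/blue overlays in `overlay_slices.png`. The NumPy sketch below shows one plausible way to do this. Three details are assumptions, not the repository's documented behavior: clipping at the `clip_q` quantile of |weight|, percentile-based grayscale scaling of the slice, and an opacity that grows with |weight| up to `alpha`.

```python
import numpy as np


def weights_to_heatmap(labels, weights, clip_q=0.99):
    """Paint each supervoxel with its weight, clip at the clip_q quantile of |weight|, scale to [-1, 1]."""
    heat = np.asarray(weights, dtype=float)[labels]
    scale = float(np.quantile(np.abs(heat), clip_q))
    if scale <= 0.0:
        return np.zeros_like(heat)
    return np.clip(heat / scale, -1.0, 1.0)


def overlay_slice(mri_slice, heat_slice, alpha=0.45):
    """Blend a signed heatmap over a grayscale slice: red = supportive, blue = contradicting."""
    lo, hi = np.percentile(mri_slice, [1, 99])
    gray = np.clip((mri_slice - lo) / (hi - lo + 1e-8), 0.0, 1.0)
    base = np.stack([gray, gray, gray], axis=-1)

    color = np.zeros_like(base)
    color[..., 0] = np.clip(heat_slice, 0.0, 1.0)   # positive weights -> red channel
    color[..., 2] = np.clip(-heat_slice, 0.0, 1.0)  # negative weights -> blue channel

    # Per-pixel opacity grows with |heat|, capped at alpha
    a = alpha * np.abs(heat_slice)[..., None]
    return (1.0 - a) * base + a * color


# Two supervoxels (left and right halves) with one supportive and one contradicting weight
labels = np.zeros((2, 4, 4), dtype=int)
labels[:, :, 2:] = 1
heat = weights_to_heatmap(labels, [2.0, -1.0], clip_q=0.99)
rgb = overlay_slice(np.linspace(0.0, 1.0, 16).reshape(4, 4), heat[0], alpha=0.45)
print(rgb.shape, heat.min(), heat.max())
```

Clipping at a high quantile rather than the maximum keeps a single extreme supervoxel from washing out every other region in the color scale.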
---

## ⚙️ Model Details

- **Model Type**: Multimodal Vision-Language Model (3D MRI → Text)
- **Architecture**: Inflated ViT + Perceiver Compressor + MLP Projector + 4-bit quantized MedGemma 1.5 (4B)
- **Input**: 3D brain MRI FLAIR volumes (64×128×128 voxels)
- **Output**: Free-form radiology reports (up to 256 tokens)
- **Parameters**: ~4.45B in total (~450M vision + ~4B language, the language model 4-bit quantized)
- **Training Compute**: 1× NVIDIA A100 64GB (≈12 GPU-hours total)
- **Framework**: PyTorch 2.0, Transformers 4.40+

### Preprocessing Requirements

- **Orientation**: RAS (via nibabel's `as_closest_canonical`)
- **Resolution**: Resampled to (64, 128, 128) via trilinear interpolation
- **Normalization**: Percentile clipping (p1, p99) + z-score normalization
- **Format**: NIfTI (.nii or .nii.gz)

---

## ⚠️ Limitations & Intended Use

### ✅ Intended Use

- **Research**: Medical AI research, neuroradiology automation
- **Education**: Teaching radiology residents about report generation
- **Prototyping**: Building diagnostic support tools (non-clinical)

### ❌ Not Intended For

- **Clinical Diagnosis**: This model is NOT FDA/CE approved for medical use
- **Primary Interpretation**: Always verify with board-certified radiologists
- **Real-Time Emergency**: Not validated for acute stroke or trauma cases

### Known Limitations

- **Training Bias**: Trained primarily on glioblastoma (BraTS dataset); may underperform on other pathologies
- **Language**: English only (radiology reports)
- **Hallucination Risk**: May generate plausible but incorrect anatomical details (always verify)
- **Compute Requirements**: Requires a GPU with ≥16GB VRAM for inference

---

## 🏥 Clinical Validation Notes

BrainGemma3D achieved **95.1% pathology F1** on the BraTS test split, but this does NOT imply clinical readiness. Key considerations:

1. **Dataset Homogeneity**: BraTS contains predominantly glioblastomas; performance on other tumor types (meningiomas, metastases) is unknown
2. **Report Quality**: Ground-truth reports come from a single institution and may not generalize to other radiology practices
3. **No Radiologist Review**: Generated reports have not been clinically validated by neuroradiologists
4. **Regulatory Status**: Not cleared by the FDA, EMA, or any other regulatory body

**Recommendation**: Use only in research settings with appropriate ethical oversight and informed consent.

---

## 🙏 Acknowledgements

This project was developed by:

**Mariano Barone** · **Francesco Di Serio** · **Giuseppe Riccio** · **Antonio Romano** · **Vincenzo Moscato**

*Department of Electrical Engineering and Information Technology*
*University of Naples Federico II, Italy*

### Built With

- [Google MedGemma](https://huggingface.co/google/medgemma-1.5-4b-it): medical domain language model
- [Google MedSigLIP](https://huggingface.co/google/medsiglip-448): medical vision encoder
- [Hugging Face Transformers](https://huggingface.co/docs/transformers): model framework

---

Built with ❤️ for the MedGemma Impact Challenge 🏆

Advancing Medical AI with Google's Health AI Developer Foundations