---
license: cc-by-4.0
language:
- en
tags:
- multimodal
- vision-language
- medical
- neuroradiology
- brain-mri
- report-generation
- 3d-vision
- medgemma
- medsiglip
datasets:
- BraTS2020
- TextBraTS2021
- MPI-Leipzig_Mind-Brain-Body
base_model:
- google/medgemma-1.5-4b-it
- google/medsiglip-448
---

# BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D

BrainGemma3D is a **multimodal vision-language model** that generates clinically accurate radiology reports directly from **native 3D brain MRI** volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context critical for accurate neuroradiological interpretation.

<div align="center">
<a href="https://github.com/PRAISELab-PicusLab/BrainGemma3D" target="_blank"><img alt="GitHub Repository"
src="https://img.shields.io/badge/GitHub-BrainGemma3D-181717?style=for-the-badge&logo=github&logoSize=auto"/></a>
<a href="https://www.kaggle.com/code/antonioromano45/braingemma3d" target="_blank"><img alt="Kaggle Notebook"
src="https://img.shields.io/badge/Kaggle-Notebook-20BEFF?style=for-the-badge&logo=kaggle&logoSize=auto"/></a>
<br>
<a href="https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview" target="_blank"><img alt="MedGemma Challenge"
src="https://img.shields.io/badge/Kaggle-MedGemma_Impact_Challenge-blue?style=for-the-badge&logo=kaggle&logoSize=auto&color=20BEFF"/></a>
</div>

---

## Key Features

- **Native 3D Processing**: Inflates the 2D medical vision encoder ([MedSigLIP](https://huggingface.co/google/medsiglip-448)) to 3D for volumetric understanding
- **Clinical Accuracy**: 95.1% F1 on pathology entity recognition (on the BraTS dataset)
- **Spatial Awareness**: 68.9% laterality F1 (correct left/right hemisphere localization)
- **Interpretable**: LIME-based 3D attribution maps show which brain regions drive predictions
- **Efficient**: Compresses each full 3D volume into 32 visual tokens
- **Research-Ready**: Trained and evaluated on 369 brain tumor cases + 99 healthy controls

---

## Architecture

BrainGemma3D combines:

1. **3D Vision Encoder**: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
   *Base model: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)*

2. **Token Compressor**: 2-layer Perceiver that reduces 3D patches to 32 visual tokens

3. **Vision-Language Projector**: 2-layer MLP that projects visual tokens into the language model's embedding space

4. **Language Model**: 4-bit quantized [MedGemma-1.5-4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) with LoRA adapters
   *Base model: [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it)*

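The first two components can be illustrated with a minimal NumPy sketch. It is not the repository's implementation: the kernel shape, token width, and the single projection-free cross-attention step are assumptions chosen for brevity. A real Perceiver layer adds learned query/key/value projections, residual connections, and an MLP.

```python
import numpy as np


def inflate_center_frame(w2d: np.ndarray, depth: int) -> np.ndarray:
    """Inflate a Conv2D kernel (out, in, kh, kw) to Conv3D (out, in, depth, kh, kw).

    Only the central depth slice receives the 2D weights; the other slices are
    zero, so at initialization the 3D kernel responds exactly like the 2D one
    applied to the slice under the kernel center.
    """
    out_ch, in_ch, kh, kw = w2d.shape
    w3d = np.zeros((out_ch, in_ch, depth, kh, kw), dtype=w2d.dtype)
    w3d[:, :, depth // 2] = w2d
    return w3d


def compress_tokens(patches: np.ndarray, latents: np.ndarray) -> np.ndarray:
    """One cross-attention step: M latent queries (M, d) pool N patch tokens (N, d)."""
    d = patches.shape[-1]
    scores = latents @ patches.T / np.sqrt(d)        # (M, N) similarity scores
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)         # softmax over patches
    return attn @ patches                            # (M, d) compressed tokens


rng = np.random.default_rng(0)
w2d = rng.normal(size=(8, 1, 16, 16)).astype(np.float32)  # hypothetical patch-embedding kernel
w3d = inflate_center_frame(w2d, depth=4)
print(w3d.shape)            # (8, 1, 4, 16, 16)

patches = rng.normal(size=(512, 64))   # hypothetical: 512 patch tokens of width 64
latents = rng.normal(size=(32, 64))    # 32 latent queries (learned in a real model)
visual_tokens = compress_tokens(patches, latents)
print(visual_tokens.shape)  # (32, 64)
```

The number of output tokens depends only on the number of latent queries, which is why the visual token count stays at 32 regardless of how many 3D patches the volume produces.
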
---

## Usage

### Requirements

```bash
pip install torch torchvision transformers nibabel scikit-image lime
```

### Model Download

```python
from huggingface_hub import snapshot_download

# 1. Download the repository containing the custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"Repository downloaded to: {local_dir}")
```

### Quick Start

```python
import os
import sys

import torch

# `local_dir` comes from the Model Download step
sys.path.append(local_dir)

from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT

# Select the hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")

# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
    vision_model_dir=f"{local_dir}/vision_model",
    language_model_dir=f"{local_dir}/language_model",
    depth=2,
    num_vision_tokens=32,
    freeze_vision=True,
    freeze_language=True,
    device_map={"": 0} if device == "cuda" else None,
)

# 3. Load projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")

# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)

# Inject the weights into the visual projector (which bridges vision and language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])

# Load the visual scaling factor, which may be stored as a tensor or a plain number
if ckpt.get("vis_scale") is not None:
    if isinstance(ckpt["vis_scale"], torch.Tensor):
        model.vis_scale.data = ckpt["vis_scale"].to(device)
    else:
        model.vis_scale.data.fill_(ckpt["vis_scale"])

# Switch the model to evaluation mode for inference
model.eval()
print("BrainGemma3D is fully loaded and ready for inference!")

# 4. Load MRI scan
volume = load_nifti_volume(
    "path/to/brain_flair.nii.gz",
    # Model Details lists (64, 128, 128); use the size matching your checkpoint
    target_size=(32, 128, 128),
).to(device)

# Add a batch dimension if the loader returned a single unbatched volume
if volume.ndim == 4:
    volume = volume.unsqueeze(0)

# 5. Generate report
with torch.no_grad():
    report = model.generate_report(
        volume,
        prompt=CANONICAL_PROMPT,
        max_new_tokens=256,
        temperature=0.1,
        top_p=0.9,
    )

print("\n===== GENERATED REPORT =====\n")
print(report)
```

### Expected Output Example

```
Generated Report:
The lesion area is in the left parietal and frontal lobes with mixed high-signal
areas. Edema signals are mainly observed around these lesions, indicating significant
edema presence affecting parts of both frontal and temporal regions as well as some
portions within the parietal lobe. Necrosis may be present at low signal intensity
or scattered throughout certain sections of the brain tissue affected by edema.
Ventricular compression effects on adjacent ventricles can occur due to pressure
from surrounding tissues near the ventricular system.
```

---

## Training Pipeline

BrainGemma3D is trained in **three progressive phases** to prevent catastrophic forgetting:

### Phase 1: Contrastive Grounding (Image-Text Alignment)
- **Goal**: Align 3D visual features with textual report embeddings
- **Loss**: InfoNCE (CLIP-style contrastive learning)
- **Trainable**: 3D Vision Encoder + Projector
- **Frozen**: Language Model
- **Epochs**: 100

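As a rough illustration of the Phase 1 objective, the sketch below computes a symmetric InfoNCE loss over a batch of paired embeddings in NumPy. The temperature of 0.07 and the embedding sizes are assumptions for the example, not values taken from the training code.

```python
import numpy as np


def info_nce(image_emb: np.ndarray, text_emb: np.ndarray, temperature: float = 0.07) -> float:
    """Symmetric InfoNCE over a batch of paired (image, text) embeddings."""
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature        # (B, B); the diagonal holds matched pairs

    def cross_entropy_diag(l: np.ndarray) -> float:
        l = l - l.max(axis=1, keepdims=True)  # numerical stability
        log_prob = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return float(-np.mean(np.diag(log_prob)))

    # Average the image-to-text and text-to-image directions
    return 0.5 * (cross_entropy_diag(logits) + cross_entropy_diag(logits.T))


rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 32))                      # hypothetical batch of 8 embeddings
aligned = info_nce(emb, emb)                        # every volume matches its own report
shuffled = info_nce(emb, np.roll(emb, 1, axis=0))   # every volume paired with a wrong report
print(aligned < shuffled)  # True
```

The loss is low when each volume embedding is most similar to its own report embedding and high when the pairing is wrong, which is the pressure that aligns the two modalities.
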
### Phase 2A: Projector Warmup
- **Goal**: Train the projector to condition the LM effectively
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector only
- **Frozen**: Vision Encoder + Language Model
- **Epochs**: 100

### Phase 2B: LoRA Linguistic Specialization
- **Goal**: Adapt LM to generate structured clinical reports
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector + LoRA adapters (rank=4) on LM attention layers
- **Frozen**: Vision Encoder + LM base weights
- **Epochs**: 100

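To show why rank-4 adapters keep Phase 2B cheap, here is a minimal NumPy sketch of a LoRA-augmented linear layer. Only the rank comes from the text above; the layer width of 64 and the scaling factor `alpha=8` are hypothetical.

```python
import numpy as np


def lora_forward(x, w_frozen, a, b, alpha):
    """y = x W^T + (alpha / r) * x A^T B^T, with W frozen and only A, B trainable."""
    r = a.shape[0]
    return x @ w_frozen.T + (alpha / r) * (x @ a.T) @ b.T


d_in, d_out, rank = 64, 64, 4            # rank=4 as stated above; widths are hypothetical
rng = np.random.default_rng(0)
w = rng.normal(size=(d_out, d_in))       # frozen base weight
a = rng.normal(scale=0.01, size=(rank, d_in))
b = np.zeros((d_out, rank))              # B starts at zero, so the adapter is initially a no-op
x = rng.normal(size=(2, d_in))

unchanged = np.allclose(lora_forward(x, w, a, b, alpha=8), x @ w.T)
trainable = a.size + b.size
print(unchanged, trainable, w.size)  # True 512 4096
```

With these toy sizes the adapter trains 512 parameters against 4096 frozen ones; the ratio shrinks further as the layer width grows, since the adapter scales with `rank * (d_in + d_out)` rather than `d_in * d_out`.
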
**Dataset**:
- 369 [BraTS 2020](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation) brain tumor MRI cases with radiologist-written reports from [TextBraTS 2021](https://github.com/Jupitern52/TextBraTS)
- 99 healthy control scans with synthetic reports from [MPI-Leipzig Mind-Brain-Body](https://openneuro.org/datasets/ds000221/versions/00002)
- Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage

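A stratified, subject-level split along these lines could be sketched as follows. This is a standard-library illustration under stated assumptions: the subject IDs are made up, and the rounding rule is one of several reasonable choices, so the resulting counts need not match the splits actually used.

```python
import random
from collections import defaultdict


def stratified_group_split(subjects, fractions=(0.7, 0.1, 0.2), seed=42):
    """Split subject IDs into train/val/test, stratified by label.

    `subjects` maps subject_id -> label. Splitting at the subject level keeps
    every scan of a subject in a single split, which prevents patient leakage.
    """
    by_label = defaultdict(list)
    for subject_id, label in subjects.items():
        by_label[label].append(subject_id)

    rng = random.Random(seed)
    splits = {"train": [], "val": [], "test": []}
    for label in sorted(by_label):
        ids = sorted(by_label[label])
        rng.shuffle(ids)
        n_train = round(fractions[0] * len(ids))
        n_val = round(fractions[1] * len(ids))
        splits["train"] += ids[:n_train]
        splits["val"] += ids[n_train:n_train + n_val]
        splits["test"] += ids[n_train + n_val:]   # remainder goes to test
    return splits


# Hypothetical subject IDs matching the dataset sizes listed above
subjects = {f"tumor_{i:03d}": "tumor" for i in range(369)}
subjects.update({f"healthy_{i:03d}": "healthy" for i in range(99)})
splits = stratified_group_split(subjects)
print({name: len(ids) for name, ids in splits.items()})
```

Stratifying by label keeps the tumor-to-healthy ratio roughly constant across the three splits, so the healthy-control specificity is measured on a representative share of controls.
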
---

## Performance

Evaluated on a dataset of **468 subjects** (369 BraTS pathological + 99 healthy controls) with group-based splits.

### Quantitative Results (Test Set)

| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Lat F1 | Anat F1 | **Path F1** |
|-------|--------|--------|---------|-------|--------|---------|-------------|
| Med3DVLM *(3D Generalist)* | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 *(2D Slice)* | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| **BrainGemma3D** *(Ours)* | **0.302** | **0.098** | **0.289** | **0.293** | **0.689** | **0.691** | **0.951** |

### Clinical Metrics Breakdown

- **Laterality F1**: 0.689 (correct left/right hemispheric localization)
- **Anatomy F1**: 0.691 (accurate anatomical structure identification)
- **Pathology F1**: 0.951 (near-perfect pathological entity recognition)
- **Healthy Specificity**: 1.0 (no pathology hallucinated on healthy controls)

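The exact entity extraction behind these scores is not described here, so the following is only a sketch of how an entity-level F1 can be computed. It assumes simple keyword matching against a small, hypothetical laterality vocabulary, and the two reports are illustrative.

```python
def entity_f1(predicted: set, reference: set) -> float:
    """F1 between the entity sets found in a generated and a reference report."""
    if not predicted and not reference:
        return 1.0                       # nothing to find and nothing claimed
    true_positives = len(predicted & reference)
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(predicted)
    recall = true_positives / len(reference)
    return 2 * precision * recall / (precision + recall)


LATERALITY_TERMS = {"left", "right", "bilateral"}   # hypothetical vocabulary


def extract_terms(report: str, vocabulary: set) -> set:
    """Return the vocabulary terms that appear as words in the report."""
    words = {word.strip(".,;:()").lower() for word in report.split()}
    return words & vocabulary


reference = "Lesion in the left parietal lobe with surrounding edema."
generated = "The lesion area is in the left parietal and right frontal lobes."
score = entity_f1(
    extract_terms(generated, LATERALITY_TERMS),
    extract_terms(reference, LATERALITY_TERMS),
)
print(round(score, 3))  # 0.667
```

In this toy case the generated report finds the correct side (recall 1.0) but also mentions a side absent from the reference (precision 0.5), which is the kind of error a laterality F1 penalizes.
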
> **Key Insight**: The **+130% relative gain in Pathology F1** (0.951 vs. 0.413 for the 2D slice baseline) suggests that native 3D processing is essential for diagnostic accuracy in neuroradiology.

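The relative gain quoted above follows directly from the test-set table. The short check below recomputes it for the three clinical metrics, using the BrainGemma3D and MedGemma 1.5 (2D slice) rows.

```python
def relative_gain(ours: float, baseline: float) -> float:
    """Relative improvement over the baseline, in percent."""
    return 100.0 * (ours - baseline) / baseline


# (BrainGemma3D, MedGemma 1.5 2D-slice baseline) pairs from the test-set table
results = {
    "Lat F1": (0.689, 0.526),
    "Anat F1": (0.691, 0.461),
    "Path F1": (0.951, 0.413),
}
for metric, (ours, baseline) in results.items():
    print(f"{metric}: {relative_gain(ours, baseline):+.1f}%")
# Lat F1: +31.0%
# Anat F1: +49.9%
# Path F1: +130.3%
```
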
---

## Interpretability

BrainGemma3D includes **LIME-based 3D interpretability** to visualize which brain regions drive diagnostic predictions.

```python
from braingemma3d_interpretability import run_interpretability

# 6. Run interpretability analysis (`model` and `report` come from the Quick Start)
weights, wvol = run_interpretability(
    model=model,
    load_nifti_volume=load_nifti_volume,
    CANONICAL_PROMPT=CANONICAL_PROMPT,
    mri_path="path/to/brain_flair.nii.gz",
    report=report,
    output_dir="./interpretability_output",
    lime_samples=100,  # Number of perturbations (more = better but slower)
    n_segments=20,     # Number of brain regions to analyze
    alpha=0.45,        # Overlay transparency
    clip_q=0.99,       # Heatmap clipping
    seed=42,
)
```

**Output**:
- `overlay_slices.png`: Full 3D heatmap (red=supportive, blue=contradicting)
- `lime_2x3_grid.png`: 2×3 grid with selected slices (original + LIME overlay)
- `lime_top_supervoxels_grid.png`: Most influential supervoxels
- `lime_weights.json`: Supervoxel importance scores

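To make the idea behind the supervoxel importance scores concrete, here is a simplified LIME sketch in NumPy. It is not the code behind `run_interpretability`: the black-box `toy_score` stands in for the model's response to a masked volume, and the proximity kernel width and plain weighted least squares (instead of ridge regression) are assumptions.

```python
import numpy as np


def lime_weights(score_fn, n_segments: int, n_samples: int = 200, seed: int = 42) -> np.ndarray:
    """Fit a weighted linear surrogate to a black-box score under random segment masking."""
    rng = np.random.default_rng(seed)
    masks = rng.integers(0, 2, size=(n_samples, n_segments)).astype(float)
    masks[0] = 1.0                                    # include the unperturbed input
    scores = np.array([score_fn(mask) for mask in masks])

    # Proximity kernel: samples that keep more segments weigh more in the fit
    distance = 1.0 - masks.mean(axis=1)
    sample_weight = np.exp(-(distance ** 2) / 0.25)

    design = np.hstack([masks, np.ones((n_samples, 1))])   # add an intercept column
    sqrt_w = np.sqrt(sample_weight)[:, None]
    coef, *_ = np.linalg.lstsq(design * sqrt_w, scores * sqrt_w[:, 0], rcond=None)
    return coef[:-1]                                  # one importance score per segment


def toy_score(mask: np.ndarray) -> float:
    """Toy black box: only segments 3 and 7 matter (stand-ins for tumor supervoxels)."""
    return 0.6 * mask[3] + 0.3 * mask[7]


weights = lime_weights(toy_score, n_segments=20)
print(np.argsort(weights)[::-1][:2])  # [3 7]
```

Positive coefficients mark segments whose presence raises the score (supportive), and negative coefficients mark segments whose presence lowers it (contradicting).
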
### Expected Output Example

<div align="left">
<img src="https://cdn-uploads.huggingface.co/production/uploads/662a12d70951c58269b066fb/UkQwmZRwkn-rlNlFBNVkH.png" alt="LIME Interpretability" width="80%">
<p><i>Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.</i></p>
</div>

---

## Model Details

- **Model Type**: Multimodal Vision-Language Model (3D MRI → Text)
- **Architecture**: Inflated ViT + Perceiver Compressor + MLP Projector + Quantized MedGemma 1.5
- **Input**: 3D brain MRI FLAIR volumes (64×128×128 voxels)
- **Output**: Free-form radiology reports (up to 256 tokens)
- **Parameters**: ~4.45B (450M vision + 4B language, 4-bit quantized)
- **Training Compute**: 1× NVIDIA A100 64GB (≈12 GPU-hours total)
- **Framework**: PyTorch 2.0, Transformers 4.40+

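As a back-of-the-envelope check on the parameter figures, the snippet below estimates the memory taken by the weights alone. It assumes the vision encoder is held in 16-bit precision, which is not stated above, and it ignores activations, the KV cache, and quantization metadata, so actual inference memory is higher.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


language_gb = weight_memory_gb(4e9, 4)     # 4-bit quantized language model
vision_gb = weight_memory_gb(450e6, 16)    # assumption: vision encoder kept in 16-bit
print(f"{language_gb:.1f} GB + {vision_gb:.1f} GB = {language_gb + vision_gb:.1f} GB")
# 2.0 GB + 0.9 GB = 2.9 GB
```
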
### Preprocessing Requirements

- **Orientation**: RAS (as-closest-canonical)
- **Resolution**: Resampled to (64, 128, 128) via trilinear interpolation
- **Normalization**: Percentile clipping (p1, p99) + z-score normalization
- **Format**: NIfTI (.nii or .nii.gz)

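The normalization step can be sketched in NumPy as shown below. This is an illustration, not the body of `load_nifti_volume`: it assumes the statistics are computed over the whole volume rather than a brain mask, and it runs on synthetic data. Reorientation and resampling are left out; nibabel's `as_closest_canonical` and a linear-order resampler would typically cover those steps.

```python
import numpy as np


def normalize_volume(volume: np.ndarray, low: float = 1.0, high: float = 99.0) -> np.ndarray:
    """Clip intensities to the [p1, p99] range, then z-score normalize."""
    volume = volume.astype(np.float64)
    p_low, p_high = np.percentile(volume, [low, high])
    clipped = np.clip(volume, p_low, p_high)     # suppress extreme intensity outliers
    std = clipped.std()
    return (clipped - clipped.mean()) / (std if std > 0 else 1.0)


rng = np.random.default_rng(0)
# Synthetic stand-in for a resampled FLAIR volume
volume = rng.gamma(shape=2.0, scale=100.0, size=(64, 128, 128)).astype(np.float32)
volume[0, 0, 0] = 1e6                            # simulated outlier voxel
normalized = normalize_volume(volume)
print(normalized.shape, float(normalized.max()))
```

Clipping before the z-score keeps a few extreme voxels from inflating the standard deviation, which would otherwise compress the contrast of the remaining tissue.
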
---

## Limitations & Intended Use

### Intended Use
- **Research**: Medical AI research, neuroradiology automation
- **Education**: Teaching radiology residents about report generation
- **Prototyping**: Building diagnostic support tools (non-clinical)

### Not Intended For
- **Clinical Diagnosis**: This model is NOT FDA/CE approved for medical use
- **Primary Interpretation**: Always verify with board-certified radiologists
- **Real-Time Emergency**: Not validated for acute stroke or trauma cases

### Known Limitations
- **Training Bias**: Trained primarily on glioblastoma (BraTS dataset), so it may underperform on other pathologies
- **Language**: English only (radiology reports)
- **Hallucination Risk**: May generate plausible but incorrect anatomical details (always verify)
- **Compute Requirements**: Requires a GPU with ≥16GB VRAM for inference

---

## Clinical Validation Notes

BrainGemma3D achieved **95.1% pathology F1** on BraTS, but this does NOT imply clinical readiness. Key considerations:

1. **Dataset Homogeneity**: BraTS contains predominantly glioblastomas, so performance on other tumor types (meningiomas, metastases) is unknown
2. **Report Quality**: Ground truth reports are from a single institution and may not generalize to other radiology practices
3. **No Radiologist Review**: Generated reports have not been clinically validated by neuroradiologists
4. **Regulatory Status**: Not cleared by FDA, EMA, or any regulatory body

**Recommendation**: Use only in research settings with appropriate ethical oversight and informed consent.

---

## Acknowledgements

This project was developed by:

**Mariano Barone** · **Francesco Di Serio** · **Giuseppe Riccio** · **Antonio Romano** · **Vincenzo Moscato**

*Department of Electrical Engineering and Information Technology*
*University of Naples Federico II, Italy*

### Built With
- [Google MedGemma](https://huggingface.co/google/medgemma-1.5-4b-it): Medical domain language model
- [Google MedSigLIP](https://huggingface.co/google/medsiglip-448): Medical vision encoder
- [Hugging Face Transformers](https://huggingface.co/docs/transformers): Model framework

---

<div align="center">
<p><i>Built with ❤️ for the <a href="https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview">MedGemma Impact Challenge</a></i></p>
<p><i>Advancing Medical AI with Google's Health AI Developer Foundations</i></p>
</div>