---
license: cc-by-4.0
language:
- en
tags:
- multimodal
- vision-language
- medical
- neuroradiology
- brain-mri
- report-generation
- 3d-vision
- medgemma
- medsiglip
datasets:
- BraTS2020
- TextBraTS2021
- MPI-Leipzig_Mind-Brain-Body
base_model:
- google/medgemma-1.5-4b-it
- google/medsiglip-448
---
# BrainGemma3D: Brain Report Automation via Inflated Vision Transformers in 3D
BrainGemma3D is a **multimodal vision-language model** that generates clinically accurate radiology reports directly from **native 3D brain MRI** volumes. Unlike 2D slice-based approaches, BrainGemma3D processes MRI scans volumetrically, preserving the spatial context critical for accurate neuroradiological interpretation.
<div align="center">
<a href="https://github.com/PRAISELab-PicusLab/BrainGemma3D" target="_blank"><img alt="GitHub Repository"
src="https://img.shields.io/badge/GitHub-BrainGemma3D-181717?style=for-the-badge&logo=github&logoSize=auto"/></a>
<a href="https://www.kaggle.com/code/antonioromano45/braingemma3d" target="_blank"><img alt="Kaggle Notebook"
src="https://img.shields.io/badge/Kaggle-Notebook-20BEFF?style=for-the-badge&logo=kaggle&logoSize=auto"/></a>
<br>
<a href="https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview" target="_blank"><img alt="MedGemma Challenge"
src="https://img.shields.io/badge/Kaggle-MedGemma_Impact_Challenge-blue?style=for-the-badge&logo=kaggle&logoSize=auto&color=20BEFF"/></a>
</div>
---
## Key Features
- **Native 3D Processing**: Inflates a 2D medical vision encoder ([MedSigLIP](https://huggingface.co/google/medsiglip-448)) to 3D for volumetric understanding
- **Clinical Accuracy**: 95.1% F1 on pathology entity recognition (BraTS-based test split)
- **Spatial Awareness**: 68.9% laterality F1 (correct left/right hemisphere localization)
- **Interpretable**: LIME-based 3D attribution maps show which brain regions drive predictions
- **Efficient**: Compresses each full 3D volume into 32 visual tokens
- **Research-Ready**: Trained and evaluated on 369 brain tumor cases + 99 healthy controls
---
## Architecture
BrainGemma3D combines:
1. **3D Vision Encoder**: MedSigLIP inflated to 3D via center-frame initialization (Conv2D → Conv3D)
*Base model: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)*
2. **Token Compressor**: 2-layer Perceiver that reduces 3D patches to 32 visual tokens
3. **Vision-Language Projector**: 2-layer MLP that projects visual tokens to language model embedding space
4. **Language Model**: 4-bit quantized [MedGemma-1.5-4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) with LoRA adapters
*Base model: [google/medgemma-1.5-4b-it](https://huggingface.co/google/medgemma-1.5-4b-it)*
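The center-frame inflation in step 1 can be sketched as follows. This is a minimal PyTorch illustration, not the repository's code: it assumes a patch-embedding `nn.Conv2d` with zero padding, and the helper name `inflate_conv2d_center` and the temporal kernel size of 3 are hypothetical. The 2D kernel is copied into the central depth slice of a zero-initialized 3D kernel, so at initialization each 3D patch responds exactly as the 2D encoder would to its middle slice.

```python
import torch
import torch.nn as nn


def inflate_conv2d_center(conv2d: nn.Conv2d, time_kernel: int = 3) -> nn.Conv3d:
    """Hypothetical helper: inflate a 2D patch-embedding conv to 3D with center-frame init."""
    kh, kw = conv2d.kernel_size
    sh, sw = conv2d.stride
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(time_kernel, kh, kw),
        stride=(time_kernel, sh, sw),
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        # Zero all depth slices, then place the pretrained 2D kernel in the central one only.
        conv3d.weight.zero_()
        conv3d.weight[:, :, time_kernel // 2] = conv2d.weight
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d
```

Compared with averaging the 2D kernel across depth, center-frame initialization keeps the pretrained 2D response intact on the middle slice and lets training learn how much the neighboring slices should contribute.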
---
## Usage
### Requirements
```bash
pip install torch torchvision transformers huggingface_hub accelerate bitsandbytes peft nibabel scikit-image lime
```
### Model Download
```python
from huggingface_hub import snapshot_download
# 1. Download the repository containing our custom architecture from Hugging Face
repo_id = "praiselab-picuslab/BrainGemma3D"
print(f"Downloading repository: {repo_id}...")
local_dir = snapshot_download(repo_id)
print(f"Repository downloaded to: {local_dir}")
```
### Quick Start
```python
import os
import torch
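import sys
# Make the custom modules shipped in the downloaded repository importable
# (`local_dir` comes from the Model Download step)
sys.path.append(local_dir)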
from medgemma3d_architecture import MedGemma3D, load_nifti_volume, CANONICAL_PROMPT
# Automatically select the optimal hardware accelerator (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Hardware accelerator selected: {device}")
# 2. Instantiate the base architecture (3D-inflated MedSigLIP + MedGemma)
model = MedGemma3D(
vision_model_dir=f"{local_dir}/vision_model",
language_model_dir=f"{local_dir}/language_model",
depth=2,
num_vision_tokens=32,
freeze_vision=True,
freeze_language=True,
device_map={"": 0} if device == "cuda" else None,
)
# 3. Load projector
proj_path = os.path.join(local_dir, "projector_vis_scale.pt")
print(f"Loading custom projector weights from: {proj_path}...")
# Load the checkpoint into memory
ckpt = torch.load(proj_path, map_location=device)
# Inject the weights into the visual projector (which bridges Vision and Language)
model.vision_projector.load_state_dict(ckpt["vision_projector"])
# Load the visual scaling factor, ensuring correct tensor formatting
if ckpt.get("vis_scale") is not None:
if isinstance(ckpt["vis_scale"], torch.Tensor):
model.vis_scale.data = ckpt["vis_scale"].to(device)
else:
model.vis_scale.data.fill_(ckpt["vis_scale"])
# Transition the model to evaluation mode for inference
model.eval()
print("BrainGemma3D is fully loaded and ready for inference!")
# 4. Load MRI scan
volume = load_nifti_volume(
"path/to/brain_flair.nii.gz",
target_size=(32, 128, 128)
).to(device)
# Ensure a batch dimension: (C, D, H, W) -> (1, C, D, H, W)
if volume.ndim == 4:
    volume = volume.unsqueeze(0)
# 5. Generate report
with torch.no_grad():
report = model.generate_report(
volume,
prompt=CANONICAL_PROMPT,
max_new_tokens=256,
temperature=0.1,
top_p=0.9,
)
print("\n===== GENERATED REPORT =====\n")
print(report)
```
### Expected Output Example
```
===== GENERATED REPORT =====
The lesion area is in the left parietal and frontal lobes with mixed high-signal
areas. Edema signals are mainly observed around these lesions, indicating significant
edema presence affecting parts of both frontal and temporal regions as well as some
portions within the parietal lobe. Necrosis may be present at low signal intensity
or scattered throughout certain sections of the brain tissue affected by edema.
Ventricular compression effects on adjacent ventricles can occur due to pressure
from surrounding tissues near the ventricular system.
```
---
## Training Pipeline
BrainGemma3D is trained in **three progressive stages** to prevent catastrophic forgetting:
### Phase 1: Contrastive Grounding (Image-Text Alignment)
- **Goal**: Align 3D visual features with textual report embeddings
- **Loss**: InfoNCE (CLIP-style contrastive learning)
- **Trainable**: 3D Vision Encoder + Projector
- **Frozen**: Language Model
- **Epochs**: 100
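Phase 1's CLIP-style objective can be sketched as a symmetric InfoNCE loss over a batch of paired volume and report embeddings. This is a generic illustration, not the project's training code; the fixed temperature of 0.07 and the function name `clip_infonce` are assumptions, since the card does not specify them.

```python
import torch
import torch.nn.functional as F


def clip_infonce(image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE: the matching (volume, report) pairs lie on the diagonal."""
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature  # (B, B) scaled cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, targets)      # each volume must pick its own report
    loss_t2i = F.cross_entropy(logits.t(), targets)  # each report must pick its own volume
    return (loss_i2t + loss_t2i) / 2
```

Every other report in the batch acts as a negative for a given volume, which is why larger batches usually give a stronger alignment signal.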
### Phase 2A: Projector Warmup
- **Goal**: Train the projector to condition the LM effectively
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector only
- **Frozen**: Vision Encoder + Language Model
- **Epochs**: 100
### Phase 2B: LoRA Linguistic Specialization
- **Goal**: Adapt LM to generate structured clinical reports
- **Loss**: Next-token prediction (Cross-Entropy)
- **Trainable**: Projector + LoRA adapters (rank=4) on LM attention layers
- **Frozen**: Vision Encoder + LM base weights
- **Epochs**: 100
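The freezing schedule across the three phases, plus a rank-4 LoRA update, can be sketched on a toy model. This is an illustration under stated assumptions: `ToyVLM` and its submodule names (`vision_encoder`, `projector`, `language_model`) are hypothetical stand-ins rather than the real `MedGemma3D` class, LoRA parameters are identified by a `lora_` name prefix, and `alpha=8.0` is an assumed scaling value.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable low-rank update (B @ A) scaled by alpha / rank."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scaling


class ToyVLM(nn.Module):
    """Hypothetical stand-in with illustrative submodule names; not the real architecture."""

    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Linear(16, 16)
        self.projector = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 32))
        self.language_model = nn.ModuleDict({"q_proj": LoRALinear(nn.Linear(32, 32), rank=4)})


def configure_phase(model: ToyVLM, phase: str) -> None:
    """Set requires_grad for phase '1', '2A' or '2B' following the schedule described above."""
    for name, param in model.named_parameters():
        if name.startswith("vision_encoder."):
            param.requires_grad = phase == "1"   # vision encoder trains only in Phase 1
        elif name.startswith("projector."):
            param.requires_grad = True           # projector trains in every phase
        elif "lora_" in name:
            param.requires_grad = phase == "2B"  # adapters train only in Phase 2B
        else:
            param.requires_grad = False          # LM base weights stay frozen throughout
```

In practice a library such as PEFT would normally handle the LoRA wrapping; the hand-written layer only makes the zero-initialized low-rank update explicit, which is what keeps Phase 2B from disturbing the behavior learned in Phase 2A at its first step.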
**Dataset**:
- 369 [BraTS 2020](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation) brain tumor MRI cases with radiologist-written reports from [TextBraTS 2021](https://github.com/Jupitern52/TextBraTS)
- 99 healthy control scans with synthetic reports from [MPI-Leipzig Mind-Brain-Body](https://openneuro.org/datasets/ds000221/versions/00002)
- Stratified group-based splits (70% train / 10% val / 20% test) to prevent patient leakage
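A stratified, group-based 70/10/20 split can be sketched in plain Python: every scan of a subject lands in the same partition (no patient leakage), and tumor and healthy subjects are split separately so each partition keeps roughly the same class ratio. The `(subject_id, label)` record format, the one-label-per-subject assumption, and the function name are illustrative, not the project's actual split code.

```python
import random
from collections import defaultdict


def stratified_group_split(records, ratios=(0.7, 0.1, 0.2), seed=42):
    """records: iterable of (subject_id, label); returns {'train', 'val', 'test'} -> subject ids."""
    subject_label = {}
    for subject_id, label in records:
        subject_label.setdefault(subject_id, label)  # group all scans by subject (one label assumed)

    by_label = defaultdict(list)
    for subject_id, label in subject_label.items():
        by_label[label].append(subject_id)

    splits = {"train": [], "val": [], "test": []}
    rng = random.Random(seed)
    for label in sorted(by_label):  # stratify: split each class independently
        subjects = sorted(by_label[label])
        rng.shuffle(subjects)
        n_train = round(len(subjects) * ratios[0])
        n_val = round(len(subjects) * ratios[1])
        splits["train"] += subjects[:n_train]
        splits["val"] += subjects[n_train:n_train + n_val]
        splits["test"] += subjects[n_train + n_val:]
    return splits
```

With 369 tumor and 99 healthy subjects this yields 74 + 20 = 94 test subjects, and a subject with several scans can never appear in two partitions.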
---
## Performance
Evaluated on a cohort of **468 subjects** (369 BraTS pathological + 99 healthy controls) using group-based splits; all scores below are on the held-out test split.
### Quantitative Results (Test Set)
| Model | BLEU-1 | BLEU-4 | ROUGE-L | CIDEr | Lat F1 | Anat F1 | **Path F1** |
|-------|--------|--------|---------|-------|--------|---------|-------------|
| Med3DVLM *(3D Generalist)* | 0.051 | 0.005 | 0.083 | 0.007 | 0.300 | 0.225 | 0.119 |
| MedGemma 1.5 *(2D Slice)* | 0.245 | 0.024 | 0.189 | 0.029 | 0.526 | 0.461 | 0.413 |
| **BrainGemma3D** *(Ours)* | **0.302** | **0.098** | **0.289** | **0.293** | **0.689** | **0.691** | **0.951** |
### Clinical Metrics Breakdown
- **Laterality F1**: 0.689 (correct left/right hemispheric localization)
- **Anatomy F1**: 0.691 (accurate anatomical structure identification)
- **Pathology F1**: 0.951 (near-perfect pathological entity recognition)
- **Healthy Specificity**: 1.0 (no healthy control was reported as pathological)
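These clinical metrics are entity-level scores. The card does not describe its entity extractor, so the following is only a sketch under loud assumptions: entities are found by whole-word keyword matching against small hypothetical vocabularies, there is no negation handling (so "no edema" would still count as edema), and F1 is computed per report on entity sets.

```python
import re

# Hypothetical vocabularies for illustration; the real term lists are not specified in this card.
LATERALITY_TERMS = {"left", "right", "bilateral"}
PATHOLOGY_TERMS = {"edema", "necrosis", "enhancement", "mass effect"}


def extract(report: str, vocabulary: set) -> set:
    """Return the vocabulary terms that appear as whole words in the report (no negation handling)."""
    text = report.lower()
    return {term for term in vocabulary if re.search(r"\b" + re.escape(term) + r"\b", text)}


def entity_f1(generated: str, reference: str, vocabulary: set) -> float:
    """Set-based F1 between entities in a generated report and its reference report."""
    pred, gold = extract(generated, vocabulary), extract(reference, vocabulary)
    if not pred and not gold:
        return 1.0  # both reports agree that none of these entities are present
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred), tp / len(gold)
    return 2 * precision * recall / (precision + recall)
```

For example, a generated report mentioning only edema, scored against a reference mentioning edema and necrosis, gets precision 1.0 and recall 0.5, hence a pathology F1 of about 0.667.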
> **Key Insight**: The **+130% relative gain in Pathology F1** (0.951 vs. 0.413 for the 2D slice baseline) indicates that native 3D processing is key to diagnostic accuracy in neuroradiology.
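The relative gains follow directly from the results table; this short check recomputes them for the three clinical metrics, using only the numbers reported there.

```python
# Test-set scores copied from the results table: 2D slice baseline vs. BrainGemma3D.
baseline = {"Lat F1": 0.526, "Anat F1": 0.461, "Path F1": 0.413}
braingemma3d = {"Lat F1": 0.689, "Anat F1": 0.691, "Path F1": 0.951}

# Relative gain = (new - old) / old
relative_gain = {k: (braingemma3d[k] - baseline[k]) / baseline[k] for k in baseline}
for metric, gain in relative_gain.items():
    print(f"{metric}: {gain:+.0%}")
```

This prints `Lat F1: +31%`, `Anat F1: +50%` and `Path F1: +130%`.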
---
## Interpretability
BrainGemma3D includes **LIME-based 3D interpretability** to visualize which brain regions drive diagnostic predictions.
```python
from braingemma3d_interpretability import run_interpretability
# 6. Run interpretability analysis
weights, wvol = run_interpretability(
model=model,
load_nifti_volume=load_nifti_volume,
CANONICAL_PROMPT=CANONICAL_PROMPT,
mri_path="path/to/brain_flair.nii.gz",
report=report,
output_dir="./interpretability_output",
lime_samples=100, # Number of perturbations (more = better but slower)
n_segments=20, # Number of brain regions to analyze
alpha=0.45, # Overlay transparency
clip_q=0.99, # Heatmap clipping
seed=42,
)
```
**Output**:
- `overlay_slices.png`: full 3D heatmap (red = supportive, blue = contradicting)
- `lime_2x3_grid.png`: 2×3 grid of selected slices (original + LIME overlay)
- `lime_top_supervoxels_grid.png`: most influential supervoxels
- `lime_weights.json`: supervoxel importance scores
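Conceptually, LIME over supervoxels switches segments on and off, scores each perturbed volume, and fits a locally weighted linear model whose coefficients become the supervoxel weights (positive = supportive, negative = contradicting). The NumPy sketch below is a generic illustration, not the implementation inside `run_interpretability`: `score_fn` stands in for "mask the MRI and score the generated report", the helper name is hypothetical, and measuring proximity as the fraction of supervoxels switched off (with kernel width 0.25) is a simplification of LIME's default distance.

```python
import numpy as np


def lime_supervoxel_weights(score_fn, n_segments=20, n_samples=100, kernel_width=0.25, seed=42):
    """score_fn maps a binary keep-mask of shape (n_segments,) to a scalar model score."""
    rng = np.random.default_rng(seed)
    masks = rng.integers(0, 2, size=(n_samples, n_segments))
    masks[0] = 1  # always include the unperturbed volume
    scores = np.array([score_fn(m) for m in masks], dtype=float)

    # Proximity kernel: samples that switch off fewer supervoxels get more weight.
    distance = 1.0 - masks.mean(axis=1)
    sample_w = np.exp(-(distance ** 2) / kernel_width ** 2)

    # Weighted least squares with an intercept column (rows scaled by sqrt of the sample weight).
    X = np.hstack([np.ones((n_samples, 1)), masks])
    sw = np.sqrt(sample_w)[:, None]
    coef, *_ = np.linalg.lstsq(X * sw, scores * sw[:, 0], rcond=None)
    return coef[1:]  # one importance weight per supervoxel
```

As with `lime_samples` above, more perturbation samples give a more stable fit at the cost of one model evaluation per sample.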
### Expected Output Example
<div align="left">
<img src="https://cdn-uploads.huggingface.co/production/uploads/662a12d70951c58269b066fb/UkQwmZRwkn-rlNlFBNVkH.png" alt="LIME Interpretability" width="80%">
<p><i>Figure 1: LIME attribution maps for a BraTS sample. Red regions show supervoxels that positively contribute to pathology predictions. The model correctly focuses on tumor-affected areas in the left parietal and frontal lobes.</i></p>
</div>
---
## Model Details
- **Model Type**: Multimodal Vision-Language Model (3D MRI → Text)
- **Architecture**: Inflated ViT + Perceiver Compressor + MLP Projector + 4-bit quantized MedGemma 1.5
- **Input**: 3D brain MRI FLAIR volumes (64×128×128 voxels)
- **Output**: Free-form radiology reports (up to 256 tokens)
- **Parameters**: ~4.45B total (~450M vision + ~4B language, 4-bit quantized)
- **Training Compute**: 1× NVIDIA A100 64GB (≈12 GPU-hours total)
- **Framework**: PyTorch 2.0, Transformers 4.40+
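The Perceiver compressor and MLP projector can be sketched in PyTorch: learned latent queries cross-attend to a variable number of 3D patch embeddings and always return a fixed set of 32 tokens, which a 2-layer MLP then maps into the language model's embedding width. This is a minimal illustration, not the repository's implementation; the class names, head count, 4× feed-forward width and the dimensions used in the test are assumptions.

```python
import torch
import torch.nn as nn


class PerceiverCompressor(nn.Module):
    """Learned latents cross-attend to patch tokens; each layer = cross-attention + MLP, with residuals."""

    def __init__(self, dim: int, num_latents: int = 32, depth: int = 2, heads: int = 8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                "norm_q": nn.LayerNorm(dim),
                "norm_kv": nn.LayerNorm(dim),
                "attn": nn.MultiheadAttention(dim, heads, batch_first=True),
                "norm_ff": nn.LayerNorm(dim),
                "ff": nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)),
            })
            for _ in range(depth)
        )

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (B, N, dim) with any N -> (B, num_latents, dim)
        x = self.latents.unsqueeze(0).expand(patches.size(0), -1, -1)
        for layer in self.layers:
            kv = layer["norm_kv"](patches)
            attn_out, _ = layer["attn"](layer["norm_q"](x), kv, kv)
            x = x + attn_out
            x = x + layer["ff"](layer["norm_ff"](x))
        return x


class VisionLanguageProjector(nn.Module):
    """2-layer MLP from the vision width to the LM embedding width."""

    def __init__(self, vision_dim: int, lm_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(vision_dim, lm_dim), nn.GELU(), nn.Linear(lm_dim, lm_dim))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.net(tokens)
```

Because the number of latents is fixed, the language model always receives 32 visual tokens, regardless of how many patches the 3D encoder produces.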
### Preprocessing Requirements
- **Orientation**: RAS (as-closest-canonical)
- **Resolution**: Resampled to (64, 128, 128) via trilinear interpolation
- **Normalization**: Percentile clipping (p1, p99) + z-score normalization
- **Format**: NIfTI (.nii or .nii.gz)
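The preprocessing steps above can be sketched with NumPy and PyTorch. This is a minimal illustration of percentile clipping, z-scoring and trilinear resampling, not the repository's `load_nifti_volume`; it assumes the array is already in RAS orientation and that `(D, H, W)` is the axis order the model expects, and the function name is hypothetical.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_volume(volume: np.ndarray, target_size=(64, 128, 128)) -> torch.Tensor:
    """Clip to the 1st/99th percentiles, z-score, resample trilinearly; returns (1, 1, D, H, W)."""
    volume = volume.astype(np.float32)
    lo, hi = np.percentile(volume, [1, 99])
    volume = np.clip(volume, lo, hi)
    volume = (volume - volume.mean()) / (volume.std() + 1e-8)  # epsilon guards against constant volumes

    tensor = torch.from_numpy(volume.astype(np.float32))[None, None]  # add batch and channel dims
    return F.interpolate(tensor, size=target_size, mode="trilinear", align_corners=False)
```

With nibabel, the input array could come from `nib.as_closest_canonical(nib.load(path)).get_fdata()`. That call returns the voxel axes in the file's RAS order, so a transpose to `(D, H, W)` may be needed.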
---
## Limitations & Intended Use
### Intended Use
- **Research**: Medical AI research, neuroradiology automation
- **Education**: Teaching radiology residents about report generation
- **Prototyping**: Building diagnostic support tools (non-clinical)
### Not Intended For
- **Clinical Diagnosis**: This model is NOT FDA/CE approved for medical use
- **Primary Interpretation**: Always verify with board-certified radiologists
- **Real-Time Emergency**: Not validated for acute stroke or trauma cases
### Known Limitations
- **Training Bias**: Trained primarily on glioblastoma (BraTS dataset); may underperform on other pathologies
- **Language**: English only (radiology reports)
- **Hallucination Risk**: May generate plausible but incorrect anatomical details (always verify)
- **Compute Requirements**: Requires a GPU with ≥16GB VRAM for inference
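A quick pre-flight check against the VRAM guidance can avoid a failed model load. This sketch uses standard PyTorch CUDA queries; the function name is hypothetical, and the default threshold of 15 decimal gigabytes is an assumption chosen slightly below the nominal 16GB because cards marketed as 16GB usually report a little less total memory.

```python
import torch


def has_enough_vram(min_gb: float = 15.0, device_index: int = 0) -> bool:
    """Return True if the CUDA device reports at least `min_gb` decimal gigabytes of total memory."""
    if not torch.cuda.is_available():
        return False  # no CUDA device: the GPU requirement above is not met
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory
    return total_bytes / 1e9 >= min_gb
```

Calling `has_enough_vram()` before instantiating `MedGemma3D` makes it possible to fail fast with a clear message instead of an out-of-memory error midway through loading.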
---
## Clinical Validation Notes
BrainGemma3D achieved **95.1% pathology F1** on the held-out test split, but this does NOT imply clinical readiness. Key considerations:
1. **Dataset Homogeneity**: BraTS contains predominantly glioblastomas; performance on other tumor types (meningiomas, metastases) is unknown
2. **Report Quality**: Ground-truth reports come from a single institution and may not generalize to other radiology practices
3. **No Radiologist Review**: Generated reports have not been clinically validated by neuroradiologists
4. **Regulatory Status**: Not cleared by FDA, EMA, or any regulatory body
**Recommendation**: Use only in research settings with appropriate ethical oversight and informed consent.
---
## Acknowledgements
This project was developed by:
**Mariano Barone** · **Francesco Di Serio** · **Giuseppe Riccio** · **Antonio Romano** · **Vincenzo Moscato**
*Department of Electrical Engineering and Information Technology*
*University of Naples Federico II, Italy*
### Built With
- [Google MedGemma](https://huggingface.co/google/medgemma-1.5-4b-it): medical domain language model
- [Google MedSigLIP](https://huggingface.co/google/medsiglip-448): medical vision encoder
- [Hugging Face Transformers](https://huggingface.co/docs/transformers): model framework
---
<div align="center">
<p><i>Built with ❤️ for the <a href="https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview">MedGemma Impact Challenge</a></i></p>
<p><i>Advancing Medical AI with Google's Health AI Developer Foundations</i></p>
</div>