Brain MRI SigLIP

Brain MRI SigLIP is a 3D MRI vision-language representation model trained with a SigLIP-style image-text contrastive objective. This repository publishes the final saved stage2_joint_finetune checkpoint from the brain_mri_siglip_run_0509 experiment.

This checkpoint is intended as a research visual encoder for brain MRI downstream tasks and as a warm-start encoder for building a medical VLM. It is not a clinical diagnostic device.
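A SigLIP-style objective scores every image-text pair in a batch independently with a sigmoid, instead of applying a softmax over the whole batch. The NumPy sketch below is the generic pairwise sigmoid loss from the SigLIP paper (Zhai et al., 2023), not this repository's training code. `siglip_loss` is a hypothetical name. The defaults t' = log 10 and b = -10 are the paper's initial values for the learnable temperature and bias, assumed here only for illustration.

```python
import numpy as np


def siglip_loss(img_emb, txt_emb, t_prime=np.log(10.0), bias=-10.0):
    """Pairwise sigmoid loss; row i of img_emb and txt_emb is a matched pair."""
    # L2-normalize so the dot product is cosine similarity.
    x = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    y = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    # Logits for every image-text pair: t * cos(x_i, y_j) + b, with t = exp(t').
    logits = np.exp(t_prime) * (x @ y.T) + bias
    # Labels z_ij = +1 on the diagonal (positives), -1 elsewhere (negatives).
    labels = 2.0 * np.eye(len(x)) - 1.0
    # -log sigmoid(z * logit) == log(1 + exp(-z * logit)), computed stably.
    per_pair = np.logaddexp(0.0, -labels * logits)
    # Sum over all pairs, normalized by the batch size as in the paper.
    return per_pair.sum() / len(x)
```

Each pair contributes an independent binary term, so the loss needs no batch-wide normalization. This is one reason sigmoid losses are convenient for large contrastive batches.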

Model Summary

  • Base text tower: google/medsiglip-448
  • Model class: BrainMRISiglipModel
  • Vision input: single-channel 3D MRI volumes
  • Expected volume shape: [1, 128, 192, 192] per volume (channel first), or [B, 1, 128, 192, 192] for a batch
  • Projection dimension: 1152
  • Patch size: [8, 16, 16]
  • Training precision: bf16
  • Training input format: preprocessed .pt tensors, float16, value range [-1, 1]

Training Context

This model was initialized from the brain_mri_siglip_run_0509/stage1_freeze_text checkpoint and then jointly fine-tuned with both vision and text towers trainable.

Training summary:

  • Training samples: 950,720
  • Validation samples: 67,450
  • Validation samples with metadata_text: 32,278
  • Stage 1: frozen text tower, vision-heavy training
  • Stage 2: joint vision-text fine-tuning
  • Stage 2 epochs configured: 8
  • World size: 5
  • Stage 2 per-device batch size: 160
  • Stage 2 contrastive forward batch: 800 (5 devices × 160 per device)
  • Gradient checkpointing: enabled for both the text and vision towers
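The batch figures above fit together as shown below. This is a back-of-the-envelope sketch that assumes one optimizer step per global batch with no gradient accumulation. The source does not state whether the last partial batch is dropped, so both the floor and the ceiling are shown. If the loss is computed over the gathered global batch, each image is scored against its matched text plus 799 in-batch negatives.

```python
import math

world_size = 5
per_device_batch = 160
train_samples = 950_720
stage2_epochs = 8

# Global contrastive batch: one per-device batch from each of the 5 ranks.
global_batch = world_size * per_device_batch

# Optimizer steps per epoch, assuming one step per global batch
# (no gradient accumulation); drop_last decides floor vs. ceiling.
steps_floor = train_samples // global_batch
steps_ceil = math.ceil(train_samples / global_batch)

print(global_batch, steps_floor, steps_ceil, stage2_epochs * steps_floor)
# → 800 1188 1189 9504
```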

Training-time retrieval evaluation used capped validation subsets and should be treated as monitoring rather than a final benchmark.
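The card does not say which retrieval metrics were monitored. Recall@K is a common choice for image-text retrieval. The NumPy sketch below uses a hypothetical `recall_at_k` helper. It assumes that row i of the image and text embedding matrices form a matched pair, and it ranks candidates by cosine similarity.

```python
import numpy as np


def recall_at_k(image_emb, text_emb, k=1):
    """Image-to-text Recall@K where row i of each matrix is a matched pair."""
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    sims = img @ txt.T  # [N_images, N_texts] cosine similarities
    # Rank texts for each image by descending similarity and keep the top k.
    top_k = np.argsort(-sims, axis=1)[:, :k]
    # A hit means the matched text (same row index) appears in the top k.
    hits = (top_k == np.arange(len(img))[:, None]).any(axis=1)
    return float(hits.mean())
```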

Loading

This model uses custom Transformers code. Load it with trust_remote_code=True.

import torch
from transformers import AutoModel, AutoProcessor

repo_id = "shenxiaochen/brain-mri-siglip"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    # `dtype=` requires a recent transformers release; older releases use `torch_dtype=`.
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()

processor = AutoProcessor.from_pretrained(
    repo_id,
    trust_remote_code=True,
)

NIfTI Preprocessing

For reproducible inference from NIfTI files, pass the file paths directly to the saved processor. This repository includes preprocessing code aligned with the offline pipeline that produced the training tensors, so NIfTI inputs follow the same tensor distribution the model was trained on.

nifti_path = "/path/to/brain_mri.nii.gz"

inputs = processor(
    volumes=nifti_path,
    return_tensors="pt",
)
pixel_values = inputs["pixel_values"].to(device)

if torch.cuda.is_available():
    pixel_values = pixel_values.to(dtype=torch.bfloat16)

with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)

print(pixel_values.shape)  # [1, 1, 128, 192, 192]
print(image_embeds.shape)  # [1, 1152]

The saved preprocessing recipe for path inputs is:

  • reorient the image to the closest canonical RAS orientation
  • build a foreground mask using a threshold of 1e-3
  • keep the largest connected foreground component
  • crop to the foreground with a 5 mm margin
  • normalize foreground intensities using the 0.5th and 99.5th percentiles
  • map intensities to [-1, 1]
  • resample to voxel spacing (1.25, 1.0, 1.0) mm
  • downscale to fit within [128, 192, 192]
  • center-pad to [128, 192, 192] with background value -1.0
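The NumPy sketch below illustrates two of the intensity steps: percentile normalization and center padding. It is not the repository's processor. Reorientation, connected-component filtering, cropping, resampling, and downscaling are omitted, and a real pipeline would resample between the two functions. The percentile handling is an assumed reading of the recipe: the window is computed on foreground voxels, then values are clipped and rescaled linearly. The helper names are hypothetical. For actual inference, use the saved processor.

```python
import numpy as np


def percentile_normalize(volume, threshold=1e-3, lower_pct=0.5, upper_pct=99.5):
    """Map intensities to [-1, 1] using percentiles of the foreground voxels."""
    volume = np.asarray(volume, dtype=np.float64)
    foreground = volume > threshold
    if not foreground.any():
        raise ValueError("no voxels above the foreground threshold")
    lo, hi = np.percentile(volume[foreground], [lower_pct, upper_pct])
    scale = max(hi - lo, 1e-8)
    # Background (<= threshold) lies below lo, so clipping sends it to -1.
    scaled = (np.clip(volume, lo, hi) - lo) / scale * 2.0 - 1.0
    return scaled.astype(np.float32)


def center_pad(volume, target_shape=(128, 192, 192), pad_value=-1.0):
    """Center a volume that already fits inside target_shape, padding with pad_value."""
    if any(s > t for s, t in zip(volume.shape, target_shape)):
        raise ValueError(f"{volume.shape} does not fit in {target_shape}")
    out = np.full(target_shape, pad_value, dtype=np.float32)
    starts = [(t - s) // 2 for s, t in zip(volume.shape, target_shape)]
    region = tuple(slice(b, b + s) for b, s in zip(starts, volume.shape))
    out[region] = volume
    return out[np.newaxis]  # add channel axis -> [1, 128, 192, 192]
```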

The exact settings are saved in preprocessor_config.json and processor_config.json.

Using Preprocessed .pt Inputs

If your data is already stored in the same offline-preprocessed tensor format used during training, you can load it directly:

payload = torch.load("/path/to/sample.pt", map_location="cpu")
pixel_values = payload["pixel_values"] if isinstance(payload, dict) else payload

if pixel_values.ndim == 4:
    pixel_values = pixel_values.unsqueeze(0)

# Match the model's weight dtype (bf16 on GPU, fp32 on CPU in the loading example above).
pixel_values = pixel_values.to(device=device, dtype=model.dtype)

with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)

Expected tensor format:

  • shape [1, 128, 192, 192] for one volume, or [B, 1, 128, 192, 192] for a batch
  • values in [-1, 1]
  • padded background voxels near -1.0
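Checking cached tensors against the format above before passing them to the model can catch preprocessing mismatches early. This is a hypothetical NumPy helper; convert a torch tensor with `tensor.float().numpy()` first. The corner-voxel test is a heuristic. It assumes the corners are background or padding, which the crop-and-pad recipe makes likely but does not guarantee.

```python
import numpy as np


def check_pixel_values(x, volume_shape=(1, 128, 192, 192), atol=1e-2):
    """Validate cached MRI tensors; returns a batched [B, 1, 128, 192, 192] array."""
    x = np.asarray(x, dtype=np.float32)
    # A single volume [1, D, H, W] gets a batch axis, mirroring the loading code.
    if x.ndim == 4:
        x = x[np.newaxis]
    if x.ndim != 5 or tuple(x.shape[1:]) != tuple(volume_shape):
        raise ValueError(f"unexpected shape {x.shape}")
    # Values should lie in [-1, 1]; atol tolerates float16 rounding.
    if x.min() < -1.0 - atol or x.max() > 1.0 + atol:
        raise ValueError(f"values outside [-1, 1]: [{x.min()}, {x.max()}]")
    # Heuristic: the first corner voxel should be background/padding near -1.0.
    corner = x[:, 0, 0, 0, 0]
    if np.any(np.abs(corner + 1.0) > atol):
        raise ValueError("corner voxel is not near -1.0; check background/padding")
    return x
```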

VLM Integration Notes

For VLM construction, use the 3D vision tower as a visual backbone and add a projector, Q-Former, Perceiver resampler, or other token compressor before connecting to an LLM.
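The token count shows why a compressor is needed. If the vision tower splits a [128, 192, 192] volume into non-overlapping [8, 16, 16] patches (an assumption about the patchification), it produces 16 × 12 × 12 = 2,304 patch tokens per volume. The sketch below computes that count and shows the core of a Perceiver-style resampler, in which a small set of learned queries cross-attends to all patch tokens and returns a fixed number of output tokens. It is single-head NumPy with random weights, for illustration only. A real resampler would be a trained module with multiple heads, layers, and normalization. All names are hypothetical.

```python
import numpy as np


def num_patch_tokens(volume_shape=(128, 192, 192), patch_size=(8, 16, 16)):
    """Tokens from non-overlapping 3D patches (assumed ViT-style patchification)."""
    count = 1
    for dim, patch in zip(volume_shape, patch_size):
        if dim % patch:
            raise ValueError(f"{dim} is not divisible by patch size {patch}")
        count *= dim // patch
    return count


def cross_attention_pool(tokens, queries, w_k, w_v):
    """Single-head cross-attention: M queries summarize N tokens into M outputs."""
    keys = tokens @ w_k            # [N, d_k]
    values = tokens @ w_v          # [N, d_v]
    scores = queries @ keys.T / np.sqrt(keys.shape[-1])  # [M, N]
    scores -= scores.max(axis=-1, keepdims=True)          # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax over tokens
    return weights @ values        # [M, d_v]


print(num_patch_tokens())  # 2304
```

With 64 queries, for example, the LLM would see 64 visual tokens per volume instead of 2,304. The query count is a design choice, not a value from this repository.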

A practical downstream recipe is:

  1. Freeze this MRI encoder and train only the multimodal projector/resampler.
  2. Evaluate downstream classification, retrieval, report alignment, or instruction-following behavior.
  3. Optionally unfreeze the top vision layers with a much smaller learning rate.

Limitations

  • This checkpoint was trained for representation learning, not diagnosis.
  • Performance should be validated on task-specific subject-level or study-level splits.
  • Scanner, protocol, site, and preprocessing differences can affect embeddings.
  • External users should preserve the saved preprocessing pipeline for NIfTI inference.
  • Retrieval monitoring during training is not a substitute for downstream clinical validation.

Citation

If you use this checkpoint, please cite this model repository and the upstream MedSigLIP model where appropriate.

Checkpoint files: Safetensors format, 0.9B parameters, F32 tensor type.