Brain Tumor Segmentation: SAM3 Linear Probe

PyTorch checkpoint for the MultiAgentMedClassifier tumor segmentation task. It contains a linear 1×1 Conv2d probe head trained on top of a frozen SAM3 backbone for pixel-level brain tumor segmentation.

These are checkpoint files for the accompanying project loaders, not standalone Transformers models.

Model Description

  • Task: brain tumor MRI segmentation (binary mask: tumor / background)
  • Architecture: frozen SAM3 image encoder + linear 1×1 Conv2d probe head
  • Backbone input resolution: 1008 × 1008
  • Probe head: nn.Conv2d(in_channels=256, out_channels=2, kernel_size=1), where 256 is the SAM3 feature_dim
  • Framework: PyTorch
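The architecture above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the project's code. Random tensors stand in for frozen SAM3 features, and the 72 × 72 feature grid is an assumption. Bilinear upsampling to 1008 × 1008 is also an assumption, as is treating channel 1 as tumor. The sketch also checks that a 1×1 Conv2d is simply the same linear map over the 256 channels applied at every pixel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

FEATURE_DIM = 256  # SAM3 feature channels (from this card)
NUM_CLASSES = 2    # background / tumor
INPUT_SIZE = 1008  # backbone input resolution (from this card)

# The linear probe: one 1x1 convolution, i.e. the same linear map applied at every pixel.
probe = nn.Conv2d(in_channels=FEATURE_DIM, out_channels=NUM_CLASSES, kernel_size=1)


def predict_mask(features: torch.Tensor, out_size: int = INPUT_SIZE) -> torch.Tensor:
    """Frozen-encoder features [B, C, h, w] -> label map [B, out_size, out_size]."""
    with torch.no_grad():
        logits = probe(features)  # [B, 2, h, w]
        # Bilinear upsampling to the input resolution is an assumption of this sketch.
        logits = F.interpolate(logits, size=(out_size, out_size),
                               mode="bilinear", align_corners=False)
        return logits.argmax(dim=1)  # assumed label order: 0 = background, 1 = tumor


# Random stand-in for SAM3 features; the 72x72 grid size is assumed, not taken from SAM3.
features = torch.randn(1, FEATURE_DIM, 72, 72)
mask = predict_mask(features)
print(mask.shape)  # torch.Size([1, 1008, 1008])

# Sanity check: the 1x1 conv equals an einsum over the channel axis plus a bias.
with torch.no_grad():
    manual = torch.einsum("oc,bchw->bohw", probe.weight[:, :, 0, 0], features)
    manual = manual + probe.bias.view(1, -1, 1, 1)
    print(torch.allclose(manual, probe(features), atol=1e-4))
```

Because only this head is trained, the probe has just 256 × 2 weights plus 2 biases (514 parameters). The backbone's features must therefore already separate tumor from background almost linearly.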

Performance

Method                          Dice                                   IoU           Sensitivity
Zero-shot SAM3                  0.189                                  0.124         0.397
Linear probe (frozen encoder)   0.836 (pixel) / 0.801 (per-case mean)  not reported  not reported
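The table reports Dice two ways. The NumPy sketch below uses the standard definitions: Dice = 2TP / (2TP + FP + FN), IoU = TP / (TP + FP + FN), and sensitivity = TP / (TP + FN). It shows how pooling pixels across cases and averaging per-case scores can diverge. This illustrates the two aggregations only. The evaluation code behind the reported numbers, including how empty masks are scored, is not shown on this card, so the conventions here are assumptions.

```python
import numpy as np


def counts(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """True positives, false positives, false negatives for one binary mask pair."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    return np.array([(pred & gt).sum(), (pred & ~gt).sum(), (~pred & gt).sum()])


def scores(tp, fp, fn, eps: float = 1e-8):
    """Dice, IoU and sensitivity; eps makes an all-empty case score 0 (an assumed convention)."""
    dice = 2 * tp / (2 * tp + fp + fn + eps)
    iou = tp / (tp + fp + fn + eps)
    sensitivity = tp / (tp + fn + eps)
    return dice, iou, sensitivity


def pixel_vs_per_case_dice(preds, gts):
    per_case = [counts(p, g) for p, g in zip(preds, gts)]
    pixel_dice = scores(*np.sum(per_case, axis=0))[0]               # pool pixels, then score
    mean_dice = float(np.mean([scores(*c)[0] for c in per_case]))  # score each case, then average
    return float(pixel_dice), mean_dice


# Toy example: a large tumor segmented perfectly and a small tumor missed entirely.
gt_a = np.zeros((32, 32), dtype=bool)
gt_a[5:15, 5:15] = True    # 100 tumor pixels
gt_b = np.zeros((32, 32), dtype=bool)
gt_b[20:22, 20:22] = True  # 4 tumor pixels
preds = [gt_a.copy(), np.zeros((32, 32), dtype=bool)]
print(pixel_vs_per_case_dice(preds, [gt_a, gt_b]))  # approximately (0.980, 0.500)
```

Pooled Dice is dominated by large tumors. The per-case mean weights every scan equally, so missed small lesions pull it down.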

Files

  • tumor_segmentation/sam3/sam3_linear_probe_tumor_segmentation_best.pt: SAM3 frozen backbone + 1×1 Conv2d linear probe for brain tumor segmentation.

Checkpoint Format

The checkpoint is a dict with:

{
    "model_state_dict": {"weight": ..., "bias": ...},  # Conv2d probe: weight [2, 256, 1, 1], bias [2]
    "feature_dim": 256,                                  # SAM3 feature channels
}
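To make the expected shapes concrete, the hedged sketch below builds a dict in the format above from a randomly initialised nn.Conv2d(256, 2, kernel_size=1). It round-trips the dict through torch.save and torch.load, then checks the tensor shapes. The helper names make_probe_checkpoint and check_probe_checkpoint are hypothetical, and the weights are random rather than the trained probe.

```python
import io

import torch


def make_probe_checkpoint(feature_dim: int = 256, num_classes: int = 2) -> dict:
    """Build a dict in the documented format (random weights, illustration only)."""
    conv = torch.nn.Conv2d(feature_dim, num_classes, kernel_size=1)
    return {"model_state_dict": conv.state_dict(), "feature_dim": feature_dim}


def check_probe_checkpoint(ckpt: dict, num_classes: int = 2) -> None:
    """Raise ValueError if the tensors do not fit a 1x1 Conv2d probe."""
    sd, c = ckpt["model_state_dict"], ckpt["feature_dim"]
    if tuple(sd["weight"].shape) != (num_classes, c, 1, 1):
        raise ValueError(f"unexpected weight shape {tuple(sd['weight'].shape)}")
    if tuple(sd["bias"].shape) != (num_classes,):
        raise ValueError(f"unexpected bias shape {tuple(sd['bias'].shape)}")


# Round-trip through an in-memory buffer, as torch.save/torch.load would with a .pt file.
buffer = io.BytesIO()
torch.save(make_probe_checkpoint(), buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location="cpu", weights_only=True)
check_probe_checkpoint(loaded)
print(sorted(loaded))  # ['feature_dim', 'model_state_dict']
```

The file holds only plain tensors and an int. Loading with weights_only=True therefore works and avoids unpickling arbitrary objects.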

The loader also accepts these alternate key layouts: classifier.weight / classifier.bias, module.classifier.weight / module.classifier.bias, or unprefixed weight / bias.
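A loader that accepts all of these layouts only has to try each prefix in turn. The sketch below uses a hypothetical helper, not the project's loader. It assumes the prefixed keys may appear either inside model_state_dict or at the top level of the checkpoint.

```python
import torch

# Key prefixes listed on this card; the empty prefix covers flat "weight"/"bias".
KNOWN_PREFIXES = ("module.classifier.", "classifier.", "")


def extract_probe_state_dict(ckpt: dict) -> dict:
    """Return {"weight": ..., "bias": ...} from any of the documented key layouts.

    Hypothetical helper; the project's own loader may differ.
    """
    sd = ckpt.get("model_state_dict", ckpt)  # fall back to a top-level state dict
    for prefix in KNOWN_PREFIXES:
        w_key, b_key = prefix + "weight", prefix + "bias"
        if w_key in sd and b_key in sd:
            return {"weight": sd[w_key], "bias": sd[b_key]}
    raise KeyError(f"no probe weights found; keys: {sorted(sd)}")


# Usage: every layout resolves to the plain state dict that nn.Conv2d expects.
w, b = torch.zeros(2, 256, 1, 1), torch.zeros(2)
layouts = [
    {"model_state_dict": {"weight": w, "bias": b}, "feature_dim": 256},
    {"model_state_dict": {"classifier.weight": w, "classifier.bias": b}},
    {"module.classifier.weight": w, "module.classifier.bias": b},
    {"weight": w, "bias": b},
]
probe = torch.nn.Conv2d(256, 2, kernel_size=1)
for ckpt in layouts:
    probe.load_state_dict(extract_probe_state_dict(ckpt))
```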

Runtime Requirements

SAM3 has strict runtime prerequisites:

  • Python 3.12+
  • PyTorch 2.7+
  • CUDA GPU with CUDA 12.6+
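A quick pre-flight check against these requirements can prevent a confusing failure later. This sketch uses only sys and standard torch attributes: torch.__version__, torch.version.cuda, and torch.cuda.is_available(). torch.version.cuda reports the CUDA version PyTorch was built against, not the installed driver, so treat that check as an approximation of the CUDA 12.6+ requirement. The function names are hypothetical.

```python
import sys
from itertools import takewhile


def version_tuple(version: str, parts: int = 2) -> tuple:
    """'2.7.1+cu126' -> (2, 7); '12.6' -> (12, 6). Keeps only leading digits of each part."""
    fields = version.split("+")[0].split(".")[:parts]
    return tuple(int("".join(takewhile(str.isdigit, f)) or 0) for f in fields)


def unmet_sam3_requirements() -> list:
    """List the runtime requirements from this card that the environment does not meet."""
    problems = []
    if sys.version_info < (3, 12):
        problems.append(f"Python 3.12+ required, found {sys.version.split()[0]}")
    try:
        import torch
    except ImportError:
        return problems + ["PyTorch 2.7+ required, but torch is not installed"]
    if version_tuple(torch.__version__) < (2, 7):
        problems.append(f"PyTorch 2.7+ required, found {torch.__version__}")
    if not torch.cuda.is_available():
        problems.append("no CUDA GPU visible to PyTorch")
    # torch.version.cuda is the CUDA version torch was built with (None on CPU-only builds).
    if torch.version.cuda is None or version_tuple(torch.version.cuda) < (12, 6):
        problems.append(f"CUDA 12.6+ build required, found {torch.version.cuda}")
    return problems


for problem in unmet_sam3_requirements():
    print("unmet:", problem)
```

An empty list means the basic prerequisites look satisfied. It does not guarantee that SAM3 itself is installed or that the GPU has enough memory.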

The probe checkpoint alone can be loaded without SAM3 installed, but inference requires the full SAM3 backbone.
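Because the probe is a plain nn.Conv2d, it can be restored with PyTorch alone, for example to inspect its weights. This minimal sketch assumes the primary model_state_dict layout and a 2-channel output. It does not run SAM3, and the helper name load_probe is hypothetical.

```python
import torch
import torch.nn as nn


def load_probe(path: str, device: str = "cpu") -> nn.Conv2d:
    """Rebuild the 1x1 Conv2d probe from the .pt file; SAM3 is not imported."""
    ckpt = torch.load(path, map_location=device, weights_only=True)
    feature_dim = int(ckpt.get("feature_dim", 256))  # 256 per this card
    probe = nn.Conv2d(feature_dim, 2, kernel_size=1)
    # Assumes the primary layout; other key layouts would need remapping first.
    probe.load_state_dict(ckpt["model_state_dict"])
    return probe.to(device).eval()
```

To use it, pass the path returned by hf_hub_download in the Inference Example section, for instance probe = load_probe(probe_path). Producing masks still requires SAM3 features as input.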

Inference Example

from huggingface_hub import hf_hub_download
from agents.sam3_tool import SAM3Tool
from config import DEFAULT_CONFIG

probe_path = hf_hub_download(
    repo_id="tamara-kostova/multiagentmed-tumor-segmentation",
    filename="tumor_segmentation/sam3/sam3_linear_probe_tumor_segmentation_best.pt",
)

DEFAULT_CONFIG.model.sam3_linear_probe_checkpoint = probe_path

tool = SAM3Tool(DEFAULT_CONFIG.model)
result = tool.segment("path/to/brain_mri.png", text_prompt="brain tumor")

print(result["mask_path"])        # binary segmentation mask
print(result["bbox"])             # [x1, y1, x2, y2]
print(result["guided_image_path"]) # original image with red bbox overlay

Output Format

{
    "mask_path": "outputs/segmentation/mask_<uid>.png",   # binary mask (0/255)
    "bbox": [x1, y1, x2, y2],                            # bounding box of mask
    "guided_image_path": "outputs/segmentation/guided_<uid>.png",  # bbox overlay for MedGemma
    "skipped": False
}
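The bbox field is described as the bounding box of the mask. The sketch below shows one way to compute it from a binary mask with NumPy. Whether the project treats x2/y2 as inclusive (as here) or exclusive is an assumption, and mask_to_bbox is a hypothetical name.

```python
import numpy as np


def mask_to_bbox(mask: np.ndarray):
    """Return [x1, y1, x2, y2] (inclusive) of the nonzero pixels, or None if the mask is empty."""
    ys, xs = np.nonzero(mask)  # rows are y, columns are x
    if xs.size == 0:
        return None
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


# Toy 0/255 mask: tumor occupies rows 2-4 and columns 3-6.
mask = np.zeros((8, 10), dtype=np.uint8)
mask[2:5, 3:7] = 255
print(mask_to_bbox(mask))  # [3, 2, 6, 4]
```

To apply it to a saved mask, load the single-channel PNG with Pillow, for example np.array(Image.open(result["mask_path"])) > 127, and pass the resulting boolean array.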

Intended Use

Research and experimentation only. Not a medical device. Always validate on your own held-out test set before using it in any pipeline.

Base model: facebook/sam3