SAM Brain Tumor Segmentation Model

This model is a fine-tuned Segment Anything Model (SAM) for brain tumor segmentation from medical imaging data. It was trained using a simulated dataset of 2D slices derived from 3D NIfTI (.nii.gz) images and their corresponding segmentation masks.
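The card does not spell out how the 3D volumes were turned into 2D training slices. Below is a minimal sketch under stated assumptions. The volume and mask arrays are assumed to come from a NIfTI loader such as nibabel's `nib.load(path).get_fdata()`. The axial axis is assumed to be the last one (`axis=2`). Intensities are min-max normalized to [0, 1] over the whole volume. Tumor-free slices are dropped by default because they provide no box prompt. The function name `nifti_volume_to_slices` and its `keep_empty` flag are hypothetical.

```python
import numpy as np

def nifti_volume_to_slices(volume, mask, axis=2, keep_empty=False):
    """Min-max normalize a 3D volume to [0, 1] and split it and its mask into 2D slices.

    Assumptions (not from the original card): the axial axis is `axis`, normalization
    is done over the whole volume, and any mask label > 0 counts as tumor.
    """
    volume = volume.astype(np.float32)
    lo, hi = float(volume.min()), float(volume.max())
    # Guard against a constant volume to avoid division by zero
    volume = (volume - lo) / (hi - lo) if hi > lo else np.zeros_like(volume)

    pairs = []
    for k in range(volume.shape[axis]):
        img = np.take(volume, k, axis=axis)   # 2D float slice in [0, 1]
        seg = np.take(mask, k, axis=axis) > 0  # 2D boolean tumor mask
        # By default, skip slices without tumor since no box prompt can be derived
        if keep_empty or seg.any():
            pairs.append((img, seg))
    return pairs
```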

Model Description

The original SAM is a general-purpose, promptable image segmentation model. This fine-tuned version specializes in segmenting brain tumors and keeps SAM's prompt-based interface: the model is prompted with a bounding box around the tumor region (derived from the ground-truth mask during training) and predicts a segmentation mask for the object inside that box.
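Deriving a box prompt from a ground-truth mask is straightforward. The sketch below assumes the tightest box around the mask's nonzero pixels, with an optional margin. The exact rule used during training is not stated. The helper `mask_to_box` and its `pad` parameter are hypothetical. SAM boxes are given as (x_min, y_min, x_max, y_max), so x indexes columns and y indexes rows.

```python
import numpy as np

def mask_to_box(mask, pad=0):
    """Return [x_min, y_min, x_max, y_max] enclosing the nonzero pixels of a 2D mask."""
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None  # no tumor in this slice, so there is nothing to prompt
    h, w = mask.shape
    # x indexes columns and y indexes rows; the margin is clipped to the image bounds
    x_min = max(int(cols.min()) - pad, 0)
    y_min = max(int(rows.min()) - pad, 0)
    x_max = min(int(cols.max()) + pad, w - 1)
    y_max = min(int(rows.max()) + pad, h - 1)
    return [x_min, y_min, x_max, y_max]

mask = np.zeros((256, 256), dtype=bool)
mask[120:150, 100:180] = True
box = mask_to_box(mask, pad=5)
print(box)  # [95, 115, 184, 154]
input_boxes = [[box]]  # SamProcessor nesting: one list of boxes per image
```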

Training Details

  • Base Model: facebook/sam-vit-base
  • Dataset: Simulated 2D axial slices from 3D NIfTI images, with intensities normalized to the 0-1 range.
  • Image Preprocessing: Grayscale images were duplicated across 3 channels to match SAM's expected input. Bounding box prompts were generated from ground truth masks.
  • Loss Functions: Binary Cross-Entropy (BCE) Loss and Dice Loss.
  • Optimizer: AdamW with a learning rate of 1e-5.
  • Epochs: 5
  • Average Dice Score on Validation Set: 0.9756 (on simulated data)
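The card lists BCE and Dice losses without saying how they are combined. Below is a minimal PyTorch sketch of that objective and of the Dice score used for validation. It assumes an unweighted sum of the two losses, a soft Dice computed on sigmoid probabilities, and masks shaped (batch, ..., H, W). The function names are hypothetical. SAM's `pred_masks` are low-resolution logits, so ground-truth masks would need to be resized to match (or the logits upsampled) before computing the loss.

```python
import torch
import torch.nn.functional as F

def bce_dice_loss(logits, target, eps=1e-6):
    """BCE on logits plus soft Dice loss, summed with equal weight (an assumption)."""
    target = target.float()
    bce = F.binary_cross_entropy_with_logits(logits, target)
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, probs.dim()))  # reduce over everything except the batch dimension
    inter = (probs * target).sum(dims)
    denom = probs.sum(dims) + target.sum(dims)
    dice = (2 * inter + eps) / (denom + eps)  # soft Dice per sample
    return bce + (1 - dice).mean()

def dice_score(pred, target, eps=1e-6):
    """Hard Dice coefficient between two binary masks: 2|A and B| / (|A| + |B|)."""
    pred, target = pred.bool(), target.bool()
    inter = (pred & target).sum().float()
    return ((2 * inter + eps) / (pred.sum() + target.sum() + eps)).item()
```

A Dice score of 1.0 means perfect overlap with the ground truth. The eps term keeps both functions defined when both masks are empty.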

Usage

To use this model for inference, you can load it with the transformers library and provide an image along with a bounding box prompt for the region of interest. The model will then predict a segmentation mask.

from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned model and processor
processor = SamProcessor.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")
model = SamModel.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example: Create a dummy image (replace with your actual medical image)
# This should be a 2D grayscale image, converted to 3 channels (RGB).
# For a real slice normalized to 0-1, scale it to 0-255 and cast to uint8 before building
# the PIL image; with its default settings the processor then rescales and normalizes it.
image_size = 256 # Example size
dummy_image_data = np.random.rand(image_size, image_size) * 255
dummy_image = Image.fromarray(dummy_image_data.astype(np.uint8)).convert("RGB")

# Example: Define a bounding box for the tumor region as (x_min, y_min, x_max, y_max) in pixels
# In a real scenario, this bounding box would be provided by an expert or a detection model.
# SamProcessor expects one list of boxes per image, i.e. a 3-level nested list.
input_boxes = [[[100, 100, 200, 200]]]  # one image, one box

# Preprocess the image and bounding box
inputs = processor(dummy_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Perform inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process the predicted mask back to the original image size
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)

# `masks` is a list with one boolean tensor per image, shaped (num_boxes, num_masks, H, W).
# With one box and multimask_output=False, squeezing leaves a single (H, W) mask.
predicted_mask = masks[0].squeeze().numpy()  # Shape (H, W)

print("Predicted mask shape:", predicted_mask.shape)
# You can visualize 'predicted_mask' using matplotlib or other image libraries.
# For example:
# import matplotlib.pyplot as plt
# plt.imshow(predicted_mask, cmap='gray')
# plt.title('Predicted Segmentation Mask')
# plt.show()
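The example above feeds random noise. For a real slice that has already been normalized to [0, 1] (as described under Training Details), the sketch below builds the 3-channel uint8 image that the example passes to the processor. It assumes that scaling to 0-255 and letting the processor's default rescaling and normalization run matches the training preprocessing. The card does not confirm this. The helper `slice_to_rgb_image` is hypothetical.

```python
import numpy as np
from PIL import Image

def slice_to_rgb_image(slice_01):
    """Convert a 2D float slice in [0, 1] into a 3-channel uint8 PIL image."""
    scaled = np.clip(slice_01, 0.0, 1.0) * 255.0
    gray = np.round(scaled).astype(np.uint8)
    # Duplicate the single grayscale channel three times to match SAM's RGB input
    rgb = np.stack([gray, gray, gray], axis=-1)
    return Image.fromarray(rgb)  # an (H, W, 3) uint8 array yields an "RGB" image
```

The returned image can replace `dummy_image` in the call to `processor(...)`.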

Inference Endpoint Configuration (Optional)

If you wish to deploy this model as an Inference Endpoint on Hugging Face, here's a sample configuration you might use in your README.md (or directly in the UI):

widget:
- src: "app.py"
  example_title: "Brain Tumor Segmentation Example"
  inputs:
  - filename: "image.png"
    image: https://huggingface.co/datasets/huggingface/sample-images/resolve/main/segmentation_image_input.png
    input_boxes: [[100, 100, 200, 200]]

--- # Optional section for specific endpoint settings

parameters:
  do_normalize: false # Skips mean/std normalization of the pixel values
  do_rescale: false   # Assumes pixel values are already scaled to the 0-1 range
  multimask_output: false # For single best mask output

# Example of specific hardware/software config for advanced users
# inference:
#   accelerator: cuda
#   container: pytorch_latest
#   hardware: gpu_small
#   task: image-segmentation

Note: The example image and input_boxes in the YAML configuration are placeholders. For a real medical image endpoint, you would provide a relevant example image and a bounding box corresponding to a tumor within that image.

Limitations

  • The model was fine-tuned on a simulated dataset. Its performance on real, diverse clinical data may differ and requires rigorous further validation.
  • The model relies on a bounding box prompt. Its accuracy is highly dependent on the quality and precision of the provided bounding box.
  • Currently, the model handles 2D slices. Adaptation for full 3D volume segmentation would require further development.

Future Work

  • Evaluate and fine-tune the model on large, real-world medical imaging datasets (e.g., BraTS, TCIA).
  • Explore methods for automatic bounding box generation for tumor regions.
  • Extend the model to handle 3D medical images directly.
  • Report additional quantitative metrics (e.g., IoU, Hausdorff distance) when evaluating on real data.
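The two metrics named above can be computed from binary masks with NumPy alone. Below is a minimal sketch. It measures the Hausdorff distance in pixels over all foreground pixels rather than over boundaries, and it does not apply voxel spacing. Returning infinity when either mask is empty is a convention chosen here, not a standard. The function names are hypothetical. The pairwise-distance matrix costs O(N·M) memory, so large masks would call for a KD-tree or a boundary-only approach.

```python
import numpy as np

def iou(pred, target):
    """Intersection over union of two binary masks (1.0 if both are empty)."""
    pred, target = pred.astype(bool), target.astype(bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred, target).sum() / union)

def hausdorff_distance(pred, target):
    """Symmetric Hausdorff distance, in pixels, between the foreground pixels of two masks."""
    a = np.argwhere(pred)    # (N, 2) row/col coordinates
    b = np.argwhere(target)  # (M, 2) row/col coordinates
    if len(a) == 0 or len(b) == 0:
        return float("inf")  # undefined when either mask is empty
    # Pairwise Euclidean distances, shape (N, M)
    d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    # Largest distance from any point in one set to its nearest point in the other
    return float(max(d.min(axis=1).max(), d.min(axis=0).max()))
```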
Model size: 93.7M parameters (Safetensors, tensor type F32)