# SAM Brain Tumor Segmentation Model
This model is a fine-tuned Segment Anything Model (SAM) for brain tumor segmentation from medical imaging data. It was trained using a simulated dataset of 2D slices derived from 3D NIfTI (.nii.gz) images and their corresponding segmentation masks.
## Model Description
The original SAM model is a powerful general-purpose image segmentation model. This fine-tuned version specializes in identifying brain tumors, leveraging the prompt-based segmentation capabilities of SAM. The model is prompted with bounding boxes around the tumor regions (derived from ground truth masks during training) to generate precise segmentation masks.
## Training Details
- Base Model: facebook/sam-vit-base
- Dataset: Simulated 2D axial slices from 3D NIfTI images, normalized to the 0-1 range.
- Image Preprocessing: Grayscale images were duplicated across 3 channels to match SAM's expected input. Bounding box prompts were generated from ground truth masks.
- Loss Functions: Binary Cross-Entropy (BCE) Loss and Dice Loss.
- Optimizer: AdamW with a learning rate of 1e-5.
- Epochs: 5
- Average Dice Score on Validation Set: 0.9756 (on simulated data)
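For reference, the Dice score reported above can be computed as follows. This is a minimal NumPy sketch, not the exact evaluation code used during training; the smoothing term `eps` is an assumption:

```python
import numpy as np

def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """Dice coefficient between two binary masks: 2*|A∩B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

# Example: two 4x4 square masks overlapping in a 3x3 region
a = np.zeros((8, 8), dtype=bool); a[2:6, 2:6] = True   # 16 pixels
b = np.zeros((8, 8), dtype=bool); b[3:7, 3:7] = True   # 16 pixels, 9 shared
print(round(dice_score(a, b), 4))  # 2*9 / (16+16) = 0.5625
```

The Dice loss used in training is typically `1 - dice_score` computed on soft (sigmoid) predictions rather than hard binary masks.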
## Usage

To use this model for inference, load it with the `transformers` library and provide an image along with a bounding box prompt for the region of interest. The model then predicts a segmentation mask.
```python
from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned model and processor
processor = SamProcessor.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")
model = SamModel.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example: create a dummy image (replace with your actual medical image).
# This should be a 2D grayscale image, then converted to 3 channels.
# For a real image, load it and normalize it to the 0-1 range before casting to uint8.
image_size = 256  # example size
dummy_image_data = np.random.rand(image_size, image_size) * 255
dummy_image = Image.fromarray(dummy_image_data.astype(np.uint8)).convert("RGB")

# Example: define a bounding box for the tumor region (x_min, y_min, x_max, y_max).
# In a real scenario, this box would come from an expert or a detection model.
# The processor expects boxes nested per image: [batch, boxes, 4].
input_boxes = [[[100, 100, 200, 200]]]

# Preprocess the image and bounding box
inputs = processor(dummy_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Perform inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process the predicted masks back to the original image size
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# `masks` is a list with one boolean tensor per image, shaped (num_boxes, num_masks, H, W).
# With multimask_output=False there is a single mask per box.
predicted_mask = masks[0].squeeze().numpy()  # shape (H, W)
print("Predicted mask shape:", predicted_mask.shape)

# You can visualize `predicted_mask` with matplotlib or other image libraries:
# import matplotlib.pyplot as plt
# plt.imshow(predicted_mask, cmap="gray")
# plt.title("Predicted Segmentation Mask")
# plt.show()
```
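During training, bounding box prompts were derived from the ground truth masks, and grayscale slices were duplicated across 3 channels. A minimal sketch of both steps is below; the `margin` padding around the box is an assumption (SAM fine-tuning recipes often jitter or pad the box slightly), not a documented detail of this model:

```python
import numpy as np

def mask_to_bbox(mask: np.ndarray, margin: int = 5) -> list:
    """Derive an (x_min, y_min, x_max, y_max) box from a binary mask."""
    ys, xs = np.where(mask > 0)
    if len(xs) == 0:
        raise ValueError("Mask is empty; no bounding box can be derived.")
    h, w = mask.shape
    x_min = max(0, int(xs.min()) - margin)
    y_min = max(0, int(ys.min()) - margin)
    x_max = min(w, int(xs.max()) + margin)
    y_max = min(h, int(ys.max()) + margin)
    return [x_min, y_min, x_max, y_max]

def slice_to_rgb(slice_2d: np.ndarray) -> np.ndarray:
    """Normalize a 2D slice to 0-255 and duplicate it across 3 channels."""
    lo, hi = slice_2d.min(), slice_2d.max()
    norm = (slice_2d - lo) / (hi - lo + 1e-8)
    return np.stack([(norm * 255).astype(np.uint8)] * 3, axis=-1)

# Example: a synthetic mask with a rectangular "tumor" region
mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:150, 120:180] = 1
print(mask_to_bbox(mask))  # [115, 95, 184, 154]
print(slice_to_rgb(np.random.rand(256, 256)).shape)  # (256, 256, 3)
```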
## Inference Endpoint Configuration (Optional)
If you wish to deploy this model as an Inference Endpoint on Hugging Face, here's a sample configuration you might use in your README.md (or directly in the UI):
```yaml
widget:
  - src: "app.py"
    example_title: "Brain Tumor Segmentation Example"
    inputs:
      - filename: "image.png"
        image: https://huggingface.co/datasets/huggingface/sample-images/resolve/main/segmentation_image_input.png
        input_boxes: [[100, 100, 200, 200]]
---  # Optional section for specific endpoint settings
parameters:
  do_normalize: false      # assuming inputs are already normalized 0-1
  do_rescale: false        # assuming inputs are already scaled correctly
  multimask_output: false  # for single best mask output
# Example of specific hardware/software config for advanced users
# inference:
#   accelerator: cuda
#   container: pytorch_latest
#   hardware: gpu_small
#   task: image-segmentation
```
**Note:** The example `image` and `input_boxes` values in the YAML configuration are placeholders. For a real medical image endpoint, you would provide a relevant example image and a bounding box corresponding to a tumor within that image.
## Limitations
- The model was fine-tuned on a simulated dataset. Its performance on real, diverse clinical data may vary and needs further rigorous validation.
- The model relies on a bounding box prompt. Its accuracy is highly dependent on the quality and precision of the provided bounding box.
- Currently, the model handles 2D slices. Adaptation for full 3D volume segmentation would require further development.
## Future Work
- Evaluate and fine-tune the model on large, real-world medical imaging datasets (e.g., BraTS, TCIA).
- Explore methods for automatic bounding box generation for tumor regions.
- Extend the model to handle 3D medical images directly.
- Implement quantitative metrics (e.g., IoU, Hausdorff Distance) during evaluation with real data.
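As an interim step toward full 3D support, a volume can be processed slice by slice and the resulting 2D masks restacked. The sketch below illustrates the idea only: `segment_slice` is a hypothetical stand-in for the SAM inference call shown in the Usage section, and the per-slice bounding boxes would come from an expert or a detection model:

```python
import numpy as np
from typing import Callable, Dict, List

def segment_volume(
    volume: np.ndarray,
    boxes: Dict[int, List[int]],
    segment_slice: Callable[[np.ndarray, List[int]], np.ndarray],
) -> np.ndarray:
    """Run a 2D segmenter over the axial slices of a (D, H, W) volume.

    `boxes` maps slice index -> bounding box; slices without a box stay empty.
    """
    out = np.zeros(volume.shape, dtype=bool)
    for z, box in boxes.items():
        out[z] = segment_slice(volume[z], box)
    return out

# Example with a dummy segmenter that simply fills the box region
def dummy_segment(slice_2d: np.ndarray, box: List[int]) -> np.ndarray:
    x0, y0, x1, y1 = box
    m = np.zeros(slice_2d.shape, dtype=bool)
    m[y0:y1, x0:x1] = True
    return m

vol = np.random.rand(4, 64, 64)
masks3d = segment_volume(vol, {1: [10, 10, 30, 30], 2: [12, 12, 32, 32]}, dummy_segment)
print(masks3d.shape, int(masks3d.sum()))  # (4, 64, 64) 800
```

A production version would also need consistency across slices (e.g. propagating the previous slice's mask as the next slice's prompt), which is part of the 3D extension work listed above.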