# SAM3 Blood Vessel Segmentation
A SAM3 model fine-tuned for blood vessel segmentation in angiography images.
## Model Performance
| Model | Dice | IoU | Recall |
|-------|------|-----|--------|
| Original SAM3 (no fine-tuning) | 0.00 | 0.00 | 0.00 |
| Baseline (5 epochs) | 0.79 | 0.66 | 0.73 |
| **Dice Optimized (10 epochs)** | **0.82** | **0.69** | **0.77** |
| Dice Optimized + Post-processing | **0.83** | **0.70** | **0.78** |
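The card does not say how these metrics were aggregated (for example, averaged per image or pooled over the whole test set). For reference, the following minimal NumPy sketch implements the standard Dice, IoU and recall definitions for one pair of binary masks. The function name `segmentation_metrics` and the `eps` smoothing term are illustrative assumptions and do not come from the training code.

```python
import numpy as np


def segmentation_metrics(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> dict:
    """Dice, IoU and recall for two binary masks of the same shape (illustrative sketch)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    tp = np.logical_and(pred, target).sum()   # vessel pixels correctly predicted
    fp = np.logical_and(pred, ~target).sum()  # background pixels predicted as vessel
    fn = np.logical_and(~pred, target).sum()  # vessel pixels the model missed
    dice = (2 * tp) / (2 * tp + fp + fn + eps)  # 2|P∩G| / (|P| + |G|)
    iou = tp / (tp + fp + fn + eps)             # |P∩G| / |P∪G|
    recall = tp / (tp + fn + eps)               # fraction of vessel pixels recovered
    return {"dice": float(dice), "iou": float(iou), "recall": float(recall)}


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(segmentation_metrics(pred, target))  # dice ≈ 0.667, iou ≈ 0.5, recall ≈ 0.667
```

For a single mask pair, IoU = Dice / (2 - Dice), so the two metrics always rank predictions the same way. Recall penalizes only missed vessel pixels and ignores false positives.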
## Files
- `checkpoint_dice_optimized.pt` - **Recommended** - Dice-optimized fine-tuned model
- `checkpoint_baseline.pt` - Baseline fine-tuned model
- `sam3_original.pt` - Original SAM3 weights, used to build the base model before loading either fine-tuned checkpoint
## Usage
```python
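from huggingface_hub import hf_hub_download
from PIL import Image
import torch
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

REPO_ID = "qimingfan10/sam3-vessel-segmentation"

# Download the original SAM3 weights (base model) and the
# Dice-optimized fine-tuned weights from this repository
base_checkpoint = hf_hub_download(repo_id=REPO_ID, filename="sam3_original.pt")
checkpoint = hf_hub_download(repo_id=REPO_ID, filename="checkpoint_dice_optimized.pt")

# Build the base SAM3 image model with the segmentation head enabled
model = build_sam3_image_model(
    checkpoint_path=base_checkpoint,
    enable_segmentation=True,
    device="cuda",
)

# Load the fine-tuned weights on top of the base model.
# weights_only=False is needed if the checkpoint stores non-tensor objects;
# only use it with checkpoints you trust.
ckpt = torch.load(checkpoint, map_location="cuda", weights_only=False)
# Strip the "module." prefix that DistributedDataParallel adds to parameter names
state_dict = {k.removeprefix("module."): v for k, v in ckpt["model"].items()}
# strict=False tolerates key mismatches, so report them instead of failing silently
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model.eval()

# Inference on a single angiography image (replace the path with your own)
image = Image.open("path/to/angiogram.png").convert("RGB")
processor = Sam3Processor(model)
state = processor.set_image(image)
output = processor.set_text_prompt(state=state, prompt="blood vessel")
masks = output["masks"]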
```
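A text prompt can return several instance masks, while the metrics above compare one binary vessel mask per image. The following sketch merges the per-instance masks into a single mask. It assumes `output["masks"]` has already been moved to the CPU and converted to an array, for example with `output["masks"].cpu().numpy()`. It also assumes the array has layout `(N, H, W)` or `(N, 1, H, W)`; check the actual shape your SAM3 version returns. The helper name `union_mask` and the 0.5 threshold are hypothetical, and this is not necessarily the post-processing behind the table's last row.

```python
import numpy as np


def union_mask(masks, threshold: float = 0.5) -> np.ndarray:
    """Merge per-instance masks into one (H, W) boolean vessel mask (hypothetical helper)."""
    arr = np.asarray(masks)
    if arr.ndim == 4 and arr.shape[1] == 1:
        arr = arr[:, 0]  # drop the assumed singleton channel axis
    if arr.ndim != 3:
        raise ValueError(f"expected (N, H, W) or (N, 1, H, W), got shape {arr.shape}")
    # Boolean masks are used as-is; probability/logit-like masks are thresholded
    binary = arr if arr.dtype == bool else arr > threshold
    # A pixel is vessel if any instance mask covers it (N == 0 gives an all-False mask)
    return binary.any(axis=0)
```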
## Training
See [VESSEL_SEGMENTATION_GUIDE.md](https://github.com/qimingfan10/Sam3/blob/main/VESSEL_SEGMENTATION_GUIDE.md) for training details.
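This card does not define "Dice Optimized" beyond pointing to the guide. If, as the name suggests, that model was trained with a Dice-based objective, a common choice is the soft Dice loss. In the formula below, $p_i$ is the predicted vessel probability at pixel $i$, $g_i \in \{0, 1\}$ is the ground-truth label, and $\epsilon$ is a small smoothing constant. This is an assumption for illustration, not a description of the actual training setup:

$$
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}
$$

Because this loss uses probabilities rather than thresholded masks, it is differentiable, and it directly targets the overlap reported in the Dice column.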
## Citation
Please cite SAM3 if you use this model.