---
language: en
tags:
- medical-imaging
- polyp-segmentation
- dinov3
- vision-transformer
- kvasir-seg
- colonoscopy
- unet
datasets:
- kmader/kvasir-segmentation
metrics:
- dice
- iou
- precision
- recall
- hd95
library_name: pytorch
pipeline_tag: image-segmentation
license: mit
---
# DINOv3 Polyp Segmentation with U-Net Decoder
## Model Description
This model performs **polyp segmentation** in colonoscopy images using a frozen DINOv3-ViT-L/16 backbone with multi-scale feature extraction and a U-Net style decoder with skip connections. The model was trained on the Kvasir-SEG dataset.
**Key Features:**
- 🏗️ **U-Net architecture**: Skip connections from shallow stem for precise boundary detection
- 📐 **Multi-scale features**: Extracts DINOv3 features from layers [5, 11, 17, 20, 23] for rich hierarchical representation
- 🩺 **Domain-specific**: Trained specifically for polyp segmentation in colonoscopy images
- 🔒 **Frozen backbone**: Reuses DINOv3's pretrained visual features; only the stem and decoder are trained, which limits the risk of overfitting on a 1,000-image dataset
- 📊 **Comprehensive metrics**: Evaluated with Dice, IoU, Precision, Recall, and HD95
- 🔄 **Cosine annealing with warm restarts**: The learning rate follows `CosineAnnealingWarmRestarts` (T_0=10, T_mult=2), periodically restarting at 1e-4 and decaying toward 1e-6
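The frozen-backbone setup can be illustrated in plain PyTorch. This is a minimal sketch, not the card's training code: the helpers `freeze` and `count_params` are hypothetical, and toy `nn.Linear` layers stand in for the DINOv3 encoder and for the trainable stem/decoder.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> nn.Module:
    """Disable gradients for every parameter and put the module in eval mode."""
    for p in module.parameters():
        p.requires_grad_(False)
    module.eval()  # fixes dropout / normalisation behaviour of the frozen part
    return module


def count_params(module: nn.Module, trainable_only: bool = True) -> int:
    """Count parameters, optionally only those that will receive gradients."""
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)


# Hypothetical stand-ins for the real encoder and stem/decoder
backbone = freeze(nn.Linear(16, 16))   # 16*16 + 16 = 272 parameters, all frozen
head = nn.Linear(16, 1)                # 16 + 1 = 17 trainable parameters
model = nn.ModuleDict({"backbone": backbone, "head": head})

print(count_params(model))                        # 17
print(count_params(model, trainable_only=False))  # 289
```

Calling `model.train()` on a parent module switches every child back to training mode, so a frozen backbone should be returned to `eval()` after such calls, and only parameters with `requires_grad=True` need to be passed to the optimizer.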
## Model Architecture
```text
                 Input Image (256×256×3)
┌────────────────────────┬────────────────────────┐
│ Shallow Stem           │ DINOv3 Encoder         │
│ (Trainable)            │ (Frozen)               │
│                        │                        │
│ Conv 3→64 (3×3)        │ Layers [5,11,17,20,23] │
│ Conv 64→128 (stride2)  │ Multi-scale concat     │
│ Conv 128→256 (stride2) │ 5 × 1024 = 5120        │
│ Conv 256→512 (stride2) │                        │
└────────┬───────────────┴───────────────┬────────┘
         │       Skip Connections        │
         │        [512, 256, 128]        │
         ↓                               ↓
┌─────────────────────────────────────────────────┐
│ U-Net Decoder (Trainable)                       │
│                                                 │
│ Conv 5120→256 + Skip(512) → ConvBlock           │
│ Upsample → Conv 384→128 + Skip(256)             │
│ Upsample → Conv 192→64 + Skip(128)              │
│ Upsample → Final Conv 64→1 (1×1)                │
└────────────────────────┬────────────────────────┘
                         ↓
             Segmentation Mask (256×256×1)
```
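The channel and resolution bookkeeping in the diagram can be checked with a shape walk-through. This is a hedged sketch: `ToyStem` and `tokens_to_map` are hypothetical stand-ins (the Conv-BatchNorm-ReLU blocks and the stride-1 first convolution are assumptions), random tensors replace real DINOv3 hidden states, and one CLS token plus four register tokens is an assumed token layout for the backbone.

```python
import torch
import torch.nn as nn


class ToyStem(nn.Module):
    """Hypothetical stem following the diagram: 3→64 (3×3), then three stride-2 convs."""

    def __init__(self, base: int = 64):
        super().__init__()

        def block(cin, cout, stride):
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, stride=stride, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            )

        self.c1 = block(3, base, 1)             # 64 channels  @ H
        self.c2 = block(base, base * 2, 2)      # 128 channels @ H/2
        self.c3 = block(base * 2, base * 4, 2)  # 256 channels @ H/4
        self.c4 = block(base * 4, base * 8, 2)  # 512 channels @ H/8

    def forward(self, x):
        s128 = self.c2(self.c1(x))
        s256 = self.c3(s128)
        s512 = self.c4(s256)
        return [s512, s256, s128]  # deepest first, matching [512, 256, 128]


def tokens_to_map(hidden, grid, num_prefix_tokens):
    """Drop CLS/register tokens and reshape [B, N, C] patch tokens to [B, C, grid, grid]."""
    patches = hidden[:, num_prefix_tokens:, :]
    b, n, c = patches.shape
    if n != grid * grid:
        raise ValueError(f"expected {grid * grid} patch tokens, got {n}")
    return patches.transpose(1, 2).reshape(b, c, grid, grid)


B, H = 2, 256
grid = H // 16   # ViT-L/16 on a 256×256 input gives a 16×16 patch grid
prefix = 1 + 4   # assumption: 1 CLS token + 4 register tokens
# Random stand-ins for the hidden states of layers [5, 11, 17, 20, 23]
hidden_states = [torch.randn(B, prefix + grid * grid, 1024) for _ in range(5)]
vit_feats = torch.cat([tokens_to_map(h, grid, prefix) for h in hidden_states], dim=1)
print(vit_feats.shape)  # torch.Size([2, 5120, 16, 16])
print([tuple(s.shape[1:]) for s in ToyStem()(torch.randn(B, 3, H, H))])
```

Under these assumptions, the 5120-channel ViT map sits at 1/16 resolution (16×16) while the deepest 512-channel skip is at 1/8 (32×32), so the decoder presumably upsamples the ViT features before the first fusion.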
## Training Details
| Hyperparameter | Value |
|---------------|-------|
| Backbone | DINOv3-ViT-L/16 (frozen) |
| Multi-scale Layers | [5, 11, 17, 20, 23] |
| Input Resolution | 256×256 |
| Batch Size | 96 |
| Epochs | 150 |
| Learning Rate | 1e-4 (initial) |
| Min Learning Rate | 1e-6 |
| Weight Decay | 1e-4 |
| Optimizer | AdamW |
| Scheduler | CosineAnnealingWarmRestarts |
| Scheduler Config | T_0=10, T_mult=2 |
| Loss Function | Focal + Dice (0.7/0.3 weights) |
| Focal Loss Gamma | 2.0 |
| Focal Loss Alpha | 0.25 |
| Trainable Parameters | ~21M (Stem + Decoder) |
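A minimal sketch of the loss and learning-rate schedule from the table. Assumptions: the 0.7 weight goes to the focal term and 0.3 to the Dice term (following the order in the table), `alpha` weights the positive (polyp) class as in the usual sigmoid focal loss, and a single dummy parameter stands in for the stem and decoder parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary sigmoid focal loss, averaged over all pixels."""
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # prob. of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class weighting
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()


def soft_dice_loss(logits, targets, eps=1e-6):
    """1 - soft Dice, computed per image and averaged over the batch."""
    p = torch.sigmoid(logits).flatten(1)
    t = targets.flatten(1)
    dice = (2 * (p * t).sum(1) + eps) / (p.sum(1) + t.sum(1) + eps)
    return 1 - dice.mean()


def combined_loss(logits, targets, w_focal=0.7, w_dice=0.3):
    # Assumption: 0.7 weights the focal term and 0.3 the Dice term
    return w_focal * focal_loss(logits, targets) + w_dice * soft_dice_loss(logits, targets)


# Optimizer and scheduler with the table's values; `params` is a placeholder
params = [nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6)

lrs = []
for epoch in range(150):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # placeholder for one epoch of training
    scheduler.step()   # advance the schedule once per epoch
```

With `T_0=10` and `T_mult=2` the cycles last 10, 20, 40 and 80 epochs, which add up to exactly 150, so the last epoch ends a cosine cycle near the 1e-6 floor rather than just after a restart.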
### Data Augmentation
- Random 90° rotation
- Horizontal/Vertical flips
- ShiftScaleRotate (shift=0.05, scale=0.05, rotate=15°)
- MotionBlur/GaussianBlur
- ColorJitter (brightness, contrast, saturation, hue)
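Geometric augmentations must move the image and its mask together, whereas blur and colour jitter should change the image only. Albumentations handles this split when the image and mask are passed to the same `Compose` call; the NumPy sketch below only illustrates the geometric part (random 90° rotation and flips) with a hypothetical helper, and is not the pipeline actually used for training.

```python
import numpy as np


def joint_geometric_aug(image, mask, rng):
    """Apply the same random 90° rotation and flips to an [H, W, 3] image and [H, W] mask."""
    k = int(rng.integers(0, 4))              # number of 90° turns
    image, mask = np.rot90(image, k), np.rot90(mask, k)
    if rng.random() < 0.5:                   # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:                   # vertical flip
        image, mask = image[::-1], mask[::-1]
    # Views from rot90/slicing are made contiguous for later tensor conversion
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
img = np.random.default_rng(1).random((64, 64, 3), dtype=np.float32)
msk = (img[..., 0] > 0.5).astype(np.uint8)   # toy mask tied to the image content
aug_img, aug_msk = joint_geometric_aug(img, msk, rng)
```

Because the toy mask is derived from the red channel, the augmented mask still marks exactly the pixels whose red value exceeds 0.5, which is a quick way to check that image and mask stayed aligned.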
## Performance Metrics
### Final Test Set Results
| Metric | Score |
|--------|-------|
| **Dice Score** | **0.8289 ± 0.0000** |
| **IoU** | **0.7078 ± 0.0000** |
| **Precision** | 0.7910 ± 0.0000 |
| **Recall** | 0.8705 ± 0.0000 |
| **HD95 (pixels)** | 45.46 ± 0.00 |
| **Best Validation Dice** | 0.7327 |
### Validation Set Results
| Metric | Score |
|--------|-------|
| **Dice Score** | **0.8795 ± 0.0304** |
| **IoU** | **0.7862 ± 0.0485** |
| **Precision** | 0.8846 ± 0.0256 |
| **Recall** | 0.8744 ± 0.0351 |
| **HD95 (pixels)** | 30.85 ± 8.90 |
### Training Set Results (Final Epoch)
| Metric | Score |
|--------|-------|
| **Dice Score** | 0.8747 ± 0.0108 |
| **IoU** | 0.7775 ± 0.0170 |
| **Precision** | 0.8698 ± 0.0189 |
| **Recall** | 0.8801 ± 0.0136 |
| **HD95 (pixels)** | 33.91 ± 1.69 |
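The overlap metrics in these tables follow the standard confusion-count definitions. The NumPy sketch below (hypothetical helper names, with a small `eps` to avoid division by zero) computes them for one pair of binary masks, together with the per-pair identity Dice = 2·IoU / (1 + IoU).

```python
import numpy as np


def binary_metrics(pred, target, eps=1e-7):
    """Dice, IoU, precision and recall for one pair of binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    tp = np.logical_and(pred, target).sum()   # polyp pixels found
    fp = np.logical_and(pred, ~target).sum()  # background predicted as polyp
    fn = np.logical_and(~pred, target).sum()  # polyp pixels missed
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
    }


def dice_from_iou(iou):
    """For a single mask pair, Dice = 2 * IoU / (1 + IoU)."""
    return 2 * iou / (1 + iou)


pred = np.zeros((4, 4), dtype=np.uint8)
pred[:2, :] = 1        # top half predicted (8 pixels)
target = np.zeros((4, 4), dtype=np.uint8)
target[:, :2] = 1      # left half is polyp (8 pixels); overlap is 4 pixels
m = binary_metrics(pred, target)   # tp=4, fp=4, fn=4 → Dice 0.5, IoU 1/3
print(round(dice_from_iou(0.7078), 4))  # 0.8289
```

The test-set pair (Dice 0.8289, IoU 0.7078) satisfies this identity to four decimals. The validation means (0.8795, 0.7862) do not match it exactly, which is expected if per-image scores are averaged, because the mapping from IoU to Dice is non-linear.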
## Usage
### Installation
```bash
pip install torch transformers pillow matplotlib numpy opencv-python albumentations scipy scikit-learn
```
### Basic Inference
```python
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Import the model architecture (model.py from the training code)
from model import PolypSegmentationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# from_pretrained expects a local checkpoint file and a config object.
# The values below are placeholders: adjust them to your setup.
config = SimpleNamespace(
    model_name="<dinov3-vitl16-backbone-id>",     # placeholder backbone name
    local_model_path="<path/to/dinov3/weights>",  # placeholder local backbone copy
    multi_scale_layers=[5, 11, 17, 20, 23],
)
model = PolypSegmentationModel.from_pretrained(
    "path/to/checkpoint.pth",  # placeholder: downloaded checkpoint file
    config=config,
    device=device,
)

def preprocess_image(image_path, target_size=(256, 256)):
    """Load an RGB image, resize it and apply ImageNet normalization."""
    image = Image.open(image_path).convert("RGB")
    image = image.resize(target_size, Image.Resampling.BILINEAR)
    # float32 in [0, 1], then normalize with ImageNet mean/std (kept float32)
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_array = (image_array - mean) / std
    # HWC -> CHW and add a batch dimension: [1, 3, H, W]
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
    return image_tensor, image

# Run inference
image_tensor, original_image = preprocess_image("colonoscopy_image.jpg")
image_tensor = image_tensor.to(device)

with torch.no_grad():
    logits = model(image_tensor)        # [1, 1, 256, 256] raw logits
    probs = torch.sigmoid(logits)
    binary_mask = (probs > 0.5).float()

mask_np = binary_mask.squeeze().cpu().numpy()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(original_image)
axes[0].set_title("Input Image")
axes[1].imshow(mask_np, cmap="gray")
axes[1].set_title("Polyp Segmentation")
axes[2].imshow(original_image)
axes[2].imshow(mask_np, cmap="Reds", alpha=0.5)
axes[2].set_title("Overlay")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
```
### Advanced Usage with Metrics
```python
import numpy as np
import torch
from scipy.ndimage import binary_erosion
from scipy.spatial.distance import cdist
from torch.utils.data import DataLoader

def compute_hd95(pred, target):
    """Symmetric 95th-percentile Hausdorff distance (in pixels) between binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    if not pred.any() or not target.any():
        return float("inf")  # undefined when either mask is empty
    # Boundary pixels = mask minus its erosion
    pred_border = pred & ~binary_erosion(pred)
    target_border = target & ~binary_erosion(target)
    d = cdist(np.argwhere(pred_border), np.argwhere(target_border))
    # Pool both directed distances (pred→target and target→pred), then take the 95th percentile
    return float(np.percentile(np.hstack([d.min(axis=1), d.min(axis=0)]), 95))

# `model` and `device` come from the Basic Inference example;
# `dataset` is your own Dataset yielding preprocessed (image, mask) pairs with 0/1 masks
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
all_metrics = {"dice": [], "iou": [], "hd95": []}

model.eval()
with torch.no_grad():
    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device).float()
        predictions = model(images)
        # Calculate metrics for each image
        for pred, mask in zip(predictions, masks):
            pred_binary = (torch.sigmoid(pred) > 0.5).float().squeeze()
            mask = mask.squeeze()
            # Dice
            intersection = (pred_binary * mask).sum()
            dice = (2.0 * intersection) / (pred_binary.sum() + mask.sum() + 1e-6)
            # IoU
            union = pred_binary.sum() + mask.sum() - intersection
            iou = intersection / (union + 1e-6)
            # HD95 (computed on CPU with NumPy/SciPy)
            hd95 = compute_hd95(pred_binary.cpu().numpy(), mask.cpu().numpy())
            all_metrics["dice"].append(dice.item())
            all_metrics["iou"].append(iou.item())
            all_metrics["hd95"].append(hd95)

# HD95 is infinite when a mask is empty, so average it over the finite cases only
hd95_finite = [h for h in all_metrics["hd95"] if np.isfinite(h)]
print(f"Average Dice: {np.mean(all_metrics['dice']):.4f} ± {np.std(all_metrics['dice']):.4f}")
print(f"Average IoU: {np.mean(all_metrics['iou']):.4f} ± {np.std(all_metrics['iou']):.4f}")
print(f"Average HD95: {np.mean(hd95_finite):.2f} ± {np.std(hd95_finite):.2f} "
      f"({len(hd95_finite)}/{len(all_metrics['hd95'])} images with finite HD95)")
```
## Model Limitations
- **Input size**: Fixed to 256×256 pixels (resize your images accordingly)
- **Domain**: Trained only on colonoscopy images from Kvasir-SEG
- **Polyp types**: May not generalize to all polyp morphologies
- **Image quality**: Expected to work best on standard white-light colonoscopy images, like those in Kvasir-SEG
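Because the model works at 256×256, predictions usually need to be mapped back to the original frame. One common approach, sketched below with a hypothetical helper, is to upsample the logits bilinearly to the original size and threshold afterwards, instead of resizing the binary mask; the 0.5 threshold matches the usage examples.

```python
import torch
import torch.nn.functional as F


def mask_to_original_size(logits, original_size, threshold=0.5):
    """Upsample [B, 1, 256, 256] logits to (height, width), then binarize."""
    height, width = original_size
    up = F.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)
    return (torch.sigmoid(up) > threshold).to(torch.uint8)


# Example with dummy logits; with a real image use
# mask_to_original_size(logits, (img.height, img.width)) for a PIL image `img`
full_mask = mask_to_original_size(torch.full((1, 1, 256, 256), 3.0), (531, 622))
```

Note that PIL's `Image.size` is `(width, height)`, so the order has to be swapped when passing it as `original_size`.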
## Dataset
Trained on the Kvasir-SEG dataset, which contains 1000 polyp images with corresponding ground truth masks from colonoscopy procedures.
## License
This model is released under the MIT License.
## Citation
If you use this model in your research, please cite:
```bibtex
@software{dinov3_polyp_seg,
  author = {Amirreza Mehrzadian},
  title = {DINOv3 Polyp Segmentation with U-Net Decoder},
  year = {2024},
  url = {https://huggingface.co/uncleMehrzad/dinov3-polyp-seg}
}
```
## Acknowledgments
- DINOv3 team for the powerful vision backbone
- Kvasir-SEG dataset providers for the polyp segmentation data
- HuggingFace for model hosting infrastructure
## Model Wrapper
```python
import torch
import torch.nn as nn

# DINOv3Encoder, ShallowStem and UNetDecoder are defined in the training code (model.py)


class PolypSegmentationModel(nn.Module):
    """Complete model wrapper matching the training architecture."""

    def __init__(self, encoder, stem, decoder):
        super().__init__()
        self.encoder = encoder  # frozen DINOv3 backbone
        self.stem = stem        # trainable shallow CNN providing skip features
        self.decoder = decoder  # trainable U-Net decoder

    def forward(self, x):
        vit_features = self.encoder(x)   # multi-scale ViT features (5 × 1024 = 5120 channels)
        skip_features = self.stem(x)     # skip features with [512, 256, 128] channels
        return self.decoder(vit_features, skip_features)  # [B, 1, H, W] logits

    @classmethod
    def from_pretrained(cls, model_path, config, device="cpu"):
        """Load the complete model from a local checkpoint file."""
        checkpoint = torch.load(model_path, map_location=device)

        # Rebuild the components; backbone weights come from config, not the checkpoint
        encoder = DINOv3Encoder(
            model_name=config.model_name,
            local_path=config.local_model_path,
            freeze=True,
            layers=config.multi_scale_layers,
        )
        stem = ShallowStem(in_channels=3, base_channels=64)
        decoder = UNetDecoder(
            vit_channels=encoder.out_channels,
            stem_channels=[512, 256, 128],
            num_classes=1,
        )

        # Load the trained stem and decoder weights
        decoder.load_state_dict(checkpoint["decoder_state_dict"])
        stem.load_state_dict(checkpoint["stem_state_dict"])

        model = cls(encoder, stem, decoder)
        model.to(device)
        model.eval()
        return model
```
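For reference, a checkpoint compatible with `from_pretrained` above needs the two keys it reads. A minimal sketch, assuming nothing else is required in the file (the training script that wrote the published checkpoint is not shown here); toy convolutions stand in for `ShallowStem` and `UNetDecoder`, and `save_checkpoint` is a hypothetical helper.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(stem: nn.Module, decoder: nn.Module, path: str) -> None:
    """Save only the trainable parts, under the keys that from_pretrained reads."""
    torch.save(
        {
            "stem_state_dict": stem.state_dict(),
            "decoder_state_dict": decoder.state_dict(),
        },
        path,
    )


# Round trip with toy modules standing in for ShallowStem / UNetDecoder
stem, decoder = nn.Conv2d(3, 8, 3), nn.Conv2d(8, 1, 1)
path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.pth")
save_checkpoint(stem, decoder, path)

ckpt = torch.load(path, map_location="cpu")
fresh_stem = nn.Conv2d(3, 8, 3)
fresh_stem.load_state_dict(ckpt["stem_state_dict"])
print(torch.equal(fresh_stem.weight, stem.weight))  # True
```

Leaving the frozen backbone out of the checkpoint keeps the file small, since the encoder is rebuilt from its own pretrained weights at load time.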