---
language: en
tags:
  - medical-imaging
  - polyp-segmentation
  - dinov3
  - vision-transformer
  - kvasir-seg
  - colonoscopy
  - unet
datasets:
  - kmader/kvasir-segmentation
metrics:
  - dice
  - iou
  - precision
  - recall
  - hd95
library_name: pytorch
pipeline_tag: image-segmentation
license: mit
---

# DINOv3 Polyp Segmentation with U-Net Decoder

## Model Description

This model segments polyps in colonoscopy images. A frozen DINOv3-ViT-L/16 backbone supplies multi-scale features, and a trainable U-Net-style decoder fuses them with skip connections from a shallow convolutional stem. The model was trained on the Kvasir-SEG dataset.

Key Features:

  • U-Net architecture: skip connections from a shallow convolutional stem preserve fine boundary detail
  • Multi-scale features: DINOv3 features are taken from layers [5, 11, 17, 20, 23] and concatenated into a hierarchical representation
  • Task-specific design: built for polyp segmentation in colonoscopy images
  • Frozen backbone: reuses DINOv3's pretrained features; only the stem and decoder are trained, which limits overfitting on a small dataset
  • Comprehensive metrics: evaluated with Dice, IoU, Precision, Recall, and HD95
  • Cosine annealing: trained with CosineAnnealingWarmRestarts, which periodically resets the learning rate
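As a rough illustration of the warm-restart schedule named in the last bullet, the sketch below evaluates the standard cosine-annealing-with-restarts formula using the values listed under Training Details (T_0=10, T_mult=2, learning rate 1e-4 down to 1e-6). It assumes the scheduler is stepped once per epoch, which this card does not state; the function name `warm_restart_lr` is illustrative.

```python
import math


def warm_restart_lr(epoch, lr_max=1e-4, lr_min=1e-6, t_0=10, t_mult=2):
    """Learning rate at the start of `epoch` under cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    # Skip over completed cycles; each cycle is t_mult times longer than the previous one.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / t_i))


# Epochs inside a 150-epoch run (0..149) at which the rate jumps back to lr_max.
restarts = [e for e in range(1, 150) if math.isclose(warm_restart_lr(e), 1e-4)]
print(restarts)  # [10, 30, 70]
```

Under that per-epoch assumption the cycle lengths are 10, 20, 40, and 80 epochs, which sum to exactly 150, so the run would end at the bottom of the fourth cycle.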

## Model Architecture

    Input Image (256×256×3)
                            ↓
    ┌────────────────────────┬──────────────────────┐
    │ Shallow Stem           │ DINOv3 Encoder       │
    │ (Trainable)            │ (Frozen)             │
    │                        │                      │
    │ Conv 3→64 (3×3)        │ Layers [5,11,17,     │
    │ Conv 64→128 (stride2)  │         20,23]       │
    │ Conv 128→256 (stride2) │ Multi-scale concat   │
    │ Conv 256→512 (stride2) │ 5 × 1024 = 5120      │
    └───────┬────────────────┴──────────┬───────────┘
            │  Skip Connections         │
            │  [512, 256, 128]          │
            ↓                           ↓
    ┌───────────────────────────────────────────────┐
    │ U-Net Decoder (Trainable)                     │
    │                                               │
    │ Conv 5120→256 + Skip(512) → ConvBlock         │
    │ Upsample → Conv 384→128 + Skip(256)           │
    │ Upsample → Conv 192→64 + Skip(128)            │
    │ Upsample → Final Conv 64→1 (1×1)              │
    └───────────────────────┬───────────────────────┘
                            ↓
    Segmentation Mask (256×256×1)
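The channel and resolution bookkeeping behind the diagram can be checked with a few lines of arithmetic. This sketch assumes each stem convolution uses a 3×3 kernel with padding 1 (the diagram gives only the strides and the first kernel size) and a ViT patch size of 16, as the backbone name indicates; the helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


input_size, patch, hidden = 256, 16, 1024
layers = [5, 11, 17, 20, 23]

# Frozen encoder: one token per 16x16 patch, five layers concatenated channel-wise.
tokens_per_side = input_size // patch
vit_channels = len(layers) * hidden

# Trainable stem: one stride-1 conv followed by three stride-2 convs.
stem = [(64, 1), (128, 2), (256, 2), (512, 2)]  # (out_channels, stride)
size, stem_shapes = input_size, []
for channels, stride in stem:
    size = conv_out(size, stride=stride)
    stem_shapes.append((channels, size))

print(tokens_per_side, vit_channels)  # 16 5120
print(stem_shapes)  # [(64, 256), (128, 128), (256, 64), (512, 32)]
```

Under these assumptions the deepest skip (512 channels) is 32×32 while the ViT token grid is 16×16, so the ViT features would need one upsampling step before the first concatenation; the diagram does not show where that happens.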

## Training Details

| Hyperparameter | Value |
|---|---|
| Backbone | DINOv3-ViT-L/16 (frozen) |
| Multi-scale Layers | [5, 11, 17, 20, 23] |
| Input Resolution | 256×256 |
| Batch Size | 96 |
| Epochs | 150 |
| Learning Rate | 1e-4 (initial) |
| Min Learning Rate | 1e-6 |
| Weight Decay | 1e-4 |
| Optimizer | AdamW |
| Scheduler | CosineAnnealingWarmRestarts |
| Scheduler Config | T_0=10, T_mult=2 |
| Loss Function | Focal + Dice (0.7/0.3 weights) |
| Focal Loss Gamma | 2.0 |
| Focal Loss Alpha | 0.25 |
| Trainable Parameters | ~21M (Stem + Decoder) |
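A minimal NumPy sketch of a combined focal and Dice loss with the settings listed above. It assumes the 0.7/0.3 weights apply to focal and Dice in that order and that both terms are computed on sigmoid probabilities; the actual training code may differ in reduction and smoothing. The names `focal_loss`, `dice_loss`, and `combined_loss` are illustrative.

```python
import numpy as np


def focal_loss(probs, target, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss averaged over pixels."""
    probs = np.clip(probs, eps, 1 - eps)
    p_t = np.where(target == 1, probs, 1 - probs)      # probability of the true class
    alpha_t = np.where(target == 1, alpha, 1 - alpha)  # class-balancing weight
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t)))


def dice_loss(probs, target, smooth=1e-6):
    """Soft Dice loss: 1 minus the Dice coefficient computed on probabilities."""
    intersection = np.sum(probs * target)
    return float(1 - (2 * intersection + smooth) / (np.sum(probs) + np.sum(target) + smooth))


def combined_loss(logits, target, w_focal=0.7, w_dice=0.3):
    """Weighted sum of focal and Dice losses (weight order is an assumption)."""
    probs = 1 / (1 + np.exp(-logits))  # sigmoid
    return w_focal * focal_loss(probs, target) + w_dice * dice_loss(probs, target)


target = np.zeros((8, 8))
target[2:6, 2:6] = 1
good = np.where(target == 1, 6.0, -6.0)  # confident, correct logits
bad = -good                              # confident, wrong logits
print(combined_loss(good, target), combined_loss(bad, target))
```

The focal term down-weights pixels that are already classified confidently, while the Dice term responds to overall overlap, which matters when polyps cover a small share of the frame.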

### Data Augmentation

  • Random 90° rotation
  • Horizontal/Vertical flips
  • ShiftScaleRotate (shift=0.05, scale=0.05, rotate=15°)
  • MotionBlur/GaussianBlur
  • ColorJitter (brightness, contrast, saturation, hue)
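The geometric transforms in this list have to be applied identically to an image and its mask, whereas blur and colour jitter affect the image only. A minimal NumPy sketch of the paired rotation and flips follows; the 0.5 flip probabilities and the helper name `paired_rot_flip` are assumptions, since the list gives no probabilities.

```python
import numpy as np


def paired_rot_flip(image, mask, rng):
    """Apply the same random 90-degree rotation and flips to image (H, W, C) and mask (H, W)."""
    k = int(rng.integers(0, 4))  # number of 90-degree turns
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    if rng.random() < 0.5:  # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # vertical flip
        image, mask = image[::-1], mask[::-1]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
image = rng.random((256, 256, 3), dtype=np.float32)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[40:90, 100:180] = 1
image[..., 0] = mask  # tag image pixels with the mask so alignment can be checked

aug_image, aug_mask = paired_rot_flip(image, mask, rng)
print(aug_image.shape, aug_mask.shape, int(aug_mask.sum()))
```

Because the first image channel was overwritten with the mask, `np.array_equal(aug_image[..., 0], aug_mask)` confirms that the two stayed aligned after augmentation.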

## Performance Metrics

### Final Test Set Results

| Metric | Score |
|---|---|
| Dice Score | 0.8289 ± 0.0000 |
| IoU | 0.7078 ± 0.0000 |
| Precision | 0.7910 ± 0.0000 |
| Recall | 0.8705 ± 0.0000 |
| HD95 (pixels) | 45.46 ± 0.00 |
| Best Validation Dice | 0.7327 |

### Validation Set Results

| Metric | Score |
|---|---|
| Dice Score | 0.8795 ± 0.0304 |
| IoU | 0.7862 ± 0.0485 |
| Precision | 0.8846 ± 0.0256 |
| Recall | 0.8744 ± 0.0351 |
| HD95 (pixels) | 30.85 ± 8.90 |

### Training Set Results (Final Epoch)

| Metric | Score |
|---|---|
| Dice Score | 0.8747 ± 0.0108 |
| IoU | 0.7775 ± 0.0170 |
| Precision | 0.8698 ± 0.0189 |
| Recall | 0.8801 ± 0.0136 |
| HD95 (pixels) | 33.91 ± 1.69 |
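For reference, the four overlap metrics all follow from pixel-level true-positive, false-positive, and false-negative counts, and Dice and IoU are tied by Dice = 2·IoU / (1 + IoU). The sketch below uses made-up counts, and `overlap_metrics` is an illustrative name.

```python
def overlap_metrics(tp, fp, fn):
    """Dice, IoU, precision, and recall from the pixel counts of one prediction."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    dice = 2 * tp / (2 * tp + fp + fn)
    iou = tp / (tp + fp + fn)
    return {"dice": dice, "iou": iou, "precision": precision, "recall": recall}


m = overlap_metrics(tp=800, fp=150, fn=50)  # hypothetical counts
print({k: round(v, 4) for k, v in m.items()})

# Dice is the harmonic mean of precision and recall, and a fixed function of IoU.
dice_from_pr = 2 * m["precision"] * m["recall"] / (m["precision"] + m["recall"])
dice_from_iou = 2 * m["iou"] / (1 + m["iou"])
print(abs(dice_from_pr - m["dice"]) < 1e-12, abs(dice_from_iou - m["dice"]) < 1e-12)
```

Applied to the test-set table, IoU 0.7078 gives 2 × 0.7078 / 1.7078 ≈ 0.8289, which equals the reported Dice. That is what one would expect if the test figures come from a single pooled set of counts rather than a per-image average, although the card does not say how they were aggregated.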

## Usage

### Installation

```bash
pip install torch transformers pillow matplotlib numpy opencv-python albumentations scipy scikit-learn
```

### Basic Inference

```python
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Import the model architecture (same as training)
from model import DINOv3Encoder, ShallowStem, UNetDecoder, PolypSegmentationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# from_pretrained(model_path, config, device) reads a local checkpoint and needs
# a config with the attributes below. The checkpoint path and model_name are
# placeholders; use the values from your training setup.
config = SimpleNamespace(
    model_name="your-dinov3-vitl16-model",
    local_model_path=None,
    multi_scale_layers=[5, 11, 17, 20, 23],
)
model = PolypSegmentationModel.from_pretrained(
    "path/to/checkpoint.pth", config, device=device
)

# Preprocess image
def preprocess_image(image_path, target_size=(256, 256)):
    image = Image.open(image_path).convert('RGB')
    image = image.resize(target_size, Image.Resampling.BILINEAR)

    # Convert to numpy and normalize with ImageNet statistics; everything stays
    # float32 so the tensor dtype matches the model weights
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
    image_array = (image_array - mean) / std

    # Convert to tensor [B, C, H, W]
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
    return image_tensor, image

# Run inference
image_tensor, original_image = preprocess_image("colonoscopy_image.jpg")

with torch.no_grad():
    prediction = model(image_tensor.to(device))  # raw logits
    mask = torch.sigmoid(prediction)
    binary_mask = (mask > 0.5).float()
    mask_np = binary_mask.squeeze().cpu().numpy()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(original_image)
axes[0].set_title("Input Image")
axes[1].imshow(mask_np, cmap='gray')
axes[1].set_title("Polyp Segmentation")
axes[2].imshow(original_image)
axes[2].imshow(mask_np, cmap='Reds', alpha=0.5)
axes[2].set_title("Overlay")
plt.show()
```

### Advanced Usage with Metrics

```python
import numpy as np
import torch
from scipy.ndimage import binary_erosion
from torch.utils.data import DataLoader


def compute_hd95(pred, target):
    """Directed 95th-percentile Hausdorff distance (pred -> target), in pixels."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    if not pred.any() or not target.any():
        return float('inf')

    # Boundary = mask pixels removed by a one-pixel erosion
    pred_border = pred & ~binary_erosion(pred)
    target_border = target & ~binary_erosion(target)

    pred_coords = np.argwhere(pred_border)
    target_coords = np.argwhere(target_border)

    # Distance from each predicted boundary pixel to the nearest target boundary pixel
    distances = []
    for p in pred_coords:
        dist = np.min(np.sqrt(np.sum((target_coords - p) ** 2, axis=1)))
        distances.append(dist)

    return np.percentile(distances, 95)

# Batch inference; `dataset` must yield (image, mask) tensor pairs
# preprocessed the same way as in Basic Inference
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
device = next(model.parameters()).device

all_metrics = {'dice': [], 'iou': [], 'hd95': []}
for images, masks in dataloader:
    with torch.no_grad():
        predictions = model(images.to(device)).cpu()

    # Calculate metrics for each image
    for pred, mask in zip(predictions, masks):
        pred_binary = (torch.sigmoid(pred) > 0.5).float()
        mask = mask.float()

        # Dice
        intersection = (pred_binary * mask).sum()
        dice = (2. * intersection) / (pred_binary.sum() + mask.sum() + 1e-6)

        # IoU
        union = pred_binary.sum() + mask.sum() - intersection
        iou = intersection / (union + 1e-6)

        # HD95
        hd95 = compute_hd95(pred_binary.numpy().squeeze(), mask.numpy().squeeze())

        all_metrics['dice'].append(dice.item())
        all_metrics['iou'].append(iou.item())
        all_metrics['hd95'].append(hd95)

# Images with an empty prediction or mask give HD95 = inf; exclude them from the average
hd95_finite = [d for d in all_metrics['hd95'] if np.isfinite(d)]

print(f"Average Dice: {np.mean(all_metrics['dice']):.4f} ± {np.std(all_metrics['dice']):.4f}")
print(f"Average IoU: {np.mean(all_metrics['iou']):.4f} ± {np.std(all_metrics['iou']):.4f}")
print(f"Average HD95: {np.mean(hd95_finite):.2f} ± {np.std(hd95_finite):.2f}")
```

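The `compute_hd95` helper above measures distance in one direction only, from the prediction boundary to the target boundary. HD95 is often reported symmetrically, pooling both directions; a vectorised NumPy sketch of that variant follows. The names `boundary` and `hd95_symmetric` are illustrative, boundaries are found with a 4-neighbour test instead of SciPy erosion, and the pairwise-distance matrix makes this practical only for small masks such as 256×256.

```python
import numpy as np


def boundary(mask):
    """Pixels of a boolean mask with at least one 4-neighbour outside the mask."""
    padded = np.pad(mask, 1, constant_values=False)
    interior = (padded[:-2, 1:-1] & padded[2:, 1:-1]
                & padded[1:-1, :-2] & padded[1:-1, 2:])
    return mask & ~interior


def hd95_symmetric(pred, target):
    """95th percentile of boundary distances pooled over both directions, in pixels."""
    pred, target = pred.astype(bool), target.astype(bool)
    if not pred.any() or not target.any():
        return float("inf")
    p = np.argwhere(boundary(pred)).astype(np.float64)
    t = np.argwhere(boundary(target)).astype(np.float64)
    # Pairwise Euclidean distances between the two boundary point sets.
    d = np.sqrt(((p[:, None, :] - t[None, :, :]) ** 2).sum(axis=-1))
    pooled = np.concatenate([d.min(axis=1), d.min(axis=0)])
    return float(np.percentile(pooled, 95))


a = np.zeros((64, 64), dtype=bool)
a[10:30, 10:30] = True
b = np.zeros((64, 64), dtype=bool)
b[10:30, 13:33] = True  # same square shifted right by 3 pixels
print(hd95_symmetric(a, a), hd95_symmetric(a, b))  # 0.0 3.0
```

A symmetric score also penalises target regions the prediction misses entirely, which the directed version can overlook.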
## Model Limitations

  • Input size: fixed at 256×256 pixels (resize your images accordingly)
  • Domain: trained only on colonoscopy images from Kvasir-SEG
  • Polyp types: may not generalize to all polyp morphologies
  • Image quality: best performance with standard white-light colonoscopy images
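Because the network works at a fixed 256×256 resolution, a predicted mask usually has to be mapped back to the size of the source frame. A minimal nearest-neighbour sketch in NumPy follows; `resize_mask_nearest` is an illustrative helper, not part of the released code, and a library resize with nearest-neighbour resampling would serve equally well.

```python
import numpy as np


def resize_mask_nearest(mask, out_height, out_width):
    """Nearest-neighbour resize of a 2-D mask, so values stay binary."""
    in_height, in_width = mask.shape
    # Map each output pixel centre back to a source row/column index.
    rows = ((np.arange(out_height) + 0.5) * in_height / out_height).astype(int)
    cols = ((np.arange(out_width) + 0.5) * in_width / out_width).astype(int)
    rows = np.minimum(rows, in_height - 1)
    cols = np.minimum(cols, in_width - 1)
    return mask[rows[:, None], cols[None, :]]


mask_256 = np.zeros((256, 256), dtype=np.uint8)
mask_256[64:128, 32:96] = 1                     # stand-in for a model prediction
full = resize_mask_nearest(mask_256, 512, 768)  # hypothetical 768x512 source frame
print(full.shape, np.unique(full))
```

Nearest-neighbour sampling is used here instead of bilinear interpolation so the resized mask contains only 0 and 1 and needs no second threshold.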

## Dataset
Trained on the Kvasir-SEG dataset, which contains 1000 polyp images with corresponding ground truth masks from colonoscopy procedures.

## License
This model is released under the MIT License.

## Citation
If you use this model in your research, please cite:

```bibtex
@software{dinov3_polyp_seg,
  author = {Amirreza Mehrzadian},
  title = {DINOv3 Polyp Segmentation with U-Net Decoder},
  year = {2024},
  url = {https://huggingface.co/uncleMehrzad/dinov3-polyp-seg}
}
```

## Acknowledgments

  • DINOv3 team for the powerful vision backbone
  • Kvasir-SEG dataset providers for the polyp segmentation data
  • HuggingFace for model hosting infrastructure
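
## Model Wrapper

The `PolypSegmentationModel` class imported in the usage examples:

```python
import torch
import torch.nn as nn

# DINOv3Encoder, ShallowStem, and UNetDecoder are defined in the same model.py


class PolypSegmentationModel(nn.Module):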
    """Complete model wrapper matching training architecture"""
    
    def __init__(self, encoder, stem, decoder):
        super().__init__()
        self.encoder = encoder
        self.stem = stem
        self.decoder = decoder
    
    def forward(self, x):
        vit_features = self.encoder(x)
        skip_features = self.stem(x)
        return self.decoder(vit_features, skip_features)
    
    @classmethod
    def from_pretrained(cls, model_path, config, device="cpu"):
        """Load the complete model from checkpoint"""
        checkpoint = torch.load(model_path, map_location=device)
        
        # Initialize components
        encoder = DINOv3Encoder(
            model_name=config.model_name,
            local_path=config.local_model_path,
            freeze=True,
            layers=config.multi_scale_layers
        )
        
        stem = ShallowStem(in_channels=3, base_channels=64)
        
        decoder = UNetDecoder(
            vit_channels=encoder.out_channels,
            stem_channels=[512, 256, 128],
            num_classes=1
        )
        
        # Load weights
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        stem.load_state_dict(checkpoint['stem_state_dict'])
        
        model = cls(encoder, stem, decoder)
        model.to(device)
        model.eval()
        
        return model
```