| --- |
| language: en |
| tags: |
| - medical-imaging |
| - polyp-segmentation |
| - dinov3 |
| - vision-transformer |
| - kvasir-seg |
| - colonoscopy |
| - unet |
| datasets: |
| - kmader/kvasir-segmentation |
| metrics: |
| - dice |
| - iou |
| - precision |
| - recall |
| - hd95 |
| library_name: pytorch |
| pipeline_tag: image-segmentation |
| license: mit |
| --- |
| |
| # DINOv3 Polyp Segmentation with U-Net Decoder |
|
|
| ## Model Description |
|
|
| This model performs **polyp segmentation** in colonoscopy images using a frozen DINOv3-ViT-L/16 backbone with multi-scale feature extraction and a U-Net style decoder with skip connections. The model was trained on the Kvasir-SEG dataset. |
|
|
| **Key Features:** |
- 🏗️ **U-Net architecture**: Skip connections from a trainable shallow convolutional stem carry high-resolution detail to the decoder for sharper polyp boundaries
- 📐 **Multi-scale features**: Concatenates DINOv3 features from layers [5, 11, 17, 20, 23] (5 × 1024 = 5120 channels) into one hierarchical representation
- 🩺 **Task-specific design**: Built for binary polyp segmentation in colonoscopy images
- 🔒 **Frozen backbone**: Reuses DINOv3's pretrained features; only the stem and decoder (~21M parameters) are trained, which limits overfitting on the 1,000-image dataset
- 📊 **Comprehensive metrics**: Evaluated with Dice, IoU, Precision, Recall, and HD95 (95th-percentile Hausdorff distance)
- 🔄 **Cosine annealing**: Trained with CosineAnnealingWarmRestarts (T_0=10, T_mult=2), decaying the learning rate from 1e-4 to 1e-6 within each cycle
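
As a rough illustration of the warm-restart schedule named in the last bullet, the sketch below evaluates the standard cosine-annealing-with-restarts formula in plain Python, using the values from the training table (T_0=10, T_mult=2, initial rate 1e-4, minimum 1e-6). It assumes the scheduler is stepped once per epoch, which this card does not state, and the helper name `warm_restart_lr` is made up for the example.

```python
import math

def warm_restart_lr(epoch, base_lr=1e-4, min_lr=1e-6, t_0=10, t_mult=2):
    """Learning rate during `epoch` (0-indexed), assuming one scheduler step per epoch."""
    t_i, t_cur = t_0, epoch
    # Skip over completed cycles; each cycle is t_mult times longer than the previous one.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    cos_term = 1 + math.cos(math.pi * t_cur / t_i)
    return min_lr + 0.5 * (base_lr - min_lr) * cos_term

# Epochs (after the first) at which the rate jumps back to base_lr in a 150-epoch run.
restarts = [e for e in range(1, 150) if math.isclose(warm_restart_lr(e), 1e-4)]
print(restarts)  # [10, 30, 70]
```

With these settings the cycles last 10, 20, 40 and 80 epochs, so a 150-epoch run ends exactly at the close of the fourth cycle, when the rate is near its minimum.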
|
|
| ## Model Architecture |
```text
            Input Image (256×256×3)
                       ↓
┌───────────────────────┬──────────────────────┐
│ Shallow Stem          │ DINOv3 Encoder       │
│ (Trainable)           │ (Frozen)             │
│                       │                      │
│ Conv 3→64 (3×3)       │ Layers [5,11,17,     │
│ Conv 64→128 (stride2) │         20,23]       │
│ Conv 128→256 (stride2)│ Multi-scale concat   │
│ Conv 256→512 (stride2)│ 5 × 1024 = 5120      │
└───────┬───────────────┴──────────┬───────────┘
        │     Skip Connections     │
        │     [512, 256, 128]      │
        ↓                          ↓
  ┌──────────────────────────────────────────┐
  │        U-Net Decoder (Trainable)         │
  │                                          │
  │ Conv 5120→256 + Skip(512) → ConvBlock    │
  │ Upsample → Conv 384→128 + Skip(256)      │
  │ Upsample → Conv 192→64 + Skip(128)       │
  │ Upsample → Final Conv 64→1 (1×1)         │
  └──────────────────┬───────────────────────┘
                     ↓
       Segmentation Mask (256×256×1)
```
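
To make the encoder branch of the diagram more concrete, here is a minimal sketch of how per-layer ViT token sequences could be turned into the 5120-channel feature map. It uses random tensors in place of real DINOv3 outputs and assumes each hidden state carries 1 class token plus 4 register tokens ahead of the patch tokens. The names `tokens_to_map` and `NUM_PREFIX_TOKENS`, the 0-based layer indexing, and the tensor layout are illustrative assumptions, not taken from the released code.

```python
import torch

NUM_PREFIX_TOKENS = 5    # assumption: 1 class token + 4 register tokens
PATCH_SIZE = 16          # ViT-L/16
LAYERS = [5, 11, 17, 20, 23]

def tokens_to_map(hidden_states, layers, image_size=256):
    """Concatenate the selected layers' patch tokens into a [B, C_total, H, W] map."""
    grid = image_size // PATCH_SIZE  # 256 / 16 = 16 patches per side
    maps = []
    for idx in layers:
        tokens = hidden_states[idx][:, NUM_PREFIX_TOKENS:, :]  # drop non-patch tokens
        b, n, c = tokens.shape
        assert n == grid * grid, "unexpected number of patch tokens"
        # [B, N, C] -> [B, C, H, W]
        maps.append(tokens.transpose(1, 2).reshape(b, c, grid, grid))
    return torch.cat(maps, dim=1)

# Stand-in for the outputs of 24 transformer blocks on a batch of two images.
fake_hidden = [torch.randn(2, NUM_PREFIX_TOKENS + 16 * 16, 1024) for _ in range(24)]
features = tokens_to_map(fake_hidden, LAYERS)
print(features.shape)  # torch.Size([2, 5120, 16, 16])
```

The channel count follows directly from the diagram: five layers of 1024-dimensional tokens give 5 × 1024 = 5120 channels on a 16×16 grid.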
|
|
|
|
|
|
| ## Training Details |
|
|
| | Hyperparameter | Value | |
| |---------------|-------| |
| | Backbone | DINOv3-ViT-L/16 (frozen) | |
| | Multi-scale Layers | [5, 11, 17, 20, 23] | |
| | Input Resolution | 256×256 | |
| | Batch Size | 96 | |
| | Epochs | 150 | |
| | Learning Rate | 1e-4 (initial) | |
| | Min Learning Rate | 1e-6 | |
| | Weight Decay | 1e-4 | |
| | Optimizer | AdamW | |
| | Scheduler | CosineAnnealingWarmRestarts | |
| | Scheduler Config | T_0=10, T_mult=2 | |
| | Loss Function | Focal + Dice (0.7/0.3 weights) | |
| | Focal Loss Gamma | 2.0 | |
| | Focal Loss Alpha | 0.25 | |
| | Trainable Parameters | ~21M (Stem + Decoder) | |
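
The combined loss in the table can be sketched as follows. This is a generic PyTorch formulation of alpha-balanced binary focal loss plus soft Dice loss, not the training code itself. It assumes the 0.7 weight applies to the focal term and 0.3 to the Dice term, and the `smooth` constant and the function name `focal_dice_loss` are choices made for this example.

```python
import torch
import torch.nn.functional as F

def focal_dice_loss(logits, target, gamma=2.0, alpha=0.25,
                    w_focal=0.7, w_dice=0.3, smooth=1.0):
    """Weighted sum of binary focal loss and soft Dice loss, computed on raw logits."""
    prob = torch.sigmoid(logits)

    # Focal term: cross-entropy down-weighted for well-classified pixels.
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = prob * target + (1 - prob) * (1 - target)        # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balancing weight
    focal = (alpha_t * (1 - p_t) ** gamma * bce).mean()

    # Soft Dice term, computed per image and averaged over the batch.
    dims = (1, 2, 3)
    intersection = (prob * target).sum(dims)
    dice = (2 * intersection + smooth) / (prob.sum(dims) + target.sum(dims) + smooth)

    return w_focal * focal + w_dice * (1 - dice).mean()

logits = torch.randn(4, 1, 256, 256)                  # raw model outputs
target = (torch.rand(4, 1, 256, 256) > 0.8).float()   # random binary masks
loss = focal_dice_loss(logits, target)
print(loss.item())
```

The focal term concentrates the gradient on hard pixels, while the Dice term directly rewards region overlap, which helps when polyps occupy a small fraction of the frame.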
|
|
| ### Data Augmentation |
| - Random 90° rotation |
| - Horizontal/Vertical flips |
| - ShiftScaleRotate (shift=0.05, scale=0.05, rotate=15°) |
| - MotionBlur/GaussianBlur |
| - ColorJitter (brightness, contrast, saturation, hue) |
|
|
| ## Performance Metrics |
|
|
| ### Final Test Set Results |
|
|
| | Metric | Score | |
| |--------|-------| |
| | **Dice Score** | **0.8289 ± 0.0000** | |
| | **IoU** | **0.7078 ± 0.0000** | |
| | **Precision** | 0.7910 ± 0.0000 | |
| | **Recall** | 0.8705 ± 0.0000 | |
| | **HD95 (pixels)** | 45.46 ± 0.00 | |
| | **Best Validation Dice** | 0.7327 | |
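
For reference, the overlap metrics in this table follow the usual definitions in terms of true-positive, false-positive and false-negative pixel counts. The sketch below is a generic NumPy illustration, not the evaluation script used for this card; the function name `overlap_metrics` and the toy masks are invented for the example.

```python
import numpy as np

def overlap_metrics(pred, target, eps=1e-6):
    """Dice, IoU, precision and recall for two binary masks of equal shape."""
    pred, target = pred.astype(bool), target.astype(bool)
    tp = np.logical_and(pred, target).sum()
    fp = np.logical_and(pred, ~target).sum()
    fn = np.logical_and(~pred, target).sum()
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
    }

# Toy example: a 4×4 ground-truth square and a prediction shifted right by one pixel.
target = np.zeros((8, 8), dtype=np.uint8)
target[2:6, 2:6] = 1
pred = np.zeros((8, 8), dtype=np.uint8)
pred[2:6, 3:7] = 1

m = overlap_metrics(pred, target)
print({k: round(float(v), 3) for k, v in m.items()})
# {'dice': 0.75, 'iou': 0.6, 'precision': 0.75, 'recall': 0.75}
```

For a single set of counts these quantities are linked by Dice = 2·IoU / (1 + IoU) = 2·P·R / (P + R). The test-set values above satisfy both relations to within rounding (for example, 2 × 0.7078 / 1.7078 ≈ 0.8289).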
|
|
| ### Validation Set Results |
|
|
| | Metric | Score | |
| |--------|-------| |
| | **Dice Score** | **0.8795 ± 0.0304** | |
| | **IoU** | **0.7862 ± 0.0485** | |
| | **Precision** | 0.8846 ± 0.0256 | |
| | **Recall** | 0.8744 ± 0.0351 | |
| | **HD95 (pixels)** | 30.85 ± 8.90 | |
|
|
| ### Training Set Results (Final Epoch) |
|
|
| | Metric | Score | |
| |--------|-------| |
| | **Dice Score** | 0.8747 ± 0.0108 | |
| | **IoU** | 0.7775 ± 0.0170 | |
| | **Precision** | 0.8698 ± 0.0189 | |
| | **Recall** | 0.8801 ± 0.0136 | |
| | **HD95 (pixels)** | 33.91 ± 1.69 | |
|
|
| ## Usage |
|
|
| ### Installation |
|
|
```bash
pip install torch transformers pillow matplotlib numpy opencv-python albumentations scipy scikit-learn
```

### Basic Inference

```python
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Import the model architecture (same as training)
from model import DINOv3Encoder, ShallowStem, UNetDecoder, PolypSegmentationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# from_pretrained takes a local checkpoint path plus a config object exposing
# model_name, local_model_path and multi_scale_layers. The path and backbone
# identifier below are placeholders; substitute your own.
config = SimpleNamespace(
    model_name="facebook/dinov3-vitl16-pretrain-lvd1689m",  # assumed hub id for DINOv3 ViT-L/16
    local_model_path=None,
    multi_scale_layers=[5, 11, 17, 20, 23],
)
model = PolypSegmentationModel.from_pretrained("path/to/checkpoint.pth", config, device=device)

# Preprocess image
def preprocess_image(image_path, target_size=(256, 256)):
    image = Image.open(image_path).convert('RGB')
    image = image.resize(target_size, Image.Resampling.BILINEAR)

    # Convert to float32 numpy and apply ImageNet normalization
    # (float32 throughout, so the tensor dtype matches the model weights)
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_array = (image_array - mean) / std

    # Convert to tensor [B, C, H, W]
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
    return image_tensor, image

# Run inference
image_tensor, original_image = preprocess_image("colonoscopy_image.jpg")

with torch.no_grad():
    prediction = model(image_tensor.to(device))  # input must live on the model's device
    mask = torch.sigmoid(prediction)
    binary_mask = (mask > 0.5).float()
    mask_np = binary_mask.squeeze().cpu().numpy()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(original_image)
axes[0].set_title("Input Image")
axes[1].imshow(mask_np, cmap='gray')
axes[1].set_title("Polyp Segmentation")
axes[2].imshow(original_image)
axes[2].imshow(mask_np, cmap='Reds', alpha=0.5)
axes[2].set_title("Overlay")
plt.show()
```

### Advanced Usage with Metrics

```python
import numpy as np
import torch
from scipy.ndimage import binary_erosion
from torch.utils.data import DataLoader

def compute_hd95(pred, target):
    """Directed 95th-percentile Hausdorff distance (prediction -> target), in pixels."""
    pred, target = pred.astype(bool), target.astype(bool)
    if not pred.any() or not target.any():
        return float('inf')

    # Border pixels = mask minus its erosion
    pred_border = pred & ~binary_erosion(pred)
    target_border = target & ~binary_erosion(target)

    pred_coords = np.argwhere(pred_border)
    target_coords = np.argwhere(target_border)

    # Distance from each predicted border pixel to the nearest target border pixel
    distances = [
        np.min(np.sqrt(np.sum((target_coords - p) ** 2, axis=1)))
        for p in pred_coords
    ]
    return np.percentile(distances, 95)

# Batch inference. `model` is the model loaded under Basic Inference; `dataset` is
# any torch Dataset yielding (image, mask) pairs preprocessed the same way.
device = next(model.parameters()).device
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

all_metrics = {'dice': [], 'iou': [], 'hd95': []}
for images, masks in dataloader:
    with torch.no_grad():
        predictions = model(images.to(device)).cpu()

    # Calculate metrics for each image
    for pred, mask in zip(predictions, masks):
        pred_binary = (torch.sigmoid(pred) > 0.5).float()

        # Dice
        intersection = (pred_binary * mask).sum()
        dice = (2. * intersection) / (pred_binary.sum() + mask.sum() + 1e-6)

        # IoU
        union = pred_binary.sum() + mask.sum() - intersection
        iou = intersection / (union + 1e-6)

        # HD95
        hd95 = compute_hd95(pred_binary.numpy().squeeze(), mask.numpy().squeeze())

        all_metrics['dice'].append(dice.item())
        all_metrics['iou'].append(iou.item())
        all_metrics['hd95'].append(hd95)

# Empty masks give an infinite HD95; leave them out of the average
finite_hd95 = [h for h in all_metrics['hd95'] if np.isfinite(h)]

print(f"Average Dice: {np.mean(all_metrics['dice']):.4f} ± {np.std(all_metrics['dice']):.4f}")
print(f"Average IoU: {np.mean(all_metrics['iou']):.4f} ± {np.std(all_metrics['iou']):.4f}")
print(f"Average HD95: {np.mean(finite_hd95):.2f} ± {np.std(finite_hd95):.2f}")
```
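
The `compute_hd95` helper above measures distances in one direction only, from predicted border pixels to the target border. A common alternative is the symmetric form, which takes the larger of the two directed 95th-percentile distances. The following is a self-contained NumPy sketch of that variant; it is not the code behind the reported HD95 values, and the helper names are illustrative.

```python
import numpy as np

def border_coords(mask):
    """Coordinates of foreground pixels that touch background (4-neighbourhood)."""
    fg = mask.astype(bool)
    m = np.pad(fg, 1, constant_values=False)
    # A pixel is interior if it and its four direct neighbours are all foreground.
    interior = m[1:-1, 1:-1] & m[:-2, 1:-1] & m[2:, 1:-1] & m[1:-1, :-2] & m[1:-1, 2:]
    return np.argwhere(fg & ~interior)

def directed_hd95(src, dst):
    """95th percentile of nearest-neighbour distances from src points to dst points."""
    diff = src[:, None, :] - dst[None, :, :]            # [n_src, n_dst, 2]
    nearest = np.sqrt((diff ** 2).sum(-1)).min(axis=1)
    return np.percentile(nearest, 95)

def symmetric_hd95(pred, target):
    """Symmetric HD95 in pixels; returns inf if either mask is empty."""
    if not pred.any() or not target.any():
        return float("inf")
    p, t = border_coords(pred), border_coords(target)
    return max(directed_hd95(p, t), directed_hd95(t, p))

pred = np.zeros((8, 8), dtype=bool)
pred[0, 0] = True
target = np.zeros((8, 8), dtype=bool)
target[0, 0] = target[0, 7] = True

print(directed_hd95(border_coords(pred), border_coords(target)))  # 0.0
print(symmetric_hd95(pred, target))  # approximately 6.65 (95th percentile of [0, 7])
```

In the toy case the prediction lies entirely on the target, so the prediction-to-target distance is zero, while the symmetric version also penalises the target pixel that the prediction missed.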

## Model Limitations

- **Input size**: Fixed to 256×256 pixels (resize your images accordingly)
- **Domain**: Trained only on colonoscopy images from Kvasir-SEG
- **Polyp types**: May not generalize to all polyp morphologies
- **Image quality**: Best performance with standard white-light colonoscopy images
| |
| ## Dataset |
| Trained on the Kvasir-SEG dataset, which contains 1000 polyp images with corresponding ground truth masks from colonoscopy procedures. |
| |
| ## License |
| This model is released under the MIT License. |
| |
| ## Citation |
| If you use this model in your research, please cite: |
| |
```bibtex
@software{dinov3_polyp_seg,
  author = {Amirreza Mehrzadian},
  title = {DINOv3 Polyp Segmentation with U-Net Decoder},
  year = {2024},
  url = {https://huggingface.co/uncleMehrzad/dinov3-polyp-seg}
}
```

## Acknowledgments

- DINOv3 team for the powerful vision backbone
- Kvasir-SEG dataset providers for the polyp segmentation data
- Hugging Face for model hosting infrastructure
| |
| |
| |
| |
| |
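## Model Definition

The `PolypSegmentationModel` wrapper imported from `model` in the usage examples combines the encoder, stem, and decoder:

```python
import torch
import torch.nn as nn

# DINOv3Encoder, ShallowStem and UNetDecoder are defined in the same module (model.py)

class PolypSegmentationModel(nn.Module):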
| """Complete model wrapper matching training architecture""" |
| |
| def __init__(self, encoder, stem, decoder): |
| super().__init__() |
| self.encoder = encoder |
| self.stem = stem |
| self.decoder = decoder |
| |
| def forward(self, x): |
| vit_features = self.encoder(x) |
| skip_features = self.stem(x) |
| return self.decoder(vit_features, skip_features) |
| |
| @classmethod |
| def from_pretrained(cls, model_path, config, device="cpu"): |
| """Load the complete model from checkpoint""" |
| checkpoint = torch.load(model_path, map_location=device) |
| |
| # Initialize components |
| encoder = DINOv3Encoder( |
| model_name=config.model_name, |
| local_path=config.local_model_path, |
| freeze=True, |
| layers=config.multi_scale_layers |
| ) |
| |
| stem = ShallowStem(in_channels=3, base_channels=64) |
| |
| decoder = UNetDecoder( |
| vit_channels=encoder.out_channels, |
| stem_channels=[512, 256, 128], |
| num_classes=1 |
| ) |
| |
| # Load weights |
| decoder.load_state_dict(checkpoint['decoder_state_dict']) |
| stem.load_state_dict(checkpoint['stem_state_dict']) |
| |
| model = cls(encoder, stem, decoder) |
| model.to(device) |
| model.eval() |
| |
        return model
```
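
`from_pretrained` above reads the keys `decoder_state_dict` and `stem_state_dict` from the checkpoint, and only these two state dicts are loaded from it. The sketch below shows one way a compatible file could be written and read back. The two `nn.Conv2d` layers are tiny stand-ins for the real stem and decoder, and the file name is arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-ins for the real ShallowStem and UNetDecoder modules.
stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
decoder = nn.Conv2d(8, 1, kernel_size=1)

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Save only the trainable parts, under the keys from_pretrained looks up.
torch.save(
    {"stem_state_dict": stem.state_dict(), "decoder_state_dict": decoder.state_dict()},
    path,
)

# Restore into freshly initialised modules.
checkpoint = torch.load(path, map_location="cpu")
new_stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
new_decoder = nn.Conv2d(8, 1, kernel_size=1)
new_stem.load_state_dict(checkpoint["stem_state_dict"])
new_decoder.load_state_dict(checkpoint["decoder_state_dict"])

print(torch.equal(stem.weight, new_stem.weight))  # True
```

Storing only the stem and decoder keeps the checkpoint small, since the frozen backbone can be rebuilt from its own pretrained weights.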