---
language: en
tags:
- medical-imaging
- polyp-segmentation
- dinov3
- vision-transformer
- kvasir-seg
- colonoscopy
- unet
datasets:
- kmader/kvasir-segmentation
metrics:
- dice
- iou
- precision
- recall
- hd95
library_name: pytorch
pipeline_tag: image-segmentation
license: mit
---

# DINOv3 Polyp Segmentation with U-Net Decoder

## Model Description

This model performs **polyp segmentation** in colonoscopy images. It combines a frozen DINOv3-ViT-L/16 backbone with multi-scale feature extraction and a U-Net-style decoder with skip connections. The model was trained on the Kvasir-SEG dataset.

**Key Features:**

- 🏗️ **U-Net architecture**: Skip connections from a trainable shallow convolutional stem for precise boundary delineation
- 📐 **Multi-scale features**: Extracts DINOv3 features from layers [5, 11, 17, 20, 23] for a rich hierarchical representation
- 🩺 **Polyp-specific segmentation**: Designed specifically for polyp segmentation in colonoscopy images
- 🔒 **Frozen backbone**: Reuses DINOv3's pretrained visual features and trains only the stem and decoder, reducing the risk of overfitting on a small dataset
- 📊 **Comprehensive metrics**: Evaluated with Dice, IoU, Precision, Recall, and HD95
- 🔄 **Cosine annealing**: Uses CosineAnnealingWarmRestarts to schedule the learning rate

## Model Architecture

```
              Input Image (256×256×3)
                         ↓
┌────────────────────────┬──────────────────────┐
│ Shallow Stem           │ DINOv3 Encoder       │
│ (Trainable)            │ (Frozen)             │
│                        │                      │
│ Conv 3→64 (3×3)        │ Layers [5,11,17,     │
│ Conv 64→128 (stride2)  │         20,23]       │
│ Conv 128→256 (stride2) │ Multi-scale concat   │
│ Conv 256→512 (stride2) │ 5 × 1024 = 5120      │
└───────┬────────────────┴───────────┬──────────┘
        │ Skip Connections           │
        │ [512, 256, 128]            │
        ↓                            ↓
┌───────────────────────────────────────────────┐
│           U-Net Decoder (Trainable)           │
│                                               │
│ Conv 5120→256 + Skip(512) → ConvBlock         │
│ Upsample → Conv 384→128 + Skip(256)           │
│ Upsample → Conv 192→64 + Skip(128)            │
│ Upsample → Final Conv 64→1 (1×1)              │
└───────────────────────┬───────────────────────┘
                        ↓
          Segmentation Mask (256×256×1)
```

## Training Details

| Hyperparameter | Value |
|---------------|-------|
| Backbone | DINOv3-ViT-L/16 (frozen) |
| Multi-scale Layers | [5, 11, 17, 20, 23] |
| Input Resolution | 256×256 |
| Batch Size | 96 |
| Epochs | 150 |
| Learning Rate | 1e-4 (initial) |
| Min Learning Rate | 1e-6 |
| Weight Decay | 1e-4 |
| Optimizer | AdamW |
| Scheduler | CosineAnnealingWarmRestarts |
| Scheduler Config | T_0=10, T_mult=2 |
| Loss Function | Focal + Dice (0.7/0.3 weights) |
| Focal Loss Gamma | 2.0 |
| Focal Loss Alpha | 0.25 |
| Trainable Parameters | ~21M (Stem + Decoder) |

### Data Augmentation

- Random 90° rotation
- Horizontal/Vertical flips
- ShiftScaleRotate (shift=0.05, scale=0.05, rotate=15°)
- MotionBlur/GaussianBlur
- ColorJitter (brightness, contrast, saturation, hue)

## Performance Metrics

### Final Test Set Results

| Metric | Score |
|--------|-------|
| **Dice Score** | **0.8289 ± 0.0000** |
| **IoU** | **0.7078 ± 0.0000** |
| **Precision** | 0.7910 ± 0.0000 |
| **Recall** | 0.8705 ± 0.0000 |
| **HD95 (pixels)** | 45.46 ± 0.00 |
| **Best Validation Dice** | 0.7327 |

### Validation Set Results

| Metric | Score |
|--------|-------|
| **Dice Score** | **0.8795 ± 0.0304** |
| **IoU** | **0.7862 ± 0.0485** |
| **Precision** | 0.8846 ± 0.0256 |
| **Recall** | 0.8744 ± 0.0351 |
| **HD95 (pixels)** | 30.85 ± 8.90 |

### Training Set Results (Final Epoch)

| Metric | Score |
|--------|-------|
| **Dice Score** | 0.8747 ± 0.0108 |
| **IoU** | 0.7775 ± 0.0170 |
| **Precision** | 0.8698 ± 0.0189 |
| **Recall** | 0.8801 ± 0.0136 |
| **HD95 (pixels)** | 33.91 ± 1.69 |

## Usage

### Installation

```bash
pip install torch transformers pillow matplotlib numpy opencv-python albumentations scipy scikit-learn
```

### Basic Inference

```python
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from types import SimpleNamespace

# Import the model architecture (same as training)
from model import DINOv3Encoder, ShallowStem, UNetDecoder, PolypSegmentationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# from_pretrained expects a local checkpoint path and a config object
config = SimpleNamespace(
    model_name="<dinov3-vitl16-model-name>",  # placeholder: backbone name used in training
    local_model_path=None,                     # placeholder: optional local copy of the backbone
    multi_scale_layers=[5, 11, 17, 20, 23],
)
model = PolypSegmentationModel.from_pretrained(
    "path/to/checkpoint.pth",  # placeholder: checkpoint with the stem/decoder state dicts
    config,
    device=device,
)

# Preprocess image
def preprocess_image(image_path, target_size=(256, 256)):
    image = Image.open(image_path).convert('RGB')
    image = image.resize(target_size, Image.Resampling.BILINEAR)

    # Convert to numpy and apply ImageNet normalization (kept in float32)
    image_array = np.array(image).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
    image_array = (image_array - mean) / std

    # Convert to tensor [B, C, H, W]
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
    return image_tensor, image

# Run inference
image_tensor, original_image = preprocess_image("colonoscopy_image.jpg")

with torch.no_grad():
    prediction = model(image_tensor.to(device))  # input must be on the model's device
    mask = torch.sigmoid(prediction)
    binary_mask = (mask > 0.5).float()

mask_np = binary_mask.squeeze().cpu().numpy()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(original_image)
axes[0].set_title("Input Image")
axes[1].imshow(mask_np, cmap='gray')
axes[1].set_title("Polyp Segmentation")
axes[2].imshow(original_image)
axes[2].imshow(mask_np, cmap='Reds', alpha=0.5)
axes[2].set_title("Overlay")
plt.show()
```

### Advanced Usage with Metrics

```python
import numpy as np
import torch
from torch.utils.data import DataLoader
from scipy.ndimage import binary_erosion
from scipy.spatial.distance import cdist


def compute_hd95(pred, target):
    """Symmetric 95th-percentile Hausdorff distance (in pixels) between two binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    if pred.sum() == 0 or target.sum() == 0:
        return float('inf')  # undefined when either mask is empty

    # Boundary pixels: the mask minus its erosion
    pred_border = pred & ~binary_erosion(pred)
    target_border = target & ~binary_erosion(target)

    pred_coords = np.argwhere(pred_border)
    target_coords = np.argwhere(target_border)

    # Pairwise Euclidean distances between the two boundaries
    dists = cdist(pred_coords, target_coords)
    # Pool the directed distances in both directions, then take the 95th percentile
    distances = np.concatenate([dists.min(axis=1), dists.min(axis=0)])
    return np.percentile(distances, 95)


# Batch inference: `dataset` is your own Dataset yielding (image, mask) tensors,
# preprocessed the same way as preprocess_image above
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
all_metrics = {'dice': [], 'iou': [], 'hd95': []}

for images, masks in dataloader:
    with torch.no_grad():
        predictions = model(images.to(device)).cpu()

    # Calculate metrics for each image
    for pred, mask in zip(predictions, masks):
        pred_binary = (torch.sigmoid(pred) > 0.5).float()
        mask = mask.float()

        # Dice
        intersection = (pred_binary * mask).sum()
        dice = (2. * intersection) / (pred_binary.sum() + mask.sum() + 1e-6)

        # IoU
        union = pred_binary.sum() + mask.sum() - intersection
        iou = intersection / (union + 1e-6)

        # HD95
        hd95 = compute_hd95(pred_binary.numpy().squeeze(), mask.numpy().squeeze())

        all_metrics['dice'].append(dice.item())
        all_metrics['iou'].append(iou.item())
        all_metrics['hd95'].append(hd95)

print(f"Average Dice: {np.mean(all_metrics['dice']):.4f} ± {np.std(all_metrics['dice']):.4f}")
print(f"Average IoU: {np.mean(all_metrics['iou']):.4f} ± {np.std(all_metrics['iou']):.4f}")

# HD95 is infinite when a prediction or ground-truth mask is empty; average the finite values
hd95_values = np.array(all_metrics['hd95'])
finite_hd95 = hd95_values[np.isfinite(hd95_values)]
print(f"Average HD95: {finite_hd95.mean():.2f} ± {finite_hd95.std():.2f} "
      f"({len(finite_hd95)}/{len(hd95_values)} images with non-empty masks)")
```

## Model Limitations

- **Input size**: Fixed to 256×256 pixels (resize your images accordingly)
- **Domain**: Trained only on colonoscopy images from Kvasir-SEG
- **Polyp types**: May not generalize to all polyp morphologies
- **Image quality**: Performs best on standard white-light colonoscopy images

## Dataset

Trained on the Kvasir-SEG dataset, which contains 1,000 polyp images from colonoscopy procedures with corresponding ground-truth masks.

## License

This model is released under the MIT License.
## Citation

If you use this model in your research, please cite:

```bibtex
@software{dinov3_polyp_seg,
  author = {Amirreza Mehrzadian},
  title = {DINOv3 Polyp Segmentation with U-Net Decoder},
  year = {2024},
  url = {https://huggingface.co/uncleMehrzad/dinov3-polyp-seg}
}
```

## Acknowledgments

- DINOv3 team for the powerful vision backbone
- Kvasir-SEG dataset providers for the polyp segmentation data
- Hugging Face for model hosting infrastructure

## Model Wrapper

The `PolypSegmentationModel` wrapper used in the usage examples:

```python
import torch
import torch.nn as nn

# DINOv3Encoder, ShallowStem and UNetDecoder are the training-time modules
# (imported from `model` in the usage examples).


class PolypSegmentationModel(nn.Module):
    """Complete model wrapper matching the training architecture."""

    def __init__(self, encoder, stem, decoder):
        super().__init__()
        self.encoder = encoder
        self.stem = stem
        self.decoder = decoder

    def forward(self, x):
        vit_features = self.encoder(x)  # frozen multi-scale DINOv3 features
        skip_features = self.stem(x)    # trainable shallow-stem skip features
        return self.decoder(vit_features, skip_features)

    @classmethod
    def from_pretrained(cls, model_path, config, device="cpu"):
        """Load the complete model from a checkpoint.

        `config` must provide `model_name`, `local_model_path` and `multi_scale_layers`.
        Only the stem and decoder weights are loaded from the checkpoint; the frozen
        DINOv3 encoder keeps its pretrained weights.
        """
        checkpoint = torch.load(model_path, map_location=device)

        # Initialize components
        encoder = DINOv3Encoder(
            model_name=config.model_name,
            local_path=config.local_model_path,
            freeze=True,
            layers=config.multi_scale_layers,
        )
        stem = ShallowStem(in_channels=3, base_channels=64)
        decoder = UNetDecoder(
            vit_channels=encoder.out_channels,
            stem_channels=[512, 256, 128],
            num_classes=1,
        )

        # Load trained weights
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        stem.load_state_dict(checkpoint['stem_state_dict'])

        model = cls(encoder, stem, decoder)
        model.to(device)
        model.eval()
        return model
```
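The Training Details table specifies a Focal + Dice loss (0.7/0.3 weights, gamma=2.0, alpha=0.25), AdamW, and CosineAnnealingWarmRestarts with T_0=10 and T_mult=2. The training script is not part of this card, so the following is only a minimal PyTorch sketch of that setup, not the author's code. It assumes the 0.7 weight applies to the focal term and 0.3 to the Dice term, uses the standard sigmoid focal-loss formulation with a per-image soft Dice loss, and introduces the hypothetical names `FocalDiceLoss` and `build_optimizer_and_scheduler`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalDiceLoss(nn.Module):
    """Hypothetical combined loss on raw logits: focal_weight * focal + dice_weight * soft Dice.

    Assumption: the "0.7/0.3" weights map to focal/Dice in that order.
    """

    def __init__(self, focal_weight=0.7, dice_weight=0.3, alpha=0.25, gamma=2.0, eps=1e-6):
        super().__init__()
        self.focal_weight = focal_weight
        self.dice_weight = dice_weight
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps

    def forward(self, logits, targets):
        targets = targets.float()
        probs = torch.sigmoid(logits)

        # Sigmoid focal loss: BCE down-weighted by (1 - p_t)^gamma, class-balanced by alpha_t
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = probs * targets + (1 - probs) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal = (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

        # Soft Dice loss per image (summed over all non-batch dims), averaged over the batch
        dims = tuple(range(1, logits.dim()))
        intersection = (probs * targets).sum(dim=dims)
        denom = probs.sum(dim=dims) + targets.sum(dim=dims)
        dice_loss = 1 - (2 * intersection + self.eps) / (denom + self.eps)

        return self.focal_weight * focal + self.dice_weight * dice_loss.mean()


def build_optimizer_and_scheduler(params, lr=1e-4, min_lr=1e-6, weight_decay=1e-4):
    """AdamW + CosineAnnealingWarmRestarts(T_0=10, T_mult=2) with the values from the table."""
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=min_lr
    )
    return optimizer, scheduler
```

In a training loop, only the stem and decoder parameters would be passed to the optimizer, since the DINOv3 encoder is frozen. If `scheduler.step()` is called once per epoch, the restart cycles last 10, 20, 40 and 80 epochs, which add up to exactly the 150 training epochs, so the last epoch runs close to the 1e-6 minimum learning rate.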