uncleMehrzad committed on
Commit
285bfbb
·
verified ·
1 Parent(s): 9759118

Update README.md

Files changed (1)
  1. README.md +301 -0
README.md CHANGED
@@ -1,3 +1,304 @@
  ---
+ language: en
+ tags:
+ - medical-imaging
+ - polyp-segmentation
+ - dinov3
+ - vision-transformer
+ - kvasir-seg
+ - colonoscopy
+ - unet
+ datasets:
+ - kmader/kvasir-segmentation
+ metrics:
+ - dice
+ - iou
+ - precision
+ - recall
+ - hd95
+ library_name: pytorch
+ pipeline_tag: image-segmentation
  license: mit
  ---
+
+ # DINOv3 Polyp Segmentation with U-Net Decoder
+
+ ## Model Description
+
+ This model performs **polyp segmentation** on colonoscopy images. It combines a frozen DINOv3-ViT-L/16 backbone, from which multi-scale features are extracted, with a trainable U-Net-style decoder that receives skip connections from a shallow convolutional stem. The model was trained on the Kvasir-SEG dataset.
+
+ **Key Features:**
+ - 🏗️ **U-Net architecture**: Skip connections from the shallow stem help preserve polyp boundaries
+ - 📐 **Multi-scale features**: DINOv3 features are taken from layers [5, 11, 17, 20, 23] and concatenated into a hierarchical representation
+ - 🩺 **Task-specific design**: Built for polyp segmentation in colonoscopy images
+ - 🔒 **Frozen backbone**: Reuses DINOv3's pretrained features; only the stem and decoder are trained, which helps limit overfitting
+ - 📊 **Comprehensive metrics**: Evaluated with Dice, IoU, Precision, Recall, and HD95
+ - 🔄 **Cosine annealing**: Trained with the CosineAnnealingWarmRestarts learning-rate scheduler
+
+ ## Model Architecture
+
+                  Input Image (256×256×3)
+                             ↓
+     ┌───────────────────────┬──────────────────────┐
+     │ Shallow Stem          │ DINOv3 Encoder       │
+     │ (Trainable)           │ (Frozen)             │
+     │                       │                      │
+     │ Conv 3→64 (3×3)       │ Layers [5,11,17,     │
+     │ Conv 64→128 (stride2) │         20,23]       │
+     │ Conv 128→256 (stride2)│ Multi-scale concat   │
+     │ Conv 256→512 (stride2)│ 5 × 1024 = 5120      │
+     └───────┬───────────────┴──────────┬───────────┘
+             │ Skip Connections         │
+             │ [512, 256, 128]          │
+             ↓                          ↓
+     ┌──────────────────────────────────────────┐
+     │ U-Net Decoder (Trainable)                │
+     │                                          │
+     │ Conv 5120→256 + Skip(512) → ConvBlock    │
+     │ Upsample → Conv 384→128 + Skip(256)      │
+     │ Upsample → Conv 192→64 + Skip(128)       │
+     │ Upsample → Final Conv 64→1 (1×1)         │
+     └──────────────────┬───────────────────────┘
+                        ↓
+          Segmentation Mask (256×256×1)
+
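The encoder branch of the diagram concatenates patch-token features from five transformer layers along the channel axis. The following is a minimal NumPy sketch of that shape bookkeeping only, not the model's actual encoder. It assumes 16×16 patches on a 256×256 input (a 16×16 token grid) and that any class or register tokens have already been dropped. Random arrays stand in for the real DINOv3 hidden states, and `tokens_to_map` is a hypothetical helper name.

```python
import numpy as np

EMBED_DIM = 1024              # ViT-L hidden size
PATCH = 16                    # patch size of ViT-L/16
IMAGE_SIZE = 256
LAYERS = [5, 11, 17, 20, 23]  # layers listed in the model card

grid = IMAGE_SIZE // PATCH    # 16 tokens per side
num_tokens = grid * grid      # 256 patch tokens


def tokens_to_map(tokens, grid_size):
    """Reshape [B, N, C] patch tokens into a [B, C, H, W] feature map."""
    batch, n, channels = tokens.shape
    assert n == grid_size * grid_size
    return tokens.transpose(0, 2, 1).reshape(batch, channels, grid_size, grid_size)


rng = np.random.default_rng(0)
# Stand-ins for the hidden states of the selected layers (assumption: patch tokens only)
hidden_states = {
    layer: rng.standard_normal((1, num_tokens, EMBED_DIM), dtype=np.float32)
    for layer in LAYERS
}

feature_maps = [tokens_to_map(hidden_states[layer], grid) for layer in LAYERS]
multi_scale = np.concatenate(feature_maps, axis=1)  # stack along channels
print(multi_scale.shape)  # (1, 5120, 16, 16)
```

The channel count matches the `5 × 1024 = 5120` entry in the diagram; the spatial size depends on the patch-size assumption stated above.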
+ ## Training Details
+
+ | Hyperparameter | Value |
+ |---------------|-------|
+ | Backbone | DINOv3-ViT-L/16 (frozen) |
+ | Multi-scale Layers | [5, 11, 17, 20, 23] |
+ | Input Resolution | 256×256 |
+ | Batch Size | 32 |
+ | Epochs | 100 |
+ | Learning Rate | 1e-4 (initial) |
+ | Min Learning Rate | 1e-6 |
+ | Weight Decay | 1e-4 |
+ | Optimizer | AdamW |
+ | Scheduler | CosineAnnealingWarmRestarts |
+ | Scheduler Config | T_0=10, T_mult=2 |
+ | Loss Function | Focal + Dice (0.7/0.3 weights) |
+ | Focal Loss Gamma | 2.0 |
+ | Focal Loss Alpha | 0.25 |
+ | Trainable Parameters | ~8.5M (Stem + Decoder) |
+
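To make the scheduler rows concrete, here is a small pure-Python sketch of the cosine-annealing-with-warm-restarts formula, `lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t_cur / cycle_len))`, evaluated with the values from the table. It assumes the "Min Learning Rate" row is the scheduler's lower bound and that the scheduler is stepped once per epoch; `warm_restart_lr` is a hypothetical helper, not part of the training code.

```python
import math


def warm_restart_lr(epoch, base_lr=1e-4, min_lr=1e-6, t_0=10, t_mult=2):
    """Learning rate at the start of `epoch` under cosine annealing with warm restarts."""
    cycle_len, cycle_start = t_0, 0
    # Walk forward to the cycle that contains this epoch
    while epoch >= cycle_start + cycle_len:
        cycle_start += cycle_len
        cycle_len *= t_mult
    t_cur = epoch - cycle_start
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / cycle_len))


# Epochs at which the learning rate jumps back to its initial value
restarts = [e for e in range(1, 100) if math.isclose(warm_restart_lr(e), 1e-4)]
print(restarts)  # [10, 30, 70]
```

With T_0=10 and T_mult=2 the cycles last 10, 20, 40, and 80 epochs, so under these assumptions a 100-epoch run ends partway through the fourth cycle rather than at the minimum learning rate.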
+ ### Data Augmentation
+ - Random 90° rotation
+ - Horizontal and vertical flips
+ - ShiftScaleRotate (shift=0.05, scale=0.05, rotate=15°)
+ - MotionBlur/GaussianBlur
+ - ColorJitter (brightness, contrast, saturation, hue)
+
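The loss row in the hyperparameter table (Focal + Dice with 0.7/0.3 weights, gamma 2.0, alpha 0.25) can be illustrated with a NumPy sketch. This is one common formulation, not necessarily the exact training code: it assumes the 0.7 weight applies to the focal term and 0.3 to the Dice term (following the order in the table), and the function names are hypothetical.

```python
import numpy as np


def focal_loss(probs, target, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss averaged over pixels; `probs` are sigmoid outputs."""
    probs = np.clip(probs, eps, 1.0 - eps)
    p_t = np.where(target == 1, probs, 1.0 - probs)      # probability of the true class
    alpha_t = np.where(target == 1, alpha, 1.0 - alpha)  # class-balancing weight
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


def dice_loss(probs, target, smooth=1e-6):
    """Soft Dice loss: 1 - 2*|P∩T| / (|P| + |T|)."""
    intersection = np.sum(probs * target)
    return float(1.0 - (2.0 * intersection + smooth) / (np.sum(probs) + np.sum(target) + smooth))


def combined_loss(probs, target, focal_weight=0.7, dice_weight=0.3):
    """Weighted sum of focal and Dice losses (weight assignment is an assumption)."""
    return focal_weight * focal_loss(probs, target) + dice_weight * dice_loss(probs, target)


target = np.array([[0, 1], [1, 1]], dtype=np.float32)
good = np.array([[0.1, 0.9], [0.8, 0.9]], dtype=np.float32)  # close to the target
bad = 1.0 - good                                             # confidently wrong
print(combined_loss(good, target) < combined_loss(bad, target))  # True
```

The focal term down-weights pixels that are already classified confidently, while the Dice term scores region overlap directly, which is why the two are often combined for small foreground objects.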
+ ## Performance Metrics
+
+ ### Final Test Set Results
+
+ | Metric | Score |
+ |--------|-------|
+ | **Dice Score** | **{test_dice:.4f} ± {test_dice_std:.4f}** |
+ | **IoU** | **{test_iou:.4f} ± {test_iou_std:.4f}** |
+ | **Precision** | {test_precision:.4f} ± {test_precision_std:.4f} |
+ | **Recall** | {test_recall:.4f} ± {test_recall_std:.4f} |
+ | **HD95 (pixels)** | {test_hd95:.2f} ± {test_hd95_std:.2f} |
+ | **Best Validation Dice** | {best_dice:.4f} |
+
+ ### Validation Set Results
+
+ | Metric | Score |
+ |--------|-------|
+ | **Dice Score** | {val_dice:.4f} ± {val_dice_std:.4f} |
+ | **IoU** | {val_iou:.4f} ± {val_iou_std:.4f} |
+ | **Precision** | {val_precision:.4f} ± {val_precision_std:.4f} |
+ | **Recall** | {val_recall:.4f} ± {val_recall_std:.4f} |
+ | **HD95 (pixels)** | {val_hd95:.2f} ± {val_hd95_std:.2f} |
+
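For reference, the four overlap metrics in these tables can all be derived from per-image true-positive, false-positive, and false-negative pixel counts. The sketch below works through a tiny example; `overlap_metrics` is a hypothetical helper and the `eps` smoothing term is an assumption, so it may differ in detail from the evaluation code that produced the reported scores.

```python
import numpy as np


def overlap_metrics(pred, target, eps=1e-6):
    """Dice, IoU, precision and recall for two binary masks of equal shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    tp = np.sum(pred & target)    # predicted polyp, truly polyp
    fp = np.sum(pred & ~target)   # predicted polyp, truly background
    fn = np.sum(~pred & target)   # missed polyp pixels
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
    }


pred = [[1, 1, 0], [0, 1, 0]]
target = [[1, 0, 0], [0, 1, 1]]
scores = overlap_metrics(pred, target)  # here tp=2, fp=1, fn=1
print({name: round(float(value), 3) for name, value in scores.items()})
# {'dice': 0.667, 'iou': 0.5, 'precision': 0.667, 'recall': 0.667}
```

A "mean ± std" entry would then correspond to the mean and standard deviation of these per-image values over the evaluation set.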
+ ## Usage
+
+ ### Installation
+
+ ```bash
+ pip install torch transformers pillow matplotlib numpy opencv-python albumentations scipy scikit-learn
+ ```
+
+ ### Basic Inference
+
+ ```python
+ from types import SimpleNamespace
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ # Import the model architecture (same as training)
+ from model import DINOv3Encoder, ShallowStem, UNetDecoder, PolypSegmentationModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # from_pretrained takes a local checkpoint path and a config object that
+ # provides model_name, local_model_path and multi_scale_layers.
+ # The checkpoint path, model_name and local_model_path below are placeholders.
+ config = SimpleNamespace(
+     model_name="your-dinov3-vit-l16-model-id",
+     local_model_path=None,
+     multi_scale_layers=[5, 11, 17, 20, 23],
+ )
+ model = PolypSegmentationModel.from_pretrained(
+     "path/to/checkpoint.pth", config, device=device
+ )
+
+ # Preprocess image
+ def preprocess_image(image_path, target_size=(256, 256)):
+     """Return a normalized [1, 3, H, W] float32 tensor and the resized PIL image."""
+     image = Image.open(image_path).convert("RGB")
+     image = image.resize(target_size, Image.Resampling.BILINEAR)
+
+     # Scale to [0, 1], then apply ImageNet mean/std normalization.
+     # float32 is used throughout so the tensor matches the model weights.
+     image_array = np.asarray(image, dtype=np.float32) / 255.0
+     mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
+     std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
+     image_array = (image_array - mean) / std
+
+     # HWC -> CHW, then add a batch dimension: [B, C, H, W]
+     image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
+     return image_tensor, image
+
+ # Run inference
+ image_tensor, original_image = preprocess_image("colonoscopy_image.jpg")
+
+ with torch.no_grad():
+     prediction = model(image_tensor.to(device))  # raw logits
+     mask = torch.sigmoid(prediction)
+     binary_mask = (mask > 0.5).float()
+     mask_np = binary_mask.squeeze().cpu().numpy()
+
+ # Visualize
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+ axes[0].imshow(original_image)
+ axes[0].set_title("Input Image")
+ axes[1].imshow(mask_np, cmap="gray")
+ axes[1].set_title("Polyp Segmentation")
+ axes[2].imshow(original_image)
+ axes[2].imshow(mask_np, cmap="Reds", alpha=0.5)
+ axes[2].set_title("Overlay")
+ plt.show()
+ ```
+
+ ### Advanced Usage with Metrics
+
+ ```python
+ import numpy as np
+ import torch
+ from scipy.ndimage import binary_erosion
+ from torch.utils.data import DataLoader
+
+ def compute_hd95(pred, target):
+     """95th-percentile Hausdorff distance from the predicted to the target border.
+
+     Distances are measured in one direction only (prediction -> target).
+     Returns inf if either mask is empty.
+     """
+     pred = pred.astype(bool)
+     target = target.astype(bool)
+     if not pred.any() or not target.any():
+         return float("inf")
+
+     # Border = mask pixels that a one-pixel erosion removes
+     pred_border = pred & ~binary_erosion(pred)
+     target_border = target & ~binary_erosion(target)
+
+     pred_coords = np.argwhere(pred_border)
+     target_coords = np.argwhere(target_border)
+
+     # For each predicted border pixel, distance to the nearest target border pixel
+     distances = [
+         np.min(np.sqrt(np.sum((target_coords - p) ** 2, axis=1)))
+         for p in pred_coords
+     ]
+     return np.percentile(distances, 95)
+
+ # Batch inference. `model` is a loaded PolypSegmentationModel and `dataset`
+ # is your own Dataset yielding (image, mask) pairs.
+ device = next(model.parameters()).device
+ dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
+
+ all_metrics = {'dice': [], 'iou': [], 'hd95': []}
+ for images, masks in dataloader:
+     with torch.no_grad():
+         predictions = model(images.to(device)).cpu()
+
+     # Calculate metrics for each image
+     for pred, mask in zip(predictions, masks):
+         pred_binary = (torch.sigmoid(pred) > 0.5).float()
+         mask = mask.float()
+
+         # Dice
+         intersection = (pred_binary * mask).sum()
+         dice = (2. * intersection) / (pred_binary.sum() + mask.sum() + 1e-6)
+
+         # IoU
+         union = pred_binary.sum() + mask.sum() - intersection
+         iou = intersection / (union + 1e-6)
+
+         # HD95
+         hd95 = compute_hd95(pred_binary.numpy().squeeze(), mask.numpy().squeeze())
+
+         all_metrics['dice'].append(dice.item())
+         all_metrics['iou'].append(iou.item())
+         all_metrics['hd95'].append(hd95)
+
+ # HD95 is infinite when a prediction or mask is empty; average finite values only
+ hd95_values = [h for h in all_metrics['hd95'] if np.isfinite(h)]
+
+ print(f"Average Dice: {np.mean(all_metrics['dice']):.4f} ± {np.std(all_metrics['dice']):.4f}")
+ print(f"Average IoU: {np.mean(all_metrics['iou']):.4f} ± {np.std(all_metrics['iou']):.4f}")
+ print(f"Average HD95: {np.mean(hd95_values):.2f} ± {np.std(hd95_values):.2f}")
+ ```
+
+ ## Model Limitations
+ - **Input size**: Fixed to 256×256 pixels (resize your images accordingly)
+ - **Domain**: Trained only on colonoscopy images from Kvasir-SEG
+ - **Polyp types**: May not generalize to all polyp morphologies
+ - **Image quality**: Best performance with standard white-light colonoscopy images
+
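Because the input size is fixed, a predicted mask comes back at 256×256 and usually has to be mapped onto the original frame. One simple option, sketched here with NumPy, is nearest-neighbour resizing, which keeps the mask strictly binary. `resize_mask_nearest` is a hypothetical helper and the 1080×1350 frame size is only an example.

```python
import numpy as np


def resize_mask_nearest(mask, out_height, out_width):
    """Nearest-neighbour resize of a 2-D mask to (out_height, out_width)."""
    in_height, in_width = mask.shape
    # For every output row/column, pick the source index it falls into
    rows = np.arange(out_height) * in_height // out_height
    cols = np.arange(out_width) * in_width // out_width
    return mask[rows[:, None], cols[None, :]]


mask_256 = np.zeros((256, 256), dtype=np.uint8)
mask_256[64:192, 64:192] = 1  # toy prediction: a centered square
full_res = resize_mask_nearest(mask_256, 1080, 1350)
print(full_res.shape)  # (1080, 1350)
```

Resizing the binary mask rather than the probabilities avoids intermediate values at the border; interpolating the sigmoid output and thresholding afterwards is an alternative that gives smoother edges.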
+ ## Dataset
+ Trained on the Kvasir-SEG dataset, which contains 1,000 polyp images from colonoscopy procedures with corresponding ground-truth masks.
+
+ ## License
+ This model is released under the MIT License.
+
+ ## Citation
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ @software{dinov3_polyp_seg,
+   author = {Your Name},
+   title = {DINOv3 Polyp Segmentation with U-Net Decoder},
+   year = {2024},
+   url = {https://huggingface.co/your-username/dinov3-polyp-seg}
+ }
+ ```
+
+ ## Acknowledgments
+ - DINOv3 team for the powerful vision backbone
+ - Kvasir-SEG dataset providers for the polyp segmentation data
+ - Hugging Face for model hosting infrastructure
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ # DINOv3Encoder, ShallowStem and UNetDecoder come from the training code
+ # and must be importable in the same module.
+
+ class PolypSegmentationModel(nn.Module):
+     """Complete model wrapper matching training architecture"""
+
+     def __init__(self, encoder, stem, decoder):
+         super().__init__()
+         self.encoder = encoder
+         self.stem = stem
+         self.decoder = decoder
+
+     def forward(self, x):
+         vit_features = self.encoder(x)   # frozen multi-scale ViT features
+         skip_features = self.stem(x)     # trainable convolutional skip features
+         return self.decoder(vit_features, skip_features)
+
+     @classmethod
+     def from_pretrained(cls, model_path, config, device="cpu"):
+         """Load the complete model from checkpoint"""
+         checkpoint = torch.load(model_path, map_location=device)
+
+         # Initialize components
+         encoder = DINOv3Encoder(
+             model_name=config.model_name,
+             local_path=config.local_model_path,
+             freeze=True,
+             layers=config.multi_scale_layers
+         )
+
+         stem = ShallowStem(in_channels=3, base_channels=64)
+
+         decoder = UNetDecoder(
+             vit_channels=encoder.out_channels,
+             stem_channels=[512, 256, 128],
+             num_classes=1
+         )
+
+         # Load weights (only the stem and decoder are stored; the backbone is frozen)
+         decoder.load_state_dict(checkpoint['decoder_state_dict'])
+         stem.load_state_dict(checkpoint['stem_state_dict'])
+
+         model = cls(encoder, stem, decoder)
+         model.to(device)
+         model.eval()
+
+         return model
+ ```