Multi-Domain DANN for Mitosis Detection (MIDOG++)
A ResNet50-based mitosis classifier trained with the Domain-Adversarial Neural Network (DANN) approach on the MIDOG++ dataset. The model is designed to generalize across scanners, labs, species, and tumor types by adversarially suppressing domain-specific features.
Model Description
The architecture uses a pretrained ResNet50 backbone split into 4 named stages. Feature maps from layer2 (512-d), layer3 (1024-d), and layer4 (2048-d) are each globally average-pooled and concatenated into a single 3584-d multi-scale feature vector. This allows the mitosis classifier to draw on both coarse geometric features (spindle shape, nuclear envelope breakdown) and fine texture.
Downstream heads operating on this 3584-d vector:
- Mitosis head: MLP binary classifier (mitotic / non-mitotic)
- 4 adversarial domain heads: tumor type, species, origin, scanner, each with its own Gradient Reversal Layer (GRL)
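A Gradient Reversal Layer is the identity in the forward pass and multiplies the gradient by a negative factor in the backward pass. The sketch below shows one common way to write this in PyTorch. The names `GradReverse` and `DomainHead` and the hidden size of 256 are assumptions for illustration, and the repository's implementation may differ.

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity on the forward pass, gradient scaled by -lambda on backward."""

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; lambda_val itself receives no gradient.
        return -ctx.lambda_val * grad_output, None


class DomainHead(nn.Module):
    """Hypothetical domain classifier placed behind its own GRL."""

    def __init__(self, in_dim, num_classes, hidden=256):  # hidden size is assumed
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, feats, lambda_val):
        return self.mlp(GradReverse.apply(feats, lambda_val))


# Check the reversal: d(sum(x))/dx is +1, so the reversed gradient is -lambda.
feats = torch.ones(2, 3584, requires_grad=True)
GradReverse.apply(feats, 2.0).sum().backward()

# One head per domain axis, e.g. a 4-way scanner head.
scanner_logits = DomainHead(3584, 4)(feats.detach(), 2.0)
```

Because the reversed gradient flows back into the shared features, the backbone is pushed toward features on which the domain heads perform poorly, while each head itself is still trained to classify its domain.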
| File | Description |
|---|---|
| best_dann_model.pth | Multi-stage DANN trained on RGB patches with augmentation |
Training Hyperparameters
| Parameter | Value |
|---|---|
| Epochs | 50 |
| Backbone LR | 1e-5 |
| Heads LR | 1e-4 |
| Lambda max (GRL) | 2.0 |
| Mitotic class weight | 2.0 |
| Batch size | 32 |
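The values in this table could be wired into a PyTorch training setup along the following lines. This is only a sketch under several assumptions: the optimizer type (Adam) and the GRL ramp-up schedule (the sigmoid-shaped schedule commonly used with DANN, with `gamma=10`) are not stated in this card, and the `nn.Linear` modules are stand-ins for the real backbone and heads.

```python
import math

import torch
from torch import nn

# Stand-ins for the real backbone and heads (assumption, for illustration only).
backbone = nn.Linear(8, 4)
heads = nn.Linear(4, 2)

# Separate learning rates for backbone and heads via parameter groups.
# Assumption: Adam; the card does not name the optimizer.
optimizer = torch.optim.Adam([
    {"params": backbone.parameters(), "lr": 1e-5},
    {"params": heads.parameters(), "lr": 1e-4},
])

# Class index 1 is "mitotic", so it receives the weight of 2.0.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]))


def grl_lambda(epoch, total_epochs=50, lambda_max=2.0, gamma=10.0):
    """Assumed schedule: ramp lambda smoothly from 0 towards lambda_max."""
    progress = epoch / total_epochs
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0)


lambda_start = grl_lambda(0)
lambda_end = grl_lambda(50)
```

Ramping lambda up from zero is a frequent choice because it lets the mitosis head learn useful features before the adversarial signal becomes strong.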
Training Data
Trained on 224x224 patches cropped around annotated bounding boxes from the MIDOG++ dataset. Patches are split 80/20 into train and test sets, stratified by domain combination (Tumor x Scanner x Origin x Species).
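A split of this kind can be sketched in plain Python as shown below. The record field names (`tumor`, `scanner`, `origin`, `species`), the function name, and the fixed seed are hypothetical; whether the repository also stratifies by class label is not stated here.

```python
import random
from collections import defaultdict


def stratified_split(records, test_fraction=0.2, seed=0):
    """Split records 80/20 within each (tumor, scanner, origin, species) group."""
    groups = defaultdict(list)
    for rec in records:
        key = (rec["tumor"], rec["scanner"], rec["origin"], rec["species"])
        groups[key].append(rec)

    rng = random.Random(seed)
    train, test = [], []
    for key in sorted(groups):  # sorted for a reproducible order
        members = groups[key]
        rng.shuffle(members)
        n_test = round(len(members) * test_fraction)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


# Toy example with two domain combinations of 10 patches each (placeholder names).
records = [
    {"id": i, "tumor": "tumor_a", "scanner": "scanner_x",
     "origin": "lab_1", "species": "species_1"}
    for i in range(10)
] + [
    {"id": 10 + i, "tumor": "tumor_b", "scanner": "scanner_y",
     "origin": "lab_2", "species": "species_2"}
    for i in range(10)
]
train_set, test_set = stratified_split(records)
```

Splitting inside each group keeps every domain combination represented in both sets in roughly the same proportion.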
Training augmentations: random flips, rotation, RandomResizedCrop, ElasticTransform, ColorJitter, shot noise, Gaussian blur, and defocus blur to simulate cross-scanner variation.
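As an illustration of one of the less standard augmentations in this list, shot noise is often simulated by resampling pixel intensities from a Poisson distribution. The sketch below follows that approach with NumPy. The function name and the `photons` parameter are assumptions, and the repository's implementation and noise strength may differ.

```python
import numpy as np


def add_shot_noise(image, photons=60.0, rng=None):
    """Apply Poisson (shot) noise to a float image with values in [0, 1].

    Lower `photons` values give stronger noise (hypothetical parameter).
    """
    rng = np.random.default_rng() if rng is None else rng
    image = np.clip(image, 0.0, 1.0)
    noisy = rng.poisson(image * photons) / photons
    return np.clip(noisy, 0.0, 1.0).astype(np.float32)


# Uniform grey 224x224 RGB patch as a toy input.
patch = np.full((224, 224, 3), 0.5, dtype=np.float32)
noisy_patch = add_shot_noise(patch, photons=60.0, rng=np.random.default_rng(0))
```

The noise variance scales with pixel intensity, which roughly mimics sensor behaviour and differs from additive Gaussian noise.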
Intended Use
Mitosis detection in H&E-stained whole-slide images, particularly in multi-scanner or multi-site settings where domain shift is a concern.
How to Use
The model must be loaded using the DANNModel class from final_model.py in the accompanying repository.
```python
import torch
from final_model import DANNModel

num_domain_classes = {'Tumor': 7, 'Species': 2, 'Origin': 4, 'Scanner': 4}
model = DANNModel(num_classes=2,
                  num_domain_classes=num_domain_classes,
                  lambda_val=0.0)
model.load_state_dict(torch.load('best_dann_model.pth', map_location='cpu'))
model.eval()

# Inference only (skips domain heads).
# image_tensor: float tensor of shape (N, 3, 224, 224), ImageNet-normalized
with torch.no_grad():
    logits = model.predict_only(image_tensor)
    pred = logits.argmax(dim=1)  # 0 = non-mitotic, 1 = mitotic
```
Limitations
- Trained and evaluated on MIDOG++ only; performance on other datasets is not validated
- Domain adaptation targets the four MIDOG++ domain axes; other sources of variation (e.g. tissue preparation) are not explicitly addressed
- Input patches must be 224x224 RGB, normalized with ImageNet mean/std
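The input requirement in the last point can be met with a small preprocessing step such as the one sketched below. The mean and standard deviation are the standard ImageNet values; the function name `preprocess_patch` and the assumption that patches arrive as `uint8` arrays in height x width x channel order are illustrative.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_patch(patch_uint8):
    """Convert a 224x224x3 uint8 RGB patch to a normalized (1, 3, 224, 224) array."""
    if patch_uint8.shape != (224, 224, 3):
        raise ValueError(f"expected (224, 224, 3), got {patch_uint8.shape}")
    x = patch_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD       # per-channel normalization
    x = np.transpose(x, (2, 0, 1))[None, ...]    # HWC -> NCHW with batch dim
    return np.ascontiguousarray(x)


# Toy input: an all-white patch.
white = np.full((224, 224, 3), 255, dtype=np.uint8)
batch = preprocess_patch(white)
```

The resulting array can be converted with `torch.from_numpy(batch)` before it is passed to the model.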