Multi-Domain DANN for Mitosis Detection (MIDOG++)

ResNet50-based mitosis classifier trained on the MIDOG++ dataset with Domain-Adversarial Neural Network (DANN) training. By adversarially suppressing domain-specific features, the model aims to generalize across scanners, labs, species, and tumor types.

Model Description

The architecture uses a pretrained ResNet50 backbone split into four named stages (layer1 to layer4). Feature maps from layer2 (512 channels), layer3 (1024 channels), and layer4 (2048 channels) are each globally average-pooled and concatenated into a single 3584-d (512 + 1024 + 2048) multi-scale feature vector. This lets the mitosis classifier draw on both fine texture, broadly captured by the shallower stages, and coarser geometric cues such as spindle shape and nuclear envelope breakdown, broadly captured by the deeper ones.

Downstream heads operating on this 3584-d vector:

  • Mitosis head: MLP binary classifier (mitotic / non-mitotic)
  • Four adversarial domain heads (tumor type, species, origin, scanner), each preceded by its own Gradient Reversal Layer (GRL)
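
The gradient reversal mechanism can be sketched as below, assuming a standard autograd-based GRL: it acts as the identity in the forward pass and multiplies incoming gradients by -λ in the backward pass, so training the domain heads to succeed pushes the backbone toward domain-invariant features. The head widths, dropout rate, and the names GradReverse, DomainHead, mitosis_head, and domain_heads are hypothetical; only the 3584-d input and the domain class counts come from this card.

```python
import torch
import torch.nn as nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_val: float) -> torch.Tensor:
        ctx.lambda_val = lambda_val
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # lambda_val is a plain float and needs no gradient, hence the trailing None.
        return -ctx.lambda_val * grad_output, None


class DomainHead(nn.Module):
    """Hypothetical adversarial head: its own GRL followed by a small MLP classifier."""

    def __init__(self, in_dim: int, n_classes: int, hidden: int = 256) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, n_classes))

    def forward(self, feats: torch.Tensor, lambda_val: float) -> torch.Tensor:
        return self.net(GradReverse.apply(feats, lambda_val))


FEAT_DIM = 3584
num_domain_classes = {'Tumor': 7, 'Species': 2, 'Origin': 4, 'Scanner': 4}

# Assumed layer sizes; the real heads in final_model.py may differ.
mitosis_head = nn.Sequential(nn.Linear(FEAT_DIM, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 2))
domain_heads = nn.ModuleDict({name: DomainHead(FEAT_DIM, n) for name, n in num_domain_classes.items()})
```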

| File | Description |
|---|---|
| best_dann_model.pth | Multi-stage DANN trained on RGB patches with augmentation |

Training Hyperparameters

| Parameter | Value |
|---|---|
| Epochs | 50 |
| Backbone LR | 1e-5 |
| Heads LR | 1e-4 |
| Lambda max (GRL) | 2.0 |
| Mitotic class weight | 2.0 |
| Batch size | 32 |
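
The hyperparameters above could be wired together roughly as follows. This is a sketch under stated assumptions: the card names neither the optimizer (Adam is assumed here) nor the λ schedule, so grl_lambda uses the sigmoid ramp from the original DANN paper, scaled to approach λ max = 2.0 by the last epoch, and the domain losses are summed with equal weight. The function names are hypothetical.

```python
import math

import torch
import torch.nn as nn

EPOCHS = 50
LAMBDA_MAX = 2.0


def grl_lambda(epoch: int, total_epochs: int = EPOCHS,
               lambda_max: float = LAMBDA_MAX, gamma: float = 10.0) -> float:
    """Assumed sigmoid ramp of the GRL coefficient: 0 at epoch 0, ~lambda_max at the last epoch."""
    p = epoch / max(total_epochs - 1, 1)  # training progress in [0, 1]
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


def build_optimizer(backbone: nn.Module, heads: nn.Module) -> torch.optim.Optimizer:
    # Assumption: Adam. Two parameter groups give the pretrained backbone a smaller LR.
    return torch.optim.Adam([
        {"params": backbone.parameters(), "lr": 1e-5},  # Backbone LR
        {"params": heads.parameters(), "lr": 1e-4},     # Heads LR
    ])


# Class weights [non-mitotic, mitotic]: mitotic samples weigh twice as much in the loss.
mitosis_criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]))
domain_criterion = nn.CrossEntropyLoss()


def total_loss(mitosis_logits, mitosis_labels, domain_logits, domain_labels):
    """Mitosis loss plus the sum of domain losses (assumption: equal weights).

    The GRLs already flip the domain gradients reaching the backbone, so the
    losses are simply added rather than subtracted.
    """
    loss = mitosis_criterion(mitosis_logits, mitosis_labels)
    for name, logits in domain_logits.items():
        loss = loss + domain_criterion(logits, domain_labels[name])
    return loss
```

The usual rationale for ramping λ from 0 is to let the mitosis head learn useful features before the adversarial signal becomes strong.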

Training Data

Trained on 224x224 patches cropped around annotated bounding boxes from the MIDOG++ dataset. Patches are split 80/20 into train and test sets, stratified by domain combination (Tumor x Scanner x Origin x Species).
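
One way to implement the per-combination 80/20 split is sketched below, assuming each patch is described by a dict with hypothetical tumor, scanner, origin, and species keys. Records are grouped by their domain combination, shuffled with a fixed seed, and split within each group, so every combination contributes to both sets in roughly the stated proportion (groups with very few patches may land entirely in the training set). The sketch splits individual patch records and does not group patches by slide.

```python
import random
from collections import defaultdict


def stratified_split(records, train_frac=0.8, seed=0):
    """Hypothetical helper: split records within each (tumor, scanner, origin, species) group."""
    groups = defaultdict(list)
    for rec in records:
        key = (rec["tumor"], rec["scanner"], rec["origin"], rec["species"])
        groups[key].append(rec)

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train, test = [], []
    for key in sorted(groups):
        members = groups[key][:]
        rng.shuffle(members)
        n_train = round(train_frac * len(members))
        train.extend(members[:n_train])
        test.extend(members[n_train:])
    return train, test
```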

Training augmentations: random flips, rotation, RandomResizedCrop, ElasticTransform, ColorJitter, shot noise, Gaussian blur, and defocus blur to simulate cross-scanner variation.

Intended Use

Classifying candidate patches for mitosis detection in H&E-stained whole-slide images, particularly in multi-scanner or multi-site settings where domain shift is a concern.

How to Use

The model must be loaded using the DANNModel class from final_model.py in the accompanying repository.

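import torch
from PIL import Image
from torchvision import transforms
from final_model import DANNModel

num_domain_classes = {'Tumor': 7, 'Species': 2, 'Origin': 4, 'Scanner': 4}

# lambda_val scales the gradient reversal, which only affects gradients during training
model = DANNModel(num_classes=2,
                  num_domain_classes=num_domain_classes,
                  lambda_val=0.0)
model.load_state_dict(torch.load('best_dann_model.pth', map_location='cpu'))
model.eval()

# Preprocess a 224x224 RGB patch with ImageNet mean/std ('patch.png' is a placeholder path)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_tensor = preprocess(Image.open('patch.png').convert('RGB')).unsqueeze(0)  # (1, 3, 224, 224)

# Inference only (skips domain heads)
with torch.no_grad():
    logits = model.predict_only(image_tensor)
    pred = logits.argmax(dim=1)  # 0 = non-mitotic, 1 = mitotic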
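
For scoring many candidate patches at once, a small hypothetical helper can batch calls to predict_only and return the softmax probability of the mitotic class instead of a hard label. It assumes predict_only returns (N, 2) logits, as the argmax above implies, and that the patches are already normalized.

```python
import torch


@torch.no_grad()
def mitotic_probabilities(model, patches: torch.Tensor, batch_size: int = 32) -> torch.Tensor:
    """Hypothetical helper: P(mitotic) for each patch in an (N, 3, 224, 224) tensor."""
    model.eval()
    probs = []
    for start in range(0, patches.shape[0], batch_size):
        logits = model.predict_only(patches[start:start + batch_size])
        probs.append(torch.softmax(logits, dim=1)[:, 1])  # column 1 = mitotic
    return torch.cat(probs)
```

Thresholding the result at 0.5 reproduces the argmax decision; a different threshold trades precision against recall.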

Limitations

  • Trained and evaluated on MIDOG++ only; performance on other datasets is not validated
  • Domain adaptation targets the four MIDOG++ domain axes; other sources of variation (e.g. tissue preparation) are not explicitly addressed
  • Input patches must be 224x224 RGB, normalized with ImageNet mean/std
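
Because inputs must be 224x224 RGB patches normalized with ImageNet statistics, a hypothetical crop_patch helper can cut a patch centered on a candidate location from a larger tile and apply that normalization. It assumes the tile is a uint8 (3, H, W) tensor with H and W greater than 112 pixels, and it uses reflection padding for candidates near the border, which is an assumption and not necessarily what the training pipeline did.

```python
import torch
import torch.nn.functional as F

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
PATCH = 224


def crop_patch(region: torch.Tensor, cx: int, cy: int, size: int = PATCH) -> torch.Tensor:
    """Hypothetical helper: crop a size x size patch centered on (cx, cy) and normalize it.

    `region` is a uint8 tensor of shape (3, H, W); pixels outside it are filled by
    reflection so candidates near the border still yield a full-size patch.
    """
    half = size // 2
    img = region.float().div(255.0)
    padded = F.pad(img.unsqueeze(0), (half, half, half, half), mode="reflect").squeeze(0)
    # After padding, original pixel (cy, cx) sits at (cy + half, cx + half).
    patch = padded[:, cy:cy + size, cx:cx + size]
    return (patch - IMAGENET_MEAN) / IMAGENET_STD
```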