# 🧠 Self-Supervised Vision Transformers for Prostate Cancer Gleason Grading
**Summary.** This repository provides pretrained self-supervised Vision Transformer (ViT) models (DINO-v1/v2) for automated Gleason grade classification from prostate cancer whole-slide images (WSIs). Beyond high accuracy, the models emphasize reproducibility and potential clinical decision-support value by addressing inter-observer variability and workload in pathology.
## 🩺 Clinical Motivation & Impact
Manual Gleason grading is time-consuming and subject to inter-observer variability, which can affect diagnostic consistency and treatment planning. By learning label-efficient, stain-robust representations with self-supervised learning (SSL), these models aim to:
- improve grading reproducibility,
- reduce pathologist workload,
- and facilitate timely, consistent decision-making as part of a decision-support pipeline.
> **Note:** These models are intended for research and development of clinical decision-support systems; they are not cleared for direct clinical use.
## Model Description
- Framework: DINO (self-distillation without labels)
- Backbones evaluated: ViT-B/16 (DINO-v1), ViT-L/14 (DINO-v2)
- Downstream heads: MLP (best), k-NN, CNN heads (e.g., DenseNet)
- Training objective: SSL pretraining on histology tiles, followed by supervised fine-tuning for Gleason classes
- Library: PyTorch / Hugging Face Transformers
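The "frozen SSL features + lightweight head" pattern above can be sketched in plain PyTorch. Here 768 is the ViT-B/16 CLS-embedding size and 9 the number of Gleason classes; the hidden width and dropout rate are illustrative assumptions, not the trained configuration:

```python
import torch
import torch.nn as nn

# Lightweight MLP head over frozen DINO ViT-B/16 CLS features (768-dim).
# Hidden width and dropout are illustrative, not the paper's values.
head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 9),   # nine Gleason classes: 3+3 ... 5+5
)
head.eval()

feats = torch.randn(4, 768)  # a batch of 4 precomputed tile embeddings (dummy)
with torch.no_grad():
    logits = head(feats)
print(logits.shape)  # torch.Size([4, 9])
```

In this setup the backbone is run once to cache tile embeddings, and only the small head is trained, which is what makes the approach label-efficient.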
## 🧪 Data & Training
- Dataset: TCGA-PRAD
- Cases / slides / patches: 403 patients, 449 diagnostic slides, 81,126 tiles (224×224 px)
- Split: patient-level (no patient leakage); 80% train / 20% test
- Preprocessing: standard WSI tiling; color/stain variability present in TCGA
- Goal: multi-class Gleason grade classification
**Classes:** 3+3, 3+4, 3+5, 4+3, 4+4, 4+5, 5+3, 5+4, 5+5
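The patient-level split can be sketched with scikit-learn's `GroupShuffleSplit`, which guarantees no patient contributes tiles to both partitions. The toy records below are illustrative, not TCGA-PRAD data:

```python
from sklearn.model_selection import GroupShuffleSplit

# Toy tile records: (tile_id, patient_id, label). In practice these come
# from the WSI tiling step; values here are purely illustrative.
tiles = [(f"tile_{i}", f"patient_{i % 5}", "3+4") for i in range(20)]
groups = [patient for _, patient, _ in tiles]

# Split by patient (group), not by tile, so embeddings of the same
# patient never leak across train/test.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(tiles, groups=groups))

train_patients = {groups[i] for i in train_idx}
test_patients = {groups[i] for i in test_idx}
assert train_patients.isdisjoint(test_patients)  # no patient leakage
print(len(train_idx), "train tiles /", len(test_idx), "test tiles")
```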
## Results
Best overall configuration: DINO-v1 ViT-B/16 + MLP
| Model | Classifier | Accuracy (%) | Precision (%) | Recall (%) | F1 (%) |
|---|---|---|---|---|---|
| DINO-v1 (ViT-B/16) | MLP | 90.40 | 90.40 | 90.39 | 90.36 |
| DINO-v1 (ViT-B/16) | k-NN | 89.36 | 89.19 | 89.38 | 89.16 |
| DINO-v2 (ViT-L/14) | MLP | 83.31 | 83.16 | 83.31 | 83.17 |
| CNN (DenseNet) | FC head | 87.86 | 87.17 | 87.86 | 87.20 |
- ROC-AUC: 0.991 (MLP)
- Agreement: Cohen's κ = 0.89 (almost perfect agreement on the Landis–Koch scale)
- Feature analysis: SSL features were robust to stain variation and histologic heterogeneity; features 11 and 13 were the most discriminative, while feature 19 carried little discriminative power.

Interpretation. SSL-derived ViT features outperformed supervised baselines and were more robust to known sources of variability in computational pathology, supporting reproducible grading and potential clinical workflow integration after appropriate external validation.
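As a sketch of how the tabulated metrics are typically computed, the snippet below scores toy tile-level predictions with scikit-learn. The roughly 90% agreement rate is simulated and unrelated to the reported results:

```python
import numpy as np
from sklearn.metrics import accuracy_score, cohen_kappa_score, f1_score

# Toy tile-level predictions over the nine Gleason classes.
classes = ["3+3", "3+4", "3+5", "4+3", "4+4", "4+5", "5+3", "5+4", "5+5"]
rng = np.random.default_rng(0)
y_true = rng.integers(0, len(classes), size=200)
# Keep ~90% of labels, corrupt the rest with random classes.
y_pred = np.where(rng.random(200) < 0.9, y_true,
                  rng.integers(0, len(classes), size=200))

acc = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average="macro")
kappa = cohen_kappa_score(y_true, y_pred)
print(f"accuracy={acc:.3f}  macro-F1={macro_f1:.3f}  kappa={kappa:.3f}")
```

Cohen's κ corrects raw agreement for chance, which is why it is reported alongside accuracy for a 9-class problem.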
## Key Contributions
- A unified multi-SSL pipeline (DINO-v1/v2, iBOT, token registration) for Gleason grading.
- Evidence that SSL features improve robustness/generalization over supervised baselines.
- Feature-level statistical validation (e.g., ANOVA, discriminant power).
- Pretrained weights released for reproducibility and benchmarking.
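The feature-level statistical validation mentioned above can be illustrated with a per-dimension one-way ANOVA (SciPy's `f_oneway`): for each feature, test whether its mean differs across classes, then rank features by F-statistic. The data below are synthetic, with only feature 0 made class-dependent:

```python
import numpy as np
from scipy.stats import f_oneway

# Synthetic setup: 3 classes x 30 tiles x 16-dim features.
# Feature 0 is made class-dependent; all others are pure noise.
rng = np.random.default_rng(0)
n_classes, n_per, dim = 3, 30, 16
feats = rng.normal(size=(n_classes, n_per, dim))
feats[:, :, 0] += np.arange(n_classes)[:, None]  # shift feature 0 per class

# One-way ANOVA per feature dimension: does its mean differ across classes?
f_stats = [f_oneway(*[feats[c, :, d] for c in range(n_classes)]).statistic
           for d in range(dim)]
best = int(np.argmax(f_stats))
print("most discriminative feature:", best)  # feature 0 by construction
```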
## 🩸 Intended Use
Research use only. Suitable for:
- prototyping automated pathology tools,
- benchmarking histopathology classifiers,
- exploring self-supervised learning in medical imaging,
- building decision-support pipelines (with additional validation).
⚠️ **Not for clinical use.** External, multi-center validation and regulatory clearance are required before any deployment that affects patient care.
## ⚖️ Limitations & Ethical Considerations
- Domain shift: Trained on TCGA-PRAD; performance may vary with scanner type, staining protocol, lab workflow, or demographics.
- Generalization: Requires multi-institutional external validation.
- Fairness & bias: Assess subgroup performance before deployment.
- Human-in-the-loop: Models should augment, not replace, expert pathology review.
## How to Use
```python
from transformers import AutoImageProcessor, ViTForImageClassification
from PIL import Image
import torch

model_id = "buseyaren/self-supervised-prostate-cancer"
processor = AutoImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)
model.eval()

# Classify a single 224x224 H&E tile.
img = Image.open("example_tile.png").convert("RGB")
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

probs = logits.softmax(dim=-1)
pred_id = probs.argmax(dim=-1).item()
pred_score = probs[0, pred_id].item()

id2label = getattr(model.config, "id2label", {}) or {}
print("Predicted class:", id2label.get(pred_id, pred_id), f"(p={pred_score:.3f})")
```
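For whole slides, per-tile predictions usually need to be aggregated into a slide-level call. A minimal sketch that averages tile softmax probabilities follows; the aggregation rule is an assumption, as this card does not specify one:

```python
import torch

# Toy tile logits for one slide: 5 tiles x 9 Gleason classes (dummy values).
tile_logits = torch.randn(5, 9)

# Average tile-level softmax probabilities, then take the argmax as the
# slide-level prediction. (Illustrative rule; other schemes, e.g. majority
# vote or attention pooling, are also common.)
tile_probs = tile_logits.softmax(dim=-1)
slide_probs = tile_probs.mean(dim=0)
slide_pred = int(slide_probs.argmax())
print("slide prediction:", slide_pred, "p =", float(slide_probs[slide_pred]))
```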
## Model Tree

- Base model: [facebook/dino-vitb16](https://huggingface.co/facebook/dino-vitb16)