---
license: apache-2.0
language: en
library_name: pytorch
tags:
- image-classification
- medical-imaging
- diabetic-retinopathy
- pytorch
- timm
- efficientnet
datasets:
- aptos2019-blindness-detection
widget:
- src: gradcam_visualizations/gradcam_sample_003.png
example_title: No DR Example
- src: gradcam_visualizations/gradcam_sample_007.png
example_title: Severe DR Example
---
# Diabetic Retinopathy Grading Model (V2)
This is a multi-task deep learning model trained to classify the severity of Diabetic Retinopathy (DR) from retinal fundus images. It is based on the **EfficientNet-B3** architecture and was specifically optimized to improve the **Quadratic Weighted Kappa (QWK)** score, a clinically relevant metric for ordinal classification tasks like DR grading.
This model is the second iteration (V2) of a project focused on building a diagnostically "smarter" classifier that is more sensitive to severe, vision-threatening stages of the disease.
## Model Details
- **Architecture:** `timm/efficientnet_b3` backbone with a custom multi-task head.
- **Input Size:** 512×512 pixels.
- **Output:** A dictionary containing logits for three tasks:
- `severity`: 5 classes (0: No DR, 1: Mild, 2: Moderate, 3: Severe, 4: Proliferative).
- `lesions`: 5 classes (multi-label for various lesion types).
- `regions`: 5 classes (multi-label for affected anatomical regions).
- **Training Objective:** The model was trained focusing only on the `severity` task by setting the loss weights for auxiliary tasks to zero. The auxiliary heads can still produce outputs for interpretability.
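The zero-weight scheme described above can be sketched as a weighted sum of per-task losses. The weight values shown are the V2 setting (auxiliary weights at zero); the choice of cross-entropy for severity and BCE for the multi-label heads is an assumption, not confirmed training code:

```python
import torch
import torch.nn as nn

def multitask_loss(outputs, targets, w_severity=1.0, w_lesions=0.0, w_regions=0.0):
    """Weighted sum of per-task losses; V2 zeroes the auxiliary weights."""
    ce = nn.CrossEntropyLoss()       # severity: single-label, 5 ordinal grades
    bce = nn.BCEWithLogitsLoss()     # lesions/regions: multi-label heads
    loss = w_severity * ce(outputs['severity'], targets['severity'])
    if w_lesions:
        loss = loss + w_lesions * bce(outputs['lesions'], targets['lesions'])
    if w_regions:
        loss = loss + w_regions * bce(outputs['regions'], targets['regions'])
    return loss
```

With the auxiliary weights at zero, gradients flow only through the severity head, while the lesion and region heads remain available at inference time.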
## How to Get Started & Use
The model can be loaded directly from the Hugging Face Hub for inference.
```bash
# Install required libraries
pip install torch torchvision timm albumentations huggingface-hub numpy pillow opencv-python
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
# Define the model architecture
class MultiTaskDRModel(nn.Module):
    def __init__(self, model_name='efficientnet_b3', num_classes=5,
                 num_lesion_types=5, num_regions=5, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features
        # Squeeze-and-excite style channel attention over pooled features
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(self.feature_dim, self.feature_dim // 8), nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim // 8, self.feature_dim), nn.Sigmoid()
        )
        self.feature_norm = nn.BatchNorm1d(self.feature_dim)
        self.dropout = nn.Dropout(0.4)
        self.severity_classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 2, num_classes)
        )
        self.lesion_detector = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_lesion_types)
        )
        self.region_predictor = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_regions)
        )

    def forward(self, x):
        features = self.backbone.forward_features(x)
        pooled_features = F.adaptive_avg_pool2d(features, 1).flatten(1)
        attention_weights = self.attention(pooled_features.unsqueeze(-1).unsqueeze(-1))
        features = pooled_features * attention_weights
        features = self.feature_norm(features)
        features = self.dropout(features)
        severity_logits = self.severity_classifier(features)
        lesion_logits = self.lesion_detector(features)
        region_logits = self.region_predictor(features)
        return {
            'severity': severity_logits,
            'lesions': lesion_logits,
            'regions': region_logits,
            'features': features
        }

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskDRModel()

# Download and load the checkpoint
model_path = hf_hub_download(
    repo_id="dheeren-tejani/DiabeticRetinpathyClassifier",
    filename="best_model_v2.pth"
)
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
print("Model loaded successfully!")

# Preprocessing function
def preprocess_image(image_path):
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    image = np.array(Image.open(image_path).convert("RGB"))
    image_tensor = transforms(image=image)['image'].unsqueeze(0)
    return image_tensor

# Example inference
def predict_dr_severity(image_path):
    image_tensor = preprocess_image(image_path).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)

    # Get severity prediction
    severity_probs = torch.softmax(outputs['severity'], dim=1)
    predicted_class = torch.argmax(severity_probs, dim=1).item()
    confidence = severity_probs[0, predicted_class].item()

    severity_labels = {
        0: "No DR",
        1: "Mild DR",
        2: "Moderate DR",
        3: "Severe DR",
        4: "Proliferative DR"
    }
    return {
        'predicted_severity': severity_labels[predicted_class],
        'confidence': confidence,
        'all_probabilities': severity_probs[0].cpu().numpy()
    }
# Example usage
# result = predict_dr_severity("path/to/your/fundus_image.jpg")
# print(f"Predicted: {result['predicted_severity']} (Confidence: {result['confidence']:.3f})")
```
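Because the auxiliary heads are kept for interpretability, their multi-label logits can also be read out. A minimal sketch, assuming an independent sigmoid per label and a 0.5 decision threshold (the actual lesion/region vocabularies and thresholds are not documented in this card), using stand-in logits so it runs on its own:

```python
import torch

# Stand-in logits shaped like the model's auxiliary outputs (batch of 1, 5 labels each)
outputs = {'lesions': torch.randn(1, 5), 'regions': torch.randn(1, 5)}

# Multi-label heads use an independent sigmoid per label, not a softmax
lesion_probs = torch.sigmoid(outputs['lesions'])[0]
active = (lesion_probs > 0.5).nonzero(as_tuple=True)[0].tolist()
print("Lesion label indices above threshold:", active)
```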
## Training Details
### V2 Improvements
This model (V2) was specifically designed to address the shortcomings of a baseline model (V1) that struggled with severe-stage DR detection:
- **Higher Resolution:** Increased from 224×224 to 512×512 to capture finer pathological details
- **Class Balancing:** Implemented WeightedRandomSampler to oversample rare minority classes (Severe and Proliferative DR)
- **Focal Loss:** Replaced standard Cross-Entropy with Focal Loss (γ=2.0) to focus on hard-to-classify examples
- **Focused Training:** Set auxiliary task weights to zero, dedicating full model capacity to severity classification
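The class-balancing and focal-loss pieces above can be sketched as follows. This is the standard focal-loss formulation (Lin et al., 2017) with γ=2.0 and an inverse-frequency sampler, not the exact training code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

class FocalLoss(nn.Module):
    """Focal Loss: down-weights easy examples by the factor (1 - p_t)^gamma."""
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction='none')
        p_t = torch.exp(-ce)  # probability assigned to the true class
        return ((1 - p_t) ** self.gamma * ce).mean()

def make_balanced_sampler(labels):
    """Draw samples inversely to class frequency, oversampling rare grades."""
    counts = torch.bincount(labels)
    weights = 1.0 / counts[labels].float()
    return WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
```

Passing the sampler to a `DataLoader` (`DataLoader(dataset, sampler=make_balanced_sampler(labels), ...)`) makes Severe and Proliferative cases appear in batches roughly as often as the majority classes.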
### Hyperparameters
- **Optimizer:** AdamW
- **Learning Rate:** 1e-4
- **Scheduler:** CosineAnnealingWarmRestarts (`T_0=10`)
- **Batch Size:** 16
- **Epochs:** 17 (Early stopping)
- **Image Size:** 512×512
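The hyperparameters above translate into roughly the following setup (a stand-in module replaces `MultiTaskDRModel` so the sketch is self-contained; note that `CosineAnnealingWarmRestarts` takes `T_0`, the number of epochs before the first restart):

```python
import torch

# Stand-in module so the sketch runs on its own (the real model is MultiTaskDRModel)
model = torch.nn.Linear(8, 5)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Cosine schedule that restarts every T_0 epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

for epoch in range(3):   # training-loop body elided
    optimizer.step()     # would follow loss.backward() in real training
    scheduler.step()
```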
## Performance
The model was evaluated on a held-out validation set of 735 images:
| Metric | Score |
|--------|-------|
| **Quadratic Weighted Kappa (QWK)** | **0.796** |
| Accuracy | 65.0% |
| F1-Score (Weighted) | 66.3% |
| F1-Score (Macro) | 53.5% |
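QWK weights each disagreement by the squared distance between grades, so predicting Mild when the truth is Proliferative costs far more than an off-by-one error. It can be computed with scikit-learn; the labels below are toy values for illustration only:

```python
import numpy as np
from sklearn.metrics import cohen_kappa_score

y_true = np.array([0, 1, 2, 3, 4, 2, 0, 4])
near = np.array([0, 1, 3, 3, 4, 1, 0, 3])   # three errors, each off by one grade
far = np.array([0, 1, 0, 3, 4, 4, 0, 0])    # three errors, far from the truth

qwk_near = cohen_kappa_score(y_true, near, weights='quadratic')
qwk_far = cohen_kappa_score(y_true, far, weights='quadratic')
print(f"near: {qwk_near:.3f}  far: {qwk_far:.3f}")  # near-miss errors score higher
```

Both toy predictions have identical accuracy, yet the near-miss predictions earn a much higher QWK, which is exactly the behavior that makes the metric clinically meaningful for ordinal DR grading.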
### Key Achievement
The V2 model improved QWK by +0.035 over the V1 baseline (0.761 → 0.796, a 4.6% relative gain), indicating it makes "smarter" errors that are more aligned with clinical judgment, despite lower overall accuracy. This trade-off prioritizes clinically relevant performance over naive accuracy.
## Limitations
⚠️ **Important Disclaimers:**
- This model was trained on a single public dataset and may not generalize to different clinical settings, camera types, or patient demographics
- The dataset may contain inherent demographic biases
- **This is NOT a medical device** and should not be used for actual clinical diagnosis
- Always consult qualified healthcare professionals for medical decisions
## Citation
If you use this model in your research, please cite:
```bibtex
@misc{dheerentejani2025dr,
author = {Dheeren Tejani},
title = {Diabetic Retinopathy Grading Model V2},
year = {2025},
publisher = {Hugging Face},
journal = {Hugging Face Model Hub},
howpublished = {\url{https://huggingface.co/dheeren-tejani/DiabeticRetinpathyClassifier}},
}
```
## License
This model is released under the Apache 2.0 License.