---
license: apache-2.0
language: en
library_name: pytorch
tags:
- image-classification
- medical-imaging
- diabetic-retinopathy
- pytorch
- timm
- efficientnet
datasets:
- aptos2019-blindness-detection
widget:
- src: gradcam_visualizations/gradcam_sample_003.png
  example_title: No DR Example
- src: gradcam_visualizations/gradcam_sample_007.png
  example_title: Severe DR Example
---

# Diabetic Retinopathy Grading Model (V2)

This is a multi-task deep learning model trained to classify the severity of Diabetic Retinopathy (DR) from retinal fundus images. It is based on the **EfficientNet-B3** architecture and was specifically optimized to improve the **Quadratic Weighted Kappa (QWK)** score, a clinically relevant metric for ordinal classification tasks such as DR grading.

This model is the second iteration (V2) of a project focused on building a diagnostically "smarter" classifier that is more sensitive to the severe, vision-threatening stages of the disease.

## Model Details

- **Architecture:** `timm/efficientnet_b3` backbone with a custom multi-task head.
- **Input Size:** 512x512 pixels.
- **Output:** A dictionary containing logits for three tasks:
  - `severity`: 5 classes (0: No DR, 1: Mild, 2: Moderate, 3: Severe, 4: Proliferative).
  - `lesions`: 5 classes (multi-label, one per lesion type).
  - `regions`: 5 classes (multi-label, one per affected anatomical region).
- **Training Objective:** The model was trained on the `severity` task only, with the loss weights for the auxiliary tasks set to zero. The auxiliary heads still produce outputs that can be used for interpretability.
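
The three heads are interpreted differently: `severity` is single-label (a softmax over the 5 ordinal grades), while `lesions` and `regions` are multi-label (an independent sigmoid per class). A minimal NumPy sketch of that post-processing, using made-up logits in place of real model output:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Made-up logits standing in for one sample's output dictionary.
outputs = {
    "severity": np.array([2.1, 0.3, -0.5, -1.0, -2.0]),  # single-label
    "lesions":  np.array([1.5, -2.0, 0.1, -0.3, 3.0]),   # multi-label
    "regions":  np.array([-1.0, 0.5, 0.5, -2.0, 1.2]),   # multi-label
}

severity_probs = softmax(outputs["severity"])     # sums to 1
predicted_grade = int(np.argmax(severity_probs))  # -> 0 (No DR) here

# Each lesion probability is independent; threshold at 0.5 (logit > 0).
lesion_probs = sigmoid(outputs["lesions"])
present_lesions = np.where(lesion_probs > 0.5)[0]

print(predicted_grade, severity_probs.round(3), present_lesions)
```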

## How to Get Started

The model can be loaded directly from the Hugging Face Hub for inference.

```bash
# Install required libraries
pip install torch torchvision timm albumentations huggingface-hub numpy pillow opencv-python
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download

# Define the model architecture (must match the checkpoint exactly)
class MultiTaskDRModel(nn.Module):
    def __init__(self, model_name='efficientnet_b3', num_classes=5,
                 num_lesion_types=5, num_regions=5, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(self.feature_dim, self.feature_dim // 8), nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim // 8, self.feature_dim), nn.Sigmoid()
        )

        self.feature_norm = nn.BatchNorm1d(self.feature_dim)
        self.dropout = nn.Dropout(0.4)

        self.severity_classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 2, num_classes)
        )

        self.lesion_detector = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_lesion_types)
        )

        self.region_predictor = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_regions)
        )

    def forward(self, x):
        features = self.backbone.forward_features(x)
        pooled_features = F.adaptive_avg_pool2d(features, 1).flatten(1)
        # Channel attention: re-weight pooled features before the task heads
        attention_weights = self.attention(pooled_features.unsqueeze(-1).unsqueeze(-1))
        features = pooled_features * attention_weights
        features = self.feature_norm(features)
        features = self.dropout(features)

        return {
            'severity': self.severity_classifier(features),
            'lesions': self.lesion_detector(features),
            'regions': self.region_predictor(features),
            'features': features
        }

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskDRModel()

# Download and load the checkpoint
model_path = hf_hub_download(
    repo_id="dheeren-tejani/DiabeticRetinpathyClassifier",
    filename="best_model_v2.pth"
)
# weights_only=False because the checkpoint is a dict containing training state
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully!")

# Preprocessing: same resize and normalization as used during training
def preprocess_image(image_path):
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    image = np.array(Image.open(image_path).convert("RGB"))
    return transforms(image=image)['image'].unsqueeze(0)

# Example inference
def predict_dr_severity(image_path):
    image_tensor = preprocess_image(image_path).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)

    severity_probs = torch.softmax(outputs['severity'], dim=1)
    predicted_class = torch.argmax(severity_probs, dim=1).item()
    confidence = severity_probs[0, predicted_class].item()

    severity_labels = {
        0: "No DR",
        1: "Mild DR",
        2: "Moderate DR",
        3: "Severe DR",
        4: "Proliferative DR"
    }

    return {
        'predicted_severity': severity_labels[predicted_class],
        'confidence': confidence,
        'all_probabilities': severity_probs[0].cpu().numpy()
    }

# Example usage
# result = predict_dr_severity("path/to/your/fundus_image.jpg")
# print(f"Predicted: {result['predicted_severity']} (Confidence: {result['confidence']:.3f})")
```
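
Because the grades are ordinal, the `all_probabilities` vector can also be collapsed into an expected severity score (Σ pᵢ·i) and rounded. This is a common post-processing trick for QWK-oriented models, not part of the original pipeline; a minimal sketch with a made-up probability vector:

```python
import numpy as np

def expected_severity(probs):
    """Collapse a 5-way probability vector into a continuous ordinal score.

    Rounding the expectation (rather than taking the argmax) tends to make
    near-miss errors, which QWK penalizes far less than distant ones.
    """
    grades = np.arange(len(probs))
    score = float(np.dot(probs, grades))
    return score, int(np.clip(round(score), 0, len(probs) - 1))

# Made-up vector standing in for result['all_probabilities']
probs = np.array([0.05, 0.10, 0.40, 0.35, 0.10])
score, grade = expected_severity(probs)
print(score, grade)  # score ~2.35 -> grade 2, even though argmax would also be 2 here
```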

## Training Details

### V2 Improvements
This model (V2) was specifically designed to address the shortcomings of a baseline model (V1) that struggled with severe-stage DR detection:

- **Higher Resolution:** Increased from 224×224 to 512×512 to capture finer pathological details
- **Class Balancing:** Used a `WeightedRandomSampler` to oversample the rare minority classes (Severe and Proliferative DR)
- **Focal Loss:** Replaced standard Cross-Entropy with Focal Loss (γ=2.0) to focus training on hard-to-classify examples
- **Focused Training:** Set the auxiliary task loss weights to zero, dedicating the full model capacity to severity classification
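
Focal Loss, FL(pₜ) = −(1−pₜ)^γ · log(pₜ), down-weights examples the model already classifies confidently so gradient signal concentrates on the hard ones. The actual training used a PyTorch implementation; this NumPy sketch with invented probability vectors just illustrates the effect of γ=2.0:

```python
import numpy as np

def focal_loss(probs, target, gamma=2.0):
    """Focal loss for one sample: -(1 - p_t)**gamma * log(p_t),
    where p_t is the probability assigned to the true class."""
    p_t = probs[target]
    return -((1.0 - p_t) ** gamma) * np.log(p_t)

# Easy example: model is confident and correct (p_t = 0.9)
easy = focal_loss(np.array([0.9, 0.05, 0.03, 0.01, 0.01]), target=0)
# Hard example: model is unsure about a Severe-DR case (p_t = 0.3)
hard = focal_loss(np.array([0.2, 0.2, 0.2, 0.3, 0.1]), target=3)

# With gamma=2 the easy example is scaled by (1 - 0.9)**2 = 0.01,
# while the hard one keeps most of its cross-entropy magnitude.
print(easy, hard)
```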

### Hyperparameters
- **Optimizer:** AdamW
- **Learning Rate:** 1e-4
- **Scheduler:** CosineAnnealingWarmRestarts (`T_0=10`)
- **Batch Size:** 16
- **Epochs:** 17 (early stopping)
- **Image Size:** 512×512

## Performance

The model was evaluated on a held-out validation set of 735 images:

| Metric | Score |
|--------|-------|
| **Quadratic Weighted Kappa (QWK)** | **0.796** |
| Accuracy | 65.0% |
| F1-Score (Weighted) | 66.3% |
| F1-Score (Macro) | 53.5% |

### Key Achievement
The V2 model improved QWK by **0.035 absolute** over the V1 baseline (0.761 → 0.796, roughly +4.6% relative), indicating that it makes "smarter" errors more closely aligned with clinical judgment, despite lower overall accuracy. This trade-off prioritizes the clinically relevant metric over naive accuracy.
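
QWK penalizes each confusion-matrix cell by the squared distance between predicted and true grade, so a Severe→Moderate mistake costs far less than a Severe→No-DR one. A self-contained NumPy sketch of the metric (equivalent in spirit to `sklearn.metrics.cohen_kappa_score(..., weights='quadratic')`; the example labels are invented):

```python
import numpy as np

def quadratic_weighted_kappa(y_true, y_pred, n_classes=5):
    """QWK = 1 - sum(W * O) / sum(W * E), with W[i, j] = (i - j)**2."""
    O = np.zeros((n_classes, n_classes))
    for t, p in zip(y_true, y_pred):
        O[t, p] += 1
    # Expected counts under independence of the two marginals
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / len(y_true)
    idx = np.arange(n_classes)
    W = (idx[:, None] - idx[None, :]) ** 2
    return 1.0 - (W * O).sum() / (W * E).sum()

y_true = [0, 2, 4]
near   = [0, 2, 3]   # off by one grade on the Proliferative case
far    = [0, 2, 0]   # calls a Proliferative case "No DR"

print(quadratic_weighted_kappa(y_true, near))  # high: near-miss error
print(quadratic_weighted_kappa(y_true, far))   # low: clinically dangerous miss
```

Plain accuracy scores both prediction sets identically (2/3 correct); QWK separates them, which is why it was the optimization target here.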

## Limitations

⚠️ **Important Disclaimers:**
- This model was trained on a single public dataset and may not generalize to different clinical settings, camera types, or patient demographics
- The dataset may contain inherent demographic biases
- **This is NOT a medical device** and should not be used for actual clinical diagnosis
- Always consult qualified healthcare professionals for medical decisions

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{dheerentejani2025dr,
  author = {Dheeren Tejani},
  title = {Diabetic Retinopathy Grading Model V2},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/dheeren-tejani/DiabeticRetinpathyClassifier}},
}
```

## License

This model is released under the Apache 2.0 License.