Retinal OCT Disease Classifier
A deep learning model for classifying retinal diseases from Optical Coherence Tomography (OCT) images. The model distinguishes between four classes, three disease categories and a healthy baseline:
- CNV - Choroidal Neovascularization
- DME - Diabetic Macular Edema
- DRUSEN - Early AMD (Age-related Macular Degeneration)
- NORMAL - Healthy Retina
Model Description
This model uses EfficientNet-B3 as the backbone with a custom classification head, trained on the Kermany2018 OCT dataset. It reaches 99.6% accuracy and 99.6% macro F1 on the held-out test set of 968 images.
Architecture
- Backbone: EfficientNet-B3 (pretrained on ImageNet)
- Head: Dropout(0.3) → Linear(1536→512) → ReLU → Dropout(0.15) → Linear(512→4)
- Input Size: 224×224 RGB images
- Output: 4-class probability distribution
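The head listed above can be sketched in plain PyTorch as follows. This is a minimal illustration, not the repository's code: the helper name `build_head` is hypothetical, and a random tensor stands in for the 1536-wide pooled features that the EfficientNet-B3 backbone would produce.

```python
import torch
from torch import nn

NUM_FEATURES = 1536  # pooled feature width fed into the head (Linear(1536→512) above)
NUM_CLASSES = 4      # CNV, DME, DRUSEN, NORMAL


def build_head(in_features=NUM_FEATURES, num_classes=NUM_CLASSES):
    """Hypothetical helper mirroring the layer order listed in the Architecture section."""
    return nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(p=0.15),
        nn.Linear(512, num_classes),
    )


head = build_head()
head.eval()  # disable dropout for inference

# Stand-in for pooled backbone output: a batch of 2 feature vectors.
features = torch.randn(2, NUM_FEATURES)
with torch.no_grad():
    probs = torch.softmax(head(features), dim=1)

print(probs.shape)  # one 4-class probability distribution per input
```

One plausible way to obtain such features is a timm backbone created with `num_classes=0`, which returns pooled features instead of logits. Whether the published checkpoint was assembled this way is an assumption.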
Performance
| Metric | Validation | Test |
|---|---|---|
| Accuracy | 97.8% | 99.6% |
| Macro F1 | 96.8% | 99.6% |
Per-Class Performance (Test Set)
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| CNV | 98.4% | 100.0% | 99.2% | 242 |
| DME | 100.0% | 100.0% | 100.0% | 242 |
| DRUSEN | 100.0% | 98.4% | 99.2% | 242 |
| NORMAL | 100.0% | 100.0% | 100.0% | 242 |
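As a quick consistency check, the F1 scores and the macro F1 in the tables above can be recomputed from the reported precision and recall. The sketch below uses only the rounded percentages from the table, so small rounding differences against the original evaluation are possible.

```python
# Per-class (precision, recall) in percent, copied from the test-set table.
per_class = {
    "CNV": (98.4, 100.0),
    "DME": (100.0, 100.0),
    "DRUSEN": (100.0, 98.4),
    "NORMAL": (100.0, 100.0),
}


def f1_score(precision, recall):
    """Harmonic mean of precision and recall (same units as the inputs)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


f1 = {name: f1_score(p, r) for name, (p, r) in per_class.items()}

# Macro F1 is the unweighted mean of the per-class F1 scores.
macro_f1 = sum(f1.values()) / len(f1)

for name, value in f1.items():
    print(f"{name}: F1 = {value:.1f}%")
print(f"Macro F1 = {macro_f1:.1f}%")
```

Because every class has the same support (242 images), the macro and support-weighted averages coincide here.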
Usage
Quick Start
import torch
from PIL import Image
from torchvision import transforms

# Load model
model = torch.hub.load('your-username/oct-classifier', 'oct_classifier', pretrained=True)
model.eval()

# Preprocess image (ImageNet mean/std, matching the pretrained backbone)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict
image = Image.open('oct_scan.jpeg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)  # add batch dimension: [1, 3, 224, 224]
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

classes = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
print(f"Prediction: {classes[predicted_class]}")
print(f"Confidence: {probabilities[0][predicted_class]:.2%}")
Using with timm
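import timm
import torch

# Load the model
model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=4)
# map_location='cpu' lets the checkpoint load on machines without a GPU.
# Note: load_state_dict requires the checkpoint keys to match this architecture.
# If the checkpoint contains the custom head described under Architecture,
# model.classifier may need to be replaced with that head before loading.
model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
model.eval()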
Training Details
Dataset
- Source: Kermany2018 OCT Dataset
- Total Images: ~84,000 OCT scans
- Train/Val Split: 80/20 stratified
- Test Set: 968 images (242 per class)
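An 80/20 stratified split keeps the class proportions the same in the training and validation sets. The sketch below shows one way to do this with the standard library only. The function name `stratified_split`, the seed, and the toy label list are illustrative assumptions, not the project's actual split code or class counts.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=42):
    """Split sample indices so each class contributes val_fraction of its items to validation."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed (assumed) for a reproducible split
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return train_idx, val_idx


# Toy, deliberately imbalanced label list (made-up counts for illustration).
labels = ["CNV"] * 50 + ["DME"] * 30 + ["DRUSEN"] * 10 + ["NORMAL"] * 60
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
```

With real data, the returned index lists could be passed to `torch.utils.data.Subset` to build the two datasets.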
Training Configuration
- Optimizer: AdamW (lr=1e-4, weight_decay=0.01)
- Scheduler: Cosine Annealing with 2-epoch warmup
- Batch Size: 32
- Epochs: 20
- Mixed Precision: FP16
- Gradient Clipping: 1.0
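To make the scheduler setting concrete, the following sketch computes a per-epoch learning rate for cosine annealing with a 2-epoch warmup, using the values listed above (lr=1e-4, 20 epochs). The linear warmup shape, the per-epoch (rather than per-step) granularity, and the floor of `min_lr=0.0` are assumptions, since the configuration does not specify them.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=2, total_epochs=20, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear warmup (assumed): ramp up to base_lr by the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine annealing over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(e) for e in range(20)]
for e, lr in enumerate(schedule):
    print(f"epoch {e:2d}: lr = {lr:.2e}")
```

In a PyTorch training loop, the same function could drive `torch.optim.lr_scheduler.LambdaLR` by returning `lr_at_epoch(epoch) / base_lr` as the multiplicative factor.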
Data Augmentation
- Horizontal flip (p=0.5)
- Rotation (±15°)
- Brightness/Contrast adjustment (±0.2)
- Gaussian noise (p=0.3)
- Gaussian blur (p=0.2)
Interpretability
The model includes Grad-CAM visualization support, which highlights the regions of the retinal scan that most influenced a prediction. This helps clinicians check that the model attends to pathologically relevant structures.
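For readers unfamiliar with the technique, the sketch below shows the core Grad-CAM computation in generic PyTorch: channel weights come from spatially averaged gradients of the class score, and the heatmap is the ReLU of the weighted activation sum. It is not the repository's implementation. The function name `grad_cam`, the choice of target layer, and the toy network used for the demo are assumptions.

```python
import torch
from torch import nn
import torch.nn.functional as F


def grad_cam(model, target_layer, input_tensor, class_idx=None):
    """Return a heatmap in [0, 1] with the input's spatial size, plus the class index used."""
    activations, gradients = [], []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # Capture the gradient of the class score w.r.t. this activation.
        output.register_hook(lambda grad: gradients.append(grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad()
        logits = model(input_tensor)
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1).item())
        logits[0, class_idx].backward()
    finally:
        handle.remove()

    acts = activations[0].detach()                    # [1, K, h, w]
    grads = gradients[0].detach()                     # [1, K, h, w]
    weights = grads.mean(dim=(2, 3), keepdim=True)    # one weight per channel
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=input_tensor.shape[-2:],
                        mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)                    # normalize to [0, 1]
    return cam, class_idx


# Toy 4-class CNN standing in for the real classifier.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
).eval()

heatmap, cls = grad_cam(toy, toy[0], torch.randn(1, 3, 32, 32))
print(heatmap.shape, cls)
```

For the actual classifier, the target layer would typically be one of the last convolutional layers of the backbone, and the heatmap would be overlaid on the resized OCT scan.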
Limitations
- Trained on a single OCT device type; may require fine-tuning for other devices
- Not validated for clinical use; any deployment would require appropriate medical device certification
- Performance may vary on images with different quality or preprocessing
Links
- GitHub: tomalmog/retinal-oct-classifier
- Hugging Face: tomalmog/oct-retinal-classifier
Citation
If you use this model, please cite:
@misc{almog2025oct,
title={Retinal OCT Disease Classification with EfficientNet-B3},
author={Almog, Tom},
year={2025},
url={https://huggingface.co/tomalmog/oct-retinal-classifier}
}
This model was trained on the Kermany2018 dataset:
@article{kermany2018identifying,
title={Identifying medical diagnoses and treatable diseases by image-based deep learning},
author={Kermany, Daniel S and Goldbaum, Michael and Cai, Wenjia and others},
journal={Cell},
volume={172},
number={5},
pages={1122--1131},
year={2018},
publisher={Elsevier}
}
License
MIT License
