Retinal OCT Disease Classifier

A deep learning model for classifying retinal diseases from Optical Coherence Tomography (OCT) images. The model distinguishes four classes, three pathologies and the healthy retina:

  • CNV - Choroidal Neovascularization
  • DME - Diabetic Macular Edema
  • DRUSEN - Drusen deposits, a sign of early AMD (Age-related Macular Degeneration)
  • NORMAL - Healthy Retina

Model Description

This model uses EfficientNet-B3 as the backbone with a custom classification head, trained on the Kermany2018 OCT dataset. On the held-out test set it reaches 99.6% accuracy and 99.6% macro F1.

Architecture

  • Backbone: EfficientNet-B3 (pretrained on ImageNet)
  • Head: Dropout(0.3) → Linear(1536→512) → ReLU → Dropout(0.15) → Linear(512→4)
  • Input Size: 224×224 RGB images
  • Output: 4-class probability distribution
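
The head listed above can be sketched as a small PyTorch module. This is a minimal sketch rather than the author's code: the class name `OCTClassifier`, the `backbone` argument, and the stand-in backbone used for the shape check are assumptions; only the layer sizes and dropout rates come from the list.

```python
import torch
from torch import nn


class OCTClassifier(nn.Module):
    """Backbone plus the classification head listed under Architecture (hypothetical class name)."""

    def __init__(self, backbone: nn.Module, feature_dim: int = 1536, num_classes: int = 4):
        super().__init__()
        # Assumption: the backbone returns pooled features of shape (batch, feature_dim).
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))  # raw logits, shape (batch, num_classes)


# Stand-in backbone for a shape check only; it is NOT EfficientNet-B3.
dummy_backbone = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1536))
model = OCTClassifier(dummy_backbone).eval()

with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
    probs = torch.softmax(logits, dim=1)  # 4-class probability distribution
```

With timm, a pooled-feature backbone would typically come from `timm.create_model('efficientnet_b3', pretrained=True, num_classes=0)`, which drops the stock classifier and returns 1536-dimensional features.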

Performance

Metric      Validation   Test
Accuracy    97.8%        99.6%
Macro F1    96.8%        99.6%

Per-Class Performance (Test Set)

Class    Precision  Recall   F1-Score  Support
CNV      98.4%      100.0%   99.2%     242
DME      100.0%     100.0%   100.0%    242
DRUSEN   100.0%     98.4%    99.2%     242
NORMAL   100.0%     100.0%   100.0%    242
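
As a worked check, the per-class figures are consistent with a confusion matrix in which 4 DRUSEN scans are predicted as CNV and every other scan is classified correctly. That matrix is inferred from the reported numbers and is an assumption, not something published with the model. Under it, DRUSEN recall is 238/242 ≈ 98.35%, which matches the reported 98.4% only up to rounding.

```python
classes = ["CNV", "DME", "DRUSEN", "NORMAL"]

# Assumed confusion matrix (rows = true class, columns = predicted class),
# inferred from the reported precision/recall; not taken from the model card.
confusion = [
    [242, 0, 0, 0],
    [0, 242, 0, 0],
    [4, 0, 238, 0],
    [0, 0, 0, 242],
]


def per_class_metrics(cm):
    """Return a (precision, recall, f1) tuple for each class of a square confusion matrix."""
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[r][k] for r in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum (support)
        precision = tp / predicted_k
        recall = tp / actual_k
        f1 = 2 * precision * recall / (precision + recall)
        results.append((precision, recall, f1))
    return results


metrics = per_class_metrics(confusion)
accuracy = sum(confusion[k][k] for k in range(4)) / sum(map(sum, confusion))
macro_f1 = sum(f1 for _, _, f1 in metrics) / len(metrics)

for name, (p, r, f1) in zip(classes, metrics):
    print(f"{name:7s} precision={p:.2%} recall={r:.2%} f1={f1:.2%}")
print(f"accuracy={accuracy:.2%} macro_f1={macro_f1:.2%}")
```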

Usage

Quick Start

import torch
from PIL import Image
from torchvision import transforms

# Load model
model = torch.hub.load('your-username/oct-classifier', 'oct_classifier', pretrained=True)
model.eval()

# Preprocess image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict
image = Image.open('oct_scan.jpeg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

classes = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
print(f"Prediction: {classes[predicted_class]}")
print(f"Confidence: {probabilities[0][predicted_class]:.2%}")

Using with timm

import timm
import torch

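# Load the model
# Note: a stock timm EfficientNet-B3 ends in a single Linear(1536→4) classifier.
# If pytorch_model.bin was saved with the custom head described under
# Architecture, the keys will not match and the head must be attached first.
model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=4)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()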
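
Before loading weights into a timm model, it can help to compare the checkpoint's parameter names with the model's. The sketch below assumes nothing about the real `pytorch_model.bin`: `compare_state_dict_keys` is a hypothetical helper, and the two small modules are stand-ins for a single-layer classifier and a two-layer head.

```python
from torch import nn


def compare_state_dict_keys(model: nn.Module, state_dict: dict):
    """Return (missing, unexpected) parameter names without modifying the model."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = sorted(model_keys - ckpt_keys)     # expected by the model, absent from the file
    unexpected = sorted(ckpt_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


# Stand-ins for illustration: a single-layer classifier vs. a two-layer head.
single_linear = nn.Sequential(nn.Linear(1536, 4))
two_layer_head = nn.Sequential(nn.Linear(1536, 512), nn.ReLU(), nn.Linear(512, 4))

missing, unexpected = compare_state_dict_keys(single_linear, two_layer_head.state_dict())
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real file, the second argument would be the dictionary returned by `torch.load('pytorch_model.bin', map_location='cpu')`. Matching names do not guarantee matching shapes; `load_state_dict` reports shape mismatches itself.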

Training Details

Dataset

  • Source: Kermany2018 OCT Dataset
  • Total Images: ~84,000 OCT scans
  • Train/Val Split: 80/20 stratified
  • Test Set: 968 images (242 per class)
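
An 80/20 stratified split keeps each class at roughly the same ratio in both subsets. The following is a minimal pure-Python sketch of that idea, not the author's split code: the function name, the seed, and the toy label counts are assumptions.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=42):
    """Split sample indices so each class keeps roughly the same train/val ratio."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)  # fixed seed (assumed) so the split is reproducible
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy labels with an imbalanced class mix (made-up counts, not the dataset's).
labels = ["CNV"] * 50 + ["DME"] * 20 + ["DRUSEN"] * 10 + ["NORMAL"] * 40
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
```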

Training Configuration

  • Optimizer: AdamW (lr=1e-4, weight_decay=0.01)
  • Scheduler: Cosine Annealing with 2-epoch warmup
  • Batch Size: 32
  • Epochs: 20
  • Mixed Precision: FP16
  • Gradient Clipping: 1.0
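
The learning-rate schedule above can be sketched as a function of the epoch. The configuration names only "Cosine Annealing with 2-epoch warmup", so the linear warmup shape, per-epoch stepping, and a minimum learning rate of 0 are all assumptions made for illustration.

```python
import math


def learning_rate(epoch, base_lr=1e-4, warmup_epochs=2, total_epochs=20, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Assumed linear warmup that reaches base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


schedule = [learning_rate(e) for e in range(20)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr={lr:.2e}")
```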

Data Augmentation

  • Horizontal flip (p=0.5)
  • Rotation (±15°)
  • Brightness/Contrast adjustment (±0.2)
  • Gaussian noise (p=0.3)
  • Gaussian blur (p=0.2)

Interpretability

The model includes Grad-CAM visualization support, which highlights the regions of the retinal scan that most influenced a prediction. This matters for clinical trust and validation.
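
For readers unfamiliar with Grad-CAM, the technique can be sketched in a few lines of PyTorch. This is a generic illustration, not the implementation shipped with the model: the `grad_cam` function and the tiny CNN are hypothetical, and for the real model the target layer would be the last convolutional stage of the backbone.

```python
import torch
import torch.nn.functional as F
from torch import nn


def grad_cam(model, target_layer, input_tensor, class_idx=None):
    """Return a heatmap in [0, 1] at input resolution and the class it explains."""
    activations = []
    handle = target_layer.register_forward_hook(
        lambda module, inputs, output: activations.append(output)
    )
    try:
        model.eval()  # note: switches the model to eval mode as a side effect
        logits = model(input_tensor)
    finally:
        handle.remove()

    if class_idx is None:
        class_idx = int(logits.argmax(dim=1)[0])

    acts = activations[0]  # feature maps, shape (1, C, H, W)
    # Gradient of the chosen class score with respect to the feature maps.
    grads = torch.autograd.grad(logits[0, class_idx], acts)[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)  # one weight per channel
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)  # normalise to [0, 1]
    return cam[0, 0].detach(), class_idx


# Tiny stand-in network; NOT the EfficientNet-B3 model described in this card.
tiny_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)
heatmap, predicted = grad_cam(tiny_cnn, tiny_cnn[1], torch.randn(1, 3, 32, 32))
print(heatmap.shape, predicted)
```

The call must run with gradients enabled, so it should not be wrapped in `torch.no_grad()`.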

Grad-CAM Samples

Limitations

  • Trained on a single OCT device type; may require fine-tuning for other devices
  • Not validated for clinical use; deployment would require proper medical device certification
  • Performance may vary on images with different quality or preprocessing

Citation

If you use this model, please cite:

@misc{almog2025oct,
  title={Retinal OCT Disease Classification with EfficientNet-B3},
  author={Almog, Tom},
  year={2025},
  url={https://huggingface.co/tomalmog/oct-retinal-classifier}
}

This model was trained on the Kermany2018 dataset:

@article{kermany2018identifying,
  title={Identifying medical diagnoses and treatable diseases by image-based deep learning},
  author={Kermany, Daniel S and Goldbaum, Michael and Cai, Wenjia and others},
  journal={Cell},
  volume={172},
  number={5},
  pages={1122--1131},
  year={2018},
  publisher={Elsevier}
}

License

MIT License
