Retinal Disease Classifier - User Guide
Complete guide for using the retinal disease classifier model.
Quick Start
Install:
pip install torch torchvision pillow albumentations scikit-learn numpy
Use:
import torch
from PIL import Image
import numpy as np
from model_inference import predict_image
result = predict_image("path/to/fundus_image.png")
print(f"Detected diseases: {result['detected_diseases']}")
Output:
{
"disease_risk": true,
"detected_diseases": ["DR", "CRVO"],
"num_detected": 2,
"predictions": { "DR": 0.993, "CRVO": 0.899, ... }
}
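To work with the returned dictionary, it can help to rank the per-disease probabilities. The sketch below assumes the result has the shape shown in the output above; `top_predictions` is a hypothetical helper, not part of the package, and the sample values are copied from that output with the remaining classes omitted.

```python
def top_predictions(result, k=3):
    """Return the k highest-probability (code, probability) pairs."""
    ranked = sorted(result["predictions"].items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:k]

# Sample values copied from the output above (remaining classes omitted).
result = {
    "disease_risk": True,
    "detected_diseases": ["DR", "CRVO"],
    "num_detected": 2,
    "predictions": {"DR": 0.993, "CRVO": 0.899},
}

for code, prob in top_predictions(result, k=2):
    print(f"{code}: {prob:.3f}")
```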
Installation
Requirements
- Python 3.10+
- PyTorch 2.0+
- 4GB RAM (2GB with smaller batch sizes)
- GPU recommended (CUDA 12.1+)
Step 1: Clone or Download
Option A: From Hugging Face
git clone https://huggingface.co/lebiraja/retinal-disease-classifier
cd retinal-disease-classifier
Option B: Manual Download
Download pytorch_model.bin from Hugging Face and place it in your project directory.
Step 2: Install Dependencies
pip install -r requirements.txt
Or manually:
pip install torch==2.1.0 torchvision==0.16.0 pillow albumentations scikit-learn numpy matplotlib tqdm
Step 3: Verify Installation
python3 << 'EOF'
import torch
from PIL import Image
import albumentations
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("✅ All dependencies OK")
EOF
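If an import fails, it is useful to know which package is missing. The following is a minimal sketch using only the standard library; `missing_packages` and `REQUIRED` are hypothetical names, and the mapping reflects that the pip names `pillow` and `scikit-learn` are imported as `PIL` and `sklearn`.

```python
import importlib.util

# pip package name -> import name (they differ for pillow and scikit-learn)
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "pillow": "PIL",
    "albumentations": "albumentations",
    "scikit-learn": "sklearn",
    "numpy": "numpy",
}

def missing_packages(required=REQUIRED):
    """Return the pip names whose import module cannot be found."""
    return [
        pip_name
        for pip_name, module in required.items()
        if importlib.util.find_spec(module) is None
    ]

missing = missing_packages()
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All dependencies found")
```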
Basic Usage
Using Command Line
# Single image
python3 inference.py --image path/to/image.png
# Batch processing
for img in *.png; do
python3 inference.py --image "$img"
done
# Custom threshold
python3 inference.py --image image.png --threshold 0.4
Using Python API
import torch
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
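# Load model
# Note: when the first argument is a local directory, torch.hub.load needs
# source='local', and that directory must contain a hubconf.py that defines
# the 'custom' entry point.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load('path/to/model', 'custom', path='pytorch_model.bin', source='local')
model = model.to(device)
model.eval()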
# Prepare image
transform = A.Compose([
A.Resize(384, 384),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
image = np.array(Image.open("fundus.png").convert("RGB"))
tensor = transform(image=image)["image"].unsqueeze(0).to(device)
# Inference
with torch.no_grad():
logits = model(tensor)
probs = torch.sigmoid(logits)[0].cpu().numpy()
# Parse results
disease_names = [
"DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
"CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
"RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
"PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
"VS", "BRAO", "PLQ", "HPED", "CL",
]
threshold = 0.5
results = {
"disease_risk": False,
"predictions": {name: float(prob) for name, prob in zip(disease_names, probs)},
"detected_diseases": [name for name, prob in zip(disease_names, probs) if prob >= threshold],
}
results["disease_risk"] = len(results["detected_diseases"]) > 0
print(results)
Understanding Results
Output Format
{
"disease_risk": true,
"predictions": {
"DR": 0.993,
"ARMD": 0.042,
"MH": 0.029,
...
},
"detected_diseases": ["DR"],
"num_detected": 1
}
Key Fields
| Field | Type | Meaning |
|---|---|---|
| disease_risk | bool | Any disease detected (above threshold) |
| predictions | dict | Probability [0,1] for all 45 diseases |
| detected_diseases | list | Diseases above threshold |
| num_detected | int | Count of detected diseases |
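The four fields can be derived from a probability vector in one place. This is a sketch under the assumption that probabilities and class names arrive in the same order; `build_result` is a hypothetical helper, and the sample values are the first three entries of the output format above.

```python
def build_result(probs, names, threshold=0.5):
    """Assemble the result dict described in the Key Fields table."""
    predictions = {name: float(p) for name, p in zip(names, probs)}
    detected = [name for name, p in predictions.items() if p >= threshold]
    return {
        "disease_risk": len(detected) > 0,
        "predictions": predictions,
        "detected_diseases": detected,
        "num_detected": len(detected),
    }

# First three classes from the example output; a real call passes all 45.
result = build_result([0.993, 0.042, 0.029], ["DR", "ARMD", "MH"])
print(result["detected_diseases"], result["num_detected"])
```

Deriving `disease_risk` and `num_detected` from the same `detected` list keeps the fields consistent with each other.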
Interpreting Probabilities
- 0.0 - 0.3: Disease unlikely
- 0.3 - 0.7: Uncertain (review recommended)
- 0.7 - 1.0: Disease likely present
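These bands can be expressed as a small lookup. The listed ranges share their endpoints (0.3 and 0.7), so the sketch below assumes a boundary value belongs to the higher band; `interpret` is a hypothetical helper name.

```python
def interpret(prob):
    """Map a probability to the band names used in this guide."""
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"probability out of range: {prob}")
    if prob < 0.3:
        return "unlikely"
    if prob < 0.7:  # assumption: 0.3 itself counts as uncertain
        return "uncertain"
    return "likely"  # assumption: 0.7 itself counts as likely

for p in (0.042, 0.5, 0.993):
    print(p, interpret(p))
```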
Disease Abbreviations
| Code | Full Name | Severity |
|---|---|---|
| DR | Diabetic Retinopathy | 🔴 High |
| ARMD | Age-Related Macular Degeneration | 🔴 High |
| CRVO | Central Retinal Vein Occlusion | 🔴 High |
| BRVO | Branch Retinal Vein Occlusion | 🟡 Medium |
| LS | Laser Scar | 🟡 Medium |
| MH | Media Haze | 🟢 Low |
| CWS | Cotton Wool Spots | 🟢 Low |
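One possible use of the severity column is to order detected diseases for review. The sketch below uses a subset of the table rows; `SEVERITY`, `RANK`, and `sort_by_severity` are hypothetical names, and placing unlisted codes last is an assumption.

```python
# Severity levels for a subset of the codes in the table above.
SEVERITY = {
    "DR": "High",
    "ARMD": "High",
    "CRVO": "High",
    "BRVO": "Medium",
    "LS": "Medium",
    "CWS": "Low",
}
RANK = {"High": 0, "Medium": 1, "Low": 2}

def sort_by_severity(codes):
    """Order codes High -> Medium -> Low; codes not in SEVERITY go last (assumption)."""
    return sorted(codes, key=lambda code: RANK.get(SEVERITY.get(code), 3))

print(sort_by_severity(["CWS", "LS", "DR"]))
```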
Advanced Usage
Batch Processing with GPU
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
class RetinalDataset(Dataset):
def __init__(self, image_paths, transform):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
tensor = self.transform(image=image)["image"]
return tensor
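# Load model (load_model() is a placeholder for the loading code in "Using Python API")
model = load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Create dataset (image_paths is your list of file paths; transform is the
# A.Compose pipeline from "Using Python API")
dataset = RetinalDataset(image_paths, transform)
loader = DataLoader(dataset, batch_size=32, num_workers=4)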
# Batch inference
all_results = []
with torch.no_grad():
for batch in loader:
batch = batch.to(device)
logits = model(batch)
probs = torch.sigmoid(logits).cpu().numpy()
all_results.append(probs)
results = np.concatenate(all_results, axis=0) # (N, 45)
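The concatenated array has one row per image, in the same order as `image_paths`, because the `DataLoader` above does not shuffle. A sketch of mapping rows back to file names follows; `per_image_results` is a hypothetical helper that accepts the (N, 45) array or plain lists of rows, and the second row and the file names below are toy values for illustration.

```python
def per_image_results(paths, prob_rows, names, threshold=0.5):
    """Map each image path to the list of codes at or above the threshold."""
    out = {}
    for path, row in zip(paths, prob_rows):
        preds = {name: float(p) for name, p in zip(names, row)}
        out[path] = [name for name, p in preds.items() if p >= threshold]
    return out

# Toy example with three classes; a real call passes all 45 names.
names = ["DR", "ARMD", "MH"]
rows = [[0.993, 0.042, 0.029], [0.10, 0.20, 0.05]]
detected = per_image_results(["a.png", "b.png"], rows, names)
print(detected)
```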
Threshold Optimization
# Adjust threshold based on use case
thresholds = {
"high_sensitivity": 0.3, # Catch more cases, more false positives
"balanced": 0.5, # Default
"high_specificity": 0.7, # Fewer false alarms, miss some cases
"ultra_conservative": 0.85, # Only very confident predictions
}
threshold = thresholds["balanced"]
detected = [name for name, prob in predictions.items() if prob >= threshold]
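To choose among these presets, one option is to measure sensitivity and specificity on labelled validation images for a single disease. The sketch below is plain Python; the probabilities and labels are toy values for illustration, not model results, and `sensitivity_specificity` is a hypothetical helper.

```python
def sensitivity_specificity(probs, labels, threshold):
    """Return (sensitivity, specificity) for binary labels at a threshold."""
    pairs = list(zip(probs, labels))
    tp = sum(1 for p, y in pairs if y == 1 and p >= threshold)
    fn = sum(1 for p, y in pairs if y == 1 and p < threshold)
    tn = sum(1 for p, y in pairs if y == 0 and p < threshold)
    fp = sum(1 for p, y in pairs if y == 0 and p >= threshold)
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    return sens, spec

# Toy validation data for one disease (1 = present, 0 = absent).
probs = [0.95, 0.60, 0.35, 0.80, 0.40, 0.10]
labels = [1, 1, 1, 0, 0, 0]

for t in (0.3, 0.5, 0.7, 0.85):
    sens, spec = sensitivity_specificity(probs, labels, t)
    print(f"threshold={t}: sensitivity={sens:.2f}, specificity={spec:.2f}")
```

On this toy data, lowering the threshold raises sensitivity and lowers specificity, which is the trade-off the presets describe.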
Per-Disease Thresholds
# Different thresholds for different diseases
per_disease_thresholds = {
"DR": 0.4, # DR is important, lower threshold
"ARMD": 0.4, # ARMD is important, lower threshold
"MH": 0.6, # Media haze (MH) treated as less critical, higher threshold
# ... default to 0.5 for others
}
detected = []
for name, prob in predictions.items():
threshold = per_disease_thresholds.get(name, 0.5)
if prob >= threshold:
detected.append(name)
Troubleshooting
Image Issues
Problem: "Image has wrong dimensions"
Solution: Ensure image is RGB, not RGBA or grayscale
image = Image.open("image.png").convert("RGB")
Problem: "Out of memory (OOM)"
Solution: Reduce batch size or image resolution
batch_size = 8 # or lower
IMG_SIZE = 256 # instead of 384
Problem: "Image too small"
Solution: Model expects 384×384 minimum
from PIL import Image
image = image.resize((384, 384))
Model Issues
Problem: "Model file not found"
# Download from Hugging Face
git clone https://huggingface.co/lebiraja/retinal-disease-classifier
Problem: "CUDA out of memory"
# Use CPU instead
device = torch.device("cpu")
model = model.to(device)
Problem: "Predictions are all zeros"
Check:
1. Image is valid fundus photo (not blank/black)
2. Model file is not corrupted (check MD5)
3. Image preprocessing is correct
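For check 2, the MD5 of the model file can be computed with the standard library and compared with a checksum from the download source. This is a minimal sketch; `file_md5` is a hypothetical helper, and no reference checksum is given in this guide, so the comparison value must come from wherever you obtained the file.

```python
import hashlib
import os

def file_md5(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

MODEL_PATH = "pytorch_model.bin"  # file name used elsewhere in this guide
if os.path.exists(MODEL_PATH):
    # Compare this value with the checksum from your download source.
    print(file_md5(MODEL_PATH))
else:
    print(f"{MODEL_PATH} not found")
```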
FAQ
Q: What image formats are supported?
A: PNG, JPG, JPEG, BMP, TIFF. Convert with:
from PIL import Image
# save() returns None, so do not assign its result
Image.open("image.bmp").convert("RGB").save("image.png")
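To collect every supported file in a folder before running inference, a sketch like the following may help. `find_images` and `SUPPORTED` are hypothetical names; the extension set follows the formats listed above, and treating `.tif` as TIFF is an assumption.

```python
from pathlib import Path

# Extensions for the formats listed above (".tif" is an assumed TIFF alias).
SUPPORTED = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"}

def find_images(folder):
    """Return sorted paths of files in folder with a supported extension."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in SUPPORTED
    )

for path in find_images("."):
    print(path)
```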
Q: Can I use this for clinical diagnosis?
A: NO. This is for research/educational purposes only. Always consult qualified ophthalmologists for medical decisions.
Q: How accurate is the model?
A: Mean AUC is 0.8204. Performance varies by disease:
- Common diseases: 0.90-0.95 AUC
- Rare diseases: 0.60-0.75 AUC
Q: Can I fine-tune on my own data?
A: Yes! See DEVELOPER.md for fine-tuning instructions.
Q: How should I interpret the probabilities?
A:
- Probability 0.9: Very confident disease present
- Probability 0.5: Uncertain, needs review
- Probability 0.1: Very confident disease absent
Q: How do I batch process images?
A: See "Batch Processing with GPU" section above.
Q: Can I run on CPU only?
A: Yes, but ~10x slower. Set device to CPU.
Q: What if I get different results each time?
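A: First make sure model.eval() is called before inference, since dropout and batch-norm layers behave differently in training mode. If results still vary, add torch.manual_seed(42) for reproducibility.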
Support
- GitHub Issues: Report bugs on the project GitHub
- Hugging Face: https://huggingface.co/lebiraja/retinal-disease-classifier
- Documentation: Check docs/ folder for detailed guides
Last Updated: February 22, 2026
Model Version: 1.0
Status: Production Ready ✅