
Retinal Disease Classifier - User Guide

Complete guide for using the retinal disease classifier model.


Table of Contents

  1. Quick Start
  2. Installation
  3. Basic Usage
  4. Understanding Results
  5. Advanced Usage
  6. Troubleshooting
  7. FAQ

Quick Start

Install:

pip install torch torchvision pillow albumentations scikit-learn numpy

Use:

import torch
from PIL import Image
import numpy as np
from model_inference import predict_image

result = predict_image("path/to/fundus_image.png")
print(f"Detected diseases: {result['detected_diseases']}")

Example value of result:

{
  "disease_risk": true,
  "detected_diseases": ["DR", "CRVO"],
  "num_detected": 2,
  "predictions": { "DR": 0.993, "CRVO": 0.899, ... }
}
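The predictions dict holds a score for every disease, not only the detected ones. Here is a small sketch for listing the highest-scoring codes. It assumes result has the shape shown above. top_predictions is a hypothetical helper and is not part of model_inference.

```python
def top_predictions(result, k=5):
    """Return the k (code, probability) pairs with the highest probability."""
    ranked = sorted(result["predictions"].items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


# Example with a result shaped like the output above
result = {
    "disease_risk": True,
    "detected_diseases": ["DR", "CRVO"],
    "num_detected": 2,
    "predictions": {"DR": 0.993, "CRVO": 0.899, "ARMD": 0.042},
}
for code, prob in top_predictions(result, k=2):
    print(f"{code}: {prob:.3f}")
```

This is useful for borderline cases, where a disease just under the threshold can still be worth a look.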

Installation

Requirements

  • Python 3.10+
  • PyTorch 2.0+
  • 4GB RAM (2GB with smaller batch sizes)
  • GPU recommended (CUDA 12.1+)

Step 1: Clone or Download

Option A: From Hugging Face

git lfs install
git clone https://huggingface.co/lebiraja/retinal-disease-classifier
cd retinal-disease-classifier

Option B: Manual Download

Download pytorch_model.bin from Hugging Face and place it in your project directory.

Step 2: Install Dependencies

pip install -r requirements.txt

Or manually:

pip install torch==2.1.0 torchvision==0.16.0 pillow albumentations scikit-learn numpy matplotlib tqdm

Step 3: Verify Installation

python3 << 'EOF'
import torch
from PIL import Image
import albumentations
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("✅ All dependencies OK")
EOF

Basic Usage

Using Command Line

# Single image
python3 inference.py --image path/to/image.png

# Batch processing
for img in *.png; do
  python3 inference.py --image "$img"
done

# Custom threshold
python3 inference.py --image image.png --threshold 0.4
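The shell loop above starts a new Python process for every image, so the model is loaded again for each file. The following sketch keeps everything in one process and writes a CSV summary. It assumes predict_image from the Quick Start takes a path and returns the dict shown there. predict_folder and the file names are hypothetical.

```python
import csv
from pathlib import Path


def predict_folder(folder, predict_fn, out_csv, pattern="*.png"):
    """Run predict_fn on every matching image and write one CSV row per image.

    predict_fn is assumed to behave like predict_image: it takes an image path and
    returns a dict with "disease_risk" and "detected_diseases" keys.
    """
    rows = []
    for path in sorted(Path(folder).glob(pattern)):
        result = predict_fn(str(path))
        rows.append({
            "image": path.name,
            "disease_risk": result["disease_risk"],
            "detected_diseases": ";".join(result["detected_diseases"]),
        })
    with open(out_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["image", "disease_risk", "detected_diseases"])
        writer.writeheader()
        writer.writerows(rows)
    return rows


# Usage (assumes model_inference.predict_image from the Quick Start):
# from model_inference import predict_image
# predict_folder("images/", predict_image, "predictions.csv")
```

Passing the predict function as an argument keeps the helper independent of how the model is loaded. It also lets you test the helper with a stub.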

Using Python API

import torch
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
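# torch.hub.load needs source='local' for a directory on disk, and that directory
# must contain a hubconf.py that defines the 'custom' entry point
model = torch.hub.load('path/to/model', 'custom', path='pytorch_model.bin', source='local')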
model = model.to(device)
model.eval()

# Prepare image
transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

image = np.array(Image.open("fundus.png").convert("RGB"))
tensor = transform(image=image)["image"].unsqueeze(0).to(device)

# Inference
with torch.no_grad():
    logits = model(tensor)
    probs = torch.sigmoid(logits)[0].cpu().numpy()

# Parse results
disease_names = [
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
    "CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
    "RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
    "PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
    "VS", "BRAO", "PLQ", "HPED", "CL",
]

threshold = 0.5
predictions = {name: float(prob) for name, prob in zip(disease_names, probs)}
detected = [name for name, prob in predictions.items() if prob >= threshold]

# Same fields as the output format described below
results = {
    "disease_risk": len(detected) > 0,
    "predictions": predictions,
    "detected_diseases": detected,
    "num_detected": len(detected),
}
print(results)

Understanding Results

Output Format

{
  "disease_risk": true,
  "predictions": {
    "DR": 0.993,
    "ARMD": 0.042,
    "MH": 0.029,
    ...
  },
  "detected_diseases": ["DR"],
  "num_detected": 1
}

Key Fields

| Field | Type | Meaning |
|---|---|---|
| disease_risk | bool | Any disease detected (above threshold) |
| predictions | dict | Probability [0,1] for all 45 diseases |
| detected_diseases | list | Diseases above threshold |
| num_detected | int | Count of detected diseases |

Interpreting Probabilities

  • 0.0 - 0.3: Disease unlikely
  • 0.3 - 0.7: Uncertain (review recommended)
  • 0.7 - 1.0: Disease likely present
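The bands above can be applied to every entry of predictions with a small helper. This is a sketch. Boundary handling is an assumption, because the bands share their endpoints: here 0.3 and 0.7 belong to the higher band. interpret_probability is a hypothetical name.

```python
def interpret_probability(prob):
    """Map a probability to the bands above (0.3 and 0.7 go to the higher band)."""
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"probability out of range: {prob}")
    if prob < 0.3:
        return "unlikely"
    if prob < 0.7:
        return "uncertain"  # review recommended
    return "likely"


# Usage: bands = {code: interpret_probability(p) for code, p in result["predictions"].items()}
```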

Disease Abbreviations

| Code | Full Name | Severity |
|---|---|---|
| DR | Diabetic Retinopathy | 🔴 High |
| ARMD | Age-Related Macular Degeneration | 🔴 High |
| CRVO | Central Retinal Vein Occlusion | 🔴 High |
| BRVO | Branch Retinal Vein Occlusion | 🟡 Medium |
| LS | Laser Scar | 🟡 Medium |
| MYA | Myopia | 🟢 Low |
| CWS | Cotton Wool Spots | 🟢 Low |
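When several diseases are detected, it can help to list the most severe ones first. The sketch below orders detected codes using the severity levels from the table above. It only covers a subset of the table, so extend the mapping as needed. Codes without a listed severity sort last. SEVERITY and sort_by_severity are hypothetical names.

```python
# Subset of the severity table above; extend with the remaining codes you need.
SEVERITY = {
    "DR": "High", "ARMD": "High", "CRVO": "High",
    "BRVO": "Medium", "LS": "Medium",
    "CWS": "Low",
}
SEVERITY_RANK = {"High": 0, "Medium": 1, "Low": 2, "Unlisted": 3}


def sort_by_severity(detected):
    """Order detected codes from most to least severe (stable within a level)."""
    return sorted(detected, key=lambda code: SEVERITY_RANK[SEVERITY.get(code, "Unlisted")])


# Usage: sort_by_severity(result["detected_diseases"])
```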

Advanced Usage

Batch Processing with GPU

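import glob

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class RetinalDataset(Dataset):
    """Reads fundus images from disk and applies an albumentations transform."""

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        return self.transform(image=image)["image"]

# Same preprocessing as in "Using Python API"
transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Load the model as in "Using Python API"; load_model() is a placeholder
# for that code and is not defined in this guide
model = load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Create dataset ("images/*.png" is an example path; point it at your data)
image_paths = sorted(glob.glob("images/*.png"))
dataset = RetinalDataset(image_paths, transform)
# shuffle defaults to False, so output rows stay in image_paths order
loader = DataLoader(dataset, batch_size=32, num_workers=4)

# Batch inference
all_results = []
with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_results.append(probs)

results = np.concatenate(all_results, axis=0)  # (N, 45)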
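The batch loop produces a bare (N, 45) array. The following sketch saves it next to the image paths so that rows stay traceable. It assumes the rows are in the same order as image_paths, which holds for a DataLoader without shuffling, and that the columns follow disease_names from "Using Python API". save_batch_results and the output file name are hypothetical.

```python
import csv

import numpy as np


def save_batch_results(image_paths, probs, disease_names, out_csv, threshold=0.5):
    """Write one row per image: path, detected codes, then every probability."""
    probs = np.asarray(probs)
    if probs.shape != (len(image_paths), len(disease_names)):
        raise ValueError(f"unexpected shape {probs.shape}")
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "detected_diseases", *disease_names])
        for path, row in zip(image_paths, probs):
            detected = [name for name, p in zip(disease_names, row) if p >= threshold]
            writer.writerow([path, ";".join(detected), *(f"{p:.4f}" for p in row)])


# Usage: save_batch_results(image_paths, results, disease_names, "batch_predictions.csv")
```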

Threshold Optimization

# Adjust threshold based on use case
thresholds = {
    "high_sensitivity": 0.3,      # Catch more cases, more false positives
    "balanced": 0.5,              # Default
    "high_specificity": 0.7,      # Fewer false alarms, miss some cases
    "ultra_conservative": 0.85,   # Only very confident predictions
}

threshold = thresholds["balanced"]
# predictions: dict of disease code -> probability, i.e. the "predictions" field of a result
detected = [name for name, prob in predictions.items() if prob >= threshold]
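The presets above are fixed cut-offs. If you have labeled validation images, one common alternative is to pick, for each disease, the highest threshold that still reaches a target sensitivity. This is a minimal sketch. It assumes 0/1 labels and probabilities for a single disease. threshold_for_sensitivity is a hypothetical helper and is not part of this repository.

```python
import math

import numpy as np


def threshold_for_sensitivity(y_true, y_prob, target=0.9):
    """Largest threshold t such that predicting prob >= t keeps sensitivity >= target."""
    if not 0.0 < target <= 1.0:
        raise ValueError("target must be in (0, 1]")
    y_true = np.asarray(y_true).astype(bool)
    pos_probs = np.sort(np.asarray(y_prob, dtype=float)[y_true])[::-1]  # descending
    if pos_probs.size == 0:
        raise ValueError("need at least one positive example")
    # k = number of positives that must score >= t; the small epsilon guards
    # against floating-point error pushing target * n_pos just above an integer
    k = max(1, math.ceil(target * pos_probs.size - 1e-9))
    return float(pos_probs[k - 1])
```

Run it once per disease column and collect the values into a dict such as per_disease_thresholds. Thresholds chosen on a small validation set can be noisy, especially for rare diseases with few positives.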

Per-Disease Thresholds

# Different thresholds for different diseases
per_disease_thresholds = {
    "DR": 0.4,      # DR is important, lower threshold
    "ARMD": 0.4,    # ARMD is important, lower threshold
    "MYA": 0.6,     # Myopia less critical, higher threshold
    # ... default to 0.5 for others
}

detected = []
for name, prob in predictions.items():
    threshold = per_disease_thresholds.get(name, 0.5)
    if prob >= threshold:
        detected.append(name)
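The loop above handles one image. For the (N, 45) array from batch processing, you can apply the same per-disease thresholds in a single step with NumPy broadcasting. This is a sketch that assumes the array's columns follow disease_names. apply_thresholds is a hypothetical name.

```python
import numpy as np


def apply_thresholds(probs, disease_names, per_disease_thresholds, default=0.5):
    """Return a boolean (N, D) mask: True where a probability meets its disease's threshold."""
    thresholds = np.array([per_disease_thresholds.get(name, default) for name in disease_names])
    return np.asarray(probs) >= thresholds  # (N, D) compared against (D,) by broadcasting


# Usage:
# mask = apply_thresholds(results, disease_names, per_disease_thresholds)
# detected_per_image = [[n for n, hit in zip(disease_names, row) if hit] for row in mask]
```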

Troubleshooting

Image Issues

Problem: "Image has wrong dimensions"

Solution: Ensure image is RGB, not RGBA or grayscale
image = Image.open("image.png").convert("RGB")

Problem: "Out of memory (OOM)"

Solution: Reduce the batch size or the input resolution
batch_size = 8  # or lower, passed to DataLoader(..., batch_size=batch_size)
A.Resize(256, 256)  # instead of A.Resize(384, 384); may lower accuracy

Problem: "Image too small"

Solution: Model expects 384×384 minimum
from PIL import Image
image = Image.open("image.png").convert("RGB").resize((384, 384))

Model Issues

Problem: "Model file not found"

# Download from Hugging Face
git clone https://huggingface.co/lebiraja/retinal-disease-classifier

Problem: "CUDA out of memory"

# Use CPU instead
device = torch.device("cpu")
model = model.to(device)

Problem: "Predictions are all zeros"

Check:
1. Image is a valid fundus photo (not blank/black)
2. Model file is not corrupted (check MD5)
3. Image preprocessing is correct
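The first two checks can be scripted. This sketch hashes the weight file in chunks, so large files do not need to fit in memory, and it flags images whose pixels barely vary. If the Hugging Face file page lists a SHA-256 rather than an MD5, pass algorithm="sha256". A file of only a few hundred bytes usually means a Git LFS pointer was downloaded instead of the weights. The min_std cut-off is an arbitrary starting value and has not been tuned. file_digest and looks_blank are hypothetical helpers.

```python
import hashlib

import numpy as np


def file_digest(path, algorithm="md5", chunk_size=1 << 20):
    """Hash a file chunk by chunk and return the hex digest."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def looks_blank(image_array, min_std=5.0):
    """Heuristic: True if pixel values barely vary (e.g. an all-black image)."""
    return float(np.asarray(image_array, dtype=np.float32).std()) < min_std


# Usage:
# print(file_digest("pytorch_model.bin", "sha256"))
# looks_blank(np.array(Image.open("fundus.png").convert("RGB")))
```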

FAQ

Q: What image formats are supported?

A: PNG, JPG/JPEG, BMP, and TIFF. Convert to PNG with:

from PIL import Image
Image.open("image.bmp").convert("RGB").save("image.png")

Q: Can I use this for clinical diagnosis?

A: NO. This is for research/educational purposes only. Always consult qualified ophthalmologists for medical decisions.

Q: How accurate is the model?

A: Mean AUC across diseases is 0.8204. Performance varies by disease:

  • Common diseases: 0.90-0.95 AUC
  • Rare diseases: 0.60-0.75 AUC
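To check per-disease AUC on your own labeled images, you can use scikit-learn's roc_auc_score column by column. This is a sketch and not necessarily how the reported 0.8204 was computed. For example, the evaluation split used for that figure is not stated here. Diseases whose labels are all 0 (or all 1) are skipped because AUC is undefined for them. per_disease_auc is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def per_disease_auc(y_true, y_prob, disease_names):
    """Return {disease code: ROC AUC} for every disease with both classes present.

    y_true: (N, D) 0/1 labels; y_prob: (N, D) probabilities; columns follow disease_names.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    aucs = {}
    for j, name in enumerate(disease_names):
        labels = y_true[:, j]
        if labels.min() == labels.max():
            continue  # only one class present, so AUC is undefined
        aucs[name] = float(roc_auc_score(labels, y_prob[:, j]))
    return aucs


# Usage: aucs = per_disease_auc(labels, results, disease_names); np.mean(list(aucs.values()))
```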

Q: Can I fine-tune on my own data?

A: Yes! See DEVELOPER.md for fine-tuning instructions.

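Q: How should I read individual probabilities?

A: Each disease gets its own independent sigmoid probability, so the 45 values do not sum to 1. As a rough guide: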

  • Probability 0.9: Very confident disease present
  • Probability 0.5: Uncertain, needs review
  • Probability 0.1: Very confident disease absent

Q: How do I batch process images?

A: See "Batch Processing with GPU" section above.

Q: Can I run on CPU only?

A: Yes, but ~10x slower. Set device to CPU.

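Q: What if I get different results each time?

A: Make sure model.eval() is called before inference (dropout and batch normalization behave differently in training mode) and that the preprocessing contains no random augmentations. For seeded runs, also add torch.manual_seed(42).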
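If you also want the other random number generators seeded and cuDNN autotuning turned off, you can group those settings in one function and call it before loading the model. This is a sketch using standard PyTorch flags. make_inference_deterministic is a hypothetical name. Some GPU operations can remain nondeterministic even with these settings.

```python
import random

import numpy as np
import torch


def make_inference_deterministic(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and disable cuDNN autotuning."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Usage: make_inference_deterministic(42), then load the model and call model.eval()
```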


Support


Last Updated: February 22, 2026
Model Version: 1.0
Status: Production Ready ✅