lebiraja's picture
Upload BACKEND.md with huggingface_hub
f95f207 verified

Backend Integration Guide

Complete guide for backend developers integrating the retinal disease classifier.


Table of Contents

  1. Architecture Overview
  2. FastAPI Implementation
  3. Flask Implementation
  4. API Endpoints
  5. Payload Formats
  6. Error Handling
  7. Deployment
  8. Performance Optimization

Architecture Overview

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Frontend/Client   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
           β”‚ HTTP
           β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   API Server        β”‚ (FastAPI/Flask)
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
           β”‚
           β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Model Inference   β”‚ (PyTorch)
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
           β”‚
           β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   GPU/CPU Device    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

FastAPI Implementation

Installation

pip install fastapi uvicorn python-multipart pillow torch torchvision albumentations

Basic Server

# app.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Create the API app; the model itself is loaded once in the startup hook below.
app = FastAPI(title="Retinal Disease Classifier API", version="1.0")

# CORS configuration
# NOTE(review): browsers reject wildcard origins combined with
# allow_credentials=True per the CORS spec — pin allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, populated by the startup hook so request handlers
# never reload weights per request.
MODEL = None
DEVICE = None

@app.on_event("startup")
async def load_model():
    """Load the classifier once when the server starts.

    Populates the module-level MODEL/DEVICE globals; prefers CUDA when
    available, otherwise falls back to CPU.
    """
    global MODEL, DEVICE
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles arbitrary code — only load trusted
    # checkpoints. This also assumes pytorch_model.bin is a full pickled
    # module (eval() is called on it), not a bare state_dict — confirm
    # against how the model was saved.
    MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
    MODEL.eval()
    print(f"Model loaded on {DEVICE}")

# Response schemas
class DiseaseResult(BaseModel):
    # True when at least one disease probability meets the threshold.
    disease_risk: bool
    # Per-disease sigmoid probability, keyed by abbreviation.
    predictions: dict
    # Abbreviations whose probability >= threshold.
    detected_diseases: list
    # Number of detected diseases (len of detected_diseases).
    num_detected: int
    # Mean probability over detected diseases; 0.0 when nothing detected.
    confidence: float

# 45 disease abbreviations, ordered to match the model's output logits.
DISEASE_NAMES = [
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
    "CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
    "RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
    "PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
    "VS", "BRAO", "PLQ", "HPED", "CL",
]

@app.get("/health")
async def health_check():
    """Liveness probe: service status, model-load state, and inference device."""
    return dict(
        status="healthy",
        model_loaded=MODEL is not None,
        device=str(DEVICE),
    )

@app.post("/predict", response_model=DiseaseResult)
async def predict(file: UploadFile = File(...), threshold: float = 0.5):
    """
    Predict diseases in retinal image.

    Args:
        file: PNG/JPG image file
        threshold: Disease detection threshold (0-1)

    Returns:
        DiseaseResult with predictions

    Raises:
        HTTPException: 400 for invalid input, 500 for inference failures.
    """
    try:
        # Validate input (content_type can be None for odd clients)
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        if threshold < 0 or threshold > 1:
            raise HTTPException(status_code=400, detail="Threshold must be 0-1")

        # Load image; convert to RGB so grayscale/RGBA uploads also work
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Preprocess — must match the training pipeline (384x384, ImageNet stats)
        transform = A.Compose([
            A.Resize(384, 384),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

        image_array = np.array(image)
        tensor = transform(image=image_array)["image"].unsqueeze(0).to(DEVICE)

        # Inference (multi-label head: per-class sigmoid, not softmax)
        with torch.no_grad():
            logits = MODEL(tensor)
            probs = torch.sigmoid(logits)[0].cpu().numpy()

        # Parse results
        predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
        detected = [name for name, prob in predictions.items() if prob >= threshold]

        return DiseaseResult(
            disease_risk=len(detected) > 0,
            predictions=predictions,
            detected_diseases=detected,
            num_detected=len(detected),
            confidence=float(np.mean([predictions[d] for d in detected]) if detected else 0.0),
        )

    except HTTPException:
        # Bug fix: re-raise validation errors unchanged. Previously the broad
        # handler below caught them and rewrote every 400 as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...), threshold: float = 0.5):
    """Batch prediction: runs the single-image endpoint once per upload."""
    outcomes = [await predict(upload, threshold) for upload in files]
    return {"results": outcomes}

@app.get("/info")
async def model_info():
    """Static model metadata: classes, expected input size, and held-out metrics."""
    metrics = {
        "best_auc": 0.8204,
        "train_loss": 0.2118,
        "val_loss": 0.2578,
    }
    return {
        "model_name": "Retinal Disease Classifier",
        "version": "1.0",
        "num_classes": len(DISEASE_NAMES),
        "diseases": DISEASE_NAMES,
        "input_size": 384,
        "metrics": metrics,
    }

if __name__ == "__main__":
    import uvicorn
    # Development entry point; production should run `uvicorn app:app --workers N`.
    uvicorn.run(app, host="0.0.0.0", port=8000)

Run Server

# Development
uvicorn app:app --reload

# Production
uvicorn app:app --host 0.0.0.0 --port 8000 --workers 4

Access API

# Health check
curl http://localhost:8000/health

# Single prediction
curl -X POST http://localhost:8000/predict \
  -F "file=@image.png" \
  -F "threshold=0.5"

# Model info
curl http://localhost:8000/info

# Swagger docs
http://localhost:8000/docs

Flask Implementation

Installation

pip install flask torch torchvision pillow albumentations

Basic Server

# app.py
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size

# Load model at import time (runs once per worker process).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles arbitrary code — only load trusted
# checkpoints. Assumes a full pickled module (eval() is called), not a
# bare state_dict — confirm against how the model was saved.
MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
MODEL.eval()

# 45 disease abbreviations, ordered to match the model's output logits.
DISEASE_NAMES = [
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
    "CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
    "RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
    "PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
    "VS", "BRAO", "PLQ", "HPED", "CL",
]

@app.route("/health", methods=["GET"])
def health():
    """Liveness probe reporting service status and the inference device."""
    status_payload = {"status": "healthy", "device": str(DEVICE)}
    return jsonify(status_payload)

@app.route("/predict", methods=["POST"])
def predict():
    """Single image prediction.

    Expects multipart/form-data with:
      - file: PNG/JPG image (required)
      - threshold: detection cut-off in [0, 1] (optional, default 0.5);
        accepted as a form field (as in the curl examples) or query param.

    Returns JSON with disease_risk, predictions, detected_diseases,
    num_detected; 400 on bad input, 500 on inference failure.
    """
    try:
        # Bug fix: the documented client sends threshold as a form field
        # (curl -F "threshold=0.5"), which lands in request.form, not
        # request.args — check the form first, then the query string.
        raw_threshold = request.form.get(
            "threshold", request.args.get("threshold", 0.5)
        )
        try:
            threshold = float(raw_threshold)
        except (TypeError, ValueError):
            # Malformed number is a client error (400), not a server error.
            return jsonify({"error": "Threshold must be a number"}), 400

        # Validate
        if "file" not in request.files:
            return jsonify({"error": "No file provided"}), 400

        file = request.files["file"]
        if file.filename == "":
            return jsonify({"error": "No file selected"}), 400

        if threshold < 0 or threshold > 1:
            return jsonify({"error": "Threshold must be 0-1"}), 400

        # Load image; convert to RGB so grayscale/RGBA uploads also work
        image = Image.open(io.BytesIO(file.read())).convert("RGB")

        # Preprocess — must match the training pipeline (384x384, ImageNet stats)
        transform = A.Compose([
            A.Resize(384, 384),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

        image_array = np.array(image)
        tensor = transform(image=image_array)["image"].unsqueeze(0).to(DEVICE)

        # Inference (multi-label head: per-class sigmoid)
        with torch.no_grad():
            logits = MODEL(tensor)
            probs = torch.sigmoid(logits)[0].cpu().numpy()

        # Results
        predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
        detected = [name for name, prob in predictions.items() if prob >= threshold]

        return jsonify({
            "disease_risk": len(detected) > 0,
            "predictions": predictions,
            "detected_diseases": detected,
            "num_detected": len(detected),
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/info", methods=["GET"])
def info():
    """Static model metadata: name, version, classes, and headline metric."""
    payload = {
        "model_name": "Retinal Disease Classifier",
        "version": "1.0",
        "num_classes": len(DISEASE_NAMES),
        "diseases": DISEASE_NAMES,
        "metrics": {"best_auc": 0.8204},
    }
    return jsonify(payload)

if __name__ == "__main__":
    # Development server only — run under gunicorn in production.
    app.run(debug=False, host="0.0.0.0", port=8000)

Run Server

# Development
python app.py

# Production (with gunicorn)
pip install gunicorn
gunicorn -w 4 -b 0.0.0.0:8000 app:app

API Endpoints

POST /predict

Predict diseases in single image.

Request:

curl -X POST http://localhost:8000/predict \
  -F "file=@fundus.png" \
  -F "threshold=0.5"

Response (200 OK):

{
  "disease_risk": true,
  "predictions": {
    "DR": 0.993,
    "CRVO": 0.899,
    "LS": 0.859,
    ...
  },
  "detected_diseases": ["DR", "CRVO", "LS"],
  "num_detected": 3
}

Errors:

  • 400: Invalid file or threshold
  • 500: Server error

POST /predict-batch (FastAPI only)

Request:

curl -X POST http://localhost:8000/predict-batch \
  -F "files=@image1.png" \
  -F "files=@image2.png" \
  -F "threshold=0.5"

Response:

{
  "results": [
    { "disease_risk": true, ... },
    { "disease_risk": false, ... }
  ]
}

GET /health

Check API status.

Response:

{
  "status": "healthy",
  "device": "cuda",
  "model_loaded": true
}

GET /info

Get model information.

Response:

{
  "model_name": "Retinal Disease Classifier",
  "version": "1.0",
  "num_classes": 45,
  "diseases": ["DR", "ARMD", ...],
  "metrics": {
    "best_auc": 0.8204,
    "train_loss": 0.2118,
    "val_loss": 0.2578
  }
}

Payload Formats

Input

Content-Type: multipart/form-data

- file: Binary image data (PNG/JPG)
- threshold: Float 0-1 (optional, default 0.5)

Output

{
  "disease_risk": boolean,
  "predictions": {
    "disease_name": probability (0-1),
    ...
  },
  "detected_diseases": [disease_names],
  "num_detected": integer,
  "confidence": average_probability (optional)
}

Error Handling

class APIError(Exception):
    """Application-level error carrying an HTTP status code.

    Args:
        message: Human-readable error description returned to the client.
        status_code: HTTP status to respond with (default 500).
    """

    def __init__(self, message, status_code=500):
        # Bug fix: initialize the Exception base so str(err) and tracebacks
        # show the message (previously str(err) was always "").
        super().__init__(message)
        self.message = message
        self.status_code = status_code

# Flask error handler: converts any raised APIError into a JSON body
# with the exception's own HTTP status code.
@app.errorhandler(APIError)
def handle_error(error):
    return jsonify({"error": error.message}), error.status_code

# Usage (fragment — `valid` stands in for your own validation result)
if not valid:
    raise APIError("Invalid threshold", 400)

Deployment

Docker

FROM python:3.10

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

Build & Run:

docker build -t retinal-classifier .
docker run -p 8000:8000 --gpus all retinal-classifier

Docker Compose

version: '3'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - CUDA_VISIBLE_DEVICES=0
    volumes:
      - ./pytorch_model.bin:/app/pytorch_model.bin
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

Performance Optimization

Batch Processing

@app.post("/predict-batch")
async def batch_predict(files: list[UploadFile]):
    """Run one batched forward pass over all uploaded images.

    Returns the per-image sigmoid probabilities as a JSON-serializable
    nested list of shape (N, 45).

    NOTE(review): `transform` is assumed to be the module-level
    preprocessing pipeline (Resize/Normalize/ToTensorV2) — confirm it is
    defined once at import time.
    """
    # Load and preprocess all images
    tensors = []
    for file in files:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
        tensor = transform(image=np.array(image))["image"]
        tensors.append(tensor)

    # Robustness: torch.stack([]) raises on an empty upload list.
    if not tensors:
        return []

    # Batch inference
    batch = torch.stack(tensors).to(DEVICE)
    with torch.no_grad():
        logits = MODEL(batch)
        probs = torch.sigmoid(logits).cpu().numpy()

    # Bug fix: a raw numpy array is not JSON-serializable by the framework;
    # convert to plain nested lists before returning.
    return probs.tolist()  # (N, 45)

Caching

from functools import lru_cache

# Memoize per-disease metadata lookups; bounded cache so the key space
# cannot grow without limit.
# NOTE(review): DISEASE_INFO is assumed to be a module-level dict defined
# elsewhere — confirm before use.
@lru_cache(maxsize=100)
def get_disease_info(disease_name):
    return DISEASE_INFO.get(disease_name, {})

Async Processing

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Shared worker pool so blocking model inference does not stall the event loop.
executor = ThreadPoolExecutor(max_workers=4)

@app.post("/predict-async")
async def predict_async(file: UploadFile):
    """Offload blocking inference to a worker thread.

    NOTE(review): `predict_sync` is assumed to be a synchronous helper
    defined elsewhere; it must read the upload without awaiting
    (UploadFile.read is a coroutine) — confirm before use.
    """
    # Run inference in thread pool
    result = await asyncio.get_event_loop().run_in_executor(
        executor,
        predict_sync,
        file
    )
    return result

GPU Optimization

# Use mixed precision (fp16 autocast on CUDA) to cut memory use and latency.
# NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch in
# favor of torch.amp.autocast("cuda") — confirm against your torch version.
from torch.cuda.amp import autocast

with autocast():
    logits = MODEL(tensor)
    probs = torch.sigmoid(logits)

# Batch inference: process images in fixed-size chunks to bound GPU memory.
batch_size = 32
for i in range(0, len(images), batch_size):
    batch = images[i:i+batch_size]
    # Process batch

Testing

# test_api.py
import requests
import io
from PIL import Image

BASE_URL = "http://localhost:8000"

def test_health():
    """The health endpoint must answer 200 when the server is up."""
    resp = requests.get(f"{BASE_URL}/health")
    assert resp.status_code == 200

def test_predict():
    """POSTing a blank 384x384 PNG should yield a well-formed prediction."""
    # Build an in-memory dummy fundus image
    buffer = io.BytesIO()
    Image.new("RGB", (384, 384)).save(buffer, format="PNG")
    buffer.seek(0)

    response = requests.post(
        f"{BASE_URL}/predict",
        files={"file": ("test.png", buffer, "image/png")},
    )

    assert response.status_code == 200
    payload = response.json()
    for key in ("disease_risk", "predictions"):
        assert key in payload

if __name__ == "__main__":
    # Smoke-test runner: requires the API server to be listening on BASE_URL.
    test_health()
    test_predict()
    # Bug fix: the original printed mojibake ("βœ…" — UTF-8 ✅ misdecoded).
    print("✅ All tests passed")

Run tests:

python test_api.py

Last Updated: February 22, 2026 · Status: Production Ready ✅