<!-- Uploaded by lebiraja via huggingface_hub (commit f95f207, verified) -->
# Backend Integration Guide
Complete guide for backend developers integrating the retinal disease classifier.
---
## Table of Contents
1. [Architecture Overview](#architecture-overview)
2. [FastAPI Implementation](#fastapi-implementation)
3. [Flask Implementation](#flask-implementation)
4. [API Endpoints](#api-endpoints)
5. [Payload Formats](#payload-formats)
6. [Error Handling](#error-handling)
7. [Deployment](#deployment)
8. [Performance Optimization](#performance-optimization)
---
## Architecture Overview
```
┌─────────────────────┐
│   Frontend/Client   │
└──────────┬──────────┘
           │ HTTP
           ▼
┌─────────────────────┐
│     API Server      │  (FastAPI/Flask)
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│   Model Inference   │  (PyTorch)
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│   GPU/CPU Device    │
└─────────────────────┘
```
---
## FastAPI Implementation
### Installation
```bash
pip install fastapi uvicorn python-multipart pillow torch torchvision albumentations
```
### Basic Server
```python
# app.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Load model once at startup
app = FastAPI(title="Retinal Disease Classifier API", version="1.0")
# CORS configuration
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global model
MODEL = None
DEVICE = None
@app.on_event("startup")
async def load_model():
global MODEL, DEVICE
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
MODEL.eval()
print(f"Model loaded on {DEVICE}")
# Response schemas
class DiseaseResult(BaseModel):
disease_risk: bool
predictions: dict
detected_diseases: list
num_detected: int
confidence: float
# Disease list
DISEASE_NAMES = [
"DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
"CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
"RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
"PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
"VS", "BRAO", "PLQ", "HPED", "CL",
]
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": MODEL is not None,
"device": str(DEVICE),
}
@app.post("/predict", response_model=DiseaseResult)
async def predict(file: UploadFile = File(...), threshold: float = 0.5):
"""
Predict diseases in retinal image.
Args:
file: PNG/JPG image file
threshold: Disease detection threshold (0-1)
Returns:
DiseaseResult with predictions
"""
try:
# Validate input
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
if threshold < 0 or threshold > 1:
raise HTTPException(status_code=400, detail="Threshold must be 0-1")
# Load image
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert("RGB")
# Preprocess
transform = A.Compose([
A.Resize(384, 384),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
image_array = np.array(image)
tensor = transform(image=image_array)["image"].unsqueeze(0).to(DEVICE)
# Inference
with torch.no_grad():
logits = MODEL(tensor)
probs = torch.sigmoid(logits)[0].cpu().numpy()
# Parse results
predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
detected = [name for name, prob in predictions.items() if prob >= threshold]
return DiseaseResult(
disease_risk=len(detected) > 0,
predictions=predictions,
detected_diseases=detected,
num_detected=len(detected),
confidence=float(np.mean([predictions[d] for d in detected]) if detected else 0.0),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...), threshold: float = 0.5):
"""Batch prediction for multiple images."""
results = []
for file in files:
result = await predict(file, threshold)
results.append(result)
return {"results": results}
@app.get("/info")
async def model_info():
return {
"model_name": "Retinal Disease Classifier",
"version": "1.0",
"num_classes": 45,
"diseases": DISEASE_NAMES,
"input_size": 384,
"metrics": {
"best_auc": 0.8204,
"train_loss": 0.2118,
"val_loss": 0.2578,
},
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
```
### Run Server
```bash
# Development
uvicorn app:app --reload
# Production
uvicorn app:app --host 0.0.0.0 --port 8000 --workers 4
```
### Access API
```bash
# Health check
curl http://localhost:8000/health
# Single prediction
curl -X POST "http://localhost:8000/predict?threshold=0.5" \
  -F "file=@image.png"
# Model info
curl http://localhost:8000/info
# Swagger docs
http://localhost:8000/docs
```
---
## Flask Implementation
### Installation
```bash
pip install flask torch torchvision pillow albumentations
```
### Basic Server
```python
# app.py
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
# Load model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
MODEL.eval()
DISEASE_NAMES = [
"DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
"CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
"RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
"PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
"VS", "BRAO", "PLQ", "HPED", "CL",
]
@app.route("/health", methods=["GET"])
def health():
return jsonify({
"status": "healthy",
"device": str(DEVICE),
})
@app.route("/predict", methods=["POST"])
def predict():
"""Single image prediction."""
try:
# Get parameters
threshold = float(request.args.get("threshold", 0.5))
# Validate
if "file" not in request.files:
return jsonify({"error": "No file provided"}), 400
file = request.files["file"]
if file.filename == "":
return jsonify({"error": "No file selected"}), 400
if threshold < 0 or threshold > 1:
return jsonify({"error": "Threshold must be 0-1"}), 400
# Load image
image = Image.open(io.BytesIO(file.read())).convert("RGB")
# Preprocess
transform = A.Compose([
A.Resize(384, 384),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
])
image_array = np.array(image)
tensor = transform(image=image_array)["image"].unsqueeze(0).to(DEVICE)
# Inference
with torch.no_grad():
logits = MODEL(tensor)
probs = torch.sigmoid(logits)[0].cpu().numpy()
# Results
predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
detected = [name for name, prob in predictions.items() if prob >= threshold]
return jsonify({
"disease_risk": len(detected) > 0,
"predictions": predictions,
"detected_diseases": detected,
"num_detected": len(detected),
})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/info", methods=["GET"])
def info():
return jsonify({
"model_name": "Retinal Disease Classifier",
"version": "1.0",
"num_classes": 45,
"diseases": DISEASE_NAMES,
"metrics": {"best_auc": 0.8204},
})
if __name__ == "__main__":
app.run(debug=False, host="0.0.0.0", port=8000)
```
### Run Server
```bash
# Development
python app.py
# Production (with gunicorn)
pip install gunicorn
gunicorn -w 4 -b 0.0.0.0:8000 app:app
```
---
## API Endpoints
### POST /predict
Predict diseases in single image.
**Request:**
```bash
curl -X POST "http://localhost:8000/predict?threshold=0.5" \
  -F "file=@fundus.png"
```
**Response (200 OK):**
```json
{
"disease_risk": true,
"predictions": {
"DR": 0.993,
"CRVO": 0.899,
"LS": 0.859,
...
},
"detected_diseases": ["DR", "CRVO", "LS"],
"num_detected": 3
}
```
**Errors:**
- 400: Invalid file or threshold
- 500: Server error
---
### POST /predict-batch (FastAPI only)
**Request:**
```bash
curl -X POST "http://localhost:8000/predict-batch?threshold=0.5" \
  -F "files=@image1.png" \
  -F "files=@image2.png"
```
**Response:**
```json
{
"results": [
{ "disease_risk": true, ... },
{ "disease_risk": false, ... }
]
}
```
---
### GET /health
Check API status.
**Response:**
```json
{
"status": "healthy",
"device": "cuda",
"model_loaded": true
}
```
---
### GET /info
Get model information.
**Response:**
```json
{
"model_name": "Retinal Disease Classifier",
"version": "1.0",
"num_classes": 45,
"diseases": ["DR", "ARMD", ...],
"metrics": {
"best_auc": 0.8204,
"train_loss": 0.2118,
"val_loss": 0.2578
}
}
```
---
## Payload Formats
### Input
```
Content-Type: multipart/form-data
- file: Binary image data (PNG/JPG)
- threshold: Float 0-1 (optional, default 0.5) — passed as a URL query parameter (e.g. `?threshold=0.5`), not as a form field
```
### Output
```json
{
"disease_risk": boolean,
"predictions": {
"disease_name": probability (0-1),
...
},
"detected_diseases": [disease_names],
"num_detected": integer,
"confidence": average_probability (optional)
}
```
---
## Error Handling
```python
class APIError(Exception):
def __init__(self, message, status_code=500):
self.message = message
self.status_code = status_code
@app.errorhandler(APIError)
def handle_error(error):
return jsonify({"error": error.message}), error.status_code
# Usage
if not valid:
raise APIError("Invalid threshold", 400)
```
---
## Deployment
### Docker
```dockerfile
FROM python:3.10
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
```
Build & Run:
```bash
docker build -t retinal-classifier .
docker run -p 8000:8000 --gpus all retinal-classifier
```
### Docker Compose
```yaml
version: '3'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- CUDA_VISIBLE_DEVICES=0
volumes:
- ./pytorch_model.bin:/app/pytorch_model.bin
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
```
---
## Performance Optimization
### Batch Processing
```python
@app.post("/predict-batch")
async def batch_predict(files: list[UploadFile]):
# Load all images
tensors = []
for file in files:
image = Image.open(io.BytesIO(await file.read())).convert("RGB")
tensor = transform(image=np.array(image))["image"]
tensors.append(tensor)
# Batch inference
batch = torch.stack(tensors).to(DEVICE)
with torch.no_grad():
logits = MODEL(batch)
probs = torch.sigmoid(logits).cpu().numpy()
return probs # (N, 45)
```
### Caching
```python
from functools import lru_cache
@lru_cache(maxsize=100)
def get_disease_info(disease_name):
return DISEASE_INFO.get(disease_name, {})
```
### Async Processing
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=4)
@app.post("/predict-async")
async def predict_async(file: UploadFile):
# Run inference in thread pool
result = await asyncio.get_event_loop().run_in_executor(
executor,
predict_sync,
file
)
return result
```
### GPU Optimization
```python
# Use mixed precision
from torch.cuda.amp import autocast
with autocast():
logits = MODEL(tensor)
probs = torch.sigmoid(logits)
# Batch inference
batch_size = 32
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
# Process batch
```
---
## Testing
```python
# test_api.py
import requests
import io
from PIL import Image
BASE_URL = "http://localhost:8000"
def test_health():
response = requests.get(f"{BASE_URL}/health")
assert response.status_code == 200
def test_predict():
# Create dummy image
image = Image.new("RGB", (384, 384))
img_bytes = io.BytesIO()
image.save(img_bytes, format="PNG")
img_bytes.seek(0)
files = {"file": ("test.png", img_bytes, "image/png")}
response = requests.post(f"{BASE_URL}/predict", files=files)
assert response.status_code == 200
data = response.json()
assert "disease_risk" in data
assert "predictions" in data
if __name__ == "__main__":
test_health()
test_predict()
print("✅ All tests passed")
```
Run tests:
```bash
python test_api.py
```
---
**Last Updated:** February 22, 2026
**Status:** Production Ready ✅