| # Backend Integration Guide | |
| Complete guide for backend developers integrating the retinal disease classifier. | |
| --- | |
| ## Table of Contents | |
| 1. [Architecture Overview](#architecture-overview) | |
| 2. [FastAPI Implementation](#fastapi-implementation) | |
| 3. [Flask Implementation](#flask-implementation) | |
| 4. [API Endpoints](#api-endpoints) | |
| 5. [Payload Formats](#payload-formats) | |
| 6. [Error Handling](#error-handling) | |
| 7. [Deployment](#deployment) | |
| 8. [Performance Optimization](#performance-optimization) | |
| 9. [Testing](#testing) | |
| --- | |
| ## Architecture Overview | |
| ``` | |
| ┌─────────────────────┐ | |
| │   Frontend/Client   │ | |
| └──────────┬──────────┘ | |
|            │ HTTP | |
|            ▼ | |
| ┌─────────────────────┐ | |
| │     API Server      │  (FastAPI/Flask) | |
| └──────────┬──────────┘ | |
|            │ | |
|            ▼ | |
| ┌─────────────────────┐ | |
| │   Model Inference   │  (PyTorch) | |
| └──────────┬──────────┘ | |
|            │ | |
|            ▼ | |
| ┌─────────────────────┐ | |
| │   GPU/CPU Device    │ | |
| └─────────────────────┘ | |
| ``` | |
| --- | |
| ## FastAPI Implementation | |
| ### Installation | |
| ```bash | |
| pip install fastapi uvicorn python-multipart pillow torch torchvision albumentations | |
| ``` | |
| ### Basic Server | |
| ```python | |
# app.py
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2

app = FastAPI(title="Retinal Disease Classifier API", version="1.0")

# CORS configuration.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# convenient for development but should be restricted to known hosts in
# production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, populated once by the startup hook below.
MODEL = None
DEVICE = None

# The preprocessing pipeline is deterministic, so build it once at import
# time instead of re-creating it on every request.
TRANSFORM = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])


@app.on_event("startup")
async def load_model():
    """Load the serialized model onto GPU when available, else CPU."""
    global MODEL, DEVICE
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): assumes pytorch_model.bin is a fully pickled nn.Module;
    # if it is a bare state_dict this fails at inference time -- confirm how
    # the checkpoint was saved.
    MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
    MODEL.eval()
    print(f"Model loaded on {DEVICE}")


class DiseaseResult(BaseModel):
    """Response schema for a single prediction."""
    disease_risk: bool        # True when at least one disease crosses threshold
    predictions: dict         # disease abbreviation -> probability (0-1)
    detected_diseases: list   # abbreviations with probability >= threshold
    num_detected: int         # len(detected_diseases)
    confidence: float         # mean probability over detected diseases (0.0 if none)


# 45 disease-label abbreviations, in the model's output order.
DISEASE_NAMES = [
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
    "CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
    "RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
    "PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
    "VS", "BRAO", "PLQ", "HPED", "CL",
]


@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model is loaded and on which device."""
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "device": str(DEVICE),
    }


@app.post("/predict", response_model=DiseaseResult)
async def predict(file: UploadFile = File(...), threshold: float = Form(0.5)):
    """
    Predict diseases in a retinal image.

    Args:
        file: PNG/JPG image file (multipart form field "file")
        threshold: Disease detection threshold (0-1); sent as a form field,
            matching the documented `curl -F "threshold=0.5"` usage

    Returns:
        DiseaseResult with per-disease probabilities and detections

    Raises:
        HTTPException: 400 for an invalid file or threshold, 500 on
            inference failure
    """
    # Validate input up front so clients get a 400, not a 500.
    # content_type can be None for some clients; guard before startswith.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    if threshold < 0 or threshold > 1:
        raise HTTPException(status_code=400, detail="Threshold must be 0-1")
    try:
        # Load image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Preprocess with the module-level pipeline and add a batch dimension.
        image_array = np.array(image)
        tensor = TRANSFORM(image=image_array)["image"].unsqueeze(0).to(DEVICE)
        # Inference (multi-label: independent sigmoid per disease).
        with torch.no_grad():
            logits = MODEL(tensor)
            probs = torch.sigmoid(logits)[0].cpu().numpy()
        # Parse results
        predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
        detected = [name for name, prob in predictions.items() if prob >= threshold]
        return DiseaseResult(
            disease_risk=len(detected) > 0,
            predictions=predictions,
            detected_diseases=detected,
            num_detected=len(detected),
            confidence=float(np.mean([predictions[d] for d in detected]) if detected else 0.0),
        )
    except HTTPException:
        # Don't remap deliberate HTTP errors (e.g. a 400 from a decode
        # failure path) into a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict-batch")
async def predict_batch(files: list[UploadFile] = File(...), threshold: float = Form(0.5)):
    """Batch prediction: runs the /predict logic for each uploaded image."""
    results = []
    for file in files:
        result = await predict(file, threshold)
        results.append(result)
    return {"results": results}


@app.get("/info")
async def model_info():
    """Static model metadata (classes, input size, training metrics)."""
    return {
        "model_name": "Retinal Disease Classifier",
        "version": "1.0",
        "num_classes": 45,
        "diseases": DISEASE_NAMES,
        "input_size": 384,
        "metrics": {
            "best_auc": 0.8204,
            "train_loss": 0.2118,
            "val_loss": 0.2578,
        },
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
| ``` | |
| ### Run Server | |
| ```bash | |
| # Development | |
| uvicorn app:app --reload | |
| # Production | |
| uvicorn app:app --host 0.0.0.0 --port 8000 --workers 4 | |
| ``` | |
| ### Access API | |
| ```bash | |
| # Health check | |
| curl http://localhost:8000/health | |
| # Single prediction | |
| curl -X POST http://localhost:8000/predict \ | |
| -F "file=@image.png" \ | |
| -F "threshold=0.5" | |
| # Model info | |
| curl http://localhost:8000/info | |
| # Swagger docs | |
| http://localhost:8000/docs | |
| ``` | |
| --- | |
| ## Flask Implementation | |
| ### Installation | |
| ```bash | |
| pip install flask torch torchvision pillow albumentations | |
| ``` | |
| ### Basic Server | |
| ```python | |
# app.py
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import torch
import numpy as np
from PIL import Image
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size

# Load model once at import time.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): assumes pytorch_model.bin is a fully pickled nn.Module --
# confirm how the checkpoint was saved.
MODEL = torch.load("pytorch_model.bin", map_location=DEVICE)
MODEL.eval()

# Deterministic preprocessing pipeline, built once rather than per request.
TRANSFORM = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# 45 disease-label abbreviations, in the model's output order.
DISEASE_NAMES = [
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN", "ERM", "LS", "MS",
    "CSR", "ODC", "CRVO", "TV", "AH", "ODP", "ODE", "ST", "AION", "PT",
    "RT", "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "CWS", "CB", "ODPM",
    "PRH", "MNF", "HR", "CRAO", "TD", "CME", "PTCR", "CF", "VH", "MCA",
    "VS", "BRAO", "PLQ", "HPED", "CL",
]


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe."""
    return jsonify({
        "status": "healthy",
        "device": str(DEVICE),
    })


@app.route("/predict", methods=["POST"])
def predict():
    """Single image prediction.

    Expects a multipart "file" field and an optional "threshold" value
    (form field, as in the documented curl examples, or query string).
    Returns per-disease probabilities and detections as JSON.
    """
    # Accept threshold from form data (matches `curl -F "threshold=0.5"`)
    # with query-string fallback; the original args-only lookup silently
    # ignored the documented form field.
    raw_threshold = request.form.get("threshold", request.args.get("threshold", "0.5"))
    try:
        threshold = float(raw_threshold)
    except ValueError:
        # A malformed number is a client error, not a 500.
        return jsonify({"error": "Threshold must be a number"}), 400

    # Validate
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "No file selected"}), 400
    if threshold < 0 or threshold > 1:
        return jsonify({"error": "Threshold must be 0-1"}), 400

    try:
        # Load image
        image = Image.open(io.BytesIO(file.read())).convert("RGB")
        # Preprocess with the module-level pipeline; add batch dimension.
        image_array = np.array(image)
        tensor = TRANSFORM(image=image_array)["image"].unsqueeze(0).to(DEVICE)
        # Inference (multi-label: independent sigmoid per disease).
        with torch.no_grad():
            logits = MODEL(tensor)
            probs = torch.sigmoid(logits)[0].cpu().numpy()
        # Results
        predictions = {name: float(prob) for name, prob in zip(DISEASE_NAMES, probs)}
        detected = [name for name, prob in predictions.items() if prob >= threshold]
        return jsonify({
            "disease_risk": len(detected) > 0,
            "predictions": predictions,
            "detected_diseases": detected,
            "num_detected": len(detected),
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/info", methods=["GET"])
def info():
    """Static model metadata."""
    return jsonify({
        "model_name": "Retinal Disease Classifier",
        "version": "1.0",
        "num_classes": 45,
        "diseases": DISEASE_NAMES,
        "metrics": {"best_auc": 0.8204},
    })


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=8000)
| ``` | |
| ### Run Server | |
| ```bash | |
| # Development | |
| python app.py | |
| # Production (with gunicorn) | |
| pip install gunicorn | |
| gunicorn -w 4 -b 0.0.0.0:8000 app:app | |
| ``` | |
| --- | |
| ## API Endpoints | |
| ### POST /predict | |
| Predict diseases in single image. | |
| **Request:** | |
| ```bash | |
| curl -X POST http://localhost:8000/predict \ | |
| -F "file=@fundus.png" \ | |
| -F "threshold=0.5" | |
| ``` | |
| **Response (200 OK):** | |
| ```json | |
| { | |
| "disease_risk": true, | |
| "predictions": { | |
| "DR": 0.993, | |
| "CRVO": 0.899, | |
| "LS": 0.859, | |
| ... | |
| }, | |
| "detected_diseases": ["DR", "CRVO", "LS"], | |
| "num_detected": 3 | |
| } | |
| ``` | |
| **Errors:** | |
| - 400: Invalid file or threshold | |
| - 500: Server error | |
| --- | |
| ### POST /predict-batch (FastAPI only) | |
| **Request:** | |
| ```bash | |
| curl -X POST http://localhost:8000/predict-batch \ | |
| -F "files=@image1.png" \ | |
| -F "files=@image2.png" \ | |
| -F "threshold=0.5" | |
| ``` | |
| **Response:** | |
| ```json | |
| { | |
| "results": [ | |
| { "disease_risk": true, ... }, | |
| { "disease_risk": false, ... } | |
| ] | |
| } | |
| ``` | |
| --- | |
| ### GET /health | |
| Check API status. | |
| **Response:** | |
| ```json | |
| { | |
| "status": "healthy", | |
| "device": "cuda", | |
| "model_loaded": true | |
| } | |
| ``` | |
| --- | |
| ### GET /info | |
| Get model information. | |
| **Response:** | |
| ```json | |
| { | |
| "model_name": "Retinal Disease Classifier", | |
| "version": "1.0", | |
| "num_classes": 45, | |
| "diseases": ["DR", "ARMD", ...], | |
| "metrics": { | |
| "best_auc": 0.8204, | |
| "train_loss": 0.2118, | |
| "val_loss": 0.2578 | |
| } | |
| } | |
| ``` | |
| --- | |
| ## Payload Formats | |
| ### Input | |
| ``` | |
| Content-Type: multipart/form-data | |
| - file: Binary image data (PNG/JPG) | |
| - threshold: Float 0-1 (optional, default 0.5) | |
| ``` | |
| ### Output | |
| ```json | |
| { | |
| "disease_risk": boolean, | |
| "predictions": { | |
| "disease_name": probability (0-1), | |
| ... | |
| }, | |
| "detected_diseases": [disease_names], | |
| "num_detected": integer, | |
| "confidence": average_probability (optional) | |
| } | |
| ``` | |
| --- | |
| ## Error Handling | |
| ```python | |
class APIError(Exception):
    """Application error carrying an HTTP status code.

    Args:
        message: Human-readable error description.
        status_code: HTTP status to return to the client (default 500).
    """

    def __init__(self, message, status_code=500):
        # Pass the message to Exception so str(err) and tracebacks show it;
        # the original omitted this, making str(err) empty.
        super().__init__(message)
        self.message = message
        self.status_code = status_code
# Flask error handler: converts any raised APIError into a JSON body with
# the carried HTTP status code, so views can simply `raise`.
@app.errorhandler(APIError)
def handle_error(error):
    return jsonify({"error": error.message}), error.status_code

# Usage (inside a view function):
if not valid:
    raise APIError("Invalid threshold", 400)
| ``` | |
| --- | |
| ## Deployment | |
| ### Docker | |
| ```dockerfile | |
| FROM python:3.10 | |
| WORKDIR /app | |
| COPY requirements.txt . | |
| RUN pip install -r requirements.txt | |
| COPY . . | |
| CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] | |
| ``` | |
| Build & Run: | |
| ```bash | |
| docker build -t retinal-classifier . | |
| docker run -p 8000:8000 --gpus all retinal-classifier | |
| ``` | |
| ### Docker Compose | |
| ```yaml | |
| version: '3' | |
| services: | |
| api: | |
| build: . | |
| ports: | |
| - "8000:8000" | |
| environment: | |
| - CUDA_VISIBLE_DEVICES=0 | |
| volumes: | |
| - ./pytorch_model.bin:/app/pytorch_model.bin | |
| deploy: | |
| resources: | |
| reservations: | |
| devices: | |
| - driver: nvidia | |
| count: 1 | |
| capabilities: [gpu] | |
| ``` | |
| --- | |
| ## Performance Optimization | |
| ### Batch Processing | |
| ```python | |
@app.post("/predict-batch")
async def batch_predict(files: list[UploadFile]):
    """Run one batched forward pass over all uploaded images.

    Returns a (N, 45) nested list of per-disease sigmoid probabilities,
    one row per uploaded file.
    """
    # Build the deterministic preprocessing pipeline (the original snippet
    # referenced an undefined `transform`, raising NameError at request time).
    transform = A.Compose([
        A.Resize(384, 384),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    # Decode and preprocess every image.
    tensors = []
    for file in files:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
        tensor = transform(image=np.array(image))["image"]
        tensors.append(tensor)
    # One batched forward pass is much faster than N single-image calls.
    batch = torch.stack(tensors).to(DEVICE)
    with torch.no_grad():
        logits = MODEL(batch)
        probs = torch.sigmoid(logits).cpu().numpy()
    # numpy arrays are not JSON-serializable by FastAPI; convert to lists.
    return probs.tolist()  # (N, 45)
| ``` | |
| ### Caching | |
| ```python | |
from functools import lru_cache

@lru_cache(maxsize=100)
def get_disease_info(disease_name):
    """Memoized lookup of static metadata for one disease abbreviation.

    DISEASE_INFO is assumed to be a module-level dict mapping abbreviations
    to metadata dicts -- TODO confirm where it is defined. lru_cache is safe
    here only while that table stays read-only.
    """
    return DISEASE_INFO.get(disease_name, {})
| ``` | |
| ### Async Processing | |
| ```python | |
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Dedicated pool so blocking model inference does not stall the event loop.
executor = ThreadPoolExecutor(max_workers=4)

@app.post("/predict-async")
async def predict_async(file: UploadFile):
    """Offload synchronous inference to a worker thread.

    predict_sync is assumed to be a blocking helper that reads the upload
    and runs the model -- define it alongside this endpoint.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() here is deprecated in modern asyncio.
    result = await asyncio.get_running_loop().run_in_executor(
        executor,
        predict_sync,
        file,
    )
    return result
| ``` | |
| ### GPU Optimization | |
| ```python | |
# Use mixed precision to speed up inference on GPUs with tensor cores.
# NOTE(review): recent PyTorch prefers torch.amp.autocast("cuda");
# torch.cuda.amp.autocast is the older spelling -- check your torch version.
from torch.cuda.amp import autocast

with autocast():
    logits = MODEL(tensor)
    probs = torch.sigmoid(logits)

# Batch inference: process images in fixed-size chunks to bound GPU memory.
batch_size = 32
for i in range(0, len(images), batch_size):
    batch = images[i:i+batch_size]
    # Process batch
| ``` | |
| --- | |
| ## Testing | |
| ```python | |
# test_api.py
# Smoke tests against a locally running server (start it first).
import requests
import io
from PIL import Image

BASE_URL = "http://localhost:8000"


def test_health():
    """The health endpoint responds while the server is up."""
    # Explicit timeout so a hung server fails the test instead of blocking forever.
    response = requests.get(f"{BASE_URL}/health", timeout=10)
    assert response.status_code == 200


def test_predict():
    """A synthetic image round-trips through /predict."""
    # Create dummy image at the model's 384x384 input size.
    image = Image.new("RGB", (384, 384))
    img_bytes = io.BytesIO()
    image.save(img_bytes, format="PNG")
    img_bytes.seek(0)
    files = {"file": ("test.png", img_bytes, "image/png")}
    # Inference can be slow on CPU; allow a generous timeout.
    response = requests.post(f"{BASE_URL}/predict", files=files, timeout=60)
    assert response.status_code == 200
    data = response.json()
    assert "disease_risk" in data
    assert "predictions" in data


if __name__ == "__main__":
    test_health()
    test_predict()
    print("✓ All tests passed")
| ``` | |
| Run tests: | |
| ```bash | |
| python test_api.py | |
| ``` | |
| --- | |
| **Last Updated:** February 22, 2026 | |
| **Status:** Production Ready ✅ | |