DrRetina / inference_server.py
masimhanif's picture
Upload folder using huggingface_hub
3028f96 verified
#!/usr/bin/env python3
"""
DrRetina – AMD MI300X Inference Server
Runs on AMD GPU server; exposes REST API for HF Spaces Gradio UI.
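Usage (on AMD server):
pip install fastapi uvicorn python-multipart pillow opencv-python matplotlib transformers
(torch and torchvision are also required; a ROCm build is needed for AMD GPUs)
python3 inference_server.py
Expects trained weights at checkpoints/best_model.pth next to this file and
listens on 0.0.0.0:$PORT (default 8000).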
"""
import os, io, base64, cv2, numpy as np
from PIL import Image
import torch, torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTMAEModel
import matplotlib; matplotlib.use("Agg")
import matplotlib.cm as cm
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
# ── Constants ─────────────────────────────────────────────────────
# Label for each model output index (the predicted index is looked up here).
GRADES = dict(enumerate((
    "No DR",
    "Mild DR",
    "Moderate DR",
    "Severe DR",
    "Proliferative DR",
)))
# Trained weights live in a "checkpoints" directory beside this script.
CHECKPOINT = os.path.join(
    os.path.dirname(__file__), "checkpoints", "best_model.pth"
)
# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ── Model ──────────────────────────────────────────────────────────
class DRClassifier(nn.Module):
    """Retinopathy-grade classifier: ViT-MAE encoder plus a small MLP head.

    The head reads the encoder's first output token (index 0) and maps it to
    ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        dropout: dropout probability inside the head.
        mask_ratio: fraction of patches ViT-MAE randomly drops before encoding.
            It overrides the value stored in the pretrained config. The
            default 0.0 keeps every patch, which inference and the 2-D
            Grad-CAM map both need.
    """

    def __init__(self, num_classes=5, dropout=0.3, mask_ratio=0.0):
        super().__init__()
        # Config attributes passed as kwargs override the downloaded config.
        self.backbone = ViTMAEModel.from_pretrained(
            "facebook/vit-mae-base", mask_ratio=mask_ratio
        )
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256), nn.BatchNorm1d(256),
            nn.ReLU(), nn.Dropout(dropout), nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values):
        """Return logits of shape (batch, num_classes) for ``pixel_values``."""
        out = self.backbone(
            pixel_values=pixel_values, noise=self._identity_noise(pixel_values)
        )
        return self.classifier(out.last_hidden_state[:, 0, :])

    def _identity_noise(self, pixel_values):
        # ViT-MAE orders (and drops) patch tokens by argsort(noise). With
        # noise=None it draws fresh random noise on every call, shuffling the
        # tokens. Strictly increasing noise makes argsort the identity, so
        # tokens stay in raster order and outputs are deterministic.
        patch = self.backbone.config.patch_size
        ph, pw = (patch, patch) if isinstance(patch, int) else patch
        num_patches = (pixel_values.shape[-2] // ph) * (pixel_values.shape[-1] // pw)
        ramp = torch.arange(num_patches, device=pixel_values.device, dtype=torch.float32)
        return ramp.expand(pixel_values.shape[0], num_patches)
print(f"[Server] Loading model on {device}...")
model = DRClassifier().to(device)
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=True)
# The checkpoint is either a dict holding the weights under "model_state_dict"
# or a bare state dict. strict=False tolerates key mismatches, so report them
# rather than silently serving partially initialised (e.g. random-head) weights.
_missing_keys, _unexpected_keys = model.load_state_dict(
    ckpt.get("model_state_dict", ckpt), strict=False
)
if _missing_keys:
    print(
        f"[Server] WARNING: {len(_missing_keys)} model keys not found in checkpoint, "
        f"e.g. {_missing_keys[:5]}"
    )
if _unexpected_keys:
    print(
        f"[Server] WARNING: {len(_unexpected_keys)} checkpoint keys not used by the model, "
        f"e.g. {_unexpected_keys[:5]}"
    )
model.eval()
print("[Server] Model ready ✅")
# ── Preprocessing ──────────────────────────────────────────────────
def circle_crop(img_bgr, threshold=15):
    """Crop an image to the bounding box of its largest bright region.

    The image is converted to grayscale and binarised at ``threshold``. The
    bounding rectangle of the largest external contour is returned as a slice
    of ``img_bgr``. If no pixel exceeds the threshold, the input is returned
    unchanged.

    Args:
        img_bgr: colour image in OpenCV BGR channel order.
        threshold: grayscale level (0-255) a pixel must exceed to be kept.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return img_bgr
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    return img_bgr[y:y + h, x:x + w]
def apply_clahe(img_bgr):
    """Apply CLAHE (clip 2.0, 8x8 tiles) to the L channel in LAB space; return BGR."""
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab_image = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab_image)
    # Only the lightness channel is equalised; colour channels pass through.
    enhanced = cv2.merge([equalizer.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
# Inference transform: PIL image -> float tensor, then per-channel mean/std normalisation.
INFER_TF = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def preprocess(pil_img):
    """Turn an uploaded image into the classifier input.

    Pipeline:
    1. Convert to RGB.
    2. Crop to the bright region (``circle_crop``).
    3. Apply CLAHE.
    4. Resize to 224x224.
    5. Apply ``INFER_TF``.

    Returns:
        ``(pil224, tensor)``. ``pil224`` is the processed 224x224 RGB image
        as a PIL image. ``tensor`` is the normalised batch of shape
        (1, 3, 224, 224), moved to ``device``.
    """
    bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    bgr = apply_clahe(circle_crop(bgr))
    bgr = cv2.resize(bgr, (224, 224))
    # Build the PIL image once and reuse it for both return values.
    pil224 = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    return pil224, INFER_TF(pil224).unsqueeze(0).to(device)
# ── GradCAM ────────────────────────────────────────────────────────
class ViTGradCAM:
    """Grad-CAM heatmaps from the last encoder layer of ``m.backbone``.

    Activations and output gradients of ``m.backbone.encoder.layer[-1]`` are
    captured with module hooks. The hooks are attached only while
    :meth:`generate` runs and are removed afterwards. Creating many instances
    on one shared model therefore does not accumulate hooks.
    """

    def __init__(self, m):
        self.m = m
        self._f = None  # last-layer activations, written by the forward hook
        self._g = None  # gradient w.r.t. the last-layer output, written by the backward hook
        self._layer = m.backbone.encoder.layer[-1]

    def _save_activation(self, module, inputs, output):
        # The layer may return a tuple whose first element is the hidden states.
        self._f = output[0] if isinstance(output, tuple) else output

    def _save_gradient(self, module, grad_input, grad_output):
        self._g = grad_output[0]

    def generate(self, tensor, cls_idx):
        """Return a Grad-CAM map for logit ``cls_idx`` of batch item 0, scaled to [0, 1].

        Args:
            tensor: preprocessed input batch; only item 0 is explained.
            cls_idx: index of the logit to explain.

        Returns:
            numpy array of shape (H, W) matching the input's spatial size.

        Raises:
            RuntimeError: if the hooks captured no activations or gradients.
            ValueError: if the number of patch tokens does not fill the patch
                grid of the input (e.g. ViT-MAE random masking is enabled).
        """
        self._f = self._g = None
        handles = [
            self._layer.register_forward_hook(self._save_activation),
            self._layer.register_full_backward_hook(self._save_gradient),
        ]
        try:
            self.m.zero_grad()
            self.m(tensor)[0, cls_idx].backward()
        finally:
            # Detach the hooks and free the parameter gradients from this pass.
            for handle in handles:
                handle.remove()
            self.m.zero_grad(set_to_none=True)
        if self._f is None or self._g is None:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

        # Drop token 0 (the one the classifier head reads) to keep patch tokens.
        acts = self._f[0, 1:, :]
        grads = self._g[0, 1:, :]
        # Grad-CAM: each channel is weighted by its gradient averaged over
        # spatial positions (here, the patch tokens).
        weights = grads.mean(dim=0)
        cam = F.relu((acts * weights).sum(dim=-1))

        height, width = int(tensor.shape[-2]), int(tensor.shape[-1])
        patch = self.m.backbone.config.patch_size
        ph, pw = (patch, patch) if isinstance(patch, int) else patch
        grid_h, grid_w = height // ph, width // pw
        if cam.numel() != grid_h * grid_w:
            raise ValueError(
                f"expected {grid_h * grid_w} patch tokens for a {height}x{width} input, "
                f"got {cam.numel()}; is ViT-MAE patch masking disabled (mask_ratio=0)?"
            )
        cam = cam.reshape(grid_h, grid_w).detach().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cv2.resize(cam, (width, height))
def overlay_heatmap(pil224, cam):
    """Blend a jet-coloured ``cam`` over ``pil224`` (55% image, 45% heat) as uint8 RGB."""
    base = np.asarray(pil224, dtype=np.float32)
    # Colormap yields RGBA in [0, 1]; keep RGB and scale to the 0-255 range.
    heat = (cm.jet(cam)[..., :3] * 255).astype(np.float32)
    blended = np.clip(0.55 * base + 0.45 * heat, 0, 255)
    return Image.fromarray(blended.astype(np.uint8))
def pil_to_b64(img: Image.Image) -> str:
    """Encode ``img`` as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        payload = stream.getvalue()
    return base64.b64encode(payload).decode()
# ── FastAPI App ────────────────────────────────────────────────────
app = FastAPI(title="DrRetina Inference API")


@app.get("/health")
def health():
    """Liveness probe: report status, the torch device in use and the model name."""
    return dict(status="ok", device=str(device), model="ViT-MAE DR Classifier")
# One Grad-CAM helper shared by all requests. Building a new instance per
# request can leave an extra set of hooks on the shared model every call.
_GRADCAM = ViTGradCAM(model)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Grade an uploaded image and return a Grad-CAM explanation.

    Response JSON keys:
    - ``grade``: predicted class index.
    - ``grade_name``: its label from ``GRADES``.
    - ``confidence``: softmax probability of the predicted class.
    - ``probs``: all per-class probabilities.
    - ``image_b64``: the preprocessed 224x224 input as a base64 PNG.
    - ``cam_b64``: the Grad-CAM overlay as a base64 PNG.

    Uploads PIL cannot identify as an image return HTTP 400. Any other failure
    returns HTTP 500. Both error bodies have the form ``{"error": message}``.
    """
    import traceback
    from PIL import UnidentifiedImageError

    # NOTE(review): inference runs synchronously inside this async handler, so
    # it blocks the event loop while a request is being processed.
    try:
        contents = await file.read()
        pil_img = Image.open(io.BytesIO(contents))
        pil224, tensor = preprocess(pil_img)
        # Classification only: no autograd graph is needed for this pass.
        with torch.no_grad():
            logits = model(tensor)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
        grade = int(probs.argmax())
        # Grad-CAM does its own forward/backward pass and needs gradients.
        with torch.enable_grad():
            cam = _GRADCAM.generate(tensor.clone(), grade)
        cam_pil = overlay_heatmap(pil224, cam)
        return JSONResponse({
            "grade": grade,
            "grade_name": GRADES[grade],
            "confidence": float(probs[grade]),
            "probs": [float(p) for p in probs],
            "image_b64": pil_to_b64(pil224),
            "cam_b64": pil_to_b64(cam_pil),
        })
    except UnidentifiedImageError as e:
        return JSONResponse({"error": str(e)}, status_code=400)
    except Exception as e:
        # Boundary handler: log the traceback instead of discarding it.
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT, defaulting to 8000.
    bind_host = "0.0.0.0"
    bind_port = int(os.getenv("PORT", "8000"))
    print(f"[Server] Starting on {bind_host}:{bind_port}")
    uvicorn.run(app, host=bind_host, port=bind_port)