|
|
| """
|
DrRetina — AMD MI300X Inference Server
|
Runs on the AMD GPU server and exposes a REST API for the Hugging Face Spaces Gradio UI.
|
|
|
| Usage (on AMD server):
|
pip install fastapi uvicorn python-multipart pillow numpy opencv-python matplotlib transformers
(plus ROCm builds of torch and torchvision)
|
| python3 inference_server.py
|
| """
|
|
|
| import os, io, base64, cv2, numpy as np
|
| from PIL import Image
|
| import torch, torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchvision import transforms
|
| from transformers import ViTMAEModel
|
| import matplotlib; matplotlib.use("Agg")
|
| import matplotlib.cm as cm
|
| from fastapi import FastAPI, UploadFile, File
|
| from fastapi.responses import JSONResponse
|
| import uvicorn
|
|
|
|
|
# Class index -> human-readable DR grade label; indexed by the argmax of the
# model's logits in predict().
GRADES = dict(
    enumerate(
        (
            "No DR",
            "Mild DR",
            "Moderate DR",
            "Severe DR",
            "Proliferative DR",
        )
    )
)

# Trained weights, resolved relative to the directory containing this file.
_HERE = os.path.dirname(__file__)
CHECKPOINT = os.path.join(_HERE, "checkpoints", "best_model.pth")

# Prefer the GPU when PyTorch reports one (ROCm builds of PyTorch expose AMD
# GPUs through the torch.cuda API); otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class DRClassifier(nn.Module):
    """Diabetic-retinopathy grade classifier: ViT-MAE encoder + MLP head on the CLS token.

    Args:
        num_classes: number of output logits (one per DR grade).
        dropout: dropout probability inside the classification head.
    """

    def __init__(self, num_classes=5, dropout=0.3):
        super().__init__()
        # ViTMAEModel drops `mask_ratio` of the patches on EVERY forward pass,
        # in eval mode too. The pretrained config uses 0.75, which would feed
        # the classifier a random quarter of the image (non-deterministic
        # predictions, 49 instead of 196 patch tokens). Keep every patch.
        self.backbone = ViTMAEModel.from_pretrained("facebook/vit-mae-base", mask_ratio=0.0)
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits of shape (batch, num_classes) for a normalized pixel batch."""
        batch, _, height, width = pixel_values.shape
        patch = self.backbone.config.patch_size
        num_patches = (height // patch) * (width // patch)
        # With noise=None the embedding layer shuffles patch tokens using
        # torch.rand. A strictly increasing noise makes that shuffle the
        # identity, so the output is deterministic and the patch tokens stay in
        # row-major spatial order (which Grad-CAM relies on).
        noise = torch.arange(
            num_patches, device=pixel_values.device, dtype=torch.float32
        ).expand(batch, -1)
        out = self.backbone(pixel_values=pixel_values, noise=noise)
        # Token 0 is the CLS token.
        return self.classifier(out.last_hidden_state[:, 0, :])
|
|
|
|
|
# Build the classifier once at import time and load the trained weights.
print(f"[Server] Loading model on {device}...")
model = DRClassifier().to(device)
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=True)
# Accept both a training checkpoint ({"model_state_dict": ...}) and a bare state_dict.
_state = ckpt.get("model_state_dict", ckpt)
# Weights saved from nn.DataParallel / DistributedDataParallel carry a "module."
# prefix on every key; with strict=False such a checkpoint would load nothing.
_state = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in _state.items()
}
# strict=False tolerates benign backbone differences, but mismatches are
# reported instead of silently ignored.
_load = model.load_state_dict(_state, strict=False)
if _load.missing_keys or _load.unexpected_keys:
    print(
        "[Server] WARNING: checkpoint/model key mismatch: "
        f"{len(_load.missing_keys)} missing (e.g. {_load.missing_keys[:5]}), "
        f"{len(_load.unexpected_keys)} unexpected (e.g. {_load.unexpected_keys[:5]})"
    )
# Without trained head weights every prediction would come from a randomly
# initialised classifier, so refuse to start rather than serve garbage.
_missing_head = [
    key
    for key in _load.missing_keys
    if key.startswith("classifier.") and not key.endswith("num_batches_tracked")
]
if _missing_head:
    raise RuntimeError(
        f"Checkpoint {CHECKPOINT} has no weights for the classification head: "
        f"{_missing_head}"
    )
model.eval()
print("[Server] Model ready ✅")
|
|
|
|
|
def circle_crop(img_bgr):
    """Crop a BGR image to the bounding box of its largest bright region.

    Pixels whose grayscale value is above 15 count as foreground; the crop is
    the axis-aligned bounding rectangle of the largest external foreground
    contour. When no foreground is found the input is returned unchanged.
    """
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(grayscale, 15, 255, cv2.THRESH_BINARY)[1]
    contours, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        biggest = max(contours, key=cv2.contourArea)
        left, top, width, height = cv2.boundingRect(biggest)
        return img_bgr[top:top + height, left:left + width]
    return img_bgr
|
|
|
def apply_clahe(img_bgr):
    """Boost local contrast by applying CLAHE to the lightness channel only.

    The image is converted BGR -> LAB, CLAHE (clip limit 2.0, 8x8 tiles) is
    applied to L, and the result is converted back to BGR, leaving the colour
    channels A and B untouched.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB))
    enhanced = cv2.merge([equalizer.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
|
|
|
# Per-channel mean/std used to normalise model inputs (the standard ImageNet
# statistics; presumably the same values used in training — confirm).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL RGB image -> float tensor (C, H, W) in [0, 1] -> channel-normalised tensor.
INFER_TF = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
def preprocess(pil_img):
    """Prepare an uploaded image for display and for the model.

    Pipeline: convert to RGB, crop to the largest bright region
    (circle_crop), apply CLAHE, resize to 224x224.

    Returns:
        (display, batch): the processed 224x224 RGB PIL image, and the same
        image as a normalised (1, 3, 224, 224) tensor on ``device``.
    """
    rgb_in = np.array(pil_img.convert("RGB"))
    fundus = circle_crop(cv2.cvtColor(rgb_in, cv2.COLOR_RGB2BGR))
    fundus = cv2.resize(apply_clahe(fundus), (224, 224))
    display = Image.fromarray(cv2.cvtColor(fundus, cv2.COLOR_BGR2RGB))
    batch = INFER_TF(display).unsqueeze(0).to(device)
    return display, batch
|
|
|
|
|
class ViTGradCAM:
    """Grad-CAM heatmaps for a ViT classifier.

    The patch tokens of the last encoder layer are the spatial grid and the
    hidden dimension plays the role of channels. Channel weights are the
    gradients of the target logit averaged over patch tokens, and the map is
    ReLU(sum_k w_k * A_k), as in standard Grad-CAM.

    The forward hook is attached only for the duration of ``generate`` and
    removed afterwards. Creating many instances, or calling ``generate``
    repeatedly, therefore never accumulates hooks on the shared model.

    Args:
        m: model exposing ``m.backbone.encoder.layer`` (indexable sequence of
           layers) and returning logits (batch, classes) from ``m(pixels)``.
    """

    def __init__(self, m):
        self.m = m
        self._layer = m.backbone.encoder.layer[-1]
        # Activations / gradients captured by the last generate() call (detached).
        self._f = self._g = None

    def _patch_size(self):
        """Backbone patch size from its config (16 if the config does not say)."""
        config = getattr(self.m.backbone, "config", None)
        return getattr(config, "patch_size", 16)

    def generate(self, tensor, cls_idx):
        """Return an (H, W) float32 numpy heatmap in [0, 1] explaining logit ``cls_idx``.

        Args:
            tensor: pixel batch (batch, C, H, W); only the first sample is explained.
            cls_idx: index of the class logit to explain.

        Raises:
            ValueError: if the number of patch tokens does not match the
                (H // patch) x (W // patch) grid, e.g. because the backbone
                dropped (masked) patches.
        """
        captured = {}

        def _keep_output(_module, _inputs, output):
            # Encoder layers may return a tuple whose first item is the hidden states.
            captured["f"] = output[0] if isinstance(output, tuple) else output

        handle = self._layer.register_forward_hook(_keep_output)
        try:
            with torch.enable_grad():
                # A fresh leaf that requires grad guarantees a graph through the
                # encoder even if the model's weights are frozen.
                pixels = tensor.detach().requires_grad_(True)
                score = self.m(pixels)[0, cls_idx]
                feats = captured["f"]
                # d(score)/d(feats) only; unlike .backward() this does not
                # populate .grad on every model parameter.
                (grads,) = torch.autograd.grad(score, feats)
        finally:
            handle.remove()

        self._f, self._g = feats.detach(), grads.detach()
        f = self._f[0, 1:, :]  # drop the CLS token -> (num_patches, hidden)
        g = self._g[0, 1:, :]

        height, width = tensor.shape[-2:]
        patch = self._patch_size()
        grid_h, grid_w = height // patch, width // patch
        if f.shape[0] != grid_h * grid_w:
            raise ValueError(
                f"expected {grid_h * grid_w} patch tokens for a {height}x{width} "
                f"input with patch size {patch}, got {f.shape[0]} "
                "(is the backbone masking patches?)"
            )

        weights = g.mean(dim=0)  # per-channel importance, averaged over patches
        cam = F.relu((f * weights).sum(dim=-1)).reshape(grid_h, grid_w)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        cam = F.interpolate(
            cam[None, None], size=(height, width), mode="bilinear", align_corners=False
        )
        return cam[0, 0].cpu().numpy()
|
|
|
def overlay_heatmap(pil224, cam, alpha=0.45):
    """Blend a jet-coloured heatmap over an image.

    Args:
        pil224: base PIL image (any mode; converted to RGB).
        cam: 2-D array of values expected in [0, 1]. It is resized to the
            image size when the shapes differ.
        alpha: weight of the heatmap; the image gets ``1 - alpha``. The
            default reproduces the original 0.55 / 0.45 blend.

    Returns:
        The blended RGB PIL image.
    """
    img = np.asarray(pil224.convert("RGB")).astype(np.float32)
    cam = np.asarray(cam)
    height, width = img.shape[:2]
    if cam.shape != (height, width):
        # cv2.resize takes (width, height) and needs a float array.
        cam = cv2.resize(cam.astype(np.float32), (width, height))
    # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to 0-255.
    heat = (cm.jet(cam)[:, :, :3] * 255).astype(np.float32)
    blended = (1.0 - alpha) * img + alpha * heat
    return Image.fromarray(blended.clip(0, 255).astype(np.uint8))
|
|
|
def pil_to_b64(img: Image.Image) -> str:
    """Serialise *img* as PNG and return it as a base64 text string."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        payload = stream.getvalue()
    return base64.b64encode(payload).decode()
|
|
|
|
|
app = FastAPI(
    title="DrRetina Inference API",
)


@app.get("/health")
def health():
    """Liveness check: report server status, the torch device in use and the model name."""
    return dict(
        status="ok",
        device=str(device),
        model="ViT-MAE DR Classifier",
    )
|
|
|
# Process-wide Grad-CAM helper, created lazily on the first request. Building
# a new ViTGradCAM per request could attach another set of hooks to the shared
# model on every call; a single instance bounds that to one set.
_gradcam = None


def _get_gradcam():
    """Return the shared ViTGradCAM for ``model``, creating it on first use."""
    global _gradcam
    if _gradcam is None:
        _gradcam = ViTGradCAM(model)
    return _gradcam


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Grade one uploaded fundus image and return probabilities plus a Grad-CAM overlay.

    Responses:
        200: {"grade", "grade_name", "confidence", "probs", "image_b64", "cam_b64"}
        400: {"error"} when the upload is empty or cannot be decoded as an image.
        500: {"error"} for any other failure; the traceback is printed to the server log.

    The model work runs synchronously inside this coroutine, so requests are
    handled one at a time. That serialisation keeps the shared Grad-CAM state
    safe; do not move this work into a thread pool without adding a lock.
    """
    import traceback

    try:
        contents = await file.read()
        if not contents:
            return JSONResponse({"error": "empty upload"}, status_code=400)
        try:
            pil_img = Image.open(io.BytesIO(contents))
            pil_img.load()  # Image.open is lazy; decode now so corrupt files fail here
        except (OSError, Image.DecompressionBombError) as e:
            return JSONResponse({"error": f"invalid image: {e}"}, status_code=400)

        pil224, tensor = preprocess(pil_img)

        # The probabilities need no autograd graph.
        with torch.no_grad():
            logits = model(tensor)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
        grade = int(probs.argmax())

        # Grad-CAM needs gradients, so enable autograd for its own pass.
        with torch.enable_grad():
            cam = _get_gradcam().generate(tensor, grade)
        cam_pil = overlay_heatmap(pil224, cam)

        return JSONResponse({
            "grade": grade,
            "grade_name": GRADES[grade],
            "confidence": float(probs[grade]),
            "probs": [float(p) for p in probs],
            "image_b64": pil_to_b64(pil224),
            "cam_b64": pil_to_b64(cam_pil),
        })
    except Exception as e:  # request boundary: report instead of crashing the worker
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)
|
|
|
|
|
if __name__ == "__main__":
    # Listen on every interface; the port comes from $PORT (default 8000).
    listen_port = int(os.getenv("PORT", "8000"))
    print(f"[Server] Starting on 0.0.0.0:{listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|
|