#!/usr/bin/env python3
"""
DrRetina — AMD MI300X Inference Server
Runs on AMD GPU server; exposes REST API for HF Spaces Gradio UI.
Usage (on AMD server):
pip install fastapi uvicorn python-multipart pillow
python3 inference_server.py
"""
import os, io, base64, cv2, numpy as np
from PIL import Image
import torch, torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTMAEModel
import matplotlib; matplotlib.use("Agg")
import matplotlib.cm as cm
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
# ── Constants ─────────────────────────────────────────────────────
# Class index -> human-readable diabetic-retinopathy grade name.
GRADES = {
    0: "No DR",
    1: "Mild DR",
    2: "Moderate DR",
    3: "Severe DR",
    4: "Proliferative DR",
}

# The checkpoint is resolved relative to this file, not the working directory.
_HERE = os.path.dirname(__file__)
CHECKPOINT = os.path.join(_HERE, "checkpoints", "best_model.pth")

# Use the accelerator when PyTorch reports one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ── Model ──────────────────────────────────────────────────────────
class DRClassifier(nn.Module):
    """ViT-MAE encoder with a small MLP head that scores the DR grades.

    The head consumes the encoder's first (CLS) token and returns raw logits
    of shape ``(batch, num_classes)``.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability used inside the classification head.
        mask_ratio: Fraction of patch tokens the MAE encoder drops. The
            pretrained config masks most patches at random, which makes
            predictions non-deterministic and leaves too few tokens for a
            full patch-grid heatmap, so the default here is ``0.0``.
    """

    def __init__(self, num_classes=5, dropout=0.3, mask_ratio=0.0):
        super().__init__()
        # Keyword arguments matching config attributes override the
        # pretrained configuration, so this disables MAE masking by default.
        self.backbone = ViTMAEModel.from_pretrained(
            "facebook/vit-mae-base", mask_ratio=mask_ratio
        )
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def _identity_noise(self, pixel_values):
        """Return masking noise that keeps patches in raster order, or None.

        ViT-MAE orders patch tokens by ``argsort(noise)`` even when nothing
        is masked. Strictly increasing noise yields the identity permutation,
        so token ``i`` stays patch ``i``. When masking is enabled, ``None`` is
        returned so the backbone keeps drawing a random mask.
        """
        if self.backbone.config.mask_ratio != 0:
            return None
        patch = self.backbone.config.patch_size
        height, width = pixel_values.shape[-2:]
        num_patches = (height // patch) * (width // patch)
        ramp = torch.arange(
            num_patches, dtype=torch.float32, device=pixel_values.device
        )
        return ramp.unsqueeze(0).expand(pixel_values.shape[0], -1)

    def forward(self, pixel_values):
        """Return class logits for a batch of normalized images.

        Args:
            pixel_values: Float tensor of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized scores.
        """
        out = self.backbone(
            pixel_values=pixel_values,
            noise=self._identity_noise(pixel_values),
        )
        # Index 0 along the sequence axis is the CLS token.
        return self.classifier(out.last_hidden_state[:, 0, :])
# Build the network once at import time and load the fine-tuned weights.
print(f"[Server] Loading model on {device}...")
model = DRClassifier().to(device)
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=True)
# The checkpoint is either a bare state dict or a dict wrapping one under
# "model_state_dict".
_load_report = model.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
# strict=False tolerates key mismatches; report them so a checkpoint that
# does not match the architecture is not served silently.
if _load_report.missing_keys:
    print(f"[Server] WARNING: checkpoint is missing keys: {_load_report.missing_keys}")
if _load_report.unexpected_keys:
    print(f"[Server] WARNING: checkpoint has unexpected keys: {_load_report.unexpected_keys}")
model.eval()
print("[Server] Model ready ✅")
# ── Preprocessing ──────────────────────────────────────────────────
def circle_crop(img_bgr):
    """Crop a BGR image to the bounding box of its largest bright region.

    Pixels with gray level above 15 are treated as foreground. The crop is
    the axis-aligned bounding rectangle of the largest external contour of
    that mask. The input is returned unchanged when no contour is found.
    """
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(grayscale, 15, 255, cv2.THRESH_BINARY)[1]
    found, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if found:
        biggest = max(found, key=cv2.contourArea)
        left, top, width, height = cv2.boundingRect(biggest)
        return img_bgr[top:top + height, left:left + width]
    return img_bgr
def apply_clahe(img_bgr):
    """Equalize local contrast of a BGR image on its LAB lightness channel.

    CLAHE (clip limit 2.0, 8x8 tiles) is applied to L only; the A and B
    channels are passed through, and the result is converted back to BGR.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge([equalizer.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR)
# Inference-time tensor conversion: scale to [0, 1] via ToTensor, then
# normalize each RGB channel with a fixed mean and standard deviation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
INFER_TF = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def preprocess(pil_img):
    """Turn an uploaded PIL image into the display image and model input.

    Steps: force RGB, crop to the largest bright region, apply CLAHE,
    resize to 224x224, then normalize with ``INFER_TF``.

    Returns:
        A ``(pil_224, batch)`` tuple: the processed 224x224 RGB PIL image and
        the matching ``(1, 3, 224, 224)`` tensor on ``device``.
    """
    rgb_in = np.array(pil_img.convert("RGB"))
    work = cv2.cvtColor(rgb_in, cv2.COLOR_RGB2BGR)  # OpenCV helpers expect BGR
    work = circle_crop(work)
    work = apply_clahe(work)
    work = cv2.resize(work, (224, 224))
    rgb_out = cv2.cvtColor(work, cv2.COLOR_BGR2RGB)
    display_img = Image.fromarray(rgb_out)
    batch = INFER_TF(Image.fromarray(rgb_out)).unsqueeze(0).to(device)
    return display_img, batch
# ── GradCAM ────────────────────────────────────────────────────────
class ViTGradCAM:
    """Grad-CAM heatmaps from the last encoder layer of a ViT classifier.

    Hooks are attached only for the duration of ``generate`` and always
    removed afterwards, so creating many instances (or calling ``generate``
    many times) never accumulates hooks on the shared model.

    Args:
        m: Model exposing ``backbone.encoder.layer`` and returning logits of
            shape ``(batch, num_classes)`` when called on an image tensor.
    """

    def __init__(self, m):
        self.m = m
        self._f = self._g = None
        self._layer = m.backbone.encoder.layer[-1]

    def _capture_features(self, _module, _inputs, output):
        """Forward hook: remember the layer's output hidden states."""
        # The layer may return a tuple whose first item is the hidden states.
        self._f = output[0] if isinstance(output, tuple) else output

    def _capture_gradients(self, _module, _grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self._g = grad_output[0]

    def _token_grid(self, n_tokens, height, width):
        """Return ``(rows, cols)`` of the patch grid for ``n_tokens`` tokens.

        Uses the backbone's ``config.patch_size`` when it is an int, otherwise
        assumes a square grid.

        Raises:
            RuntimeError: If the token count does not fill the grid, e.g.
                because the backbone masked some patch tokens.
        """
        config = getattr(self.m.backbone, "config", None)
        patch = getattr(config, "patch_size", None)
        if isinstance(patch, int) and patch > 0:
            rows, cols = height // patch, width // patch
        else:
            rows = cols = int(round(n_tokens ** 0.5))
        if rows * cols != n_tokens:
            raise RuntimeError(
                f"GradCAM expected {rows}x{cols} patch tokens for a "
                f"{height}x{width} input but the encoder produced {n_tokens}; "
                "is patch masking (mask_ratio > 0) enabled on the backbone?"
            )
        return rows, cols

    def generate(self, tensor, cls_idx):
        """Compute a normalized heatmap for one class.

        Args:
            tensor: Image batch of shape ``(1, 3, H, W)``; only item 0 is used.
            cls_idx: Index of the logit to explain.

        Returns:
            ``numpy`` float array of shape ``(H, W)`` scaled to ``[0, 1]``.

        Raises:
            RuntimeError: If activations or gradients were not captured, or
                the number of patch tokens does not match the input size.
        """
        self._f = self._g = None
        handles = [
            self._layer.register_forward_hook(self._capture_features),
            self._layer.register_full_backward_hook(self._capture_gradients),
        ]
        try:
            # Gradients are required even if the caller is inside no_grad().
            with torch.enable_grad():
                self.m.zero_grad()
                self.m(tensor)[0, cls_idx].backward()
        finally:
            for handle in handles:
                handle.remove()
            # backward() also filled every parameter's .grad; drop those so
            # they do not linger between requests.
            self.m.zero_grad()

        if self._f is None or self._g is None:
            raise RuntimeError("GradCAM hooks did not capture activations/gradients")

        # Token 0 is the CLS token; the remaining tokens are image patches.
        g = self._g[0, 1:, :]
        f = self._f[0, 1:, :]
        self._f = self._g = None  # release the autograd graph

        height, width = (int(s) for s in tensor.shape[-2:])
        rows, cols = self._token_grid(f.shape[0], height, width)

        # Weight each token's features by its mean gradient, keep positive evidence.
        cam = F.relu((g.mean(-1).unsqueeze(-1) * f).sum(-1))
        cam = cam.reshape(rows, cols).detach().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        # cv2.resize takes (width, height).
        return cv2.resize(cam, (width, height))
def overlay_heatmap(pil224, cam):
    """Blend a jet-colored heatmap over an image (55% image, 45% heat).

    Args:
        pil224: PIL image used as the base layer.
        cam: 2-D array of values in ``[0, 1]`` with the same height and width.

    Returns:
        PIL image holding the 8-bit blended result.
    """
    base = np.array(pil224).astype(np.float32)
    # cm.jet returns RGBA in [0, 1]; keep RGB and scale to the 0-255 range.
    jet_rgb = cm.jet(cam)[:, :, :3] * 255
    heat = jet_rgb.astype(np.float32)
    blended = 0.55 * base + 0.45 * heat
    blended = blended.clip(0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def pil_to_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
# ── FastAPI App ────────────────────────────────────────────────────
# ASGI application object; the route decorators below and uvicorn use it.
app = FastAPI(
    title="DrRetina Inference API",
)
@app.get("/health")
def health():
    """Report that the server is up, plus the device and model label."""
    status = {
        "status": "ok",
        "device": str(device),
        "model": "ViT-MAE DR Classifier",
    }
    return status
# One GradCAM helper for the whole process: building a new one per request
# can register additional hooks on the shared model each time.
_GRADCAM = None


def _get_gradcam():
    """Return the shared ViTGradCAM instance, creating it on first use."""
    global _GRADCAM
    if _GRADCAM is None:
        _GRADCAM = ViTGradCAM(model)
    return _GRADCAM


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Grade an uploaded fundus image and return a GradCAM overlay.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        JSON with ``grade``, ``grade_name``, ``confidence``, ``probs``,
        ``image_b64`` (preprocessed image, PNG/base64) and ``cam_b64``
        (heatmap overlay, PNG/base64). On failure returns ``{"error": ...}``
        with status 400 for an empty or undecodable upload and 500 for any
        other error.
    """
    try:
        contents = await file.read()
        if not contents:
            return JSONResponse({"error": "Empty upload"}, status_code=400)

        try:
            pil_img = Image.open(io.BytesIO(contents))
            # Image.open is lazy; decode now so corrupt files are reported
            # as a client error rather than a server fault.
            pil_img.load()
        except (OSError, ValueError, SyntaxError, Image.DecompressionBombError) as e:
            return JSONResponse({"error": f"Invalid image: {e}"}, status_code=400)

        pil224, tensor = preprocess(pil_img)

        # The probability pass needs no gradients, so skip the autograd graph.
        with torch.no_grad():
            logits = model(tensor)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
        grade = int(probs.argmax())

        # GradCAM backpropagates, so gradients must be enabled here.
        with torch.enable_grad():
            cam = _get_gradcam().generate(tensor.clone(), grade)
        cam_pil = overlay_heatmap(pil224, cam)

        return JSONResponse({
            "grade": grade,
            "grade_name": GRADES[grade],
            "confidence": float(probs[grade]),
            "probs": [float(p) for p in probs],
            "image_b64": pil_to_b64(pil224),
            "cam_b64": pil_to_b64(cam_pil),
        })
    except Exception as e:
        # Top-level boundary: keep the JSON error contract, but leave a
        # traceback in the server log so failures can be diagnosed.
        import traceback

        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # The listening port comes from the PORT environment variable (default 8000).
    listen_port = int(os.environ.get("PORT", 8000))
    print(f"[Server] Starting on 0.0.0.0:{listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|