File size: 6,294 Bytes
3028f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""

DrRetina – AMD MI300X Inference Server

Runs on AMD GPU server; exposes REST API for HF Spaces Gradio UI.



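Usage (on AMD server):

    pip install fastapi uvicorn python-multipart pillow numpy opencv-python matplotlib
    pip install torch torchvision transformers   # use a ROCm build of torch on AMD GPUs

    python3 inference_server.py

The model weights are read from ``checkpoints/best_model.pth`` next to this
script. The server listens on 0.0.0.0 at port $PORT (default 8000) and runs
on the GPU when PyTorch reports one as available, otherwise on the CPU.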

"""

import os, io, base64, cv2, numpy as np
from PIL import Image
import torch, torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTMAEModel
import matplotlib; matplotlib.use("Agg")
import matplotlib.cm as cm
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn

# ── Constants ─────────────────────────────────────────────────────
# Human-readable label for each class index produced by the classifier head.
GRADES = dict(enumerate((
    "No DR",
    "Mild DR",
    "Moderate DR",
    "Severe DR",
    "Proliferative DR",
)))

# Weights file expected in a ``checkpoints`` folder next to this script.
_HERE = os.path.dirname(__file__)
CHECKPOINT = os.path.join(_HERE, "checkpoints", "best_model.pth")

# Prefer the GPU when torch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ── Model ──────────────────────────────────────────────────────────
class DRClassifier(nn.Module):
    """ViT-MAE backbone with a small MLP head for DR grade classification.

    The head reads the CLS token (index 0) of the backbone's last hidden
    state and maps it through Linear -> BatchNorm -> ReLU -> Dropout -> Linear
    to ``num_classes`` logits.
    """

    def __init__(self, num_classes=5, dropout=0.3):
        super().__init__()
        # ViTMAE's embedding layer drops patches according to
        # ``config.mask_ratio`` on EVERY forward pass, even in eval mode.
        # Classification must see all patches, so masking is disabled here.
        self.backbone = ViTMAEModel.from_pretrained(
            "facebook/vit-mae-base", mask_ratio=0.0
        )
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256), nn.BatchNorm1d(256),
            nn.ReLU(), nn.Dropout(dropout), nn.Linear(256, num_classes),
        )

    def _ordered_noise(self, pixel_values):
        """Return a monotone per-patch 'noise' ramp of shape (batch, num_patches).

        ViTMAE orders patches by ``argsort(noise)``; with ``noise=None`` it
        draws random noise, which permutes the patch tokens on every call
        (nondeterministic features, scrambled per-patch saliency maps). An
        increasing ramp makes that argsort the identity permutation.
        """
        patch = self.backbone.config.patch_size
        ph, pw = (patch, patch) if isinstance(patch, int) else patch
        num_patches = (pixel_values.shape[-2] // ph) * (pixel_values.shape[-1] // pw)
        ramp = torch.arange(
            num_patches, device=pixel_values.device, dtype=torch.float32
        )
        return ramp.unsqueeze(0).expand(pixel_values.shape[0], -1)

    def forward(self, pixel_values):
        """Return class logits for a batch of normalised images."""
        out = self.backbone(
            pixel_values=pixel_values, noise=self._ordered_noise(pixel_values)
        )
        return self.classifier(out.last_hidden_state[:, 0, :])


print(f"[Server] Loading model on {device}...")
model = DRClassifier().to(device)
ckpt  = torch.load(CHECKPOINT, map_location=device, weights_only=True)
# Accept either a wrapper dict with a "model_state_dict" entry or a bare
# state dict.
_load_result = model.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
# strict=False tolerates key mismatches, but a silent mismatch would leave
# those layers (e.g. the classifier head) at their random initialisation.
# Surface it so a wrong checkpoint is not served unnoticed.
if _load_result.missing_keys:
    print(f"[Server] WARNING: {len(_load_result.missing_keys)} model keys missing "
          f"from checkpoint, e.g. {_load_result.missing_keys[:5]}")
if _load_result.unexpected_keys:
    print(f"[Server] WARNING: {len(_load_result.unexpected_keys)} unexpected keys "
          f"in checkpoint, e.g. {_load_result.unexpected_keys[:5]}")
model.eval()
print("[Server] Model ready ✅")

# ── Preprocessing ──────────────────────────────────────────────────
def circle_crop(img_bgr):
    """Crop a BGR image to the bounding box of its largest non-dark region.

    Pixels whose grayscale value exceeds 15 form a binary mask; the bounding
    rectangle of the largest external contour of that mask is returned as a
    slice of the input. If no contour is found, the image is returned as-is.
    """
    grey = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(grey, 15, 255, cv2.THRESH_BINARY)[1]
    contours, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return img_bgr
    largest = max(contours, key=cv2.contourArea)
    left, top, width, height = cv2.boundingRect(largest)
    return img_bgr[top:top + height, left:left + width]

def apply_clahe(img_bgr):
    """Equalise local contrast with CLAHE on the lightness channel only.

    The image is converted to Lab, CLAHE (clip limit 2.0, 8x8 tiles) is
    applied to L, the a/b channels pass through untouched, and the result is
    converted back to BGR.
    """
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB))
    equalised_lab = cv2.merge([equaliser.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(equalised_lab, cv2.COLOR_LAB2BGR)

# ImageNet per-channel mean / std applied after ToTensor().
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference-time transform: PIL image -> float CHW tensor -> normalised.
INFER_TF = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

def preprocess(pil_img):
    """Turn an uploaded PIL image into (display image, model input batch).

    Pipeline: force RGB, switch to OpenCV BGR, crop to the non-dark region
    (``circle_crop``), apply CLAHE, resize to 224x224, switch back to RGB.

    Returns:
        A tuple ``(image, batch)``: the processed 224x224 RGB PIL image and
        the ``INFER_TF`` tensor of that image with a leading batch dimension,
        moved to ``device``.
    """
    as_bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    enhanced = cv2.resize(apply_clahe(circle_crop(as_bgr)), (224, 224))
    square = Image.fromarray(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB))
    batch = INFER_TF(square).unsqueeze(0).to(device)
    return square, batch

# ── GradCAM ────────────────────────────────────────────────────────
class ViTGradCAM:
    """Grad-CAM-style saliency for the last encoder layer of a ViT backbone.

    Hooks ``m.backbone.encoder.layer[-1]`` to capture its output activations
    (forward hook) and the gradient w.r.t. that output (full backward hook),
    then combines them over the patch tokens (CLS token excluded) into a
    square heatmap normalised to [0, 1].

    The hooks stay registered on ``m`` until :meth:`remove` is called, or the
    instance is used as a context manager. Creating instances repeatedly
    without removing them accumulates hooks on the model.
    """

    def __init__(self, m):
        self.m = m
        self._f = self._g = None  # last captured activations / gradients
        layer = m.backbone.encoder.layer[-1]
        # Keep the handles so the hooks can be detached later.
        self._handles = [
            layer.register_forward_hook(self._save_activation),
            layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, inputs, output):
        """Forward hook: store the layer's hidden-state output."""
        # The layer may return a tuple whose first element is the hidden state.
        self._f = output[0] if isinstance(output, tuple) else output

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer's first output."""
        self._g = grad_output[0]

    def remove(self):
        """Detach both hooks from the model (safe to call more than once)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove()
        return False

    def generate(self, tensor, cls_idx, out_size=(224, 224)):
        """Return a [0, 1] heatmap explaining logit ``cls_idx`` of batch item 0.

        Args:
            tensor: model input; only batch element 0 is explained.
            cls_idx: index of the logit that is back-propagated.
            out_size: ``(width, height)`` handed to ``cv2.resize``; defaults to
                the 224x224 resolution used by this server.

        Returns:
            A numpy array of shape ``(height, width)``.

        Raises:
            RuntimeError: if the number of patch tokens is not a perfect square.
        """
        self.m.zero_grad(set_to_none=True)
        try:
            self.m(tensor)[0, cls_idx].backward()
            grads = self._g[0, 1:, :]  # drop the CLS token
            feats = self._f[0, 1:, :]
            # NOTE(review): each token is weighted by its gradient averaged over
            # the hidden dim; classic Grad-CAM averages gradients over tokens
            # (per channel) instead. Kept as-is to preserve existing heatmaps.
            cam = F.relu((grads.mean(-1).unsqueeze(-1) * feats).sum(-1))
            side = int(round(cam.numel() ** 0.5))
            cam = cam.reshape(side, side).detach().cpu().numpy()
        finally:
            # Parameter gradients from this pass are never used; free them
            # instead of keeping a full copy of the weights' grads alive.
            self.m.zero_grad(set_to_none=True)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cv2.resize(cam, out_size)

def overlay_heatmap(pil224, cam):
    """Blend a jet-coloured heatmap over an RGB image and return a PIL image.

    ``cam`` is mapped through matplotlib's ``jet`` colormap (alpha dropped,
    scaled to 0-255) and mixed 55% image / 45% heatmap. The mix is clipped to
    0-255 and cast to uint8. ``cam`` is expected to match the image's
    height and width.
    """
    base = np.array(pil224).astype(np.float32)
    colour = (cm.jet(cam)[..., :3] * 255).astype(np.float32)
    blended = np.clip(0.55 * base + 0.45 * colour, 0, 255)
    return Image.fromarray(blended.astype(np.uint8))

def pil_to_b64(img: Image.Image) -> str:
    """Encode ``img`` as PNG and return the bytes as base64 text."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        payload = stream.getvalue()
    return base64.b64encode(payload).decode()

# ── FastAPI App ────────────────────────────────────────────────────
app = FastAPI(title="DrRetina Inference API")


@app.get("/health")
def health():
    """Health-check endpoint: report status, compute device and model name."""
    report = {"status": "ok"}
    report["device"] = str(device)
    report["model"] = "ViT-MAE DR Classifier"
    return report

_gradcam = None  # shared ViTGradCAM, created on the first /predict call


def _get_gradcam():
    """Return the process-wide ViTGradCAM, registering its hooks only once.

    Constructing a ViTGradCAM registers forward/backward hooks on the global
    ``model``; doing that per request would pile up hooks (and the objects
    they reference) with every call.
    """
    global _gradcam
    if _gradcam is None:
        _gradcam = ViTGradCAM(model)
    return _gradcam


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded fundus image and return a Grad-CAM overlay.

    On success the JSON body holds the grade index and name, the confidence
    of that grade, the full softmax vector, and base64 PNGs of the
    preprocessed 224x224 image and its heatmap overlay. Any exception is
    reported as ``{"error": <message>}`` with status 500.
    """
    try:
        contents = await file.read()
        pil_img  = Image.open(io.BytesIO(contents))
        pil224, tensor = preprocess(pil_img)

        # Plain inference pass: no autograd graph is needed just to pick the
        # grade; Grad-CAM below performs its own forward + backward.
        with torch.no_grad():
            logits = model(tensor)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
        grade = int(probs.argmax())
        # Sharing one Grad-CAM instance is safe here: there is no `await`
        # between the forward/backward passes, so requests cannot interleave.
        cam     = _get_gradcam().generate(tensor, grade)
        cam_pil = overlay_heatmap(pil224, cam)

        return JSONResponse({
            "grade":      grade,
            "grade_name": GRADES[grade],
            "confidence": float(probs[grade]),
            "probs":      [float(p) for p in probs],
            "image_b64":  pil_to_b64(pil224),
            "cam_b64":    pil_to_b64(cam_pil),
        })
    except Exception as e:
        # Top-level request boundary: keep the JSON error contract, but log
        # the traceback so failures are visible server-side.
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)


if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 8000.
    listen_port = int(os.getenv("PORT", "8000"))
    print(f"[Server] Starting on 0.0.0.0:{listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)