File size: 2,825 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
412ca3a
 
 
 
ae8f111
300c3a7
6a60023
 
 
412ca3a
ae8f111
3dee463
ae8f111
 
 
 
 
 
412ca3a
ae8f111
 
 
412ca3a
ae8f111
 
 
 
 
 
 
e473c8a
ae8f111
412ca3a
ae8f111
 
 
 
 
 
 
 
 
412ca3a
ae8f111
 
 
 
6a60023
ae8f111
3dee463
412ca3a
 
 
 
 
52fef39
64f36cb
412ca3a
6a60023
3dee463
8440cfb
3dee463
6a60023
3dee463
 
8440cfb
 
 
 
 
 
e473c8a
6a60023
8440cfb
6a60023
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import os
import datetime

# Setup
# Inference runs on the CPU; every tensor and the model are moved to this device.
device = torch.device("cpu")
# Folder that receives one Grad-CAM overlay PNG per prediction.
save_dir = "/home/user/app/saved_predictions"
# Create the folder with a single EAFP call instead of exists()-then-makedirs(),
# which could race with another process creating the same path.
try:
    os.makedirs(save_dir)
except FileExistsError:
    # The path is already taken. That is fine for a directory; anything else
    # (e.g. a regular file) would break saving later, so surface it now.
    if not os.path.isdir(save_dir):
        raise
else:
    # Only announce the folder when this run actually created it.
    print("📁 Folder created:", save_dir)

# Load model
# Build a ResNet-50 without pretrained weights, replace its final layer with a
# 2-output linear head, then fill it from the local checkpoint file.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
# NOTE(review): torch.load unpickles the checkpoint; only ship a trusted file,
# and consider weights_only=True if the installed torch version supports it.
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself, so chaining
# leaves `model` bound to the same object, now on `device` and in eval mode.
model = model.to(device).eval()

# Grad-CAM is attached to the last block of the network's layer4 stage.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])

# Preprocessing
# Input pipeline applied to every image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalise each channel.
# The mean/std values match the standard ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Predict and save
def predict_retinopathy(image):
    """Classify a retinal image and build a Grad-CAM overlay for the result.

    The image is converted to RGB and resized to 224x224, passed through the
    module-level ``transform`` and ``model``, and the predicted class is
    visualised with the module-level ``cam``. The overlay is also written to
    ``save_dir`` as ``<timestamp>_<label>_<confidence>.png``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the callback
            is invoked without an uploaded image.

    Returns:
        A ``(overlay, message)`` tuple. ``overlay`` is the Grad-CAM heatmap
        blended onto the resized input as a PIL image, and ``message`` is a
        string such as ``"DR (Confidence: 0.97)"``. If ``image`` is ``None``,
        ``overlay`` is ``None`` and ``message`` asks the user for an upload.
    """
    # The UI can call this with no image (e.g. submit without an upload);
    # answer with a readable message instead of failing on ``None.convert``.
    if image is None:
        return None, "No image provided. Please upload a retinal image."

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    img = image.convert("RGB").resize((224, 224))
    # Add the batch dimension the model expects and move to the target device.
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Plain forward pass for the prediction; no gradients are needed here.
    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # Class index 0 is reported as "DR"; any other index as "NoDR".
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM runs outside the no_grad block. show_cam_on_image is given the
    # resized image as float32 scaled to [0, 1].
    rgb_img_np = np.array(img).astype(np.float32) / 255.0
    rgb_img_np = np.ascontiguousarray(rgb_img_np)
    grayscale_cam = cam(
        input_tensor=img_tensor,
        targets=[ClassifierOutputTarget(pred)],
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save the overlay with the label and confidence encoded in the file name.
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"

# Gradio app
# Build the web UI around predict_retinopathy and start serving it.
# Input: one PIL image. Outputs: the Grad-CAM overlay and the prediction text.
gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil", label="Upload Retinal Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Heatmap"),
        gr.Text(label="Prediction"),
    ],
    title="OpthaDetect – AI Retinal Screening",
    description=(
        "A lightweight AI system for detecting Diabetic Retinopathy from retinal images. "
        "Upload an image to classify DR and visualise the Grad-CAM heatmap showing important regions."
    ),
    # The leading emoji was stored as mis-decoded bytes and rendered as
    # garbage in the page footer; it is restored to the intended symbol.
    article=(
        "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
        "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
    ),
).launch()