File size: 2,794 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
412ca3a
 
 
 
ae8f111
957df8a
 
18a45f0
 
 
412ca3a
ae8f111
3dee463
ae8f111
 
 
 
 
 
412ca3a
ae8f111
 
 
412ca3a
ae8f111
 
 
 
 
 
 
412ca3a
ae8f111
412ca3a
ae8f111
 
 
 
 
 
 
 
 
412ca3a
ae8f111
 
 
 
 
 
3dee463
412ca3a
 
 
 
 
ab15865
64f36cb
412ca3a
3dee463
 
957df8a
3dee463
957df8a
 
3dee463
957df8a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import os
import datetime

# Setup
device = torch.device("cpu")
# Create save directory in current working directory for cross-platform compatibility.
# A single EAFP makedirs replaces the old exists-check + duplicate makedirs pair,
# which was redundant and racy; the message is still printed only on creation.
save_dir = os.path.join(os.getcwd(), "saved_predictions")
try:
    os.makedirs(save_dir)
except FileExistsError:
    # An existing directory is fine; an existing *file* at this path is not.
    if not os.path.isdir(save_dir):
        raise
else:
    print("📁 Folder created:", save_dir)

# Load model: a ResNet-50 whose final fully-connected layer is replaced by a
# 2-output head, with weights read from resnet50_dr_classifier.pth.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# weights_only=True (torch >= 1.13) restricts unpickling to tensors and plain
# containers, so a tampered checkpoint cannot execute arbitrary code; it also
# silences the FutureWarning newer torch versions emit when it is omitted.
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode: BatchNorm uses its running statistics

# Grad-CAM: attribute predictions to the last block of layer4, the final
# convolutional stage of the ResNet before global pooling.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])

# Preprocessing: resize to 224x224, convert the PIL image to a float tensor,
# then normalize per channel with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Predict and save
def predict_retinopathy(image):
    """Classify an uploaded image and return a Grad-CAM overlay with the label.

    The overlay is also saved to ``save_dir`` as
    ``<timestamp>_<label>_<confidence>.png``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            if the user submitted without uploading anything.

    Returns:
        ``(cam_pil, text)``: the Grad-CAM heatmap blended onto the 224x224
        resized input, and a string such as ``"DR (Confidence: 0.93)"``.

    Raises:
        gr.Error: if ``image`` is ``None``; Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Microseconds keep filenames unique when several requests land within the
    # same second (otherwise an identical label/confidence would overwrite).
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    # Resize up front so the RGB array used for the overlay has the same
    # 224x224 size as the tensor the CAM is computed on.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # NOTE(review): assumes class index 0 is DR in the training label order —
    # confirm against the dataset used to produce the checkpoint.
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM needs gradients, so it runs outside the no_grad block above.
    # show_cam_on_image expects a float RGB image scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save image with label and confidence
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"

# Gradio app
# Offer the example only when the file exists. The previous form produced
# `[None]` when it was missing, which registers a bogus `None` example row
# instead of disabling examples.
_EXAMPLE_PATH = "example_oct.jpg"
_examples = [[_EXAMPLE_PATH]] if os.path.exists(_EXAMPLE_PATH) else None

demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil", label="Upload OCT Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Heatmap"),
        gr.Text(label="Diabetic Retinopathy Prediction"),
    ],
    title="AI Diabetic Retinopathy Detection",
    description="Upload an OCT image to analyze for diabetic retinopathy. The AI will show a Grad-CAM heatmap highlighting areas of interest.",
    examples=_examples,
)
# Bind on all interfaces so the app is reachable from outside a container.
demo.launch(server_name="0.0.0.0", server_port=7860)