File size: 2,825 Bytes
ae8f111 412ca3a ae8f111 300c3a7 6a60023 412ca3a ae8f111 3dee463 ae8f111 412ca3a ae8f111 412ca3a ae8f111 e473c8a ae8f111 412ca3a ae8f111 412ca3a ae8f111 6a60023 ae8f111 3dee463 412ca3a 52fef39 64f36cb 412ca3a 6a60023 3dee463 8440cfb 3dee463 6a60023 3dee463 8440cfb e473c8a 6a60023 8440cfb 6a60023 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import os
import datetime
# Setup
# All inference runs on the CPU.
device = torch.device("cpu")
# Directory where every Grad-CAM overlay produced by the app is written.
save_dir = "/home/user/app/saved_predictions"
# Create the folder exactly once. exist_ok=True closes the gap between the
# isdir() check and the creation call (another process could create it in
# between), which the previous check-then-create plus a second unconditional
# makedirs() call handled redundantly.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir, exist_ok=True)
    print("π Folder created:", save_dir)
# Load model
# ResNet-50 built without pretrained weights; its final fully-connected layer
# is swapped for a 2-output head before the fine-tuned checkpoint is restored.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself, so the device
# move and the switch to inference mode can be chained.
model.to(device).eval()
# Grad-CAM
# Explain predictions using activations/gradients of the last block in layer4.
target_layer = model.layer4[-1]
cam = GradCAM(
    model=model,
    target_layers=[target_layer],
)
# Preprocessing
# Resize to the 224x224 network input, convert to a float tensor, then
# normalise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Predict and save
def predict_retinopathy(image):
    """Classify a retinal image and return a Grad-CAM overlay plus a label.

    The image is converted to RGB, resized to 224x224 and classified by the
    module-level ``model``. Class index 0 is reported as "DR", any other index
    as "NoDR". A Grad-CAM heatmap for the predicted class is blended onto the
    resized image and the result is also saved as a PNG in ``save_dir``.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(overlay, text)`` tuple. ``overlay`` is the Grad-CAM PIL image and
        ``text`` is ``"<label> (Confidence: <p>)"``. When no image was given,
        ``overlay`` is ``None`` and ``text`` asks the user to upload one.
    """
    if image is None:
        # Previously this crashed with AttributeError on image.convert().
        return None, "Please upload a retinal image."

    # Microseconds keep filenames unique when several predictions finish in
    # the same second; the "-" keeps the underscore-separated field layout
    # (<date>_<time>_<label>_<confidence>.png) unchanged.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S-%f")

    # The resized PIL image is reused below as the Grad-CAM overlay base, so
    # it must have the same 224x224 size as the network input.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Class probabilities need no gradients.
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0, pred].item()
    # NOTE(review): assumes index 0 is the DR class (e.g. alphabetical
    # ImageFolder ordering "DR" < "NoDR") - confirm against training setup.
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM relies on gradients, so cam() is called outside no_grad().
    # show_cam_on_image expects a float image scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.asarray(img, dtype=np.float32) / 255.0)
    grayscale_cam = cam(
        input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)]
    )[0]
    cam_pil = Image.fromarray(
        show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    )

    # Persist the overlay with label and confidence encoded in the filename.
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))
    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Gradio app
# Single-input / two-output interface around predict_retinopathy: the upload
# goes in, the Grad-CAM overlay and the prediction text come out.
# NOTE(review): "β" in the title and "βοΈ" in the article look like
# mis-decoded UTF-8 (probably a dash and an emoji) - confirm against the
# original file; they are reproduced here unchanged.
demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil", label="Upload Retinal Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Heatmap"),
        gr.Text(label="Prediction"),
    ],
    title="OpthaDetect β AI Retinal Screening",
    description=(
        "A lightweight AI system for detecting Diabetic Retinopathy from retinal images. "
        "Upload an image to classify DR and visualise the Grad-CAM heatmap showing important regions."
    ),
    article=(
        "βοΈ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
        "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
    ),
)
demo.launch()
|