Spaces:
Configuration error
Configuration error
File size: 2,794 Bytes
ae8f111 412ca3a ae8f111 957df8a 18a45f0 412ca3a ae8f111 3dee463 ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 412ca3a ae8f111 3dee463 412ca3a ab15865 64f36cb 412ca3a 3dee463 957df8a 3dee463 957df8a 3dee463 957df8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import os
import datetime
# Setup
device = torch.device("cpu")

# Create save directory in current working directory for cross-platform compatibility.
# A single makedirs(exist_ok=True) call tolerates the directory appearing
# concurrently; the isdir check only decides whether to log the creation.
# If a regular file already occupies the path, makedirs raises FileExistsError.
save_dir = os.path.join(os.getcwd(), "saved_predictions")
if not os.path.isdir(save_dir):
    os.makedirs(save_dir, exist_ok=True)
    print("📁 Folder created:", save_dir)
# Load model: ResNet-50 built without pretrained weights, its classifier head
# replaced by a 2-class linear layer, then restored from the fine-tuned
# checkpoint "resnet50_dr_classifier.pth" (resolved relative to the working directory).
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# The checkpoint is consumed by load_state_dict, i.e. it is a mapping of
# tensors, so weights_only=True is sufficient and avoids torch.load's default
# pickle-based unpickling of arbitrary Python objects.
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # evaluation mode: BatchNorm layers use their running statistics
# Grad-CAM
# Explanations are taken from the last bottleneck block of layer4, the final
# convolutional stage of ResNet-50 before global pooling and the fc head.
target_layer = model.layer4[-1]
cam = GradCAM(
    model=model,
    target_layers=[target_layer],
)
# Preprocessing: resize to the 224x224 model input size, convert the PIL
# image to a float tensor, then normalize each channel with the standard
# ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Predict and save
def predict_retinopathy(image):
    """Classify an uploaded image and explain the prediction with Grad-CAM.

    The image is converted to RGB, resized to 224x224, classified by the
    module-level ``model``, and a Grad-CAM heatmap for the predicted class is
    overlaid on the resized image. The overlay is also written to ``save_dir``
    as ``<timestamp>_<label>_<confidence>.png``.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        Tuple ``(cam_pil, message)``: the Grad-CAM overlay as a PIL image and
        a string such as ``"DR (Confidence: 0.93)"``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # Without this guard image.convert() raises AttributeError and the
        # user sees an opaque failure instead of an actionable message.
        raise gr.Error("Please upload an image first.")

    # Microsecond resolution keeps filenames unique when two predictions with
    # the same label and confidence complete within the same second.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    # Resize the PIL image itself (not only the tensor) so rgb_img_np below
    # has the same 224x224 size as the model input.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Plain classification needs no gradients.
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0, pred].item()
    # NOTE(review): assumes class index 0 is "DR" in the training label order — confirm.
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM relies on gradients, so it is called outside the no_grad block.
    rgb_img_np = np.ascontiguousarray(np.asarray(img, dtype=np.float32) / 255.0)
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    cam_pil = Image.fromarray(show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True))

    # Save image with label and confidence encoded in the filename.
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Gradio app
# Offer the bundled example only when the file exists. `examples` must be a
# list of per-input lists or None; the previous expression produced [None]
# when the file was missing, i.e. one empty example instead of none.
example_path = "example_oct.jpg"
examples = [[example_path]] if os.path.exists(example_path) else None

demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil", label="Upload OCT Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Heatmap"),
        gr.Text(label="Diabetic Retinopathy Prediction"),
    ],
    title="AI Diabetic Retinopathy Detection",
    description="Upload an OCT image to analyze for diabetic retinopathy. The AI will show a Grad-CAM heatmap highlighting areas of interest.",
    examples=examples,
)
# Listen on all network interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)
|