| import gradio as gr |
| from PIL import Image |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from torchvision import models, transforms |
| from pytorch_grad_cam import GradCAM |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
| import csv |
| import datetime |
| import os |
|
|
| |
# All inference runs on the CPU; checkpoint tensors are remapped onto it below.
device = torch.device("cpu")

# ResNet-50 architecture (no pretrained weights) whose final fully connected
# layer is swapped for a 2-output head, then restored from the local checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (torch >= 1.13 also accepts weights_only=True).
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# Module.to() returns the module itself, so eval() can be chained onto it.
model.to(device).eval()
|
|
| |
# Grad-CAM is attached to the last block of ``layer4``, the final residual
# stage of torchvision's ResNet, and is reused for every request.
target_layer = model.layer4[-1]
cam = GradCAM(target_layers=[target_layer], model=model)
|
|
| |
# Preprocessing applied to each input before it reaches the model:
# resize to 224x224, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean / std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
| |
log_path = "prediction_logs.csv"

# Column names written as the first row of a newly created (or empty) log file.
_LOG_HEADER = ["timestamp", "filename", "prediction", "confidence"]


def log_prediction(filename, prediction, confidence, log_file=None):
    """Append one prediction record to the CSV prediction log.

    Args:
        filename: Name of the image that was classified.
        prediction: Human-readable predicted label.
        confidence: Probability of the predicted class; written with 4 decimals.
        log_file: Path of the CSV file to append to. Defaults to the
            module-level ``log_path`` when omitted.

    Each row is ``[timestamp, filename, prediction, confidence]``. The timestamp
    is ISO-8601 local time including its UTC offset, so rows remain unambiguous
    across machines and daylight-saving changes. A header row is written only
    when the file is new or empty; existing files are simply appended to.
    """
    path = log_path if log_file is None else log_file
    timestamp = datetime.datetime.now().astimezone().isoformat()
    row = [timestamp, filename, prediction, f"{confidence:.4f}"]

    print("⏺ Logging prediction:", row)

    # Explicit UTF-8: uploaded filenames may contain non-ASCII characters that
    # the platform default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Append mode positions the stream at end-of-file, so offset 0 means
        # the file is new or empty and still needs its header.
        if file.tell() == 0:
            writer.writerow(_LOG_HEADER)
        writer.writerow(row)
|
|
| |
def predict_retinopathy(image):
    """Classify a retinal image and return a Grad-CAM overlay with its label.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(cam_pil, text)`` tuple. ``cam_pil`` is the Grad-CAM heatmap
        blended over the 224x224 RGB input, as a PIL image. ``text`` has the
        form ``"<label> (Confidence: 0.97)"``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a readable
            message instead of an ``AttributeError`` traceback.

    Side effect: appends the prediction to the CSV log via ``log_prediction``.
    """
    if image is None:
        raise gr.Error("Please upload a retinal image before submitting.")

    # Resize once here so the overlay below lines up with the 224x224 heatmap;
    # the transform's own Resize((224, 224)) then has nothing left to change.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # Class index 0 is reported as DR, any other index as No DR.
    # NOTE(review): this must match the class order used at training time — confirm.
    label = "Diabetic Retinopathy (DR)" if pred == 0 else "No DR"

    # Grad-CAM needs gradients, so it runs outside torch.no_grad(); it explains
    # the class that was just predicted.
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    # show_cam_on_image takes a float RGB image scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)

    # Images opened from disk carry a ``filename``. It can be missing or empty
    # (e.g. in-memory images), so fall back to a placeholder in both cases.
    filename = getattr(image, "filename", None) or "uploaded_image"
    log_prediction(filename, label, confidence)

    cam_pil = Image.fromarray(cam_image)
    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
|
|
| |
# Build the UI at module level and bind it to ``demo``, the name the `gradio`
# CLI reload mode looks for. Only start the server when this file is run as a
# script, so importing the module does not block on launch().
demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM"),
        gr.Text(label="Prediction"),
    ],
    title="Diabetic Retinopathy Detection",
    description="Upload a retinal image to classify DR and view Grad-CAM heatmap. All predictions are logged for analysis.",
)

if __name__ == "__main__":
    demo.launch()
|
|