Spaces:
Configuration error
Configuration error
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torchvision import models, transforms | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import os | |
| import datetime | |
# Setup
# All tensors and the model are placed on the CPU.
device = torch.device("cpu")

# Create save directory in current working directory for cross-platform compatibility.
save_dir = os.path.join(os.getcwd(), "saved_predictions")
# Note whether the folder is missing *before* creating it, so the message is
# only printed the first time. The creation itself is a single idempotent
# call: exist_ok=True means a concurrent creator cannot make it fail, unlike a
# separate "check, then create" sequence.
_save_dir_is_new = not os.path.isdir(save_dir)
os.makedirs(save_dir, exist_ok=True)
if _save_dir_is_new:
    print("📁 Folder created:", save_dir)
# Load model
# Build a ResNet-50 with no pretrained weights and replace its final fully
# connected layer with a 2-output head before loading the checkpoint.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module, so they chain.
model = model.to(device).eval()

# Grad-CAM
# The heatmap is computed from the last block of the network's `layer4` stage.
target_layer = model.layer4[-1]
cam = GradCAM(model=model, target_layers=[target_layer])
# Preprocessing
# Pipeline applied to every input image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels with the mean/std below
# (the values commonly used for ImageNet-trained torchvision models).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Predict and save
def predict_retinopathy(image):
    """Classify an uploaded image, build a Grad-CAM overlay, and save it.

    Args:
        image: PIL image delivered by the Gradio input component, or ``None``
            when the form is submitted without an upload.

    Returns:
        A ``(cam_pil, text)`` tuple: the Grad-CAM overlay as a PIL image and
        a string of the form ``"<label> (Confidence: <p>)"`` where ``<label>``
        is ``"DR"`` or ``"NoDR"``.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of an opaque attribute error.

    Side effects:
        Writes the overlay as a PNG into ``save_dir``; the file name carries
        the timestamp, predicted label and confidence.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Microsecond resolution keeps two predictions made within the same
    # second (with the same label and confidence) from overwriting each other.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    # The resized RGB image is kept because it is also the background for the
    # heatmap overlay, which must match the 224x224 CAM output.
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Plain classification pass: no gradients needed here.
    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # Class index 0 is reported as "DR", any other index as "NoDR".
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM
    # show_cam_on_image expects a float image scaled to [0, 1].
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    # Called outside the no_grad block: Grad-CAM needs gradients of the
    # predicted class score with respect to the target layer.
    grayscale_cam = cam(
        input_tensor=img_tensor,
        targets=[ClassifierOutputTarget(pred)],
    )[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save image with label and confidence
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"
# Gradio app
# `examples` must be either a list of example rows or None. The existence
# check therefore wraps the whole list: when the bundled file is missing the
# interface gets no examples at all, rather than a list holding a None entry.
example_rows = [["example_oct.jpg"]] if os.path.exists("example_oct.jpg") else None

demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil", label="Upload OCT Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM Heatmap"),
        gr.Text(label="Diabetic Retinopathy Prediction"),
    ],
    title="AI Diabetic Retinopathy Detection",
    description="Upload an OCT image to analyze for diabetic retinopathy. The AI will show a Grad-CAM heatmap highlighting areas of interest.",
    examples=example_rows,
)

# Listen on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)