import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from torchvision import transforms

# Side length (pixels) images are resized to. SimpleCNN's fc1 layer is sized
# from this value, so the two must stay in sync.
IMAGE_SIZE = 224


# ---------------- MODEL ----------------
class SimpleCNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by a two-layer head.

    Each pooling stage halves the spatial size (flooring), so an
    ``input_size`` x ``input_size`` image leaves a 128-channel feature map of
    side ``input_size // 8``. ``fc1`` is sized for exactly that map.

    Args:
        num_classes: Number of output logits.
        input_size: Side length of the square input images. The default of
            224 reproduces the original ``128 * 28 * 28`` fc1 width.
    """

    def __init__(self, num_classes=5, input_size=IMAGE_SIZE):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        feature_side = input_size // 8  # three halvings: 224 -> 112 -> 56 -> 28
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for a batch of 3-channel images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# ---------------- LOAD MODEL ----------------
device = torch.device("cpu")
model = SimpleCNN(num_classes=5).to(device)
# NOTE(review): torch.load unpickles the checkpoint, so only load files you
# trust. On torch >= 1.13, passing weights_only=True restricts it to tensors.
model.load_state_dict(torch.load("best_model_aptos.pth", map_location=device))
model.eval()

# ---------------- LABEL MAP ----------------
# Model output index -> human-readable DR grade. The ordering looks like the
# alphabetical class-folder order that torchvision ImageFolder assigns.
# TODO confirm against the training script.
label_map = {
    0: 'Mild',
    1: 'Moderate',
    2: 'No DR',
    3: 'Proliferative DR',
    4: 'Severe',
}

# ---------------- TRANSFORM ----------------
# Resize to the model's input size and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied here. Presumably this
# matches training; verify.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])


# ---------------- PREDICTION FUNCTION ----------------
def predict(image):
    """Classify one retinal image and return a probability per DR grade.

    Args:
        image: A PIL image, as supplied by ``gr.Image(type="pil")``. Gradio
            passes ``None`` when the user submits without uploading.

    Returns:
        Dict mapping every label in ``label_map`` to its softmax probability,
        in the format ``gr.Label`` accepts.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to the
            user instead of an opaque stack trace.
    """
    if image is None:
        raise gr.Error("Please upload a retinal image.")
    # conv1 expects exactly 3 channels. RGBA, grayscale and palette uploads
    # would otherwise produce a tensor with the wrong channel count.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Take the single batch row explicitly instead of squeeze(), so the
    # result is always a 1-D list of per-class probabilities.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {label_map[idx]: float(p) for idx, p in enumerate(probs)}
# ---------------- GRADIO UI ----------------
# One PIL image in; the top five labels from `predict` out, ranked by
# probability.
# No `examples=` are configured, so the description does not promise sample
# images. Add an examples list of local image paths or direct image URLs
# before advertising samples again.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title=" Diabetic Retinopathy Classifier",
    description="Upload a retinal image to see the predicted probability of each diabetic retinopathy grade.",
)

if __name__ == "__main__":
    # Serve locally only (share=False) with server-side rendering disabled.
    interface.launch(share=False, ssr_mode=False)