| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| import gradio as gr |
| from torchvision import transforms |
|
|
| |
class SimpleCNN(nn.Module):
    """Small three-stage CNN classifier for square RGB images.

    Each stage is a 3x3 convolution (padding 1, so height/width are kept),
    a ReLU, and a 2x2 max-pool that floor-halves height and width. After the
    three stages the feature map is
    ``128 x (input_size // 8) x (input_size // 8)``; it is flattened and fed
    through two fully-connected layers.

    Args:
        num_classes: number of output logits.
        input_size: side length in pixels of the square images the model is
            fed. The default of 224 gives the original ``128 * 28 * 28``
            flattened size, so existing checkpoints load unchanged.
    """

    def __init__(self, num_classes=5, input_size=224):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)

        # Three floor-halving 2x2 pools shrink each side to input_size // 8
        # (floor(floor(floor(n/2)/2)/2) == n // 8).
        feature_side = input_size // 8
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) class logits of shape (batch, num_classes).

        ``x`` is expected to be a (batch, 3, input_size, input_size) tensor;
        other spatial sizes will not match ``fc1``'s input features.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Keep the batch dimension, flatten channels x height x width.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)
|
|
|
|
| |
# Inference runs on the CPU.
device = torch.device("cpu")

# Rebuild the architecture and load the trained weights.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. The file only needs to
# hold a state_dict, which is all load_state_dict consumes anyway.
model = SimpleCNN(num_classes=5)
model.load_state_dict(
    torch.load("best_model_aptos.pth", map_location=device, weights_only=True)
)
model.to(device)
# Inference mode. SimpleCNN has no dropout/batch-norm layers, but eval() is the
# conventional safeguard if such layers are added later.
model.eval()
|
|
|
|
| |
# Model output index -> human-readable diagnosis label.
# NOTE(review): the indices are not in clinical severity order (index 2 is
# "No DR"). They coincide with alphabetical sorting of the class names
# (Mild, Moderate, No DR, Proliferative DR, Severe), as torchvision's
# ImageFolder would assign; confirm against the training setup before changing.
label_map = dict(
    zip(
        (2, 0, 1, 4, 3),
        ('No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'),
    )
)
|
|
|
|
| |
# Preprocessing applied to every image before it reaches the model:
# resize to a fixed 224x224 (the resolution SimpleCNN's fc1 is sized for by
# default, 128 * 28 * 28 after three 2x pools), then convert to a float tensor.
# No mean/std normalization is applied.
# NOTE(review): this must match the preprocessing used at training time — confirm.
_target_size = (224, 224)
transform = transforms.Compose(
    [
        transforms.Resize(_target_size),
        transforms.ToTensor(),
    ]
)
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
def predict(image):
    """Classify a retinal image and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each label in ``label_map`` to its softmax probability,
        the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        # Friendly UI message instead of an opaque error from the transform.
        raise gr.Error("Please upload a retinal image first.")

    # conv1 expects 3 input channels; RGBA, grayscale or palette uploads would
    # otherwise become 4- or 1-channel tensors and crash the forward pass.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Take the single batch row explicitly; .squeeze() would also drop the
    # class dimension if it ever had size 1.
    probs = torch.softmax(logits, dim=1)[0].tolist()

    return {label_map[idx]: float(p) for idx, p in enumerate(probs)}
|
|
| |
|
|
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
|
|
| |
# Gradio UI: one PIL image in, ranked class probabilities out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    # Show every class the model can output.
    outputs=gr.Label(num_top_classes=len(label_map)),
    title=" Diabetic Retinopathy Classifier",
    # No `examples=` are configured, so the description must not promise
    # sample images.
    description="Upload a retinal image to classify its diabetic retinopathy grade",
)
|
|
if __name__ == "__main__":
    # Serve locally only (no public share link) with server-side rendering off.
    launch_options = {"ssr_mode": False, "share": False}
    interface.launch(**launch_options)