| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
| import gradio as gr |
|
|
| |
# Build a ResNet-18 skeleton with no pretrained weights, swap its classifier
# head for a 5-way linear layer, then load the trained parameters from
# resnet_model.pth onto the CPU.
# `weights=None` is the current torchvision spelling of the deprecated
# `pretrained=False` (same behaviour, no deprecation warning).
model = models.resnet18(weights=None)
# Output size must equal the number of labels in `classes` (5), because
# `predict` indexes that list with the argmax over these logits.
model.fc = nn.Linear(model.fc.in_features, 5)
# A state_dict only contains tensors in plain containers, so restrict
# unpickling with weights_only=True instead of allowing arbitrary pickled
# objects from the checkpoint file.
model.load_state_dict(
    torch.load("resnet_model.pth", map_location="cpu", weights_only=True)
)
# Put layers such as BatchNorm into evaluation behaviour for inference.
model.eval()
|
|
| |
# Preprocessing applied to every incoming PIL image before inference:
#   1. resize to exactly 224x224 (a (h, w) tuple, so aspect ratio is not kept);
#   2. convert to a channels-first float tensor (values in [0, 1] for 8-bit images).
# NOTE(review): there is no Normalize step here; confirm this matches the
# preprocessing that was used when resnet_model.pth was trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
|
|
# Human-readable labels, ordered to match the model's output positions:
# `predict` returns classes[argmax(logits)].
# Presumably diabetic-retinopathy severity grades (inferred from the names).
classes = [
    "No_DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative",
]
|
|
def predict(image):
    """Classify a single image and return its predicted label.

    Args:
        image: A PIL image, as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of the module-level ``classes`` list at the index of the
        highest logit produced by ``model``.

    Raises:
        gr.Error: If no image was provided, so the Gradio UI shows a
            readable message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # ToTensor keeps the source channel count (4 for RGBA, 1 for L/P modes),
    # but torchvision's resnet18 stem takes exactly 3 input channels.
    # Converting here also protects callers that use predict() directly.
    rgb_image = image.convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = transform(rgb_image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    pred = torch.argmax(logits, dim=1).item()
    return classes[pred]
|
|
# Web UI: one image (passed to `predict` as a PIL image) in, the predicted
# label out as plain text.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

# Launch only when this file is executed as a script. Importing the module
# (e.g. to reuse `predict` or `model`) then does not start a blocking server.
if __name__ == "__main__":
    demo.launch(share=False, ssr_mode=False)