"""Gradio app serving a 5-class ResNet-18 image classifier.

Loads a fine-tuned ResNet-18 state_dict from ``resnet_model.pth`` (on CPU) and
exposes a single-image upload UI that returns the predicted class label.
The label names (``No_DR`` ... ``Proliferative``) suggest diabetic-retinopathy
grading — confirm against the training code.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr

# Output labels, indexed by the model's logit position. The order must match
# the label encoding used at training time.
classes = ["No_DR", "Mild", "Moderate", "Severe", "Proliferative"]

# model: ResNet-18 architecture (no pretrained download) with a classification
# head sized to `classes`, weights restored from a local checkpoint.
# `weights=None` replaces the deprecated `pretrained=False` keyword.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(classes))
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint cannot run code on load.
model.load_state_dict(
    torch.load("resnet_model.pth", map_location="cpu", weights_only=True)
)
model.eval()

# transform: resize to a fixed 224x224 input and convert the PIL image to a
# float tensor.
# NOTE(review): there is no transforms.Normalize here. Preprocessing must
# mirror training exactly — confirm whether training used ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image):
    """Return the predicted class label for a single uploaded image.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``classes`` with the highest logit, or a prompt string
        when no image was provided.
    """
    if image is None:
        return "Please upload an image."
    # Uploads may be RGBA (PNG), grayscale ("L") or palette ("P"); the
    # network's first convolution expects exactly 3 input channels.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    pred = torch.argmax(logits, dim=1).item()
    return classes[pred]


demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")
demo.launch(share=False, ssr_mode=False)