Spaces:
Build error
Build error
| import torch | |
| import gradio as gr | |
| from torchvision import transforms | |
| from torchvision import models | |
| from PIL import Image | |
| from src.config import Config | |
| from src.utils import id2label | |
_TOP_K = 5  # number of highest-probability classes returned by predict()


def _build_transform():
    """Return the eval-time preprocessing pipeline sized by ``Config.imgsize``."""
    return transforms.Compose([
        # Resize the shorter side to ~1.143x the crop size (presumably 256/224,
        # the usual ImageNet resize-then-center-crop ratio), then crop.
        transforms.Resize(int(Config.imgsize * 1.143)),
        transforms.CenterCrop(Config.imgsize),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std normalization.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def _to_rgb_image(image):
    """Convert the raw input array to an RGB PIL image, or raise ``gr.Error``."""
    try:
        # Let Pillow infer the mode from the array's shape (L, RGB, RGBA, ...)
        # and convert afterwards: forcing mode "RGB" fails on 2-D grayscale
        # arrays and mis-decodes 4-channel arrays.
        return Image.fromarray(image.astype("uint8")).convert("RGB")
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error("Image could not be converted to RGB. Please try another image.") from err


def predict(image):
    """Classify an uploaded image and return its most likely classes.

    Args:
        image: The image supplied by the Gradio input component, or ``None``
            when nothing was uploaded. Presumably a NumPy array of shape
            (H, W) or (H, W, C); confirm against the interface definition.

    Returns:
        dict[str, float]: Capitalized label name -> softmax probability for
        the (up to) five most probable classes, in descending probability.

    Raises:
        gr.Error: If no image was supplied or it could not be converted to RGB.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    transform = _build_transform()
    model = Config.model
    pil_image = _to_rgb_image(image)

    # Add a leading batch dimension -> (1, C, H, W).
    # NOTE(review): the tensor stays on CPU; if Config.model lives on a GPU,
    # move `batch` to the model's device first.
    batch = transform(pil_image).unsqueeze(0)

    # Inference only: disable dropout/batch-norm updates and skip autograd.
    model.eval()
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(_TOP_K, probs.shape[1])
    top = torch.topk(probs, k)
    return {
        id2label(idx.item()).capitalize(): prob.item()
        for prob, idx in zip(top.values[0], top.indices[0])
    }