Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, datasets, models | |
# Preprocessing: the preset inference transforms bundled with the
# ImageNet-1K ResNet-18 weights.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# All inference runs on the CPU.
device = torch.device("cpu")

# Human-readable labels; index i names output unit i of the classifier head
# (predict() pairs them positionally with the softmax outputs).
class_names = [
    'Anger',
    'Disgust',
    'Fear',
    'Happy',
    'Pain',
    'Sad',
]
classes_count = len(class_names)
import warnings

# ResNet-18 backbone initialised from the default pretrained weights. Any
# parameter found in the checkpoint below overrides these.
# NOTE(review): this downloads ImageNet weights at startup; if the checkpoint
# is a complete state_dict, weights=None would avoid the download.
model = models.resnet18(weights='DEFAULT')

# Replace the ImageNet head with one output per label. Keep the Sequential
# wrapper: it determines the parameter names ('fc.0.*') that the checkpoint
# has to match.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)

# Move to the target device only AFTER the head has been swapped. Otherwise
# the new layer stays on the default device, and load_state_dict (which copies
# values in place) would not move it either.
model = model.to(device)

# strict=False tolerates key mismatches, but it would also silently leave
# layers (e.g. the new head) randomly initialised. Surface that case.
_load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=device), strict=False
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint './model_param.pt' does not match the model: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )
def predict(image):
    """Return the classifier's probability for every label.

    Args:
        image: The input picture as delivered by ``gr.Image(type='pil')``
            (a PIL image), or ``None`` when the input is empty/cleared.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability (a Python float), or ``None`` when no image was given.
        Returning ``None`` clears the ``gr.Label`` output.
    """
    # With live=True this also runs when the user clears the image, and
    # Gradio then passes None. Previously transformer(None) raised.
    if image is None:
        return None
    # ImageNet normalization expects 3 channels. Convert RGBA/greyscale PIL
    # images so callers outside Gradio (which may not convert) also work.
    if hasattr(image, 'convert'):
        image = image.convert('RGB')

    batch = transformer(image).unsqueeze(0).to(device)  # add batch dimension
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    # Labels and outputs are aligned by position.
    return dict(zip(class_names, probs.tolist()))
# Web UI: upload an image and see the probability of every label.
# live=True re-runs predict whenever the input changes, so there is no
# submit button.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(label='Predictions', num_top_classes=classes_count),
    # Sample inputs can be shown by passing
    # examples=['./example1.jpg', './example2.jpg', './example3.jpg']
    # once those files are available next to this script.
    live=True,
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()