from os import listdir

import gradio as gr
import torch
from torchvision import transforms
from torchvision.models import resnet34


def image_classifier(image):
    """Classify an image and return a score for each known label.

    Args:
        image: The picture to classify. It is handed unchanged to the
            module-level ``transform`` pipeline; the Gradio widget defined
            in this file is configured to supply an RGB PIL image.
            ``None`` (nothing uploaded) is accepted.

    Returns:
        A dict mapping entries of the module-level ``labels`` list to float
        scores, or ``None`` when ``image`` is ``None``.
    """
    # The UI can be submitted without a picture, in which case the handler
    # receives None; report "no result" instead of failing in the transform.
    if image is None:
        return None

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(batch)

    # The softmax is taken over every output of the model and only the first
    # len(labels) entries are reported, so the returned scores need not sum
    # to 1 when the model has more outputs than there are labels.
    predictions = torch.nn.functional.softmax(outputs[0], dim=0)
    scores = predictions[:len(labels)].tolist()
    return dict(zip(labels, scores))


# Build the ResNet-34 architecture and populate it with the weights stored
# at the relative path 'model.pth'.
model = resnet34()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs; map_location places the tensors on CPU.
state = torch.load(
    'model.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state)
# Put the network in evaluation mode before it is used for predictions.
model.eval()


# Preprocessing applied to every input image before it reaches the model:
# resize, centre-crop to 224, convert to a tensor, then normalise each of
# the three channels with the given mean and standard deviation.
# NOTE(review): these mean/std values match the ImageNet statistics listed
# in the torchvision documentation -- confirm they match how the model was
# trained.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# Display names for the classes. image_classifier pairs the entry at index i
# with the model's i-th output score, so the order here is significant.
labels = [
    'Pedro Pascal',
    'Robert Downey Jr',
    'Tom Holland',
]


# Web UI: an image input (delivered to the handler as an RGB PIL image)
# wired to image_classifier, with the returned dict shown by Gradio's
# 'label' output component. Flagging is disabled.
img_widget = gr.Image(image_mode='RGB', type='pil')
app = gr.Interface(
    fn=image_classifier,
    inputs=img_widget,
    outputs='label',
    flagging_mode='never',
)

# Start the Gradio server.
app.launch()