lab17 / app.py
mariaria's picture
Upload 3 files
71f9807 verified
from os import listdir
import gradio as gr
import torch
from torchvision import transforms
from torchvision.models import resnet34
def image_classifier(image):
# Готовим данные
inputs = transform(image)
inputs = inputs.unsqueeze(0)
# Пропускаем данные через модель
with torch.no_grad():
outputs = model(inputs)
# Ставим метки и вероятности
predictions = torch.nn.functional.softmax(outputs[0], dim=0)
predictions = predictions[:len(labels)]
return {labels[i]: p.item() for i, p in enumerate(predictions)}
# Загружаем модель
model = resnet34()
state = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()
# Преобразования над сырыми данными
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# Названия классов
labels = [
'Pedro Pascal',
'Robert Downey Jr',
'Tom Holland',
]
# Добавляем виджет в интерфейс
img_widget = gr.Image(image_mode='RGB', type='pil')
app = gr.Interface(fn=image_classifier, inputs=img_widget, outputs='label',
flagging_mode='never')
# Запускаем приложение
app.launch()