from os import listdir import gradio as gr import torch from torchvision import transforms from torchvision.models import resnet34 def image_classifier(image): # Готовим данные inputs = transform(image) inputs = inputs.unsqueeze(0) # Пропускаем данные через модель with torch.no_grad(): outputs = model(inputs) # Ставим метки и вероятности predictions = torch.nn.functional.softmax(outputs[0], dim=0) predictions = predictions[:len(labels)] return {labels[i]: p.item() for i, p in enumerate(predictions)} # Загружаем модель model = resnet34() state = torch.load('model.pth', map_location=torch.device('cpu')) model.load_state_dict(state) model.eval() # Преобразования над сырыми данными transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Названия классов labels = [ 'Pedro Pascal', 'Robert Downey Jr', 'Tom Holland', ] # Добавляем виджет в интерфейс img_widget = gr.Image(image_mode='RGB', type='pil') app = gr.Interface(fn=image_classifier, inputs=img_widget, outputs='label', flagging_mode='never') # Запускаем приложение app.launch()