File size: 1,520 Bytes
71f9807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from os import listdir

import gradio as gr
import torch
from torchvision import transforms
from torchvision.models import resnet34


def image_classifier(image):
    # Готовим данные
    inputs = transform(image)
    inputs = inputs.unsqueeze(0)

    # Пропускаем данные через модель
    with torch.no_grad():
        outputs = model(inputs)

    # Ставим метки и вероятности
    predictions = torch.nn.functional.softmax(outputs[0], dim=0)
    predictions = predictions[:len(labels)]
    return {labels[i]: p.item() for i, p in enumerate(predictions)}


# Загружаем модель
model = resnet34()
state = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()

# Преобразования над сырыми данными
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Названия классов
labels = [
    'Pedro Pascal',
    'Robert Downey Jr',
    'Tom Holland',
]

# Добавляем виджет в интерфейс
img_widget = gr.Image(image_mode='RGB', type='pil')
app = gr.Interface(fn=image_classifier, inputs=img_widget, outputs='label',
                   flagging_mode='never')

# Запускаем приложение
app.launch()