File size: 1,305 Bytes
d2a649a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torchvision import models, transforms
from PIL import Image

# Pretrained ResNet-101 classifier, built once at import time and used by
# predict_object_function. torchvision loads the checkpoint, downloading it to
# the torch hub cache on first use.
# The explicit enum is exactly what the old ``weights=True`` resolved to via
# torchvision's deprecated legacy-argument shim (IMAGENET1K_V1), without the
# deprecation warning or the dependency on that shim surviving.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

# Standard ImageNet per-channel mean/std used to normalise inputs for
# torchvision's pretrained classifiers.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image -> normalised 224x224 tensor pipeline: scale the shorter side to
# 256 px, keep the central 224x224 crop, convert to a float tensor, then
# normalise each channel with the ImageNet statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


def predict_object_function(img_path, labels_path='predict_image/data/new_data.txt', top_k=5):
    """Classify one image with the module-level ResNet-101 and return the best label.

    The image is opened with PIL, converted to RGB, run through ``preprocess``
    and fed to ``resnet`` as a batch of one. Progress messages, the best
    prediction and the ``top_k`` best predictions are printed to stdout.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (a path or a binary
            file object).
        labels_path: UTF-8 text file with one class label per line; line *i*
            is used as the name of model output *i*. A relative path is
            resolved against the current working directory.
        top_k: Number of highest-scoring classes to print (clamped to the
            number of model outputs).

    Returns:
        A ``(label, percentage)`` tuple for the highest-scoring class, where
        ``percentage`` is its softmax probability times 100 as a Python float.

    Raises:
        Whatever PIL raises when the image cannot be opened or decoded
        (e.g. ``FileNotFoundError``, ``PIL.UnidentifiedImageError``),
        ``OSError`` if the labels file cannot be read, and ``IndexError`` if
        the labels file has fewer lines than a selected class index requires.
    """
    print('to detect object')

    # Always convert to RGB: grayscale, palette or RGBA images would otherwise
    # reach the 3-channel Normalize step with the wrong channel count. The
    # ``with`` block closes the file PIL opened once the pixels are loaded.
    with Image.open(img_path) as img:
        batch_t = torch.unsqueeze(preprocess(img.convert('RGB')), 0)

    # Evaluation mode so BatchNorm uses its running statistics instead of the
    # statistics of this single-image batch.
    resnet.eval()

    print('detecting object')
    # Gradients are never needed for prediction; skipping autograd saves
    # memory and time without changing the output values.
    with torch.no_grad():
        out = resnet(batch_t)
    print('detected object')

    with open(labels_path, encoding='utf-8') as file:
        labels = [line.strip() for line in file]

    # Softmax over the class dimension of the single batch row, as percentages.
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    best = int(torch.max(out, 1)[1][0])
    output = labels[best], percentage[best].item()
    print("Output", output)

    # topk avoids a full sort; clamp k so small heads or odd top_k values
    # cannot make topk raise.
    k = max(0, min(top_k, out.size(1)))
    top_indices = torch.topk(out[0], k).indices.tolist()
    print([(labels[idx], percentage[idx].item()) for idx in top_indices])
    return output