from typing import List

import numpy as np
import torch
from PIL.Image import Image
from scipy.special import softmax
from torch import Tensor
from torch.jit import RecursiveScriptModule
from torchvision import transforms


def normalize_neural_network_output(neural_network_output: np.ndarray) -> List[float]:
    """Turn one row of raw logits into a probability distribution."""
    return softmax(neural_network_output).tolist()


def predict(pil_images: List[Image],
            labels_to_output_index: List[str],
            neural_network_model: RecursiveScriptModule
            ) -> List[dict]:
    """Run the classifier on a batch of PIL images and return, for each
    image, a dict of per-label probabilities.

    `labels_to_output_index` must be ordered so that position j holds the
    label for the model's j-th output neuron.
    """
    neural_network_input = format_pils_images(pil_images)
    neural_network_output = predict_neural_network(neural_network_input, neural_network_model)
    return format_output(neural_network_output, labels_to_output_index)


def predict_neural_network(neural_network_input: Tensor,
                           neural_network_model: RecursiveScriptModule) -> List[List[float]]:
    # Inference only: disable autograd to avoid building a computation graph.
    with torch.no_grad():
        all_outputs = neural_network_model(neural_network_input).numpy()

    # Normalize each row of logits into probabilities that sum to 1.
    return list(map(normalize_neural_network_output, all_outputs))


def format_output(neural_network_output: List[List[float]], labels_to_output_index: List[str]) -> List[dict]:
    results = []
    for a_nn_output in neural_network_output:
        results.append({
            'probabilities_neural_network': {
                labels_to_output_index[j]: p for j, p in enumerate(a_nn_output)
            },
        })
    return results


def format_pils_images(pil_images: List[Image]) -> Tensor:
    pil_images_transformed = []

    # Standard ImageNet preprocessing: resize, convert to a CHW float
    # tensor in [0, 1], then normalize with the ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    for a_pil_image in pil_images:
        a_pil_image = transform(a_pil_image)
        pil_images_transformed.append(a_pil_image.unsqueeze(0))

    # Concatenate the 1xCxHxW tensors into a single NxCxHxW batch.
    return torch.cat(pil_images_transformed, dim=0)
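

# Minimal usage sketch (illustrative, not part of the module's API): the
# model path, the blank image, and the label names below are assumptions
# made for demonstration only.
if __name__ == '__main__':
    from PIL import Image as PILImage

    # Hypothetical TorchScript classifier with three output neurons.
    model = torch.jit.load('model.pt')
    labels = ['cat', 'dog', 'bird']  # ordered to match the model's outputs

    # A blank RGB image stands in for real input data.
    images = [PILImage.new('RGB', (256, 256))]
    print(predict(images, labels, model))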