# clothes-stance/src/shared/stance_recognize.py
# Author: Louis Julien — model and first documentation (commit 26cab2b)
from typing import List
import numpy as np
import torch
from PIL.Image import Image
from scipy.special import softmax
from torch import Tensor
from torch.jit import RecursiveScriptModule
from torchvision import transforms
def normalize_neural_network_output(neural_network_output: np.ndarray) -> List[float]:
    """Turn raw model logits into a probability distribution.

    Applies softmax over the logits and converts the result to a plain
    Python list (native floats), so it can be serialized directly.
    """
    probabilities = softmax(neural_network_output)
    return probabilities.tolist()
def predict(pil_images: List[Image],
            labels_to_output_index: list,
            neural_network_model: RecursiveScriptModule
            ) -> List[dict]:
    """Run the stance-recognition model on a batch of PIL images.

    Pipeline: preprocess the images into one batched tensor, run the
    model under no-grad inference, then attach a label to each output
    probability. Returns one result dict per input image.
    """
    batch = format_pils_images(pil_images)
    probabilities_per_image = predict_neural_network(batch, neural_network_model)
    return format_output(probabilities_per_image, labels_to_output_index)
def predict_neural_network(neural_network_input: Tensor, neural_network_model: RecursiveScriptModule) -> List[
        List[float]]:
    """Run inference and return softmax-normalized outputs per image.

    Gradient tracking is disabled since this is pure inference; each raw
    output row is normalized into a probability list.
    """
    with torch.no_grad():
        raw_outputs = neural_network_model(neural_network_input).tolist()
    return [normalize_neural_network_output(row) for row in raw_outputs]
def format_output(neural_network_output: list, labels_to_output_index: list) -> List[dict]:
    """Pair each image's probability vector with its class labels.

    For every output row, builds a dict under the key
    'probabilities_neural_network' mapping the label at each output
    index to the corresponding probability. Returns one dict per image.
    """
    return [
        {
            'probabilities_neural_network': {
                labels_to_output_index[index]: probability
                for index, probability in enumerate(row)
            },
        }
        for row in neural_network_output
    ]
def format_pils_images(pil_images: List[Image]) -> Tensor:
    """Preprocess PIL images into one batched model-input tensor.

    Each image is resized to 256x256, converted to a tensor, and
    normalized with the standard ImageNet mean/std values; the results
    are stacked along a new leading batch dimension.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return torch.stack([preprocess(an_image) for an_image in pil_images])