from typing import List, Sequence, Union

import numpy as np
import torch
from PIL.Image import Image
from scipy.special import softmax
from torch import Tensor
from torch.jit import RecursiveScriptModule
from torchvision import transforms

# (height, width) in pixels that every input image is resized to before inference.
_INPUT_SIZE = (256, 256)

# Per-channel RGB mean / std used to normalise inputs (the standard ImageNet
# statistics used by torchvision-pretrained backbones).
_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_NORMALIZATION_STD = [0.229, 0.224, 0.225]

# Built once at import time instead of on every call to format_pils_images.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZATION_MEAN, std=_NORMALIZATION_STD),
])


def normalize_neural_network_output(
        neural_network_output: Union[Sequence[float], np.ndarray]) -> List[float]:
    """Turn raw scores into probabilities with a softmax.

    Args:
        neural_network_output: Raw scores for one sample. The softmax is taken
            over all elements, so a 1-D sequence yields one probability per
            output position.

    Returns:
        The probabilities as a plain Python list of floats (summing to 1).
    """
    return softmax(neural_network_output).tolist()


def predict(pil_images: List[Image], labels_to_output_index: list,
            neural_network_model: RecursiveScriptModule) -> List[dict]:
    """Classify a batch of PIL images and return per-label probabilities.

    Args:
        pil_images: Images to classify; they are run through the model as a
            single batch.
        labels_to_output_index: Label names indexed by model output position
            (``labels_to_output_index[j]`` names output ``j``).
        neural_network_model: TorchScript model applied to the batch.

    Returns:
        One dict per input image, in input order, of the form
        ``{'probabilities_neural_network': {label: probability, ...}}``.
        An empty ``pil_images`` list yields an empty list.
    """
    if not pil_images:
        # Nothing to classify. Returning early also avoids building a batch
        # from zero tensors, which torch rejects.
        return []
    neural_network_input = format_pils_images(pil_images)
    neural_network_output = predict_neural_network(neural_network_input, neural_network_model)
    return format_output(neural_network_output, labels_to_output_index)


def predict_neural_network(neural_network_input: Tensor,
                           neural_network_model: RecursiveScriptModule) -> List[List[float]]:
    """Run the model on a batch and softmax-normalise each sample's output.

    Args:
        neural_network_input: Batched model input, e.g. as produced by
            ``format_pils_images``.
        neural_network_model: TorchScript model to evaluate.

    Returns:
        One list of probabilities per element of the model's first output
        dimension.
    """
    # Inference only: disable autograd tracking.
    with torch.no_grad():
        raw_outputs = neural_network_model(neural_network_input).tolist()
    # NOTE(review): assumes the model returns (batch, num_outputs) scores, so
    # each row is one sample's score vector -- confirm against the model.
    return [normalize_neural_network_output(sample_output) for sample_output in raw_outputs]


def format_output(neural_network_output: list, labels_to_output_index: list) -> List[dict]:
    """Attach label names to each sample's probability vector.

    Args:
        neural_network_output: One sequence of probabilities per sample.
        labels_to_output_index: Label names indexed by output position.

    Returns:
        One ``{'probabilities_neural_network': {label: probability}}`` dict per
        sample, in input order.

    Raises:
        IndexError: If a sample has more outputs than there are labels.
    """
    return [
        {
            'probabilities_neural_network': {
                labels_to_output_index[output_index]: probability
                for output_index, probability in enumerate(sample_probabilities)
            },
        }
        for sample_probabilities in neural_network_output
    ]


def format_pils_images(pil_images: List[Image]) -> Tensor:
    """Preprocess PIL images into one normalised batch tensor.

    Each image is converted to RGB, resized to 256x256, turned into a float
    tensor and normalised with the ImageNet per-channel mean/std.

    Args:
        pil_images: Images to preprocess.

    Returns:
        A tensor of shape ``(len(pil_images), 3, 256, 256)``.

    Raises:
        RuntimeError: If ``pil_images`` is empty (torch cannot stack zero tensors).
    """
    # convert('RGB') guarantees three channels to match the 3-element mean/std.
    # Without it, grayscale/palette/RGBA images fail in Normalize, and other
    # 3-channel modes (e.g. YCbCr) would be normalised with RGB statistics.
    return torch.stack(
        [_IMAGE_TRANSFORM(pil_image.convert('RGB')) for pil_image in pil_images],
        dim=0,
    )