import configparser import pandas as pd from PIL import Image import torch from torchvision import models, transforms # This is an abstract class. The method "get_model" must be implemented # by the child class. class ImageClassifierBase(): def __init__(self): pass # self.logger = logging.getLogger(__name__) # logging.basicConfig(filename='app.log', level=logging.INFO) def __read_text_labels__(self): text_labels = pd.read_csv('imagenet_labels.csv').values text_labels = text_labels.flatten() return text_labels def __read_image__(self, device, image): preprocess = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) input_tensor = preprocess(image) input_batch = input_tensor.unsqueeze(0) input_batch = input_batch.to(device) return input_batch def get_device(self): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return device # This method must be implemented by the child class. def get_model(self, device): pass def image_classification(self, model, input_image, device, hard_detection_threshold=0.0): input_batch = self.__read_image__(device, input_image) with torch.no_grad(): output = model(input_batch).data text_labels = self.__read_text_labels__() classification_summary = pd.DataFrame() classification_summary['label'] = text_labels classification_summary['prob'] = output[0] classification_summary = \ classification_summary.sort_values(by=['prob'], ascending=False) return classification_summary class ImageClassifierVGG16(ImageClassifierBase): def __init__(self): super().__init__() def get_model(self, device): model = models.vgg16(pretrained=True) model.to(device) model.eval() return model