# avori/src/apps/classification.py
# Author: Alimustoofaa — commit 0e549ca ("first commit")
import os
import torch
from PIL import Image
import numpy as np
from torch import cuda, device
from torchvision import transforms
from src.utils.utils import download_and_unzip_model
class Classification:
    """Image classifier wrapping a pickled PyTorch model loaded from disk.

    The model file is ``{root_path}/{model_config['filename']}``; when it is
    missing it is fetched via ``download_and_unzip_model`` using the ``url``
    and ``file_size`` entries of ``model_config``.

    Args:
        root_path: Directory that holds (or will receive) the model file.
        model_config: Dict with keys ``filename``, ``image_size``,
            ``classes``, ``filter_classes``, ``url`` and ``file_size``.
    """

    def __init__(self, root_path, model_config):
        self.root_path = root_path
        self.model_config = model_config
        self.model_name = f'{root_path}/{model_config["filename"]}'
        # Side length the input image is resized to before the transform pipeline.
        self.image_size = model_config['image_size']
        # Class labels, indexed by the model's output column.
        self.classes = model_config['classes']
        # Only labels contained here are reported by __call__.
        self.filter_classes = model_config['filter_classes']
        self.device = device("cuda:0" if cuda.is_available() else "cpu")
        self.transform_image = self.transfrom_img()
        self.model = self.__load_model()

    @staticmethod
    def __check_model(root_path: str, model_config: dict) -> None:
        """Ensure the model file exists under ``root_path``, fetching it if not."""
        if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
            # unzip=False: presumably the asset at ``url`` is the raw model
            # file rather than an archive — confirm against the helper.
            download_and_unzip_model(
                root_dir=root_path,
                name=model_config['filename'],
                url=model_config['url'],
                file_size=model_config['file_size'],
                unzip=False,
            )
        else:
            print('Load model Classification')

    def __load_model(self):
        """Load the model onto ``self.device`` and put it in inference mode."""
        self.__check_model(self.root_path, self.model_config)
        print(self.model_name)
        # NOTE(review): torch.load unpickles arbitrary objects; only load model
        # files from a trusted source. On newer torch releases loading a full
        # pickled module may also require ``weights_only=False`` — confirm.
        model = torch.load(self.model_name, map_location=self.device)
        # If the module was saved in training mode, BatchNorm/Dropout would
        # otherwise behave as during training at inference time.
        if isinstance(model, torch.nn.Module):
            model.eval()
        return model

    def transfrom_img(self):
        """Build the preprocessing pipeline: resize to 256, center-crop 224, to tensor."""
        return transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
        ])

    @classmethod
    def tensor_image_to_numpy(cls, tensor):
        """Return ``tensor`` as a CPU NumPy array, detaching it first if it tracks gradients."""
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def image_processing(self, image):
        """Apply the transform pipeline, add a batch dimension and move to ``self.device``."""
        image = self.transform_image(image)
        image = image.unsqueeze(0)
        return image.to(self.device)

    def __call__(self, image):
        """Classify a single image.

        Args:
            image: A PIL image, or a NumPy array accepted by ``Image.fromarray``.
                NOTE(review): arrays are interpreted as RGB; a BGR array (e.g.
                from OpenCV) would need converting first — confirm with callers.

        Returns:
            A list of ``[label, confidence]`` pairs for labels present in
            ``filter_classes``, ordered by descending confidence. Confidence is
            ``exp(model output)`` rounded to 2 decimals.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The pipeline below must yield exactly 3 channels; grayscale, RGBA and
        # palette images would otherwise fail in the view() call.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize((self.image_size, self.image_size))
        # image_processing already places the tensor on self.device, so no
        # separate .cuda() is needed (it could pick a GPU other than cuda:0).
        image_tensor = self.image_processing(image).view(1, 3, 224, 224)
        with torch.no_grad():
            output = self.model(image_tensor)
        # Assumes the model emits log-probabilities (e.g. log_softmax) — TODO confirm.
        probs = torch.exp(output)
        # Rank every class, then keep only the requested ones. Using
        # len(filter_classes) as k raised IndexError whenever filter_classes
        # was a strict subset of classes.
        topk, topclass = probs.topk(probs.size(1), dim=1)
        scores = topk.cpu().numpy()[0]
        indices = topclass.cpu().numpy()[0]
        result = list()
        for idx, score in zip(indices, scores):
            label = self.classes[idx]
            if label in self.filter_classes:
                result.append([label, round(score, 2)])
        return result
if __name__ == '__main__':
    # Model directory. expanduser makes the path portable across users; for
    # the original author it resolves to /Users/alimustofa/Downloads/.
    root_path = os.path.expanduser('~/Downloads/')
    classification_config = {
        'filename'       : 'resnet34_model_400.pt',
        'image_size'     : 480,
        'classes'        : ['almost_ripe', 'not_ripe', 'overripe', 'ripe'],
        'filter_classes' : ['almost_ripe', 'not_ripe', 'overripe', 'ripe'],
        # NOTE(review): this URL names a different asset
        # (container_iso_maxgross.pt) than 'filename'; Classification fetches
        # it when the file is missing locally — confirm the URL and file_size.
        'url'            : 'https://github.com/Alimustoofaa/ContainerNumber-Dev/releases/download/detection_v1/container_iso_maxgross.pt',
        'file_size'      : 14749585,
    }
    avori_classification = Classification(root_path, model_config=classification_config)