import torch
from torchvision import models, transforms
from PIL import Image

# Class-label file, one label per line, indexed by the model's output
# position. Resolved relative to the current working directory.
LABELS_PATH = 'predict_image/data/new_data.txt'

# ResNet-101 with ImageNet weights. The legacy spelling ``weights=True`` is
# deprecated and mapped by torchvision to IMAGENET1K_V1 with a warning; name
# the enum explicitly to load the same weights without the warning.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

# ImageNet-style preprocessing: resize the short side to 256, center-crop
# 224x224, convert to a tensor and normalize with the given per-channel
# mean/std (three values each, so the input must have three channels).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict_object_function(img_path):
    """Classify the image at *img_path* with the module-level ResNet-101.

    Prints progress messages, the top-1 prediction and the top-5
    predictions with their softmax percentages.

    Args:
        img_path: Image location accepted by ``PIL.Image.open``.

    Returns:
        A ``(label, percentage)`` tuple for the highest-scoring class, where
        ``label`` is the matching line of ``LABELS_PATH`` and ``percentage``
        is the softmax probability times 100 as a Python float.

    Raises:
        Whatever ``PIL.Image.open``/decoding or reading ``LABELS_PATH``
        raises (e.g. ``FileNotFoundError``); ``IndexError`` if the label
        file has fewer lines than the model has classes.
    """
    print('to detect object')
    # Always convert to 3-channel RGB: Normalize above uses three channel
    # statistics, so grayscale/palette/RGBA inputs must be converted. The
    # context manager closes the underlying file once preprocessing is done.
    with Image.open(img_path) as img:
        img_t = preprocess(img.convert('RGB'))
    batch_t = torch.unsqueeze(img_t, 0)  # add a batch dimension of size 1

    # Re-assert eval mode each call in case other code toggled training mode.
    resnet.eval()
    print('detecting object')
    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        out = resnet(batch_t)
    print('detected object')

    with open(LABELS_PATH, encoding='utf-8') as file:
        labels = [line.strip() for line in file]

    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    top_index = int(torch.argmax(out, dim=1)[0])
    output = labels[top_index], percentage[top_index].item()
    print("Output", output)

    # Top-5 classes by score (fewer if the model has fewer than 5 outputs).
    top_k = min(5, out.shape[1])
    top_indices = torch.topk(out, k=top_k, dim=1).indices[0].tolist()
    print([(labels[idx], percentage[idx].item()) for idx in top_indices])
    return output