Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image | |
# Pretrained ResNet-101 classifier used by predict_object_function.
# IMAGENET1K_V1 is the checkpoint that the deprecated `weights=True` form
# resolved to (torchvision >= 0.13); naming the enum explicitly loads the same
# weights without triggering torchvision's deprecation warning.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

# Evaluation-time transform: resize the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalise each of the
# three channels with the usual ImageNet mean / std values below.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def predict_object_function(img_path, labels_path='predict_image/data/new_data.txt', top_k=5):
    """Classify the image at ``img_path`` with the module-level ``resnet`` model.

    The image is converted to RGB, passed through ``preprocess`` and fed to
    ``resnet`` as a batch of one. Class names are read from ``labels_path``,
    one label per line, where line ``i`` names model output index ``i``.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (path or file object).
        labels_path: Text file of class labels, one per line. A relative path
            is resolved against the current working directory.
        top_k: Number of highest-scoring classes to print for debugging.

    Returns:
        A ``(label, percentage)`` tuple for the highest-scoring class, where
        ``percentage`` is the softmax probability scaled to 0-100.

    Raises:
        Any error raised while opening or decoding the image (e.g.
        ``FileNotFoundError``); its message is printed before it propagates.
    """
    print('to detect object')
    try:
        # Context manager releases the file handle promptly. convert('RGB')
        # forces a full decode and normalises grayscale/palette/alpha modes so
        # the 3-channel Normalize in `preprocess` always applies.
        with Image.open(img_path) as img:
            batch_t = torch.unsqueeze(preprocess(img.convert('RGB')), 0)
    except Exception as e:
        # The previous fallback re-opened the image WITHOUT convert('RGB'),
        # which either failed again with the same error or fed a non-RGB
        # tensor to the model. Report the problem and let the caller see it.
        print(str(e))
        raise
    resnet.eval()
    print('detecting object')
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        out = resnet(batch_t)
    print('detected object')
    with open(labels_path, encoding='utf-8') as file:
        labels = [line.strip() for line in file]
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    best = int(out.argmax(dim=1)[0])
    output = labels[best], percentage[best].item()
    print("Output", output)
    # Rank by logits (same order as softmax) and clamp k to the class count.
    k = min(top_k, out.shape[1])
    top_indices = out[0].topk(k).indices.tolist()
    print([(labels[idx], percentage[idx].item()) for idx in top_indices])
    return output