Spaces:
Sleeping
Sleeping
File size: 1,305 Bytes
d2a649a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | import torch
from torchvision import models, transforms
from PIL import Image
# Pretrained ResNet-101 classifier, created once at import time and shared by
# every prediction call. An explicit weights enum is used because passing a
# boolean for `weights` is deprecated in torchvision; IMAGENET1K_V1 is the
# checkpoint the legacy `True` value resolved to.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

# Preprocessing pipeline applied to each PIL image before it is fed to the
# network: resize to 256, center-crop to 224, convert to a tensor, then
# normalize each of the three channels with the given mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def _image_to_batch(img_path, force_rgb):
    """Load *img_path* and return it as a preprocessed batch of one image."""
    # The context manager closes the underlying file as soon as the pixels
    # have been turned into a tensor, instead of leaking the handle.
    with Image.open(img_path) as source:
        img = source.convert('RGB') if force_rgb else source
        return torch.unsqueeze(preprocess(img), 0)


def predict_object_function(img_path):
    """Classify the image at *img_path* with the module-level ResNet.

    Progress messages, the best prediction and the five highest-scoring
    classes are printed to stdout. Class names are read from
    ``predict_image/data/new_data.txt`` (one label per line, the line number
    being the class index) on every call.

    Args:
        img_path: Image location, passed straight to ``PIL.Image.open``.

    Returns:
        A ``(label, percentage)`` tuple for the highest-scoring class, where
        ``percentage`` is the softmax probability multiplied by 100.

    Raises:
        Exception: Whatever the second loading attempt raises if the image
            cannot be opened or preprocessed.
        FileNotFoundError: If the labels file does not exist.
        IndexError: If the labels file has fewer lines than the predicted
            class index.
    """
    print('to detect object')
    try:
        batch_t = _image_to_batch(img_path, force_rgb=True)
    except Exception as e:
        # Best-effort retry kept from the original behavior: report the
        # failure, then try once more without forcing an RGB conversion.
        # A failure here propagates to the caller.
        print(str(e))
        batch_t = _image_to_batch(img_path, force_rgb=False)

    # Inference mode (disables dropout / batch-norm updates).
    resnet.eval()

    print('detecting object')
    # No gradients are needed for prediction; skipping autograd bookkeeping
    # saves memory and time without changing the scores.
    with torch.no_grad():
        out = resnet(batch_t)
    print('detected object')

    with open('predict_image/data/new_data.txt') as file:
        labels = [line.strip() for line in file]

    # Per-class confidence for the single image in the batch, in percent.
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

    best = int(torch.max(out, 1)[1][0])
    output = labels[best], percentage[best].item()
    print("Output", output)

    # Indices of the five highest raw scores, best first.
    indices = torch.sort(out, 1, descending=True)[1][0][:5].tolist()
    print([(labels[idx], percentage[idx].item()) for idx in indices])
    return output
|