object / detect_object.py
msabonkudi's picture
Update detect_object.py
d2a649a verified
import torch
from torchvision import models, transforms
from PIL import Image
# Pretrained ResNet-101 shared by every prediction call in this module.
# The checkpoint is named explicitly: the boolean form ``weights=True`` is a
# deprecated legacy spelling that torchvision maps to IMAGENET1K_V1 (with a
# warning), so this loads the same weights without relying on that shim.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
# Per-channel statistics (three values each) used by the Normalize step.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image -> tensor pipeline applied to every input before it reaches the model:
# resize, centre-crop to 224, convert to a tensor, then normalize per channel.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])
def predict_object_function(img_path):
    """Classify the image at ``img_path`` with the module-level ``resnet``.

    The image is opened with PIL, run through the module-level ``preprocess``
    pipeline, and fed to ``resnet`` in eval mode. Class indices are mapped to
    names using ``predict_image/data/new_data.txt`` (one label per line, where
    the line position is the class index). Progress messages, the winning
    prediction and the five highest-scoring predictions are printed to stdout.

    Args:
        img_path: Location of the image; passed straight to ``Image.open``.

    Returns:
        A ``(label, percentage)`` tuple for the highest-scoring class, where
        ``percentage`` is the softmax probability scaled to 0-100 as a float.

    Raises:
        Exception: Whatever ``Image.open`` / ``preprocess`` raise when both
            the RGB attempt and the unconverted retry fail.
        OSError: If the labels file cannot be opened.
        IndexError: If the labels file has fewer lines than the predicted
            class index.
    """
    print('to detect object')
    try:
        # Convert to RGB so the image has the three channels that the
        # three-value mean/std in ``preprocess`` expects.
        with Image.open(img_path) as img:
            img_t = preprocess(img.convert('RGB'))
    except Exception as e:
        # Best-effort fallback: report the failure and retry without the
        # RGB conversion. A second failure propagates to the caller.
        print(str(e))
        with Image.open(img_path) as img:
            img_t = preprocess(img)

    # The model expects a leading batch dimension.
    batch_t = torch.unsqueeze(img_t, 0)

    resnet.eval()
    print('detecting object')
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        out = resnet(batch_t)
    print('detected object')

    with open('predict_image/data/new_data.txt') as file:
        labels = [line.strip() for line in file]

    # Scores for the single image in the batch, as percentages.
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

    best = int(torch.argmax(out[0]))
    output = labels[best], percentage[best].item()
    print("Output", output)

    # Rank on the raw logits (softmax is monotonic) and cap k so a model with
    # fewer than five classes does not make topk fail.
    logits = out[0]
    top_indices = torch.topk(logits, k=min(5, logits.numel())).indices.tolist()
    print([(labels[idx], percentage[idx].item()) for idx in top_indices])
    return output