| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import json | |
def load_classes():
    """Load the class labels from ``utils/imagenet-simple-labels.json``.

    The path is relative, so it is resolved against the current working
    directory of the running process.

    Returns:
        The decoded JSON document (presumably a list of label strings
        indexed by class id -- verify against the data file).

    Raises:
        FileNotFoundError: If the labels file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; pin the encoding so decoding does not depend on
    # the platform's default locale.
    with open('utils/imagenet-simple-labels.json', encoding='utf-8') as f:
        return json.load(f)
def class_id_to_label(i):
    """Return the label stored at position ``i`` of the loaded class labels.

    The labels are obtained by calling ``load_classes()`` on every
    invocation, so each lookup re-reads the labels file.

    Args:
        i: Index (class id) into the labels returned by ``load_classes()``.

    Returns:
        The entry found at ``i``.
    """
    return load_classes()[i]
def load_model():
    """Create a pretrained MobileNetV2 and put it in evaluation mode.

    Returns:
        The torchvision ``mobilenet_v2`` model, built with
        ``pretrained=True`` and switched to eval mode.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
    # favour of ``weights=`` -- confirm the installed version before migrating.
    net = models.mobilenet_v2(pretrained=True)
    # ``eval()`` returns the module itself, so its result can be returned.
    return net.eval()
def transform_image(img):
    """Preprocess an image into a normalized tensor for the model.

    The image is resized to 224x224, converted to a tensor, and normalized
    per channel with the mean/std values below (the statistics commonly
    used with ImageNet-pretrained torchvision models).

    Args:
        img: The input image. PIL images in a non-RGB mode (e.g. RGBA,
            grayscale, palette) are converted to RGB first; anything else
            is passed to the transform pipeline unchanged.

    Returns:
        The transformed image tensor produced by the pipeline.
    """
    # Normalize is configured for exactly three channels, so a 1- or
    # 4-channel PIL image would fail there. Convert up front instead.
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img)