|
|
from typing import List |
|
|
import torch |
|
|
import timm |
|
|
from src.interface import ModelInterface |
|
|
from src.data.classification_result import ClassificationResult |
|
|
from PIL import Image |
|
|
import urllib.request |
|
|
|
|
|
class MobilenetV3(ModelInterface):
    """Image classifier backed by timm's pretrained ``mobilenetv3_large_100``.

    Construction loads the model in eval mode, builds the inference
    preprocessing pipeline resolved from the model's data config, and loads
    the class names from ``imagenet_classes.txt`` (downloaded into the
    current working directory when no usable local copy exists).
    """

    # Where the class-name list comes from and where it is cached locally.
    LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    LABELS_FILENAME = "imagenet_classes.txt"
    # Maximum number of top-scoring classes returned by classify_image.
    TOP_K = 5

    def __init__(self):
        print('init... mobilenet v3 model')
        self.model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()

        # The preprocessing depends only on the model, so it is built once
        # here rather than on every classify_image call.
        data_config = timm.data.resolve_model_data_config(self.model)
        self._transforms = timm.data.create_transform(**data_config, is_training=False)

        self.class_labels = self._load_class_labels()

    @classmethod
    def _load_class_labels(cls) -> List[str]:
        """Return the class names, downloading the labels file only if needed.

        A missing or empty local file triggers a (re-)download; otherwise the
        existing copy is reused so construction does not require network
        access every time.
        """
        try:
            labels = cls._read_labels(cls.LABELS_FILENAME)
        except FileNotFoundError:
            labels = []

        if not labels:
            urllib.request.urlretrieve(cls.LABELS_URL, cls.LABELS_FILENAME)
            labels = cls._read_labels(cls.LABELS_FILENAME)
        return labels

    @staticmethod
    def _read_labels(path: str) -> List[str]:
        """Read one whitespace-stripped label per line from *path*."""
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    def classify_image(self, image) -> List[ClassificationResult]:
        """Classify *image* and return the highest-probability classes.

        Args:
            image: The image to classify. It is passed to the timm transform
                pipeline; presumably a ``PIL.Image.Image`` -- TODO confirm
                against callers.

        Returns:
            Up to ``TOP_K`` ``ClassificationResult`` objects in descending
            order of softmax probability, each holding the class name and
            its confidence.
        """
        # PIL images in other modes (RGBA, L, P, ...) are converted to RGB
        # before preprocessing. NOTE(review): assumes the model's data config
        # expects 3-channel input -- confirm. Non-PIL inputs are passed
        # through untouched.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        # Add the batch dimension expected by the model.
        input_tensor = self._transforms(image).unsqueeze(0)

        with torch.no_grad():
            output = self.model(input_tensor)

        # Never request more entries than the model has output classes.
        k = min(self.TOP_K, output.shape[1])
        probabilities, class_indices = torch.topk(output.softmax(dim=1), k=k)

        # Row 0 is the single image in the batch; tolist() yields plain
        # Python ints/floats, matching what .item() produced per element.
        return [
            ClassificationResult(
                class_name=self.class_labels[class_index],
                confidence=confidence,
            )
            for class_index, confidence in zip(
                class_indices[0].tolist(), probabilities[0].tolist()
            )
        ]