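# NOTE: the three lines below are residue from a web file viewer (size banner,
# commit-hash gutter, line-number gutter); they are not part of the program.
# File size: 1,624 Bytes
# cebad5c 9087ee6 cebad5c 9087ee6 cebad5c 9087ee6 cebad5c 9087ee6 cebad5c
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44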
from typing import List
import torch
import timm
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from PIL import Image
import urllib.request
class MobilenetV3(ModelInterface):
    """Image classifier backed by timm's pretrained ``mobilenetv3_large_100``.

    Construction creates the model in eval mode, builds the model-specific
    preprocessing pipeline once, and loads the class-label list (from the
    local label file when it exists, otherwise by downloading it).
    """

    # Where the class-label list (one label per line) comes from, and the
    # local file it is stored in (relative to the current working directory).
    LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    LABELS_FILENAME = "imagenet_classes.txt"
    # Seconds to wait on the label download rather than blocking forever.
    DOWNLOAD_TIMEOUT = 30
    # Maximum number of predictions returned by classify_image.
    TOP_K = 5

    def __init__(self):
        print('init... mobilenet v3 model')
        self.model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()

        # The transforms (normalization, resize) depend only on the model, so
        # build them once here instead of on every classify_image call.
        data_config = timm.data.resolve_model_data_config(self.model)
        self._transforms = timm.data.create_transform(**data_config, is_training=False)

        self.class_labels = self._load_class_labels()

    def _read_label_file(self) -> List[str]:
        """Read the local label file; return an empty list if it is missing."""
        try:
            with open(self.LABELS_FILENAME, "r", encoding="utf-8") as f:
                return [line.strip() for line in f]
        except FileNotFoundError:
            return []

    def _load_class_labels(self) -> List[str]:
        """Return the class labels, downloading the label file only if needed.

        Raises:
            OSError: if no usable local copy exists and the download fails
                (``urllib.error.URLError`` is a subclass of ``OSError``).
        """
        labels = self._read_label_file()
        if labels:
            return labels

        # Read the whole response before touching the disk, so a transfer that
        # raises midway does not leave a partially written label file behind.
        with urllib.request.urlopen(self.LABELS_URL, timeout=self.DOWNLOAD_TIMEOUT) as response:
            payload = response.read()
        with open(self.LABELS_FILENAME, "wb") as f:
            f.write(payload)
        return self._read_label_file()

    def classify_image(self, image) -> List[ClassificationResult]:
        """Classify ``image`` and return its most probable classes.

        Args:
            image: The image to classify. NOTE(review): presumably a
                ``PIL.Image.Image`` (the file imports PIL) - confirm against
                callers. PIL images that are not in RGB mode are converted to
                RGB first; any other input is passed to the transforms as-is.

        Returns:
            Up to ``TOP_K`` ``ClassificationResult`` objects in the order
            produced by ``torch.topk`` (highest probability first), each
            carrying the class label and its softmax probability.
        """
        # Grayscale / RGBA / palette images would otherwise reach the
        # transforms with a channel count other than three.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        # Add the batch dimension the model expects.
        input_tensor = self._transforms(image).unsqueeze(0)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            output = self.model(input_tensor)

        # Probabilities for the single image in the batch.
        probabilities = output.softmax(dim=1)[0]
        # topk raises if k exceeds the number of model outputs.
        k = min(self.TOP_K, probabilities.numel())
        top_probabilities, top_class_indices = torch.topk(probabilities, k=k)

        return [
            ClassificationResult(
                class_name=self.class_labels[class_index],
                confidence=probability,
            )
            for probability, class_index in zip(
                top_probabilities.tolist(), top_class_indices.tolist()
            )
        ]