# Commit header captured with this file (presumably author, message, hash):
# MEYTI BECI BAGUNDA
# Update 4 files
# 9087ee6
from typing import List
import torch
import timm
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from PIL import Image
import urllib.request
class MobilenetV3(ModelInterface):
    """Image classifier backed by timm's pretrained ``mobilenetv3_large_100``.

    The model, its preprocessing pipeline and the human-readable class labels
    are all prepared once at construction time; :meth:`classify_image` then
    returns the most likely classes for a single image.
    """

    # Source of the class names, one name per line, in model-output order.
    LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    # Local cache for the labels, relative to the current working directory.
    LABELS_FILENAME = "imagenet_classes.txt"

    def __init__(self):
        print('init... mobilenet v3 model')
        self.model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
        # Model-specific transforms (resize, crop, normalization) depend only on
        # the model, so build them once here instead of on every classify call.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
        self.class_labels = self._load_class_labels(self.LABELS_URL, self.LABELS_FILENAME)

    @staticmethod
    def _load_class_labels(url: str, filename: str) -> List[str]:
        """Return the class names cached in *filename*, downloading from *url* if absent.

        Args:
            url: where to fetch the label list when no local copy exists.
            filename: local path used as the cache.

        Returns:
            One stripped label per line of the file.

        Raises:
            urllib.error.URLError: if the file is missing and the download fails.
        """
        try:
            with open(filename, "r", encoding="utf-8") as f:
                return [line.strip() for line in f]
        except FileNotFoundError:
            pass
        # Not cached yet: read the whole response before touching the cache file,
        # so a failed download never leaves a truncated label list behind.
        with urllib.request.urlopen(url) as response:
            text = response.read().decode("utf-8")
        with open(filename, "w", encoding="utf-8") as f:
            f.write(text)
        return [line.strip() for line in text.splitlines()]

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Return the ``top_k`` most likely classes for *image*.

        Args:
            image: the image to classify. PIL images in any mode are accepted and
                converted to RGB; presumably callers always pass a
                ``PIL.Image.Image`` (what the timm eval transform expects) --
                TODO confirm against callers.
            top_k: number of predictions to return (default 5). It is clamped
                to the number of classes the model outputs.

        Returns:
            ``ClassificationResult`` objects ordered from most to least
            confident, with softmax probabilities as ``confidence``.
        """
        # The pretrained network takes 3-channel input; grayscale, palette or RGBA
        # images would otherwise yield a tensor with the wrong number of channels.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        # Add a leading batch dimension of size 1.
        input_tensor = self.transforms(image).unsqueeze(0)

        with torch.no_grad():
            output = self.model(input_tensor)

        # Single-image batch: take row 0 of the class probabilities.
        probabilities = output.softmax(dim=1)[0]
        k = min(top_k, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k=k)

        return [
            ClassificationResult(class_name=self.class_labels[index], confidence=prob)
            for prob, index in zip(top_probs.tolist(), top_indices.tolist())
        ]