File size: 1,624 Bytes
cebad5c
9087ee6
 
cebad5c
 
9087ee6
 
cebad5c
 
 
 
 
9087ee6
 
 
 
 
 
 
cebad5c
 
9087ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cebad5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List
import torch
import timm
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from PIL import Image
import urllib.request

class MobilenetV3(ModelInterface):
    """Image classifier backed by timm's pretrained ``mobilenetv3_large_100``.

    Class names come from the ``imagenet_classes.txt`` list published in the
    PyTorch Hub repository.
    """

    # Where the class-name list is fetched from, and the local file it is saved
    # to (a relative path, so it lands in the current working directory).
    _LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    _LABELS_FILENAME = "imagenet_classes.txt"

    def __init__(self):
        """Load the pretrained model, its eval-time transforms, and the class labels.

        The label file is downloaded on every construction, so network access
        is required. Errors raised by ``urllib`` propagate to the caller.
        """
        print('init... mobilenet v3 model')
        self.model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()

        # The preprocessing pipeline (resize/crop/normalize) depends only on the
        # model, so build it once here rather than on every classify_image call.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        # Download and read class labels. Line i of the file is used as the name
        # of output index i (assumes the list matches the model's classifier head).
        urllib.request.urlretrieve(self._LABELS_URL, self._LABELS_FILENAME)
        with open(self._LABELS_FILENAME, "r", encoding="utf-8") as f:
            self.class_labels = [line.strip() for line in f]

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return its ``top_k`` most likely classes.

        Args:
            image: Input image, fed through the model's eval transforms. A PIL
                image in any mode other than RGB (e.g. ``L``, ``P``, ``RGBA``)
                is converted to RGB first, because the model's normalization
                expects exactly three channels.
            top_k: Number of predictions to return (default 5). Values larger
                than the number of model outputs are clamped.

        Returns:
            ``ClassificationResult`` objects ordered from most to least
            confident. Each ``confidence`` is a softmax probability.

        Raises:
            ValueError: If ``top_k`` is less than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        # Grayscale, palette, or alpha images would yield a 1- or 4-channel tensor
        # and break the 3-channel normalization step.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        input_tensor = self.transforms(image).unsqueeze(0)  # add a batch dimension

        # Perform inference without tracking gradients.
        with torch.no_grad():
            logits = self.model(input_tensor)

        probabilities = logits.softmax(dim=1)
        k = min(top_k, probabilities.shape[1])
        scores, indices = torch.topk(probabilities, k=k)

        # Only one image is in the batch, so row 0 holds all the predictions.
        return [
            ClassificationResult(class_name=self.class_labels[index], confidence=score)
            for score, index in zip(scores[0].tolist(), indices[0].tolist())
        ]