File size: 1,408 Bytes
9087ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import List
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch

class Resnet50(ModelInterface):
    """Image classifier backed by the pretrained ``microsoft/resnet-50`` checkpoint."""

    # Checkpoint identifier shared by the image processor and the model.
    MODEL_NAME = "microsoft/resnet-50"

    def __init__(self):
        """Load the image processor and the classification model for ``MODEL_NAME``."""
        print('init... resnet-50 model')
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_NAME)
        self.model = ResNetForImageClassification.from_pretrained(self.MODEL_NAME)

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return the most probable labels.

        Args:
            image: Input passed unchanged to the image processor as ``images=``.
                NOTE(review): presumably a single PIL image or array; only the
                first item of the resulting batch is scored -- confirm with callers.
            top_k: Maximum number of predictions to return. Defaults to 5 and is
                clamped to the range ``[0, number of classes]``.

        Returns:
            A list of ``ClassificationResult`` objects ordered from highest to
            lowest confidence, where confidence is the softmax probability.
        """
        # Preprocess the image into PyTorch tensors.
        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class axis; keep only the first item of the batch.
        probabilities = torch.softmax(logits, dim=-1)[0]

        # Never request more entries than there are classes (or fewer than zero).
        k = max(0, min(top_k, probabilities.shape[-1]))
        confidences, class_ids = torch.topk(probabilities, k)

        # tolist() yields plain Python ints/floats for the dict lookup and result.
        id2label = self.model.config.id2label
        return [
            ClassificationResult(class_name=id2label[class_id], confidence=confidence)
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]