File size: 1,588 Bytes
472fb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4b4f11
472fb0c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import List
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
import torch

class GoogleVit(ModelInterface):
    """Image classifier backed by the ``google/vit-base-patch16-224`` checkpoint.

    The pretrained model and its image processor are loaded once at
    construction time; :meth:`classify_image` then runs a single forward
    pass per call and reports the most probable labels.
    """

    # Hugging Face checkpoint shared by the model and its image processor.
    _CHECKPOINT = 'google/vit-base-patch16-224'

    def __init__(self):
        """Load the pretrained ViT classification model and image processor."""
        print('init... google vit model')
        self.model = ViTForImageClassification.from_pretrained(self._CHECKPOINT)
        self.processor = ViTImageProcessor.from_pretrained(self._CHECKPOINT)
        # ``ViTFeatureExtractor`` is a deprecated alias of ``ViTImageProcessor``
        # and was never used for inference here. The attribute is kept so
        # existing readers of ``feature_extractor`` keep working, but it now
        # shares the already-loaded processor instead of loading it twice.
        self.feature_extractor = self.processor

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return its most probable labels.

        Args:
            image: Input forwarded unchanged to the image processor as
                ``images=image`` (presumably a PIL image or array; confirm
                against callers). If a batch is passed, only the first
                item's predictions are returned.
            top_k: Maximum number of predictions to return. Defaults to 5.
                Values larger than the number of model labels are clamped,
                and values below 1 yield an empty list.

        Returns:
            ``ClassificationResult`` objects ordered from highest to lowest
            confidence, each carrying the label name from the model config
            and its softmax probability as a Python ``float``.
        """
        # Preprocess the image into PyTorch tensors.
        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class dimension; keep the first batch item only.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Never request more entries than there are classes (or fewer than 0).
        k = max(0, min(top_k, probabilities.shape[-1]))
        confidences, class_ids = torch.topk(probabilities, k)

        # ``tolist()`` yields plain Python ints/floats, so ``id2label`` is
        # indexed with real ints and confidences are ordinary floats.
        return [
            ClassificationResult(
                class_name=self.model.config.id2label[class_id],
                confidence=confidence,
            )
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]