File size: 1,821 Bytes
9087ee6
 
 
cebad5c
 
 
 
 
 
9087ee6
 
 
 
cebad5c
 
9087ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import List
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult

class ClipVit(ModelInterface):
    """Zero-shot image classifier backed by the CLIP ViT-B/32 checkpoint.

    The image is scored against a fixed set of text labels (``LABELS``) and
    the best-scoring labels are returned with their softmax probabilities.
    """

    # Checkpoint used for both the model weights and the processor.
    MODEL_NAME = "openai/clip-vit-base-patch32"
    # Candidate text prompts the image is compared against. Kept in a single
    # place so the prompts fed to the model and the names reported back in the
    # results can never drift apart.
    LABELS = ("volcano", "mountain", "alp", "mount", "valley")

    def __init__(self):
        print('Initializing CLIP VIT model...')
        # Load pre-trained CLIP model and processor
        self.model = CLIPModel.from_pretrained(self.MODEL_NAME)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_NAME)

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Score ``image`` against ``LABELS`` and return the best matches.

        Args:
            image: Image handed directly to ``CLIPProcessor`` (presumably a
                PIL image, since the file imports ``PIL.Image``; confirm with
                callers).
            top_k: Maximum number of results to return. It is clamped to the
                range ``[0, len(LABELS)]``. Defaults to 5, the previous
                hard-coded value.

        Returns:
            Up to ``top_k`` ``ClassificationResult`` objects ordered by
            descending confidence. Each one pairs a label with its softmax
            probability over all labels.
        """
        labels = list(self.LABELS)
        inputs = self.processor(text=labels, images=image, return_tensors="pt", padding=True)

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # logits_per_image holds image-text similarity scores with one row per
        # input image. Only the first (and here, only) image is used.
        probabilities = torch.nn.functional.softmax(outputs.logits_per_image, dim=1)[0]

        k = max(0, min(top_k, len(labels)))
        top_indices = torch.argsort(probabilities, descending=True)[:k].tolist()

        # Look each label up by the same index as its probability. Zipping the
        # unsorted label list with the sorted probabilities would pair
        # confidences with the wrong names.
        return [
            ClassificationResult(class_name=labels[index], confidence=float(probabilities[index]))
            for index in top_indices
        ]