MEYTI BECI BAGUNDA committed on
Commit
472fb0c
·
1 Parent(s): 74ffb90

Update file google_vit.py

Browse files
Files changed (1) hide show
  1. src/models/google_vit.py +38 -0
src/models/google_vit.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src.interface import ModelInterface
3
+ from src.data.classification_result import ClassificationResult
4
+ from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
5
+ import torch
6
+
7
class GoogleVit(ModelInterface):
    """Image classifier backed by the ``google/vit-base-patch16-224`` checkpoint.

    The pretrained model and its preprocessing objects are loaded once in
    ``__init__``; ``classify_image`` then returns the most probable labels
    for a single image.
    """

    # Single source of truth for the checkpoint used by the model and by
    # both preprocessing objects.
    _CHECKPOINT = 'google/vit-base-patch16-224'

    def __init__(self):
        """Load the pretrained ViT model, feature extractor and image processor."""
        print('init... google vit model')
        # NOTE(review): ``feature_extractor`` is not read anywhere in this
        # class (``classify_image`` uses ``processor``). It is kept so that
        # external code relying on the attribute keeps working -- confirm
        # whether it can be dropped.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(self._CHECKPOINT)
        self.model = ViTForImageClassification.from_pretrained(self._CHECKPOINT)
        self.processor = ViTImageProcessor.from_pretrained(self._CHECKPOINT)

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return its most probable labels.

        Args:
            image: The image to classify, in any form accepted by the
                image processor's ``images`` argument.
            top_k: Maximum number of predictions to return. Defaults to 5.
                It is clamped to the number of classes the model outputs,
                so a model with fewer labels never causes an index error.

        Returns:
            Up to ``top_k`` ``ClassificationResult`` objects ordered from
            the highest to the lowest confidence. Only the first item of
            the processed batch is reported.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        # Preprocess the image into PyTorch tensors.
        inputs = self.processor(images=image, return_tensors="pt")

        # Pure inference: skip gradient tracking so no autograd graph is
        # built and kept alive for the forward pass.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class axis; keep only the first batch entry.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # topk returns values already sorted in descending order.
        k = min(top_k, probabilities.shape[-1])
        confidences, class_ids = torch.topk(probabilities, k)

        # .tolist() yields plain Python ints/floats, so id2label is indexed
        # with a real int and confidence is a real float.
        id2label = self.model.config.id2label
        return [
            ClassificationResult(class_name=id2label[class_id], confidence=confidence)
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]