File size: 1,588 Bytes
472fb0c d4b4f11 472fb0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from typing import List
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
import torch
class GoogleVit(ModelInterface):
    """Image classifier backed by the `google/vit-base-patch16-224` ViT checkpoint.

    Preprocessing is done with `ViTImageProcessor` and inference with
    `ViTForImageClassification`. Results are returned as a list of
    `ClassificationResult` objects.
    """

    # Hugging Face Hub checkpoint used for both the processor and the model.
    MODEL_NAME = 'google/vit-base-patch16-224'

    def __init__(self):
        print('init... google vit model')
        # ViTImageProcessor is the current preprocessing class. ViTFeatureExtractor
        # is a deprecated subclass of it, so loading both just loads the same
        # config twice (and triggers a deprecation warning). Load the processor
        # once and keep `feature_extractor` as an alias for existing callers.
        self.processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME)
        self.feature_extractor = self.processor
        self.model = ViTForImageClassification.from_pretrained(self.MODEL_NAME)
        # Make inference mode explicit (disables dropout etc.).
        self.model.eval()

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify `image` and return its most probable labels.

        Args:
            image: Any image input accepted by `ViTImageProcessor`. Presumably a
                PIL image or array; verify against callers. If a batch is
                passed, only the first image is classified.
            top_k: Maximum number of predictions to return. It is clamped to
                the number of labels the model knows; a value <= 0 yields an
                empty list.

        Returns:
            `ClassificationResult` objects ordered by descending softmax
            confidence.
        """
        inputs = self.processor(images=image, return_tensors="pt")

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Only the first item of the batch is reported (same as before).
        probabilities = torch.softmax(logits[0], dim=-1)

        # Clamp so models with fewer than `top_k` labels don't raise IndexError.
        k = max(0, min(top_k, probabilities.shape[-1]))
        top = torch.topk(probabilities, k)

        id2label = self.model.config.id2label
        return [
            ClassificationResult(
                class_name=id2label[int(label_idx)],
                confidence=float(score),
            )
            for score, label_idx in zip(top.values.tolist(), top.indices.tolist())
        ]