| from typing import List | |
| from src.interface import ModelInterface | |
| from src.data.classification_result import ClassificationResult | |
| from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor | |
| import torch | |
class GoogleVit(ModelInterface):
    """Image classifier backed by a Hugging Face ViT image-classification checkpoint.

    Defaults to ``google/vit-base-patch16-224``. ``classify_image`` returns the
    highest-probability labels as ``ClassificationResult`` objects.
    """

    def __init__(self, model_name: str = 'google/vit-base-patch16-224'):
        """Load the ViT image processor and classification model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
                Defaults to the checkpoint this class originally hard-coded.
        """
        print('init... google vit model')
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        # ViTFeatureExtractor is a deprecated subclass of ViTImageProcessor and
        # was previously loaded separately but never used here. Keep the
        # attribute for any external caller, without loading the config twice.
        self.feature_extractor = self.processor
        self.model = ViTForImageClassification.from_pretrained(model_name)

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Classify ``image`` and return its most likely labels.

        Args:
            image: Any input accepted by ``ViTImageProcessor``. Only the first
                item of the processed batch is scored.
            top_k: Maximum number of predictions to return. It is clamped to
                the number of labels the model defines. Defaults to 5, the
                previously hard-coded value. Non-positive values return ``[]``.

        Returns:
            ``ClassificationResult`` objects in descending order of confidence.
        """
        if top_k <= 0:
            return []

        inputs = self.processor(images=image, return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the label axis of the first (and normally only) image.
        probabilities = torch.softmax(logits[0], dim=-1)
        # Clamp so checkpoints with fewer than top_k labels don't fail.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_ids = torch.topk(probabilities, k)

        id2label = self.model.config.id2label
        return [
            ClassificationResult(class_name=id2label[label_id], confidence=prob)
            for prob, label_id in zip(top_probs.tolist(), top_ids.tolist())
        ]