# File size: 1,408 Bytes | 9087ee6
from typing import List
from src.interface import ModelInterface
from src.data.classification_result import ClassificationResult
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch
class Resnet50(ModelInterface):
    """Image classifier backed by the ``microsoft/resnet-50`` checkpoint.

    The image processor and the classification model are loaded once in the
    constructor and reused for every call to :meth:`classify_image`.
    """

    # Checkpoint identifier shared by the processor and the model.
    _CHECKPOINT = "microsoft/resnet-50"
    # Upper bound on the number of predictions returned per image.
    _TOP_K = 5

    def __init__(self):
        """Load the image processor and the ResNet-50 classification model."""
        # The previous message said "clip vit model", which was misleading
        # for this class.
        print('init... resnet-50 model')
        self.processor = AutoImageProcessor.from_pretrained(self._CHECKPOINT)
        self.model = ResNetForImageClassification.from_pretrained(self._CHECKPOINT)

    def classify_image(self, image) -> List[ClassificationResult]:
        """Classify ``image`` and return the most probable labels.

        Args:
            image: The image to classify, in any form accepted by the
                processor's ``images=`` argument. If a batch is passed, only
                the first item is classified.

        Returns:
            Up to ``_TOP_K`` ``ClassificationResult`` objects ordered from the
            highest to the lowest confidence. ``confidence`` is the softmax
            probability of the class as a Python ``float``.
        """
        # Preprocess the image into PyTorch tensors.
        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class dimension; keep the first batch item only.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Never ask for more entries than the model has classes.
        k = min(self._TOP_K, probabilities.shape[-1])
        confidences, class_ids = torch.topk(probabilities, k)

        # ``tolist()`` yields plain Python ints/floats, so the label lookup
        # uses ordinary ``int`` keys and the confidence is a builtin ``float``.
        id2label = self.model.config.id2label
        return [
            ClassificationResult(
                class_name=id2label[class_id],
                confidence=confidence,
            )
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]