|
|
from typing import List |
|
|
from src.interface import ModelInterface |
|
|
from src.data.classification_result import ClassificationResult |
|
|
from transformers import AutoImageProcessor, ResNetForImageClassification |
|
|
import torch |
|
|
|
|
|
class Resnet50(ModelInterface):
    """Image classifier backed by the ``microsoft/resnet-50`` checkpoint.

    The image processor and the classification model are both loaded once,
    at construction time, and reused for every ``classify_image`` call.
    """

    # Checkpoint identifier shared by the processor and the model so the two
    # can never drift apart.
    MODEL_NAME = "microsoft/resnet-50"

    def __init__(self):
        """Load the image processor and the ResNet classification model."""
        print('init... resnet-50 model')
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_NAME)
        self.model = ResNetForImageClassification.from_pretrained(self.MODEL_NAME)

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Return the most probable labels for ``image``.

        Args:
            image: Input passed unchanged to the image processor as
                ``images=image``. If it yields a batch, only the first
                item's predictions are returned.
            top_k: Maximum number of results to return (default 5). It is
                clamped to the number of classes the model predicts.

        Returns:
            ``ClassificationResult`` objects ordered from highest to lowest
            confidence, where confidence is the softmax probability.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k}")

        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class axis, then keep the first batch item.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Clamp so models with fewer than ``top_k`` labels do not fail.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_ids = torch.topk(probabilities, k)

        # ``tolist()`` yields plain Python ints/floats, which is what the
        # ``id2label`` mapping and the result objects are fed.
        id2label = self.model.config.id2label
        return [
            ClassificationResult(
                class_name=id2label[class_id],
                confidence=confidence,
            )
            for class_id, confidence in zip(top_ids.tolist(), top_probs.tolist())
        ]