|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import CLIPProcessor, CLIPModel |
|
|
from typing import List |
|
|
from src.interface import ModelInterface |
|
|
from src.data.classification_result import ClassificationResult |
|
|
|
|
|
class ClipVit(ModelInterface):
    """Zero-shot image classifier built on CLIP ViT-B/32.

    Scores an input image against a fixed set of candidate labels using
    OpenAI's pretrained CLIP model and returns per-label confidences.
    """

    # Single source of truth for the candidate labels. The original code
    # duplicated this list in two places (the processor call and the result
    # construction), which is exactly how they drift out of sync.
    CLASS_NAMES = ["volcano", "mountain", "alp", "mount", "valley"]

    def __init__(self):
        """Load the pretrained CLIP ViT-B/32 model and its processor.

        NOTE: downloads weights from the Hugging Face hub on first use.
        """
        print('Initializing CLIP VIT model...')

        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def classify_image(self, image) -> List[ClassificationResult]:
        """Classify ``image`` against ``CLASS_NAMES``.

        Args:
            image: a PIL image (or anything ``CLIPProcessor`` accepts as
                ``images``) — assumed single image, not a batch; TODO confirm
                against callers.

        Returns:
            A list of ``ClassificationResult`` sorted by descending
            confidence, one entry per candidate label.
        """
        inputs = self.processor(
            text=self.CLASS_NAMES,
            images=image,
            return_tensors="pt",
            padding=True,
        )

        # Pure inference: disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Shape (1, num_labels); take the row for the single input image.
        # (No need for the original detach().numpy() round-trip — we can
        # sort and index the tensor directly.)
        probabilities = torch.nn.functional.softmax(outputs.logits_per_image, dim=1)[0]

        # Label indices ordered from most to least probable.
        top_indices = torch.argsort(probabilities, descending=True).tolist()

        # BUG FIX: the original zipped the *sorted* probabilities against the
        # *unsorted* label list, so whenever the model's ranking differed from
        # declaration order every result carried the wrong label. Index both
        # names and probabilities with the same sorted indices instead.
        return [
            ClassificationResult(
                class_name=str(self.CLASS_NAMES[i]),
                confidence=float(probabilities[i]),
            )
            for i in top_indices
        ]