from typing import Any, Dict, List

import base64
import io

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


class EndpointHandler:
    def __init__(self, path: str = "facebook/dinov2-small"):
        # Load the DINOv2 model and its image processor
        self.model = AutoModel.from_pretrained(path)
        self.processor = AutoImageProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Get the list of base64-encoded images from the request payload
        images_b64 = data.pop("inputs", data)

        # Decode base64 strings into PIL images; convert to RGB so
        # grayscale or RGBA inputs don't break the processor
        images = []
        for img_b64 in images_b64:
            img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
            images.append(img)

        # Preprocess the batch of images into model-ready tensors
        inputs = self.processor(images=images, return_tensors="pt")

        # Run the model without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Derive a global embedding per image by mean-pooling the last
        # hidden states over all tokens (including the CLS token)
        image_features = outputs.last_hidden_state.mean(dim=1)

        # If exactly two images were provided, also return their cosine similarity
        if len(images) == 2:
            similarity = torch.cosine_similarity(
                image_features[0], image_features[1], dim=0
            ).item()
            return [{"similarity": similarity, "embeddings": image_features.tolist()}]

        return [{"embeddings": image_features.tolist()}]
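
# --- Usage sketch (not part of the original handler; a minimal local smoke
# test under the assumption that the handler is invoked with a JSON-style
# dict whose "inputs" key holds base64-encoded image strings). The file
# names below are hypothetical placeholders.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = []
    for fname in ["image_a.jpg", "image_b.jpg"]:  # hypothetical example files
        with open(fname, "rb") as f:
            payload.append(base64.b64encode(f.read()).decode("utf-8"))
    result = handler({"inputs": payload})
    # With exactly two images, the response includes a cosine similarity score
    print(result[0].get("similarity"))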