"""Inference handler that embeds an image or a text string with OpenCLIP ViT-B-16."""

import base64
import io
import os
from typing import Any, Dict, List, Optional

import open_clip
import requests
import torch
from PIL import Image

# OpenCLIP architecture name; must match the checkpoint stored in the model directory.
_MODEL_NAME = "ViT-B-16"
# Candidate weight files, checked in order; the first one that exists is loaded.
_WEIGHT_FILES = ("open_clip_model.safetensors", "open_clip_pytorch_model.bin")
# Inputs starting with one of these are downloaded as images.
_URL_PREFIXES = ("http://", "https://")


class EndpointHandler:
    """Return an OpenCLIP embedding (``normalize=True``) for a single input.

    ``data["inputs"]`` must be a non-empty string and is interpreted as:

    * an ``http://`` / ``https://`` URL -> the image is downloaded and embedded;
    * base64-encoded image bytes, optionally wrapped in a
      ``data:<mime>;base64,`` URI -> the image is decoded and embedded;
    * anything else -> embedded as text.
    """

    # Seconds to wait for an image URL to respond.
    request_timeout = 10

    def __init__(self, path: str = "") -> None:
        """Load the model, preprocessing transform and tokenizer.

        Args:
            path: Directory containing the weight file ("" means the current
                working directory).

        Raises:
            RuntimeError: If none of the expected weight files exist in *path*.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            _MODEL_NAME,
            pretrained=self._find_weights(path),
        )
        self.tokenizer = open_clip.get_tokenizer(_MODEL_NAME)
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _find_weights(path: str) -> str:
        """Return the first existing weight file in *path*, preferring safetensors."""
        for filename in _WEIGHT_FILES:
            # os.path.join keeps path="" relative instead of pointing at "/".
            candidate = os.path.join(path, filename)
            if os.path.isfile(candidate):
                return candidate
        raise RuntimeError(f"No open_clip weights found in {path}")

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Embed ``data["inputs"]`` and return the vector as a list of floats.

        Raises:
            ValueError: If ``inputs`` is missing/empty or not a string, or if an
                input detected as a base64 image cannot be decoded.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("'inputs' is required — pass an image URL, base64 string, or text")
        if not isinstance(inputs, str):
            raise ValueError(f"'inputs' must be a string, got {type(inputs).__name__}")
        if self._is_image(inputs):
            return self._embed_image(inputs)
        return self._embed_text(inputs)

    def _is_image(self, source: str) -> bool:
        """Return True if *source* is an image URL or a base64-encoded image."""
        return source.startswith(_URL_PREFIXES) or self._decode_base64_image(source) is not None

    def _embed_image(self, source: str) -> List[float]:
        """Load, preprocess and encode the image referenced by *source*."""
        image = self._load_image(source)
        # Add a batch dimension of 1 before moving to the model's device.
        pixel_values = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(pixel_values, normalize=True)
        return features[0].tolist()

    def _embed_text(self, text: str) -> List[float]:
        """Tokenize and encode a single text string."""
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            features = self.model.encode_text(tokens, normalize=True)
        return features[0].tolist()

    def _load_image(self, source: str) -> Image.Image:
        """Return *source* (URL or base64 image) as an RGB PIL image.

        Raises:
            requests.RequestException: If downloading a URL fails.
            ValueError: If a non-URL *source* is not a decodable base64 image.
        """
        if source.startswith(_URL_PREFIXES):
            response = requests.get(source, timeout=self.request_timeout)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")

        image_bytes = self._decode_base64_image(source)
        if image_bytes is None:
            raise ValueError("Could not load image from input: not a URL or base64-encoded image")
        try:
            # convert() forces the full decode, which can still fail on truncated data.
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError) as e:
            raise ValueError(f"Could not load image from input: {e}") from e

    @staticmethod
    def _decode_base64_image(source: str) -> Optional[bytes]:
        """Return the decoded bytes if *source* is a base64-encoded image, else None.

        Accepts an optional ``data:<mime>;base64,`` prefix and ignores embedded
        whitespace (e.g. line-wrapped base64). Decoding is strict and the bytes
        must be recognised by PIL as an image, so ordinary text returns None.
        """
        payload = source
        if source.startswith("data:"):
            header, sep, payload = source.partition(",")
            if not sep or not header.endswith(";base64"):
                return None
        try:
            raw = base64.b64decode("".join(payload.split()), validate=True)
        except ValueError:  # binascii.Error subclasses ValueError; also raised for non-ASCII str
            return None
        try:
            # Image.open is lazy (reads only the header), so this is a cheap format sniff.
            with Image.open(io.BytesIO(raw)):
                pass
        except OSError:  # includes PIL.UnidentifiedImageError
            return None
        return raw