import base64
import io
from pathlib import Path
from typing import Any, Dict, List
from urllib.request import urlopen

import open_clip
import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, Normalize, ToTensor

# Side length (pixels) every input image is squashed to before encoding.
INPUT_SIZE = 224

# Error text returned when a request carries no usable text or image.
_MISSING_INPUT_MESSAGE = "Provide 'text' or 'image' (base64 or URL)."


def _is_git_lfs_pointer(path: Path) -> bool:
    """Return True if *path* is a Git LFS pointer stub instead of real content.

    Files larger than 1 KiB are assumed to be real content and are not read;
    otherwise the first 64 bytes are checked for the LFS spec header.
    """
    if not path.is_file() or path.stat().st_size > 1024:
        return False
    with path.open("rb") as handle:
        return handle.read(64).startswith(b"version https://git-lfs.github.com/spec/v1")


class EndpointHandler:
    """Inference handler that embeds text and/or images with a local OpenCLIP model.

    Requests may carry ``text`` (a string or a list of strings) and/or
    ``image`` (an http(s) URL or a base64 string, optionally with a
    ``data:...,`` prefix). Embeddings are computed with ``normalize=True``
    and returned as plain Python lists.
    """

    def __init__(self, model_dir: str = "", **kwargs: Any):
        """Validate *model_dir* and load the OpenCLIP model, tokenizer and preprocessing.

        Args:
            model_dir: Directory with ``open_clip_config.json`` and a checkpoint;
                ``/repository`` is used when empty.
            **kwargs: Ignored.

        Raises:
            FileNotFoundError: If the config file or every checkpoint is missing.
            RuntimeError: If a checkpoint is a Git LFS pointer, not real weights.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = Path(model_dir or "/repository")
        # Fail fast with an actionable message before open_clip tries to load.
        self._validate_model_files()

        model_id = f"local-dir:{self.model_dir}"
        self.model, preprocess = open_clip.create_model_from_pretrained(
            model_id,
            device=self.device,
            return_transform=True,
        )
        self.tokenizer = open_clip.get_tokenizer(model_id)
        self.model.eval()
        self.tensor_preprocess = self._build_tensor_preprocess(preprocess)

    def _validate_model_files(self) -> None:
        """Check that the model directory holds a config and a real checkpoint.

        Raises:
            FileNotFoundError: If ``open_clip_config.json`` is absent, or neither
                ``open_clip_model.safetensors`` nor ``open_clip_pytorch_model.bin``
                exists.
            RuntimeError: If any present checkpoint is a Git LFS pointer file.
        """
        config_path = self.model_dir / "open_clip_config.json"
        checkpoint_paths = [
            self.model_dir / "open_clip_model.safetensors",
            self.model_dir / "open_clip_pytorch_model.bin",
        ]

        if not config_path.is_file():
            raise FileNotFoundError(
                f"Missing {config_path.name} in {self.model_dir}. "
                "This repository must contain the OpenCLIP config file."
            )

        existing_checkpoints = [path for path in checkpoint_paths if path.is_file()]
        if not existing_checkpoints:
            raise FileNotFoundError(
                f"No OpenCLIP checkpoint found in {self.model_dir}. "
                "Expected open_clip_model.safetensors or open_clip_pytorch_model.bin."
            )

        pointer_paths = [path.name for path in existing_checkpoints if _is_git_lfs_pointer(path)]
        if pointer_paths:
            raise RuntimeError(
                "The repository contains Git LFS pointer files instead of real model weights: "
                f"{', '.join(pointer_paths)}. "
                "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
            )

    @staticmethod
    def _build_tensor_preprocess(original_preprocess) -> Compose:
        """Extract Normalize from the model's preprocess and build ToTensor + Normalize only.

        The default model preprocess includes Resize + CenterCrop + ToTensor + Normalize.
        Since we manually squash images to INPUT_SIZE x INPUT_SIZE, we only need
        ToTensor + Normalize to match the existing embedding pipeline. If no
        Normalize transform is found, mean/std of 0.5 per channel is used.
        """
        normalize = None
        for t in original_preprocess.transforms:
            if isinstance(t, Normalize):
                normalize = t
                break
        if normalize is None:
            normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        return Compose([ToTensor(), normalize])

    @staticmethod
    def _prepare_image(img: Image.Image) -> Image.Image:
        """Squash image to INPUT_SIZE x INPUT_SIZE (aspect ratio is not preserved)."""
        return img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)

    def _load_image(self, image_input: Any) -> Image.Image | None:
        """Decode an image payload into an upright RGB PIL image.

        Args:
            image_input: An ``http://``/``https://`` URL, or a base64 string.
                Anything up to the last comma (e.g. a ``data:`` URI prefix)
                is discarded before decoding.

        Returns:
            The RGB image, or None when *image_input* is not a string or is
            empty/whitespace-only.
        """
        # An empty string would otherwise decode to b"" and make Image.open raise;
        # treat it as "no image supplied".
        if not isinstance(image_input, str) or not image_input.strip():
            return None

        if image_input.startswith(("http://", "https://")):
            with urlopen(image_input, timeout=10) as response:
                raw = response.read()
        else:
            raw = base64.b64decode(image_input.split(",")[-1])

        img = Image.open(io.BytesIO(raw))
        # Apply the EXIF orientation tag so rotated photos are encoded upright.
        img = ImageOps.exif_transpose(img)
        return img.convert("RGB")

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Squash to INPUT_SIZE, normalize, add a batch dimension and move to device."""
        image = self._prepare_image(image)
        return self.tensor_preprocess(image).unsqueeze(0).to(self.device)

    def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
        """Tokenize a string or list of strings as one batch on the model device."""
        texts = text if isinstance(text, list) else [text]
        return self.tokenizer(texts).to(self.device)

    def _encode_image(self, image: Image.Image) -> List[float]:
        """Return the normalized embedding of a single image as a list of floats."""
        image_tensor = self._preprocess_image(image)
        image_features = self.model.encode_image(image_tensor, normalize=True)
        return image_features[0].cpu().tolist()

    def _encode_text(self, text: str | List[str]) -> Dict[str, Any]:
        """Embed *text* and return it under the response key matching its shape.

        A list yields ``{"text_embeddings": [[...], ...]}`` (one vector per
        string); a single string yields ``{"text_embedding": [...]}``.
        """
        text_tensor = self._tokenize_text(text)
        text_features = self.model.encode_text(text_tensor, normalize=True)
        if isinstance(text, list):
            return {"text_embeddings": text_features.cpu().tolist()}
        return {"text_embedding": text_features[0].cpu().tolist()}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the text and/or image carried by a request.

        Args:
            data: Request body. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from *data* itself: ``text`` (str or
                list of str) and ``image`` (URL or base64 string).

        Returns:
            A dict with ``image_embedding`` and/or ``text_embedding`` (single
            string) / ``text_embeddings`` (list input), or ``{"error": ...}``
            when no usable input was supplied.
        """
        payload = data.get("inputs", data)
        # A non-dict ``inputs`` (e.g. a bare string) has no fields to read;
        # report it rather than crashing with AttributeError on ``.get``.
        if not isinstance(payload, dict):
            return {"error": _MISSING_INPUT_MESSAGE}

        text = payload.get("text")
        image = self._load_image(payload.get("image"))
        if image is None and text is None:
            return {"error": _MISSING_INPUT_MESSAGE}

        response: Dict[str, Any] = {}
        with torch.no_grad():
            if image is not None:
                response["image_embedding"] = self._encode_image(image)
            if text is not None:
                response.update(self._encode_text(text))
        return response