| import base64 |
| import io |
| from pathlib import Path |
| from typing import Any, Dict, List |
| from urllib.request import urlopen |
|
|
| import open_clip |
| import torch |
| from PIL import Image, ImageOps |
| from torchvision.transforms import Compose, Normalize, ToTensor |
|
|
# Side length, in pixels, of the square image handed to the vision model.
# EndpointHandler._prepare_image resizes (squashes, without cropping) every
# input image to INPUT_SIZE x INPUT_SIZE.
INPUT_SIZE = 224
|
|
|
|
| def _is_git_lfs_pointer(path: Path) -> bool: |
| if not path.is_file() or path.stat().st_size > 1024: |
| return False |
|
|
| with path.open("rb") as handle: |
| return handle.read(64).startswith(b"version https://git-lfs.github.com/spec/v1") |
|
|
|
|
class EndpointHandler:
    """Embed images and/or text with an OpenCLIP model loaded from a local directory.

    Each call takes a payload with an optional ``"text"`` (a string or a list
    of strings) and an optional ``"image"`` (an ``http://``/``https://`` URL or
    a base64 string, optionally prefixed like a data URI). It returns the
    model's normalized embeddings as plain Python lists.
    """

    # Returned when a request carries neither usable text nor an image.
    _MISSING_INPUT_ERROR = "Provide 'text' or 'image' (base64 or URL)."

    def __init__(self, model_dir: str = "", **kwargs: Any):
        """Load the OpenCLIP model, its tokenizer and the tensor preprocessing.

        Args:
            model_dir: Directory containing ``open_clip_config.json`` and a
                checkpoint. An empty string falls back to ``/repository``.
            **kwargs: Extra keyword arguments; accepted and ignored.

        Raises:
            FileNotFoundError: If the config file or every checkpoint is missing.
            RuntimeError: If a checkpoint is a Git LFS pointer stub.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = Path(model_dir or "/repository")

        # Fail fast with an actionable message before open_clip tries to
        # parse missing or placeholder weight files.
        self._validate_model_files()

        model_id = f"local-dir:{self.model_dir}"
        self.model, preprocess = open_clip.create_model_from_pretrained(
            model_id,
            device=self.device,
            return_transform=True,
        )
        self.tokenizer = open_clip.get_tokenizer(model_id)
        self.model.eval()
        self.tensor_preprocess = self._build_tensor_preprocess(preprocess)

    def _validate_model_files(self) -> None:
        """Check that the model directory holds a usable config and checkpoint.

        Raises:
            FileNotFoundError: If ``open_clip_config.json`` is missing, or if
                neither ``open_clip_model.safetensors`` nor
                ``open_clip_pytorch_model.bin`` exists.
            RuntimeError: If any existing checkpoint is a Git LFS pointer file
                rather than the actual weights.
        """
        config_path = self.model_dir / "open_clip_config.json"
        checkpoint_paths = [
            self.model_dir / "open_clip_model.safetensors",
            self.model_dir / "open_clip_pytorch_model.bin",
        ]

        if not config_path.is_file():
            raise FileNotFoundError(
                f"Missing {config_path.name} in {self.model_dir}. "
                "This repository must contain the OpenCLIP config file."
            )

        existing_checkpoints = [path for path in checkpoint_paths if path.is_file()]
        if not existing_checkpoints:
            raise FileNotFoundError(
                f"No OpenCLIP checkpoint found in {self.model_dir}. "
                "Expected open_clip_model.safetensors or open_clip_pytorch_model.bin."
            )

        pointer_paths = [path.name for path in existing_checkpoints if _is_git_lfs_pointer(path)]
        if pointer_paths:
            raise RuntimeError(
                "The repository contains Git LFS pointer files instead of real model weights: "
                f"{', '.join(pointer_paths)}. "
                "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
            )

    @staticmethod
    def _build_tensor_preprocess(original_preprocess) -> Compose:
        """Build a ToTensor + Normalize pipeline reusing the model's Normalize.

        The default model preprocess includes Resize + CenterCrop + ToTensor +
        Normalize. Images are squashed to INPUT_SIZE x INPUT_SIZE beforehand
        (see ``_prepare_image``), so only ToTensor + Normalize are needed to
        match the existing embedding pipeline.

        Args:
            original_preprocess: Transform returned by open_clip; its
                ``.transforms`` sequence is scanned for the first ``Normalize``.

        Returns:
            ``Compose([ToTensor(), normalize])``. When the original pipeline
            has no ``Normalize``, mean/std of 0.5 per channel is used instead.
        """
        normalize = None
        for t in original_preprocess.transforms:
            if isinstance(t, Normalize):
                normalize = t
                break
        if normalize is None:
            normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        return Compose([ToTensor(), normalize])

    @staticmethod
    def _prepare_image(img: Image.Image) -> Image.Image:
        """Squash *img* to INPUT_SIZE x INPUT_SIZE with bicubic resampling.

        The aspect ratio is not preserved and nothing is cropped.
        """
        return img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)

    def _load_image(self, image_input: Any) -> Image.Image | None:
        """Decode the request's image field into an RGB PIL image.

        Args:
            image_input: An ``http://``/``https://`` URL, or a base64 string.
                For base64, everything up to the last comma is discarded, so a
                ``data:image/...;base64,`` prefix is accepted.

        Returns:
            The EXIF-orientation-corrected image converted to RGB, or ``None``
            when *image_input* is not a string (including when it is absent).

        Network, base64 and image-decoding errors are not caught here.
        """
        if not isinstance(image_input, str):
            return None

        if image_input.startswith(("http://", "https://")):
            # NOTE(review): the URL comes straight from the request and is
            # fetched server-side with no host allow-list and no cap on the
            # response size; consider restricting both if the endpoint is public.
            with urlopen(image_input, timeout=10) as response:
                img = Image.open(io.BytesIO(response.read()))
        else:
            image_bytes = base64.b64decode(image_input.split(",")[-1])
            img = Image.open(io.BytesIO(image_bytes))

        # Apply the EXIF orientation so rotated phone photos embed upright.
        img = ImageOps.exif_transpose(img)
        return img.convert("RGB")

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Squash *image* to INPUT_SIZE, normalize it, and return a batch of one on the device."""
        image = self._prepare_image(image)
        return self.tensor_preprocess(image).unsqueeze(0).to(self.device)

    def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
        """Tokenize a string or list of strings into a batch on the model's device."""
        texts = text if isinstance(text, list) else [text]
        return self.tokenizer(texts).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute embeddings for the image and/or text in *data*.

        Args:
            data: Request body. The payload is ``data["inputs"]`` when that key
                exists, otherwise *data* itself. The payload may contain
                ``"text"`` (str or list of str) and/or ``"image"`` (URL or
                base64 string).

        Returns:
            A dict with ``"image_embedding"`` when an image was given, plus
            ``"text_embedding"`` for a single string or ``"text_embeddings"``
            for a list of strings. Returns ``{"error": ...}`` when the payload
            is not a dict or contains neither input.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            # e.g. {"inputs": "a photo of a cat"}: a bare value has no
            # "text"/"image" keys, and calling .get on it used to raise
            # AttributeError instead of returning a usable error.
            return {"error": self._MISSING_INPUT_ERROR}

        text = payload.get("text")
        image = self._load_image(payload.get("image"))

        if image is None and text is None:
            return {"error": self._MISSING_INPUT_ERROR}

        # Image first, then text, so the response key order matches the
        # original combined branch.
        response: Dict[str, Any] = {}
        with torch.no_grad():
            if image is not None:
                image_features = self.model.encode_image(
                    self._preprocess_image(image), normalize=True
                )
                response["image_embedding"] = image_features[0].cpu().tolist()

            if text is not None:
                text_features = self.model.encode_text(
                    self._tokenize_text(text), normalize=True
                )
                if isinstance(text, list):
                    response["text_embeddings"] = text_features.cpu().tolist()
                else:
                    response["text_embedding"] = text_features[0].cpu().tolist()

        return response
|
|