import base64
import io
from pathlib import Path
from typing import Any, Dict, List
from urllib.request import urlopen
import open_clip
import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, Normalize, ToTensor
INPUT_SIZE = 224
def _is_git_lfs_pointer(path: Path) -> bool:
if not path.is_file() or path.stat().st_size > 1024:
return False
with path.open("rb") as handle:
return handle.read(64).startswith(b"version https://git-lfs.github.com/spec/v1")
class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving OpenCLIP embeddings.

    Loads an OpenCLIP model from a local repository directory and, per
    request, returns L2-normalized image and/or text embeddings. Accepts
    images as base64 strings or http(s) URLs and text as a single string
    or a list of strings.
    """

    def __init__(self, model_dir: str = "", **kwargs: Any):
        """Load model, tokenizer, and preprocessing pipeline.

        Args:
            model_dir: Path to the model repository; defaults to the
                Inference Endpoints mount point ``/repository``.
            **kwargs: Ignored; accepted for forward compatibility with the
                endpoint runtime's handler contract.

        Raises:
            FileNotFoundError: If the config or checkpoint files are absent.
            RuntimeError: If checkpoints are Git LFS pointer stubs.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = Path(model_dir or "/repository")
        self._validate_model_files()
        model_id = f"local-dir:{self.model_dir}"
        self.model, preprocess = open_clip.create_model_from_pretrained(
            model_id,
            device=self.device,
            return_transform=True,
        )
        self.tokenizer = open_clip.get_tokenizer(model_id)
        self.model.eval()
        self.tensor_preprocess = self._build_tensor_preprocess(preprocess)

    def _validate_model_files(self) -> None:
        """Fail fast with a clear message if model files are missing or stubs.

        Checks that the OpenCLIP config exists, that at least one checkpoint
        file exists, and that no checkpoint is a Git LFS pointer (a common
        deployment mistake that would otherwise fail deep inside open_clip).
        """
        config_path = self.model_dir / "open_clip_config.json"
        checkpoint_paths = [
            self.model_dir / "open_clip_model.safetensors",
            self.model_dir / "open_clip_pytorch_model.bin",
        ]
        if not config_path.is_file():
            raise FileNotFoundError(
                f"Missing {config_path.name} in {self.model_dir}. "
                "This repository must contain the OpenCLIP config file."
            )
        existing_checkpoints = [path for path in checkpoint_paths if path.is_file()]
        if not existing_checkpoints:
            raise FileNotFoundError(
                f"No OpenCLIP checkpoint found in {self.model_dir}. "
                "Expected open_clip_model.safetensors or open_clip_pytorch_model.bin."
            )
        pointer_paths = [path.name for path in existing_checkpoints if _is_git_lfs_pointer(path)]
        if pointer_paths:
            raise RuntimeError(
                "The repository contains Git LFS pointer files instead of real model weights: "
                f"{', '.join(pointer_paths)}. "
                "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
            )

    @staticmethod
    def _build_tensor_preprocess(original_preprocess) -> Compose:
        """Extract Normalize from the model's preprocess and build ToTensor + Normalize only.

        The default model preprocess includes Resize + CenterCrop + ToTensor + Normalize.
        Since we manually squash images to INPUT_SIZE x INPUT_SIZE, we only need
        ToTensor + Normalize to match the existing embedding pipeline.
        """
        normalize = None
        for t in original_preprocess.transforms:
            if isinstance(t, Normalize):
                normalize = t
                break
        if normalize is None:
            # Fallback for preprocess pipelines without an explicit Normalize.
            normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        return Compose([ToTensor(), normalize])

    @staticmethod
    def _prepare_image(img: Image.Image) -> Image.Image:
        """Squash image to INPUT_SIZE x INPUT_SIZE (aspect ratio is NOT preserved)."""
        return img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)

    def _load_image(self, image_input: Any) -> Image.Image | None:
        """Decode an image from a URL or base64 string; None if not a string.

        Accepts plain base64 or a data-URI (the part after the last comma is
        decoded). EXIF orientation is applied before converting to RGB.
        """
        if not isinstance(image_input, str):
            return None
        if image_input.startswith(("http://", "https://")):
            # NOTE(review): fetching caller-supplied URLs is an SSRF vector if
            # this endpoint is reachable from untrusted clients — consider an
            # allow-list or disabling URL inputs in hardened deployments.
            with urlopen(image_input, timeout=10) as response:
                img = Image.open(io.BytesIO(response.read()))
        else:
            image_bytes = base64.b64decode(image_input.split(",")[-1])
            img = Image.open(io.BytesIO(image_bytes))
        img = ImageOps.exif_transpose(img)
        return img.convert("RGB")

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Squash to INPUT_SIZE and apply tensor normalization; adds batch dim."""
        image = self._prepare_image(image)
        return self.tensor_preprocess(image).unsqueeze(0).to(self.device)

    def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
        """Tokenize one string or a list of strings into a batch tensor."""
        texts = text if isinstance(text, list) else [text]
        return self.tokenizer(texts).to(self.device)

    def _encode_image(self, image: Image.Image) -> List[float]:
        """Embed a single image; returns the L2-normalized embedding as a list."""
        image_tensor = self._preprocess_image(image)
        image_features = self.model.encode_image(image_tensor, normalize=True)
        return image_features[0].cpu().tolist()

    def _encode_text(self, text: str | List[str]) -> Dict[str, Any]:
        """Embed text; plural "text_embeddings" key for list input, singular otherwise."""
        text_tensor = self._tokenize_text(text)
        text_features = self.model.encode_text(text_tensor, normalize=True)
        if isinstance(text, list):
            return {"text_embeddings": text_features.cpu().tolist()}
        return {"text_embedding": text_features[0].cpu().tolist()}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload; embeddings are read from ``data["inputs"]``
                when present, otherwise from ``data`` itself. Recognized keys:
                ``"text"`` (str or list of str) and ``"image"`` (base64 or URL).

        Returns:
            Dict with "image_embedding" and/or "text_embedding(s)" keys, or
            an {"error": ...} dict when no usable input was supplied.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            # Robustness fix: previously a non-dict "inputs" value (e.g. a bare
            # string) crashed with AttributeError on payload.get; return the
            # standard error response instead of a 500.
            return {"error": "Provide 'text' or 'image' (base64 or URL)."}
        text = payload.get("text")
        image = self._load_image(payload.get("image"))
        if image is None and text is None:
            return {"error": "Provide 'text' or 'image' (base64 or URL)."}
        response: Dict[str, Any] = {}
        with torch.no_grad():
            if image is not None:
                response["image_embedding"] = self._encode_image(image)
            if text is not None:
                response.update(self._encode_text(text))
        return response