"""
SigLIP2 embedding handler for Hugging Face Inference Endpoints.
Supports image and text embeddings via get_image_features and get_text_features.
"""
import base64
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image
def _load_image_from_input(image_input: Union[str, bytes]) -> Image.Image:
    """Load a PIL Image from raw bytes, a data URL, a bare base64 string, a URL, or a file path.

    Args:
        image_input: One of:
            - raw image bytes,
            - a data URL ("data:image/jpeg;base64,<b64data>"),
            - a bare base64 string (detected heuristically),
            - an http(s) URL or local file path (delegated to transformers' load_image).

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: If the input is neither str nor bytes, or a data URL is
            malformed (missing the "," separator before the base64 payload).
    """
    if isinstance(image_input, bytes):
        return Image.open(BytesIO(image_input)).convert("RGB")
    if not isinstance(image_input, str):
        raise ValueError(f"Image input must be str or bytes, got {type(image_input)}")
    # Data URL, format: data:image/jpeg;base64,<b64data>
    if image_input.startswith("data:"):
        if "," not in image_input:
            # Previously the whole string (prefix included) was fed to b64decode,
            # which silently discards invalid chars and yields garbage bytes,
            # surfacing later as a confusing PIL error. Fail fast and clearly.
            raise ValueError(
                "Malformed data URL: missing ',' separator before the base64 payload."
            )
        b64_data = image_input.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGB")
    # Heuristic for bare base64 without a data: prefix — JPEG base64 starts with
    # "/9j/", and strings longer than 500 chars are unlikely to be URLs/paths.
    if image_input.startswith("/9j/") or len(image_input) > 500:
        try:
            return Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")
        except Exception:
            pass  # Not decodable as an image; fall through to URL/path handling.
    # URL or local file path.
    return load_image(image_input)
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SigLIP2 image and text embeddings."""

    def __init__(self, path: str = ""):
        """Load model and processor from the given path (repo root when deployed)."""
        self.model = (
            AutoModel.from_pretrained(
                path,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            .eval()
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _l2_normalize(embeddings: torch.Tensor) -> torch.Tensor:
        """Scale *embeddings* to unit L2 norm along the last dimension."""
        return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

    def _embed_images(
        self, images: List[Any], normalize: bool, max_num_patches: int
    ) -> List[List[float]]:
        """Decode each image input, run the vision tower, and return embeddings as lists."""
        pil_images = [_load_image_from_input(img) for img in images]
        inputs = self.processor(
            images=pil_images,
            return_tensors="pt",
            max_num_patches=max_num_patches,
        ).to(self.model.device)
        embeddings = self.model.get_image_features(**inputs)
        if normalize:
            embeddings = self._l2_normalize(embeddings)
        return embeddings.cpu().tolist()

    def _embed_texts(self, texts: List[str], normalize: bool) -> List[List[float]]:
        """Tokenize *texts*, run the text tower, and return embeddings as lists.

        Padding is required so that batches of different-length texts can be
        stacked into one tensor (without it the processor raises on such input).
        NOTE(review): padding="max_length", max_length=64 matches the SigLIP /
        SigLIP2 model cards' recommended text preprocessing — confirm against
        the deployed checkpoint's card.
        """
        inputs = self.processor(
            text=texts,
            padding="max_length",
            max_length=64,
            truncation=True,
            return_tensors="pt",
        ).to(self.model.device)
        embeddings = self.model.get_text_features(**inputs)
        if normalize:
            embeddings = self._l2_normalize(embeddings)
        return embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a request containing images and/or texts and return embeddings.
        Args:
            data: Request payload with "inputs" key. Expected shape:
                {
                    "inputs": {
                        "images": ["url1", "url2"] | ["data:image/jpeg;base64,...", ...],
                        "texts": ["text1", "text2"]
                    },
                    "normalize": true,          # optional, default True
                    "max_num_patches": 256      # optional, default 256
                }
            At least one of "images" or "texts" must be provided.
        Returns:
            {
                "image_embeddings": [[...], [...]] | null,
                "text_embeddings": [[...], [...]] | null
            }
        Raises:
            ValueError: If the payload is not a dict, neither key is provided,
                or "images"/"texts" are present but not lists.
        """
        payload = data.get("inputs", data)
        normalize = data.get("normalize", True)
        # Optional image-processor knob; 256 preserves the previous hard-coded value.
        max_num_patches = data.get("max_num_patches", 256)
        if not isinstance(payload, dict):
            raise ValueError(
                "inputs must be a dict with 'images' and/or 'texts' keys. "
                f"Got {type(payload)}."
            )
        images = payload.get("images")
        texts = payload.get("texts")
        if not images and not texts:
            raise ValueError("At least one of 'images' or 'texts' must be provided.")
        if images is not None and not isinstance(images, list):
            raise ValueError("'images' must be a list.")
        if texts is not None and not isinstance(texts, list):
            raise ValueError("'texts' must be a list.")
        result: Dict[str, Optional[List[List[float]]]] = {
            "image_embeddings": None,
            "text_embeddings": None,
        }
        # inference_mode() is no_grad() plus view/version-counter bookkeeping
        # skipped — equivalent for a pure forward pass, slightly faster.
        with torch.inference_mode():
            if images:
                result["image_embeddings"] = self._embed_images(
                    images, normalize, max_num_patches
                )
            if texts:
                result["text_embeddings"] = self._embed_texts(texts, normalize)
        return result