|
|
""" |
|
|
Custom HuggingFace Inference Endpoint Handler for CLIP Image Embeddings. |
|
|
|
|
|
This handler generates 512-dimensional embeddings for wine label images using CLIP ViT-B/32. |
|
|
Optimized for similarity search with L2 normalization. |
|
|
|
|
|
Deployment: |
|
|
1. Upload this file to your HuggingFace model repository as 'handler.py' |
|
|
2. Add requirements.txt with dependencies |
|
|
3. Deploy via Inference Endpoints dashboard |
|
|
|
|
|
Input Format: |
|
|
- Binary image data (JPEG/PNG) sent as raw bytes |
|
|
- OR JSON with base64-encoded image: {"inputs": "base64_string"} |
|
|
|
|
|
Output Format: |
|
|
- List of floats (512-dim normalized embedding) |
|
|
- Format: [0.123, 0.456, ..., 0.789] |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Any, Union |
|
|
import logging |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom handler for CLIP image embedding generation.

    Returns L2-normalized 512-dim embeddings for cosine similarity search.
    """

    def __init__(self, path: str = ""):
        """
        Initialize CLIP model and processor.

        Args:
            path: Path to model weights (provided by HuggingFace Inference Endpoints)

        Raises:
            RuntimeError: If the model or processor cannot be loaded.
        """
        try:
            # Imported lazily so the module can be inspected without
            # transformers/torch installed (HF endpoints provide them at runtime).
            from transformers import CLIPProcessor, CLIPModel
            import torch

            logger.info(f"Loading CLIP model from: {path}")

            self.model = CLIPModel.from_pretrained(path)
            self.processor = CLIPProcessor.from_pretrained(path)

            # Prefer GPU when available; eval() disables dropout for inference.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()

            logger.info(f"CLIP model loaded successfully on device: {self.device}")

        except Exception as e:
            logger.error(f"Failed to initialize CLIP model: {e}")
            # Chain explicitly so the root cause survives the re-raise.
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Union[Dict[str, Any], bytes, Image.Image]) -> List[float]:
        """
        Generate CLIP embedding for input image.

        Args:
            data: Request data with one of:
                - Binary image bytes (raw JPEG/PNG data)
                - Dict with "inputs" key containing base64-encoded image string
                - A PIL Image instance

        Returns:
            List[float]: 512-dim L2-normalized embedding vector

        Raises:
            ValueError: If image format is invalid or unsupported
        """
        try:
            # Coerce whatever the endpoint delivered into an RGB PIL image.
            image = self._parse_input(data)

            # Forward pass through CLIP's image tower.
            embedding = self._generate_embedding(image)

            # Unit-normalize so dot product == cosine similarity downstream.
            normalized_embedding = self._normalize_embedding(embedding)

            logger.info(
                f"Generated CLIP embedding: dim={len(normalized_embedding)}, "
                f"norm={np.linalg.norm(normalized_embedding):.3f}"
            )

            return normalized_embedding

        except Exception as e:
            logger.error(f"Error generating embedding: {e}", exc_info=True)
            raise ValueError(f"Failed to generate embedding: {str(e)}") from e

    @staticmethod
    def _coerce_image(value: Any) -> Union[Image.Image, None]:
        """
        Try to coerce a single value into an RGB PIL image.

        Accepts a PIL Image, raw bytes/bytearray, or a base64 string
        (optionally with a "data:image/...;base64," URL prefix).

        Args:
            value: Candidate image payload in any supported form.

        Returns:
            PIL.Image in RGB mode, or None if the value is not a supported type.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")

        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")

        if isinstance(value, str):
            b64_str = value
            # Strip a data-URL header if present; the payload follows the comma.
            if value.startswith("data:"):
                b64_str = value.split(",", 1)[1]
            image_bytes = base64.b64decode(b64_str)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return None

    def _parse_input(self, data: Union[Dict[str, Any], bytes, Image.Image]) -> Image.Image:
        """
        Parse input data into PIL Image.

        Supports:
        1. Raw binary image bytes (JPEG/PNG)
        2. Dict with "inputs" key containing base64 string
        3. Dict with "inputs" key containing binary bytes
        4. Nested dict form: {"inputs": {"image": <bytes|base64|PIL>}}

        Args:
            data: Input data in various formats

        Returns:
            PIL.Image: Parsed image

        Raises:
            ValueError: If image format is invalid
        """
        try:
            # Direct payloads: PIL image, raw bytes, or base64 string.
            image = self._coerce_image(data)
            if image is not None:
                return image

            if isinstance(data, dict):
                # Fall back to the dict itself when there is no "inputs" key,
                # so {"image": ...} at the top level still works.
                inputs = data.get("inputs", data)

                image = self._coerce_image(inputs)
                if image is not None:
                    return image

                # One level of nesting: {"inputs": {"image": ...}}.
                if isinstance(inputs, dict) and "image" in inputs:
                    image = self._coerce_image(inputs["image"])
                    if image is not None:
                        return image

                raise ValueError(f"Unsupported inputs type: {type(inputs)}")

            raise ValueError(f"Unsupported data type: {type(data)}")

        except Exception as e:
            logger.error(f"Failed to parse input image: {e}")
            raise ValueError(f"Invalid image format: {str(e)}") from e

    def _generate_embedding(self, image: Image.Image) -> np.ndarray:
        """
        Generate CLIP embedding for image.

        Args:
            image: PIL Image

        Returns:
            np.ndarray: Raw embedding vector (512-dim)
        """
        import torch

        # Processor handles resize/crop/normalize to CLIP's expected input.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only — no gradients needed.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Batch of one: take the single row back on CPU as numpy.
        embedding = image_features.cpu().numpy()[0]

        return embedding

    def _normalize_embedding(self, embedding: np.ndarray) -> List[float]:
        """
        L2-normalize embedding for cosine similarity.

        Args:
            embedding: Raw embedding vector

        Returns:
            List[float]: Normalized embedding (unit norm)
        """
        norm = np.linalg.norm(embedding)

        # Degenerate case: avoid division by zero; caller gets the raw vector.
        if norm == 0:
            logger.warning("Embedding has zero norm, returning as-is")
            return embedding.tolist()

        normalized = embedding / norm
        return normalized.tolist()
|
|
|