# NOTE: provenance — author "hedemil", commit 9405dd8 ("Update handler.py with new logic").
"""
Custom HuggingFace Inference Endpoint Handler for CLIP Image Embeddings.
This handler generates 512-dimensional embeddings for wine label images using CLIP ViT-B/32.
Optimized for similarity search with L2 normalization.
Deployment:
1. Upload this file to your HuggingFace model repository as 'handler.py'
2. Add requirements.txt with dependencies
3. Deploy via Inference Endpoints dashboard
Input Format:
- Binary image data (JPEG/PNG) sent as raw bytes
- OR JSON with base64-encoded image: {"inputs": "base64_string"}
Output Format:
- List of floats (512-dim normalized embedding)
- Format: [0.123, 0.456, ..., 0.789]
"""
from typing import Dict, List, Any, Union
import logging
import numpy as np
from PIL import Image
import io
import base64
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    Custom handler for CLIP image embedding generation.

    Returns L2-normalized 512-dim embeddings for cosine similarity search.
    """

    def __init__(self, path: str = ""):
        """
        Initialize CLIP model and processor.

        Args:
            path: Path to model weights (provided by HuggingFace Inference Endpoints)

        Raises:
            RuntimeError: If the model or processor fails to load.
        """
        try:
            # Imported lazily so the module can be inspected without heavy deps.
            from transformers import CLIPProcessor, CLIPModel
            import torch

            logger.info("Loading CLIP model from: %s", path)
            # Load CLIP ViT-B/32 model and processor
            self.model = CLIPModel.from_pretrained(path)
            self.processor = CLIPProcessor.from_pretrained(path)
            # Set device (GPU if available, otherwise CPU)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()  # Inference only — disable dropout/batchnorm updates
            logger.info("CLIP model loaded successfully on device: %s", self.device)
        except Exception as e:
            logger.error("Failed to initialize CLIP model: %s", e)
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Union[Dict[str, Any], bytes, "Image.Image"]) -> List[float]:
        """
        Generate CLIP embedding for input image.

        Args:
            data: Request data — raw image bytes (JPEG/PNG), a PIL image,
                a base64 string, or a dict with an "inputs" key holding any
                of those (optionally nested as {"inputs": {"image": ...}}).

        Returns:
            List[float]: 512-dim L2-normalized embedding vector.

        Raises:
            ValueError: If the image format is invalid or unsupported.
        """
        try:
            image = self._parse_input(data)
            embedding = self._generate_embedding(image)
            # Normalize so dot product == cosine similarity downstream.
            normalized_embedding = self._normalize_embedding(embedding)
            logger.info(
                "Generated CLIP embedding: dim=%d, norm=%.3f",
                len(normalized_embedding),
                np.linalg.norm(normalized_embedding),
            )
            return normalized_embedding
        except Exception as e:
            logger.error("Error generating embedding: %s", e, exc_info=True)
            raise ValueError(f"Failed to generate embedding: {str(e)}") from e

    @staticmethod
    def _coerce_to_image(value: Any) -> Union["Image.Image", None]:
        """
        Convert one value (PIL image, raw bytes, or base64/data-URL string)
        into an RGB PIL image. Returns None if the type is not supported so
        the caller can keep probing other input shapes.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")
        if isinstance(value, str):
            # Accept both plain base64 and "data:image/...;base64,..." URLs.
            b64_str = value.split(",", 1)[1] if value.startswith("data:") else value
            return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB")
        return None

    def _parse_input(self, data: Union[Dict[str, Any], bytes, "Image.Image"]) -> "Image.Image":
        """
        Parse input data into a PIL Image.

        Supports a PIL image, raw binary bytes, a base64 string, or a dict
        whose "inputs" key (or nested {"image": ...}) holds any of those.

        Args:
            data: Input data in various formats.

        Returns:
            PIL.Image: Parsed RGB image.

        Raises:
            ValueError: If the image format is invalid or unsupported.
        """
        try:
            # Direct payloads: PIL image, bytes, or base64 string.
            image = self._coerce_to_image(data)
            if image is not None:
                return image
            if isinstance(data, dict):
                # Many endpoints pass {"inputs": <something>}; fall back to
                # the dict itself when the key is absent.
                inputs = data.get("inputs", data)
                image = self._coerce_to_image(inputs)
                if image is not None:
                    return image
                # Nested dict like {"image": <...>}
                if isinstance(inputs, dict) and "image" in inputs:
                    image = self._coerce_to_image(inputs["image"])
                    if image is not None:
                        return image
                raise ValueError(f"Unsupported inputs type: {type(inputs)}")
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            logger.error("Failed to parse input image: %s", e)
            raise ValueError(f"Invalid image format: {str(e)}") from e

    def _generate_embedding(self, image: "Image.Image") -> np.ndarray:
        """
        Generate a raw CLIP embedding for an image.

        Args:
            image: PIL Image.

        Returns:
            np.ndarray: Raw embedding vector (512-dim for ViT-B/32).
        """
        import torch

        # Preprocess image and move tensors to the model's device.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # No gradients needed for inference.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        # Batch of one — take the first (only) row as a numpy vector.
        return image_features.cpu().numpy()[0]

    def _normalize_embedding(self, embedding: np.ndarray) -> List[float]:
        """
        L2-normalize an embedding for cosine similarity.

        Args:
            embedding: Raw embedding vector.

        Returns:
            List[float]: Unit-norm embedding (or the raw values if the
            vector has zero norm, which cannot be normalized).
        """
        norm = np.linalg.norm(embedding)
        if norm == 0:
            logger.warning("Embedding has zero norm, returning as-is")
            return embedding.tolist()
        return (embedding / norm).tolist()