kas1293's picture
Upload handler.py
7b1872f verified
Raw
History Blame Contribute Delete
1.28 kB
from typing import Dict, Any
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
import base64
from io import BytesIO
class EndpointHandler:
    """Callable handler that converts one image into an L2-normalized CLIP embedding.

    The handler loads a CLIP model and its processor once at construction
    time and is then invoked with a request payload for every prediction.
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor.

        Args:
            path: Location passed unchanged to ``from_pretrained`` for both
                the model and the processor.
        """
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        # Inference only: put the module in evaluation mode.
        self.model.eval()

    @staticmethod
    def _decode_base64(text: str) -> bytes:
        """Decode a base64 string, tolerating an optional data-URI header.

        Clients frequently send ``data:image/png;base64,<payload>`` instead of
        the bare payload. The lenient decoder would silently drop the ``:``,
        ``;`` and ``,`` characters and decode the header letters as data,
        producing corrupt bytes, so the header is removed first.

        Args:
            text: Bare base64 text or a ``data:`` URI carrying base64 data.

        Returns:
            The decoded bytes.

        Raises:
            ValueError: If the payload is not decodable base64
                (``binascii.Error`` is a ``ValueError`` subclass).
        """
        payload = text.strip()
        # The URI scheme is case-insensitive; ":" and "," never occur in
        # valid base64, so this check cannot misfire on a bare payload.
        if payload[:5].lower() == "data:":
            _header, separator, remainder = payload.partition(",")
            if separator:
                payload = remainder
        return base64.b64decode(payload)

    def _to_image(self, x) -> "Image.Image":
        """Coerce a supported input into an RGB ``PIL.Image.Image``.

        Args:
            x: One of
                * a PIL image,
                * raw encoded image bytes (``bytes`` / ``bytearray``),
                * a base64 string, optionally wrapped in a ``data:`` URI,
                * a dict holding any of the above under the ``"image"`` key.

        Returns:
            The image converted to RGB mode.

        Raises:
            ValueError: If ``x`` is none of the supported input kinds.
        """
        if isinstance(x, Image.Image):
            return x.convert("RGB")
        if isinstance(x, (bytes, bytearray)):
            return Image.open(BytesIO(x)).convert("RGB")
        if isinstance(x, str):
            return Image.open(BytesIO(self._decode_base64(x))).convert("RGB")
        if isinstance(x, dict) and "image" in x:
            # Unwrap {"image": ...} and re-dispatch on the inner value.
            return self._to_image(x["image"])
        raise ValueError("Unsupported image input")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the normalized image embedding for one request.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                when that key exists; otherwise ``data`` itself is treated as
                the image input (e.g. ``{"image": ...}``).

        Returns:
            A dict with ``"embedding"`` (list of floats, unit L2 norm) and
            ``"dim"`` (the embedding length).

        Raises:
            ValueError: If the image input is of an unsupported kind.
        """
        inputs = data.get("inputs", data)
        image = self._to_image(inputs)
        proc = self.processor(images=image, return_tensors="pt")
        # No gradients are needed for a forward-only pass.
        with torch.no_grad():
            feats = self.model.get_image_features(**proc)
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)  # L2 normalize
        # A single image was processed, so take row 0 of the batch.
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}