kas1293's picture
Upload handler.py
e9dc81c verified
Raw
History Blame
1.28 kB
from typing import Dict, Any
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
import base64
from io import BytesIO
class EndpointHandler:
    """Return an L2-normalized CLIP image embedding for a single image.

    The handler loads a CLIP model and its processor once at construction
    time and is then called with a request payload containing one image.
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor from *path*.

        Args:
            path: Location handed unchanged to ``from_pretrained`` for both
                the model and the processor.
        """
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        # Run on the GPU when one is present; otherwise the model stays on
        # the CPU exactly as it was loaded.
        if torch.cuda.is_available():
            self.model.to("cuda")
        self.model.eval()

    def _to_image(self, x) -> Image.Image:
        """Coerce a supported input into an RGB ``PIL.Image.Image``.

        Accepted inputs:
            * a PIL image;
            * raw encoded image bytes (``bytes`` / ``bytearray``);
            * a base64 string, optionally wrapped as a data URI
              (``data:<mime>;base64,<payload>``);
            * a dict with an ``"image"`` key holding any of the above.

        Raises:
            ValueError: If *x* is none of the supported input kinds.
        """
        if isinstance(x, Image.Image):
            return x.convert("RGB")
        if isinstance(x, (bytes, bytearray)):
            return Image.open(BytesIO(x)).convert("RGB")
        if isinstance(x, str):
            encoded = x.strip()
            if encoded.startswith("data:"):
                # b64decode silently drops non-alphabet characters, so the
                # letters of a data-URI header would otherwise be decoded
                # as if they were image bytes. Keep only the payload that
                # follows the first comma.
                _, comma, payload = encoded.partition(",")
                if comma:
                    encoded = payload
            raw = base64.b64decode(encoded)
            return Image.open(BytesIO(raw)).convert("RGB")
        if isinstance(x, dict) and "image" in x:
            return self._to_image(x["image"])
        raise ValueError("Unsupported image input")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the image carried by *data*.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                when that key exists; otherwise *data* itself is treated as
                the image input (see ``_to_image`` for accepted forms).

        Returns:
            A dict with ``"embedding"`` (the normalized feature vector of
            the image as a list of floats) and ``"dim"`` (the size of the
            last feature dimension).

        Raises:
            ValueError: If the image input is of an unsupported kind.
        """
        payload = data.get("inputs", data)
        image = self._to_image(payload)
        batch = self.processor(images=image, return_tensors="pt")
        # Keep the input tensors on the same device as the model weights.
        batch = batch.to(self.model.device)
        with torch.no_grad():
            feats = self.model.get_image_features(**batch)
        # L2-normalize along the feature dimension.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}