| from typing import Dict, Any |
| from PIL import Image |
| from transformers import CLIPProcessor, CLIPModel |
| import torch |
| import base64 |
| from io import BytesIO |
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler returning an L2-normalised CLIP image embedding.

    The CLIP model and processor are loaded once from ``path``; each call decodes
    a single image, runs it through ``get_image_features`` and returns the
    unit-length feature vector as a plain list.
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor from ``path`` and set eval mode.

        Args:
            path: Passed unchanged to ``CLIPModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``.
        """
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        # Inference only: put the model in eval mode.
        self.model.eval()

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, accepting an optional ``data:...,`` URI header.

        Clients frequently send ``data:image/png;base64,<b64>``. The default
        (non-validating) ``b64decode`` would otherwise decode the header letters
        into junk bytes in front of the image, so the header is stripped first.
        """
        text = payload.strip()
        if text.startswith("data:"):
            _, sep, body = text.partition(",")
            if sep:
                text = body
        # Non-validating decode skips whitespace/newlines, so line-wrapped
        # base64 is still accepted.
        return base64.b64decode(text)

    @staticmethod
    def _l2_normalize(feats: torch.Tensor) -> torch.Tensor:
        """Scale ``feats`` to unit L2 norm along the last dimension.

        ``F.normalize`` divides by ``max(norm, eps)``, so an all-zero vector stays
        zero instead of turning into NaN (NaN is not valid JSON).
        """
        return torch.nn.functional.normalize(feats, p=2, dim=-1)

    def _to_image(self, x) -> "Image.Image":
        """Coerce a supported input into an RGB PIL image.

        Accepted forms:
            * a PIL ``Image`` (converted to RGB);
            * encoded image bytes as ``bytes``, ``bytearray`` or ``memoryview``;
            * a base64 string, optionally with a ``data:...;base64,`` prefix;
            * a dict with an ``"image"`` key holding any of the above.

        Raises:
            ValueError: if ``x`` is none of the accepted forms. Errors from
                base64 decoding or from ``Image.open`` propagate unchanged.
        """
        if isinstance(x, Image.Image):
            return x.convert("RGB")
        if isinstance(x, (bytes, bytearray, memoryview)):
            return Image.open(BytesIO(x)).convert("RGB")
        if isinstance(x, str):
            return Image.open(BytesIO(self._decode_base64(x))).convert("RGB")
        if isinstance(x, dict) and "image" in x:
            return self._to_image(x["image"])
        raise ValueError("Unsupported image input")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed one image and return its unit-length CLIP feature vector.

        Args:
            data: Request payload. The image is taken from ``data["inputs"]``
                when that key exists; otherwise ``data`` itself is treated as
                the input (so a bare ``{"image": ...}`` dict also works).

        Returns:
            ``{"embedding": [...], "dim": int}`` where ``embedding`` is row 0 of
            the normalised features and ``dim`` is their last-axis size.
        """
        inputs = data.get("inputs", data)
        image = self._to_image(inputs)
        proc = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            feats = self.model.get_image_features(**proc)
        # Presumably (1, dim) for a single image -- row 0 is taken below.
        feats = self._l2_normalize(feats)
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}