import io
from typing import Any, Dict

import requests
import torch
import open_clip
from PIL import Image


class EndpointHandler:
    """Custom inference handler that returns OpenCLIP embeddings for text or images."""

    def __init__(self, model_dir: str):
        self.device = "cpu"
        # Load the pretrained ViT-B-32 model and its inference-time image transform.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k", device=self.device
        )
        self.model.eval()  # inference mode: disable dropout etc.
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

    def _encode_text(self, text: str):
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(tokens).cpu().numpy()[0].tolist()

    def _encode_image(self, image: Image.Image):
        # Preprocess to a normalized tensor and add a batch dimension.
        img = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model.encode_image(img).cpu().numpy()[0].tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if "image" in data:
            if isinstance(data["image"], str):
                # A string is treated as a URL; fetch it and fail fast on HTTP errors.
                resp = requests.get(data["image"], timeout=10)
                resp.raise_for_status()
                img = Image.open(io.BytesIO(resp.content)).convert("RGB")
            else:
                # Otherwise assume a file path or file-like object PIL can open.
                img = Image.open(data["image"]).convert("RGB")
            emb = self._encode_image(img)
        elif "inputs" in data:
            emb = self._encode_text(data["inputs"])
        else:
            raise ValueError("Provide 'image' or 'inputs'.")
        return {"embedding": emb}