import io
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from transformers import pipeline
class EndpointHandler:
    """Inference handler that turns an image URL into a CLIP image embedding.

    The CLIP processor and model are loaded once from ``path`` at start-up.
    Each call downloads the referenced image, runs it through the processor
    and ``CLIPModel.get_image_features``, and returns the resulting vector.
    """

    # Seconds to wait on the image download before giving up, so a stalled
    # remote host cannot block the worker indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path: str = "") -> None:
        """Load the CLIP processor and model.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model (the endpoint's model directory).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPProcessor.from_pretrained(path)
        self.model = CLIPModel.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the CLIP image embedding for the image at a URL.

        Args:
            data: Request payload. The image reference is read from
                ``data["inputs"]``, or from ``data`` itself when there is no
                ``"inputs"`` key. It is either a mapping with an
                ``"image_url"`` key or the URL string itself. ``data`` is
                not modified.

        Returns:
            ``{"embedding": [...]}``: the first row of the image features
            produced by the model, converted with ``tolist()``.

        Raises:
            KeyError: if the payload mapping has no ``"image_url"`` key.
            requests.HTTPError: if the image download returns an error status.
        """
        # .get (not .pop) so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            image_url = inputs
        else:
            try:
                image_url = inputs["image_url"]
            except KeyError:
                raise KeyError(
                    "request payload must provide 'image_url' "
                    "(or 'inputs' as a URL string)"
                ) from None

        # Download fully and release the connection before decoding.
        # `.content` also undoes any Content-Encoding (e.g. gzip) that a raw
        # stream would pass through to PIL.
        with requests.get(image_url, timeout=self.REQUEST_TIMEOUT) as response:
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

        processed_image = self.processor(images=image, return_tensors="pt").to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            features = self.model.get_image_features(processed_image["pixel_values"])
        return {"embedding": features[0].tolist()}