|
|
from typing import Dict, List, Any |
|
|
import torch |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import io |
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that embeds text and images with SigLIP2.

    Each input item is routed either to the text tower or to the image tower
    of the model and turned into one embedding vector (a list of floats).

    Accepted item forms:
        * ``str`` holding plain text                  -> text embedding
        * ``str`` holding a base64-encoded image      -> image embedding
        * ``str`` holding a ``data:<mime>;base64,...`` URI -> image embedding
        * ``bytes`` holding raw encoded image data    -> image embedding
    """

    MODEL_ID = "google/siglip2-so400m-patch14-384"

    # SigLIP2's documented usage pads text to a fixed 64 tokens
    # (padding="max_length", max_length=64). Passing the length explicitly
    # avoids depending on whatever default the tokenizer happens to carry.
    TEXT_MAX_LENGTH = 64

    def __init__(self, path=""):
        """Load the processor and model onto the best available device.

        Args:
            path: Accepted for interface compatibility with the endpoint
                runtime. It is not used; the model is always loaded from
                ``MODEL_ID``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID).to(self.device).eval()

    def __call__(self, data: Any) -> List[List[float]]:
        """Embed every input item and return the embeddings in input order.

        Args:
            data: Either a dict with an ``"inputs"`` key, or the inputs
                themselves. The inputs may be a single item or a list/tuple
                of items (see the class docstring for accepted item forms).

        Returns:
            One embedding (list of floats) per input item.

        Raises:
            ValueError: If an item that must be an image cannot be decoded.
            Exception: Any error raised by the processor or model is logged
                and re-raised unchanged.
        """
        results = []
        for item in self._extract_inputs(data):
            try:
                results.append(self._embed_item(item))
            except Exception as e:
                print(f"Error processing item: {e}")
                # Bare ``raise`` keeps the original traceback intact.
                raise
        return results

    @staticmethod
    def _extract_inputs(data):
        """Normalize the request payload to a list of items."""
        # Only dicts have ``.get``; a bare list or string is the input itself.
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs_data, (list, tuple)):
            return list(inputs_data)
        return [inputs_data]

    @staticmethod
    def _data_uri_payload(s):
        """Return the base64 payload of a ``data:...;base64,`` URI, else None."""
        if not s.startswith("data:"):
            return None
        header, sep, payload = s.partition(",")
        if not sep or not header.endswith(";base64"):
            return None
        return payload

    def _embed_item(self, item):
        """Route one input item to the text or image tower."""
        if not isinstance(item, str):
            # Non-string input (e.g. raw bytes) can only be an image.
            return self._embed_image(self._decode_image(item))

        if self._data_uri_payload(item) is not None:
            # An explicit data URI is unambiguous: decode strictly.
            return self._embed_image(self._decode_image(item))

        if not self._is_base64(item):
            return self._embed_text(item)

        # Ambiguous: ordinary words such as "test" are also valid base64.
        # Treat the string as an image only if it really decodes to one;
        # otherwise it is text. The broad catch is deliberate, since any
        # decoding failure means "not an image" here.
        try:
            image = self._load_image(item)
        except Exception:
            return self._embed_text(item)
        return self._embed_image(image)

    def _embed_text(self, text):
        """Return the text-tower embedding of ``text`` as a list of floats."""
        inputs = self.processor(
            text=[text],
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,  # longer text would exceed the padded length
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features[0].cpu().tolist()

    def _embed_image(self, image):
        """Return the image-tower embedding of ``image`` as a list of floats."""
        inputs = self.processor(images=[image], return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features[0].cpu().tolist()

    def _is_base64(self, s):
        """Return True if ``s`` is non-empty, canonical base64.

        Line breaks inside ``s`` are ignored. Empty or whitespace-only input
        is rejected: it would round-trip trivially but carries no payload.
        Anything that is not ``str``/``bytes`` yields False.
        """
        try:
            if isinstance(s, bytes):
                s = s.decode("utf-8")
            compact = s.replace("\n", "").replace("\r", "")
            if not compact:
                return False
            # Round-trip comparison rejects strings the lenient decoder would
            # otherwise accept by silently dropping invalid characters.
            return base64.b64encode(base64.b64decode(s)).decode("utf-8") == compact
        except (ValueError, TypeError, AttributeError):
            # binascii.Error and UnicodeDecodeError are ValueError subclasses.
            return False

    def _load_image(self, data):
        """Decode ``data`` into an RGB image, letting any error propagate."""
        if isinstance(data, str):
            payload = self._data_uri_payload(data)
            image_bytes = base64.b64decode(data if payload is None else payload)
        else:
            image_bytes = data
        img = Image.open(io.BytesIO(image_bytes))
        # Force a full decode now so corrupt data fails here, not later.
        img.load()
        return img.convert("RGB")

    def _decode_image(self, data):
        """Decode base64 text, a base64 data URI, or raw bytes into an RGB image.

        Raises:
            ValueError: If ``data`` cannot be decoded as an image. The
                underlying error is attached as the cause.
        """
        try:
            return self._load_image(data)
        except Exception as e:
            print(f"Image decode failed: {e}")
            raise ValueError(f"Invalid image data: {e}") from e
|
|
|