Findle
/

marqo-fashionCLIP / handler.py
Findle's picture
Update handler.py
0d3b7b8 verified
from typing import Dict, Any
from PIL import Image
import open_clip
import torch
import base64
import io
import os
import requests
class EndpointHandler:
    """Embed an image or a piece of text with an open_clip ViT-B-16 model.

    The class follows the ``__init__(path)`` / ``__call__(data)`` shape,
    presumably so it can be used as a Hugging Face Inference Endpoints custom
    handler (TODO confirm against the deployment).

    ``data["inputs"]`` must be a string and is interpreted as:

    * an ``http://`` / ``https://`` URL        -> downloaded and embedded as an image
    * a ``data:image/...;base64,...`` URI      -> decoded and embedded as an image
    * raw base64 of a JPEG/PNG/GIF/WebP file   -> decoded and embedded as an image
    * anything else                            -> embedded as text
    """

    # str.startswith accepts a tuple, so one call covers both schemes.
    _URL_PREFIXES = ("http://", "https://")
    _DATA_URI_PREFIX = "data:image/"
    # Weight files are probed in this order; the first one that exists wins.
    _WEIGHT_FILES = (
        "open_clip_model.safetensors",
        "open_clip_pytorch_model.bin",
    )
    # 16 base64 characters decode to 12 bytes, enough to see the WebP
    # signature ("RIFF" + 4 size bytes + "WEBP").
    _SNIFF_CHARS = 16
    _IMAGE_SIGNATURES = (
        b"\xff\xd8\xff",        # JPEG
        b"\x89PNG\r\n\x1a\n",   # PNG
        b"GIF87a",              # GIF
        b"GIF89a",              # GIF
    )

    def __init__(self, path: str = ""):
        """Load the model, preprocessing transform and tokenizer.

        Args:
            path: Directory that contains the open_clip weight file. An empty
                string means the current working directory.

        Raises:
            RuntimeError: If neither known weight file exists under ``path``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # os.path.join keeps an empty ``path`` relative to the working
        # directory; plain f"{path}/file" would have pointed at "/file".
        candidates = (os.path.join(path, name) for name in self._WEIGHT_FILES)
        pretrained = next((c for c in candidates if os.path.exists(c)), None)
        if pretrained is None:
            raise RuntimeError(f"No open_clip weights found in {path}")

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-16",
            pretrained=pretrained,
        )
        self.tokenizer = open_clip.get_tokenizer("ViT-B-16")
        self.model = self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> list:
        """Return the embedding for ``data["inputs"]`` as a flat list.

        Args:
            data: Request payload; only the ``"inputs"`` key is read.

        Returns:
            The first (and only) row of the encoder output converted with
            ``Tensor.tolist()``.

        Raises:
            ValueError: If ``inputs`` is missing, empty, not a string, or an
                image payload that cannot be decoded.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("'inputs' is required — pass an image URL, base64 string, or text")
        if not isinstance(inputs, str):
            # Without this guard a list/dict payload would fail later with an
            # unhelpful AttributeError on ``.startswith``.
            raise ValueError(
                f"'inputs' must be a string, got {type(inputs).__name__}"
            )

        if self._is_image(inputs):
            return self._embed_image(inputs)
        return self._embed_text(inputs)

    def _is_image(self, source: str) -> bool:
        """Tell whether ``source`` should be routed to the image encoder.

        True for http(s) URLs, ``data:image/...;base64,`` URIs, and raw base64
        whose decoded prefix carries a JPEG, PNG, GIF or WebP signature.
        Everything else is treated as text.
        """
        if source.startswith(self._URL_PREFIXES):
            return True
        if source.startswith(self._DATA_URI_PREFIX) and ";base64," in source:
            return True
        return self._looks_like_base64_image(source)

    @classmethod
    def _looks_like_base64_image(cls, source: str) -> bool:
        """Sniff the first few base64 characters for a known image signature.

        Only a short prefix is decoded, so large payloads are not decoded
        twice. Strict validation rejects ordinary text (spaces, punctuation),
        and the signature check rejects short words that happen to be valid
        base64 (e.g. "shoe").
        """
        head = source.lstrip()[: cls._SNIFF_CHARS]
        head = head[: len(head) - len(head) % 4]  # base64 decodes in 4-char groups
        if not head:
            return False
        try:
            raw = base64.b64decode(head, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return False
        if raw.startswith(cls._IMAGE_SIGNATURES):
            return True
        return raw[:4] == b"RIFF" and raw[8:12] == b"WEBP"

    def _embed_image(self, source: str) -> list:
        """Load the image behind ``source`` and return its embedding."""
        image = self._load_image(source)
        # unsqueeze(0) adds the leading batch axis before the encoder call.
        pixel_values = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(pixel_values, normalize=True)
        return features[0].tolist()

    def _embed_text(self, text: str) -> list:
        """Tokenize ``text`` and return its embedding."""
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            features = self.model.encode_text(tokens, normalize=True)
        return features[0].tolist()

    def _load_image(self, source: str) -> "Image.Image":
        """Fetch or decode ``source`` into an RGB PIL image.

        Args:
            source: An http(s) URL, a ``data:`` URI, or a raw base64 string.

        Raises:
            requests.HTTPError: If the URL responds with an error status.
            ValueError: If a base64 payload cannot be decoded or opened.
        """
        if source.startswith(self._URL_PREFIXES):
            # NOTE(review): the URL comes straight from the request payload and
            # is fetched as-is; restrict reachable hosts if this endpoint is
            # exposed to untrusted callers.
            response = requests.get(source, timeout=10)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")

        payload = source
        if source.startswith("data:"):
            # Drop the "data:<mime>;base64," header; only the part after the
            # first comma is base64.
            payload = source.partition(",")[2]

        try:
            image_bytes = base64.b64decode(payload)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Could not load image from input: {e}") from e