Findle
/

File size: 2,483 Bytes
0d1ca6a
 
446df1b
0d1ca6a
 
 
446df1b
0d1ca6a
 
 
 
 
b4d3c6e
446df1b
9839589
 
 
 
 
 
 
 
0d3b7b8
446df1b
 
 
 
0d3b7b8
446df1b
0d1ca6a
 
 
 
 
0d3b7b8
0d1ca6a
0d3b7b8
 
 
 
 
 
 
0d1ca6a
0d3b7b8
 
 
0d1ca6a
446df1b
0d3b7b8
0d1ca6a
0d3b7b8
 
 
 
0d1ca6a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Dict, Any
from PIL import Image
import open_clip
import torch
import base64
import io
import os
import requests


class EndpointHandler:
    """Embed an image or a text string with an OpenCLIP ViT-B-16 model.

    The ``__init__(path)`` / ``__call__(data)`` shape matches the custom
    handler interface presumably used by Hugging Face Inference Endpoints
    (confirm against the deployment). Every call returns a single embedding
    as a plain ``list`` of floats, produced with ``normalize=True``.
    """

    # Prefixes that mark ``inputs`` as a remote image to download.
    _URL_PREFIXES = ("http://", "https://")
    # Prefix of an inline image given as a data URI (``data:image/png;base64,...``).
    _DATA_URI_PREFIX = "data:image/"

    def __init__(self, path: str = ""):
        """Load the OpenCLIP ViT-B-16 weights stored in directory *path*.

        ``open_clip_model.safetensors`` is preferred over
        ``open_clip_pytorch_model.bin`` when both exist.

        Raises:
            RuntimeError: if neither weights file exists under *path*.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # os.path.join (rather than f"{path}/...") keeps the default path=""
        # relative to the working directory instead of the filesystem root.
        candidates = (
            os.path.join(path, "open_clip_model.safetensors"),
            os.path.join(path, "open_clip_pytorch_model.bin"),
        )
        pretrained = next((c for c in candidates if os.path.exists(c)), None)
        if pretrained is None:
            raise RuntimeError(f"No open_clip weights found in {path}")

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-16",
            pretrained=pretrained,
        )
        self.tokenizer = open_clip.get_tokenizer("ViT-B-16")
        self.model = self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> list:
        """Return the embedding for ``data["inputs"]``.

        ``inputs`` may be an http(s) image URL, a ``data:image/...`` URI, a
        bare base64-encoded image, or any other string, which is embedded as
        text.

        Raises:
            ValueError: if ``inputs`` is missing, empty or not a string, or if
                an image input cannot be decoded.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("'inputs' is required — pass an image URL, base64 string, or text")
        if not isinstance(inputs, str):
            raise ValueError(f"'inputs' must be a string, got {type(inputs).__name__}")

        if self._is_image(inputs):
            return self._embed_image(inputs)
        return self._embed_text(inputs)

    def _is_url(self, source: str) -> bool:
        """Return True if *source* is an http(s) URL."""
        return source.startswith(self._URL_PREFIXES)

    @staticmethod
    def _strip_base64(source: str) -> str:
        """Return the base64 payload of *source*.

        Any ``data:...,`` header is dropped, and all whitespace (for example
        line wraps) is removed.
        """
        if source.startswith("data:"):
            source = source.partition(",")[2]
        return "".join(source.split())

    def _is_image(self, source: str) -> bool:
        """Return True if *source* should be embedded as an image rather than text.

        A string counts as an image if it is an http(s) URL, a ``data:image/``
        URI, or strict base64 that decodes to bytes Pillow recognises as an
        image.
        """
        if self._is_url(source) or source.startswith(self._DATA_URI_PREFIX):
            return True
        return self._looks_like_base64_image(source)

    def _looks_like_base64_image(self, source: str) -> bool:
        """Return True if *source* is strict base64 whose bytes Pillow can open."""
        payload = self._strip_base64(source)
        if not payload:
            return False
        try:
            # validate=True rejects characters outside the base64 alphabet, so
            # ordinary prompts with spaces or punctuation are ruled out cheaply.
            raw = base64.b64decode(payload, validate=True)
        except ValueError:  # binascii.Error subclasses ValueError
            return False
        try:
            # Image.open only reads the header here; it raises
            # UnidentifiedImageError (an OSError) for non-image bytes.
            with Image.open(io.BytesIO(raw)):
                return True
        except (OSError, ValueError):
            return False

    def _embed_image(self, source: str) -> list:
        """Load *source* as an RGB image and return its normalized image embedding."""
        image = self._load_image(source)
        # unsqueeze(0) adds the batch dimension expected by encode_image.
        pixel_values = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(pixel_values, normalize=True)
        return features[0].tolist()

    def _embed_text(self, text: str) -> list:
        """Tokenize *text* and return its normalized text embedding."""
        tokens = self.tokenizer([text]).to(self.device)
        with torch.no_grad():
            features = self.model.encode_text(tokens, normalize=True)
        return features[0].tolist()

    def _load_image(self, source: str) -> Image.Image:
        """Load *source* and return it as an RGB PIL image.

        *source* is either an http(s) URL or base64 data, optionally wrapped
        in a ``data:`` URI.

        Raises:
            requests.HTTPError: if the URL responds with an error status.
            ValueError: if the data cannot be decoded into an image.
        """
        if self._is_url(source):
            # NOTE(security): this fetches an arbitrary caller-supplied URL with
            # no size cap. If the endpoint is publicly reachable, that is an
            # SSRF / memory risk; consider an allow-list and a streamed size limit.
            response = requests.get(source, timeout=10)
            response.raise_for_status()
            image_bytes = response.content
        else:
            try:
                image_bytes = base64.b64decode(self._strip_base64(source))
            except ValueError as e:
                raise ValueError(f"Could not load image from input: {e}") from e

        try:
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError) as e:
            raise ValueError(f"Could not load image from input: {e}") from e