File size: 2,788 Bytes
fbf307e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b2ecfb
 
 
 
 
 
 
 
 
 
 
d4572b2
7b2ecfb
 
 
 
 
 
fbf307e
 
 
 
 
 
 
 
 
 
 
 
a3019f7
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Dict, List, Any
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import base64
import io

class EndpointHandler:
    """Hugging Face Inference Endpoints handler producing SigLIP2 embeddings.

    Routes each input item to the text or image tower of the model and
    returns one embedding vector (list of floats) per item.
    """

    def __init__(self, path=""):
        """Load the pinned SigLIP2 checkpoint onto GPU when available.

        Args:
            path: model directory supplied by the Endpoints runtime; unused
                here because the model id is pinned explicitly below.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_id = "google/siglip2-so400m-patch14-384"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device).eval()

    def __call__(self, data: Any) -> List[List[float]]:
        """
        Args:
            data (:obj:):
                either a dict with an "inputs" key, or the inputs directly.
                Inputs may be a single item or a list; each item is a text
                string or a base64-encoded image.
        Return:
            A :obj:`list` with one embedding (list of floats) per input item.
        """
        # FIX: only dicts have .get(); a raw list/str payload previously
        # raised AttributeError on this line.
        if isinstance(data, dict):
            inputs_data = data.get("inputs", data)
        else:
            inputs_data = data

        # Normalize a single item into a one-element batch.
        if not isinstance(inputs_data, list):
            inputs_data = [inputs_data]

        results = []
        for item in inputs_data:
            try:
                # Plain text that does not look like base64 -> text tower.
                if isinstance(item, str) and not self._is_base64(item):
                    inputs = self.processor(text=[item], padding="max_length", return_tensors="pt").to(self.device)
                    with torch.no_grad():
                        features = self.model.get_text_features(**inputs)
                    results.append(features[0].cpu().tolist())
                # Anything else is treated as a base64/raw-bytes image.
                else:
                    image = self._decode_image(item)
                    # print(f"Processing image: {image.size} {image.mode}")
                    inputs = self.processor(images=[image], return_tensors="pt").to(self.device)
                    with torch.no_grad():
                        features = self.model.get_image_features(**inputs)
                    results.append(features[0].cpu().tolist())
            except Exception as e:
                print(f"Error processing item: {e}")
                # FIX: bare `raise` re-raises with the original traceback
                # intact instead of appending this frame via `raise e`.
                raise

        return results

    def _is_base64(self, s):
        """Best-effort check whether *s* round-trips cleanly through base64.

        NOTE(review): short alphanumeric words (e.g. "test") also round-trip
        and would be misrouted to the image branch — callers that can should
        pass images explicitly rather than rely on this heuristic.
        """
        try:
            # FIX: the empty string round-trips trivially but can never be an
            # image payload; classify it as text so it reaches the text tower
            # instead of crashing inside _decode_image.
            if not s:
                return False
            if isinstance(s, bytes):
                s = s.decode('utf-8')
            # Decode-then-reencode must reproduce the input (ignoring any
            # line breaks a transport layer may have inserted).
            return base64.b64encode(base64.b64decode(s)).decode('utf-8') == s.replace('\n', '').replace('\r', '')
        except Exception:
            return False

    def _decode_image(self, data):
        """Decode a base64 string or raw bytes into an RGB PIL image.

        Raises:
            ValueError: if the payload cannot be decoded as an image.
        """
        try:
            if isinstance(data, str):
                image_bytes = base64.b64decode(data)
            else:
                image_bytes = data
            img = Image.open(io.BytesIO(image_bytes))
            # Force a full decode now so corrupt data fails here, not later
            # inside the processor.
            img.load()
            return img.convert("RGB")
        except Exception as e:
            print(f"Image decode failed: {e}")
            # FIX: chain the cause so the underlying decode error survives.
            raise ValueError(f"Invalid image data: {e}") from e