Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
File size: 5,520 Bytes
d68ec83
 
 
 
 
d7068bc
36b465e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68ec83
 
 
 
36b465e
d68ec83
 
 
 
5c6483a
d7068bc
d68ec83
 
 
 
d7068bc
 
 
d68ec83
 
d7068bc
d68ec83
d7068bc
d68ec83
 
 
 
 
 
 
d7068bc
 
 
d68ec83
 
 
d7068bc
 
d68ec83
 
d7068bc
d68ec83
d7068bc
d68ec83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Custom Inference Handler for SigLIP2-base-patch16-512
Supports: zero_shot, image_embedding, text_embedding, similarity
Returns 768D embeddings.
"""
from typing import Any, Dict, List, Union
import torch
from PIL import Image
import requests
from io import BytesIO
import base64
from transformers import AutoProcessor, AutoModel


class EndpointHandler:
    """Custom Inference Endpoints handler for SigLIP2 (zero-shot + embeddings).

    Supported modes (chosen via ``parameters["mode"]`` or auto-detected):
      * ``zero_shot``       -- classify image(s) against candidate labels
      * ``image_embedding`` -- L2-normalised image embeddings
      * ``text_embedding``  -- L2-normalised text embeddings
      * ``similarity``      -- image x text cosine-similarity matrix
    """

    def __init__(self, path: str = ""):
        """Load model and processor from *path*, placing the model on GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()  # inference only: disable dropout / train-mode layers

    def _load_image(self, image_data: Any) -> "Image.Image":
        """Decode *image_data* (URL, base64 / data-URI string, or raw bytes) into an RGB PIL image.

        Raises:
            ValueError: if *image_data* is neither ``str`` nor ``bytes``.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                # Surface HTTP errors instead of trying to decode an error page as an image.
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Otherwise treat the string as base64; strip one "data:...;base64," prefix if present.
            if "," in image_data:
                image_data = image_data.split(",", 1)[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _get_image_embeddings(self, images: "List[Image.Image]") -> torch.Tensor:
        """Return L2-normalised image embeddings, shape ``(len(images), dim)``."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalised text embeddings, shape ``(len(texts), dim)``.

        ``padding="max_length"`` matches how SigLIP text towers are trained.
        """
        inputs = self.processor(text=texts, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Dispatch a request to the handler for ``parameters["mode"]`` (auto-detected by default).

        Raises:
            ValueError: for an unrecognised explicit mode.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")

        if mode == "auto":
            mode = self._detect_mode(inputs, parameters)

        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    @staticmethod
    def _detect_mode(inputs: Any, parameters: Dict[str, Any]) -> str:
        """Heuristically pick a mode from the request shape.

        NOTE(review): a bare base64 payload without a "data:" prefix that is
        shorter than 500 chars is misdetected as text -- callers sending such
        payloads should pass an explicit ``mode``.
        """
        def looks_like_text(value: Any) -> bool:
            return isinstance(value, str) and not value.startswith(("http", "data:")) and len(value) < 500

        if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
            return "similarity"
        if "candidate_labels" in parameters:
            return "zero_shot"
        if looks_like_text(inputs):
            return "text_embedding"
        if isinstance(inputs, list) and all(looks_like_text(i) for i in inputs):
            return "text_embedding"
        return "image_embedding"

    def _zero_shot(self, inputs, parameters):
        """Classify image(s) against candidate labels.

        Returns a list of ``{"label", "score"}`` dicts sorted by descending score
        (a list of such lists when multiple images are given).

        Raises:
            ValueError: if ``candidate_labels`` resolves to an empty list.
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]
        if not candidate_labels:
            raise ValueError("candidate_labels must be a non-empty list or comma-separated string")

        images = [self._load_image(inputs)] if not isinstance(inputs, list) else [self._load_image(i) for i in inputs]
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(candidate_labels)

        # Raw cosine similarities live in [-1, 1], so softmax over them is nearly flat.
        # SigLIP models expose a learned temperature (logit_scale); apply it when present
        # so the resulting distribution is meaningfully peaked. (The additive logit_bias
        # is a per-logit constant and would cancel inside softmax, so it is omitted.)
        logits = image_embeds @ text_embeds.T
        logit_scale = getattr(self.model, "logit_scale", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp()
        probs = torch.softmax(logits, dim=-1)

        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: -pair[1])
            results.append([{"label": label, "score": score} for label, score in ranked])

        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input image."""
        images = [self._load_image(inputs)] if not isinstance(inputs, list) else [self._load_image(i) for i in inputs]
        embeddings = self._get_image_embeddings(images)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input text."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Return the image x text cosine-similarity matrix plus input counts.

        Raises:
            ValueError: if the image side or the text side of the request is missing.
        """
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if image_input is None or text_input is None:
            raise ValueError("similarity mode requires both 'image'/'images' and 'text'/'texts' inputs")

        images = [self._load_image(image_input)] if not isinstance(image_input, list) else [self._load_image(i) for i in image_input]
        texts = [text_input] if isinstance(text_input, str) else text_input

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)

        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}