File size: 4,764 Bytes
233acb0
aa10251
e1369ab
 
233acb0
9188b68
35037e4
aa10251
9188b68
 
35037e4
e1369ab
825b375
e1369ab
825b375
 
 
35037e4
 
825b375
35037e4
08ce2dc
e1369ab
 
 
35037e4
825b375
aa10251
35037e4
 
 
e1369ab
35037e4
 
825b375
35037e4
9188b68
e1369ab
aa10251
 
e1369ab
 
 
9188b68
825b375
 
e1369ab
 
08ce2dc
 
9188b68
e1369ab
08ce2dc
233acb0
35037e4
e1369ab
aa10251
 
 
 
 
 
e1369ab
 
aa10251
08ce2dc
e1369ab
35037e4
233acb0
08ce2dc
e1369ab
35037e4
e1369ab
35037e4
08ce2dc
 
35037e4
aa10251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# handler.py  (repo root)

# handler.py  (repo root)

import io, base64, torch
from PIL import Image
import open_clip


class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP) with a
    text-embedding cache.

    Client JSON (the outer "inputs" envelope is optional):
    {
      "inputs": {
        "image": "<base64 PNG/JPEG>",
        "candidate_labels": ["cat", "dog", ...]
      }
    }

    On success, calling the handler returns a list of
    ``{"label": str, "score": float}`` dicts sorted by descending score.
    On invalid input it returns ``{"error": "<message>"}``.
    """

    # Upper bound on the number of cached label embeddings. Labels are
    # client-supplied, so an unbounded cache would grow for the lifetime of
    # the process.
    MAX_CACHED_LABELS = 4096

    # ------------------------------------------------- #
    #                 INITIALISATION                    #
    # ------------------------------------------------- #
    def __init__(self, path: str = ""):
        """Load the MobileCLIP-B weights found at ``<path>/mobileclip_b.pt``.

        Args:
            path: Directory that contains ``mobileclip_b.pt``.
        """
        weights = f"{path}/mobileclip_b.pt"

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # cache: {label -> L2-normalised text embedding living on self.device}
        self.label_cache: dict[str, torch.Tensor] = {}

    # ------------------------------------------------- #
    #                     HELPERS                       #
    # ------------------------------------------------- #
    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64-encoded image payload into an RGB PIL image."""
        # Decode first, in its own statement, so malformed base64 is reported
        # before any image parsing is attempted.
        raw = base64.b64decode(img_b64)
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _text_features(self, labels):
        """Return the stacked normalised text embeddings for ``labels``.

        Only labels that are not cached yet are tokenised and encoded; the
        result rows follow the order of ``labels`` (duplicates included).
        """
        # dict.fromkeys de-duplicates while keeping order, so a label that is
        # repeated inside one request is encoded only once.
        missing = [
            label for label in dict.fromkeys(labels)
            if label not in self.label_cache
        ]
        if missing:
            tokens = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tokens)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            self.label_cache.update(zip(missing, emb))

        txt_feat = torch.stack([self.label_cache[label] for label in labels])

        # Evict the oldest entries (dicts keep insertion order) only after the
        # features for this request have been gathered.
        excess = len(self.label_cache) - self.MAX_CACHED_LABELS
        if excess > 0:
            for stale in list(self.label_cache)[:excess]:
                del self.label_cache[stale]
        return txt_feat

    # ------------------------------------------------- #
    #                    INFERENCE                      #
    # ------------------------------------------------- #
    def __call__(self, data):
        """Classify one image against the candidate labels.

        Args:
            data: Request dict, either ``{"inputs": {...}}`` or the inner
                dict itself, holding ``"image"`` (base64 string) and
                ``"candidate_labels"`` (list of strings, or a single string).

        Returns:
            A list of ``{"label", "score"}`` dicts sorted by descending
            score, or ``{"error": "<message>"}`` when the request is invalid.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            return {"error": "inputs must be an object with image and candidate_labels"}

        # --- labels ----
        labels = payload.get("candidate_labels", [])
        if isinstance(labels, str):
            # A bare string would otherwise be iterated character by character.
            labels = [labels]
        if not labels:
            return {"error": "candidate_labels list is empty"}
        if not isinstance(labels, (list, tuple)) or not all(
            isinstance(label, str) for label in labels
        ):
            return {"error": "candidate_labels must be a list of strings"}

        # --- image ----
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image is missing"}
        try:
            image = self._decode_image(img_b64)
        except (TypeError, ValueError, OSError) as exc:
            # binascii.Error is a ValueError; PIL's UnidentifiedImageError is
            # an OSError.
            return {"error": f"image could not be decoded: {exc}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # --- text (with cache) ----
        txt_feat = self._text_features(labels)

        # --- forward & softmax ----
        # Mixed precision only on CUDA; torch.autocast replaces the deprecated
        # CUDA-only torch.cuda.amp.autocast, which warns on CPU-only hosts.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=use_amp
        ):
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # --- sorted output ----
        ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
        return [{"label": label, "score": float(score)} for label, score in ranked]

# # handler.py  (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip

# class EndpointHandler:
#     """
#     Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).

#     Expected client JSON *to the endpoint*:
#     {
#       "inputs": {
#         "image": "<base64 PNG/JPEG>",
#         "candidate_labels": ["cat", "dog", ...]
#       }
#     }
#     """

#     def __init__(self, path: str = ""):
#         weights = f"{path}/mobileclip_b.pt"
#         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained=weights
#         )
#         self.model.eval()

#         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.model.to(self.device)

#     def __call__(self, data):
#         # ── unwrap Hugging Face's `inputs` envelope ───────────
#         payload = data.get("inputs", data)

#         img_b64 = payload["image"]
#         labels  = payload.get("candidate_labels", [])
#         if not labels:
#             return {"error": "candidate_labels list is empty"}

#         # Decode & preprocess image
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

#         # Tokenise labels
#         text_tokens = self.tokenizer(labels).to(self.device)

#         # Forward pass
#         with torch.no_grad(), torch.cuda.amp.autocast():
#             img_feat = self.model.encode_image(img_tensor)
#             txt_feat = self.model.encode_text(text_tokens)
#             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
#             txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
#             probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

#         # Sorted output
#         return [
#             {"label": l, "score": float(p)}
#             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
#         ]