File size: 5,135 Bytes
233acb0
aa10251
233acb0
9188b68
35037e4
aa10251
 
9188b68
 
35037e4
233acb0
825b375
aa10251
825b375
 
 
35037e4
 
825b375
35037e4
08ce2dc
aa10251
 
 
35037e4
825b375
aa10251
 
35037e4
 
 
 
aa10251
 
 
 
35037e4
aa10251
 
825b375
35037e4
9188b68
aa10251
 
 
 
 
 
 
9188b68
aa10251
825b375
 
aa10251
 
08ce2dc
 
9188b68
aa10251
08ce2dc
233acb0
35037e4
aa10251
 
 
 
 
 
 
 
 
 
 
08ce2dc
aa10251
35037e4
233acb0
08ce2dc
35037e4
aa10251
 
 
 
35037e4
08ce2dc
 
35037e4
aa10251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# handler.py  (repo root)

import io, base64, torch
from PIL import Image
import open_clip
from open_clip import fuse_conv_bn_sequential


class EndpointHandler:
    """
    Zero-shot classifier for MobileCLIP-B (OpenCLIP).

    Client JSON format:
    {
      "inputs": {
        "image": "<base64 PNG/JPEG>",
        "candidate_labels": ["cat", "dog", ...]
      }
    }
    """

    # ----------------------------------------------------- #
    #               INITIALISATION (once)                  #
    # ----------------------------------------------------- #
    def __init__(self, path: str = ""):
        """Load model, transforms and tokenizer once at container startup.

        Args:
            path: Repository root; must contain ``mobileclip_b.pt``.
        """
        weights = f"{path}/mobileclip_b.pt"

        # Load model + image preprocessing transforms
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )

        # Fuse Conv+BN for faster inference.
        # NOTE(review): `fuse_conv_bn_sequential` is not a documented public
        # open_clip export — confirm it exists in the pinned open_clip version
        # (MobileCLIP reparameterization is commonly `reparameterize_model`).
        self.model = fuse_conv_bn_sequential(self.model).eval()

        # Tokeniser
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

        # Device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # -------- text-embedding cache --------
        # key: label string  •  value: L2-normalised fp32 tensor on `self.device`.
        # Grows without bound — fine for a bounded label vocabulary; add
        # eviction if clients send arbitrary labels.
        self.label_cache: dict[str, torch.Tensor] = {}

    # ----------------------------------------------------- #
    #              INFERENCE  (per request)                #
    # ----------------------------------------------------- #
    def __call__(self, data: dict):
        """Classify one base64 image against `candidate_labels`.

        Returns:
            A list of ``{"label": str, "score": float}`` sorted by descending
            score, or ``{"error": ...}`` when no labels were supplied.

        Raises:
            KeyError: if the payload has no ``"image"`` field.
        """
        # 1. Unwrap the HF "inputs" envelope (plain payloads also accepted)
        payload = data.get("inputs", data)

        img_b64 = payload["image"]
        labels  = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # 2. Decode & preprocess image
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # 3. Text embeddings with cache — encode only unseen labels, in fp32
        #    (outside autocast) so cached vectors have a stable dtype.
        missing = [l for l in labels if l not in self.label_cache]
        if missing:
            tokens = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tokens)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for lbl, vec in zip(missing, emb):
                self.label_cache[lbl] = vec  # store on device

        txt_feat = torch.stack([self.label_cache[l] for l in labels])

        # 4. Forward pass for image.
        #    `torch.autocast` replaces the deprecated `torch.cuda.amp.autocast`
        #    and is explicitly disabled on CPU, where the old call was a
        #    warning-emitting no-op anyway.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # BUG FIX: under CUDA autocast `img_feat` is fp16 while the cached
        # text features are fp32, and the matmul below runs *outside* the
        # autocast context — the dtype mismatch raised a RuntimeError on GPU.
        # Cast image features to the text features' dtype before comparing.
        img_feat = img_feat.to(txt_feat.dtype)

        # 5. Cosine similarity scaled by 100 (CLIP's conventional logit
        #    scale), softmaxed over the label axis.
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # 6. Return results sorted by descending score
        return [
            {"label": l, "score": float(p)}
            for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]

# # handler.py  (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip

# class EndpointHandler:
#     """
#     Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).

#     Expected client JSON *to the endpoint*:
#     {
#       "inputs": {
#         "image": "<base64 PNG/JPEG>",
#         "candidate_labels": ["cat", "dog", ...]
#       }
#     }
#     """

#     def __init__(self, path: str = ""):
#         weights = f"{path}/mobileclip_b.pt"
#         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained=weights
#         )
#         self.model.eval()

#         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.model.to(self.device)

#     def __call__(self, data):
#         # ── unwrap Hugging Face's `inputs` envelope ───────────
#         payload = data.get("inputs", data)

#         img_b64 = payload["image"]
#         labels  = payload.get("candidate_labels", [])
#         if not labels:
#             return {"error": "candidate_labels list is empty"}

#         # Decode & preprocess image
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

#         # Tokenise labels
#         text_tokens = self.tokenizer(labels).to(self.device)

#         # Forward pass
#         with torch.no_grad(), torch.cuda.amp.autocast():
#             img_feat = self.model.encode_image(img_tensor)
#             txt_feat = self.model.encode_text(text_tokens)
#             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
#             txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
#             probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

#         # Sorted output
#         return [
#             {"label": l, "score": float(p)}
#             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
#         ]