File size: 4,081 Bytes
147df04
652b877
9188b68
147df04
dc1caec
9188b68
08ce2dc
652b877
147df04
652b877
147df04
35037e4
147df04
35037e4
2fb4fd2
652b877
35037e4
825b375
35037e4
652b877
9188b68
 
825b375
e1369ab
 
08ce2dc
 
9188b68
652b877
147df04
652b877
147df04
 
652b877
 
 
 
 
147df04
 
652b877
 
147df04
35037e4
 
08ce2dc
 
35037e4
652b877
 
aa10251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

import contextlib, io, base64, torch
from PIL import Image
import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model


class EndpointHandler:
    """Zero-shot image classifier backed by MobileCLIP-B (OpenCLIP).

    Expected request payload (optionally wrapped in Hugging Face's
    ``inputs`` envelope)::

        {
          "inputs": {
            "image": "<base64-encoded PNG/JPEG>",
            "candidate_labels": ["cat", "dog", ...]
          }
        }

    Returns a list of ``{"label": str, "score": float}`` dicts sorted by
    descending probability, or a ``{"error": str}`` dict on bad input.
    """

    def __init__(self, path: str = ""):
        # You can also pass pretrained='datacompdr' to let OpenCLIP download.
        weights = f"{path}/mobileclip_b.pt"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        # Reparameterize in eval mode: fuses MobileOne's training-time
        # branches into single convolutions for faster inference.
        self.model.eval()
        self.model = reparameterize_model(self.model)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

    def __call__(self, data):
        """Classify one base64 image against the given candidate labels."""
        # Unwrap Hugging Face's `inputs` envelope if present.
        payload = data.get("inputs", data)

        img_b64 = payload.get("image")
        if not img_b64:
            # Mirror the graceful error used for labels instead of
            # letting a KeyError escape the handler.
            return {"error": "image field is missing"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # ---------------- decode inputs ----------------
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        text_tokens = self.tokenizer(labels).to(self.device)

        # ---------------- forward pass -----------------
        # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda") is the
        # supported spelling. Mixed precision only makes sense on CUDA.
        autocast_ctx = (
            torch.autocast("cuda")
            if self.device.startswith("cuda")
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            img_feat = self.model.encode_image(img_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalize so the dot product below is cosine similarity.
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
            # 100 is CLIP's conventional logit scale before softmax.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        return [
            {"label": label, "score": float(score)}
            for label, score in sorted(
                zip(labels, probs), key=lambda x: x[1], reverse=True
            )
        ]

        
# # handler.py  (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip

# class EndpointHandler:
#     """
#     Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).

#     Expected client JSON *to the endpoint*:
#     {
#       "inputs": {
#         "image": "<base64 PNG/JPEG>",
#         "candidate_labels": ["cat", "dog", ...]
#       }
#     }
#     """

#     def __init__(self, path: str = ""):
#         weights = f"{path}/mobileclip_b.pt"
#         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained=weights
#         )
#         self.model.eval()

#         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.model.to(self.device)

#     def __call__(self, data):
#         # ── unwrap Hugging Face's `inputs` envelope ───────────
#         payload = data.get("inputs", data)

#         img_b64 = payload["image"]
#         labels  = payload.get("candidate_labels", [])
#         if not labels:
#             return {"error": "candidate_labels list is empty"}

#         # Decode & preprocess image
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

#         # Tokenise labels
#         text_tokens = self.tokenizer(labels).to(self.device)

#         # Forward pass
#         with torch.no_grad(), torch.cuda.amp.autocast():
#             img_feat = self.model.encode_image(img_tensor)
#             txt_feat = self.model.encode_text(text_tokens)
#             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
#             txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
#             probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

#         # Sorted output
#         return [
#             {"label": l, "score": float(p)}
#             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
#         ]