File size: 4,297 Bytes
652b877
9188b68
147df04
9106e2d
9188b68
652b877
147df04
c5d457b
 
 
 
 
 
35037e4
c5d457b
 
35037e4
652b877
c5d457b
 
 
 
 
 
9188b68
 
825b375
e1369ab
 
08ce2dc
 
9188b68
652b877
147df04
c5d457b
 
 
 
 
 
147df04
 
652b877
c5d457b
 
147df04
 
652b877
 
c5d457b
35037e4
 
08ce2dc
 
35037e4
c5d457b
aa10251
 
c5d457b
aa10251
 
c5d457b
aa10251
c5d457b
aa10251
 
 
 
 
c5d457b
aa10251
 
 
c5d457b
aa10251
 
 
 
 
 
 
 
c5d457b
aa10251
c5d457b
aa10251
 
c5d457b
 
 
 
 
aa10251
 
c5d457b
 
aa10251
 
 
 
 
 
c5d457b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import contextlib, io, base64, torch
from PIL import Image
import open_clip
from reparam import reparameterize_model

class EndpointHandler:
    """Zero-shot image classifier served as an inference endpoint.

    Loads MobileCLIP-B with the 'datacompdr' pretrained weights via
    open_clip, fuses the reparameterizable branches for inference, and —
    when CUDA is available — runs the whole pipeline in float16.
    """

    def __init__(self, path: str = ""):
        # `path` is kept for endpoint-API compatibility; weights are
        # downloaded from the web so they match the local script exactly.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        base_model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained='datacompdr'
        )
        base_model.eval()
        # Fuse train-time branches into single ops for faster inference.
        self.model = reparameterize_model(base_model)

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.model.to(self.device)

        # On GPU, cast weights to half precision (mirrors the local
        # script's torch.set_default_dtype(torch.float16) behavior).
        if self.device == "cuda":
            self.model.to(torch.float16)

    def __call__(self, data):
        """Score a base64-encoded image against candidate text labels.

        Accepts {"inputs": {"image": <b64>, "candidate_labels": [...]}} or
        the same keys at the top level. Returns a score-sorted list of
        {"label", "score"} dicts, or an {"error": ...} dict when the label
        list is empty.
        """
        payload = data.get("inputs", data)
        encoded_image = payload["image"]
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # ---------------- decode inputs ----------------
        pil_image = Image.open(io.BytesIO(base64.b64decode(encoded_image))).convert("RGB")
        pixels = self.preprocess(pil_image).unsqueeze(0).to(self.device)

        # The preprocessor emits float32; match the half-precision weights.
        if self.device == "cuda":
            pixels = pixels.to(torch.float16)

        tokens = self.tokenizer(labels).to(self.device)

        # ---------------- forward pass -----------------
        # Everything is already float16 on GPU, so autocast is unnecessary.
        with torch.no_grad():
            image_embedding = self.model.encode_image(pixels)
            text_embedding = self.model.encode_text(tokens)
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
            scores = (100 * image_embedding @ text_embedding.T).softmax(dim=-1)[0].cpu().tolist()

        ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
        return [{"label": name, "score": float(value)} for name, value in ranked]
# import contextlib, io, base64, torch
# from PIL import Image
# import open_clip
# from reparam import reparameterize_model


# class EndpointHandler:
#     def __init__(self, path: str = ""):
#         # You can also pass pretrained='datacompdr' to let OpenCLIP download
#         weights = f"{path}/mobileclip_b.pt"
#         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained=weights
#         )
#         self.model.eval()
#         self.model = reparameterize_model(self.model)   # *** fuse branches ***

#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.model.to(self.device)
#         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

#     def __call__(self, data):
#         payload = data.get("inputs", data)
#         img_b64 = payload["image"]
#         labels  = payload.get("candidate_labels", [])
#         if not labels:
#             return {"error": "candidate_labels list is empty"}

#         # ---------------- decode inputs ----------------
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor  = self.preprocess(image).unsqueeze(0).to(self.device)
#         text_tokens = self.tokenizer(labels).to(self.device)

#         # ---------------- forward pass -----------------
#         autocast_ctx = (
#             torch.cuda.amp.autocast if self.device.startswith("cuda") else contextlib.nullcontext
#         )
#         with torch.no_grad(), autocast_ctx():
#             img_feat = self.model.encode_image(img_tensor)
#             txt_feat = self.model.encode_text(text_tokens)
#             img_feat /= img_feat.norm(dim=-1, keepdim=True)
#             txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
#             probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

#         return [
#             {"label": l, "score": float(p)}
#             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
#         ]