File size: 5,457 Bytes
cd0f6ce
270502a
88c8e02
 
 
 
cd0f6ce
270502a
 
 
88c8e02
270502a
88c8e02
270502a
88c8e02
270502a
 
 
 
 
 
88c8e02
270502a
 
88c8e02
 
cd0f6ce
 
88c8e02
cd0f6ce
 
 
88c8e02
 
270502a
88c8e02
270502a
 
88c8e02
 
270502a
 
 
 
88c8e02
 
cd0f6ce
270502a
 
 
 
 
88c8e02
270502a
88c8e02
270502a
 
 
 
 
01f2a47
270502a
88c8e02
cd0f6ce
 
88c8e02
270502a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# handler.py
import io, base64, time, torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

class EndpointHandler:
    """Zero-shot CLIP image classifier for an inference endpoint.

    Expected request JSON (the ``"inputs"`` wrapper is optional)::

        {"inputs": {"image": "<base64 or data: URL>",
                    "candidate_labels": ["prompt-1", "prompt-2", ...]}}

    Response: a list of ``{"label": str, "score": float}`` dicts sorted by
    descending score; scores are a softmax over the candidate labels.
    """

    # Fixed temperature applied to cosine similarities before the softmax
    # (the same constant the handler has always used).
    LOGIT_SCALE = 100.0

    def __init__(self, path=""):
        """Load the CLIP model and processor from *path*.

        The model is moved to CUDA when available (CPU otherwise) and put
        in eval mode.
        """
        self.model      = CLIPModel.from_pretrained(path)
        self.processor  = CLIPProcessor.from_pretrained(path)
        self.device     = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()
        # prompt -> L2-normalised text embedding, kept on self.device.
        # NOTE(review): unbounded - grows with every distinct prompt seen
        # by this process; consider an LRU bound if prompts are arbitrary.
        self.cache: dict[str, torch.Tensor] = {}

    # -------------------------------------------------------
    @staticmethod
    def _as_label_list(labels):
        """Return *labels* as a non-empty list of prompt strings.

        A bare string is treated as one label (iterating it would otherwise
        produce one "label" per character).

        Raises:
            ValueError: if no labels were supplied.
        """
        if isinstance(labels, str):
            labels = [labels]
        labels = list(labels)
        if not labels:
            raise ValueError("candidate_labels list is empty")
        return labels

    @staticmethod
    def _decode_base64(data):
        """Decode a base64 image payload to bytes.

        Accepts an optional ``data:<mime>;base64,`` prefix. Without
        stripping it, the non-alphabet characters would be skipped and the
        remaining prefix letters would corrupt the decoded bytes.
        """
        if isinstance(data, str) and data.startswith("data:"):
            _, _, data = data.partition(",")
        return base64.b64decode(data)

    def _sync(self):
        """Wait for queued CUDA work so stage timings are attributed correctly.

        CUDA kernels launch asynchronously; without this, a stage's timer can
        stop before its GPU work has finished. No-op on CPU.
        """
        if self.device == "cuda":
            torch.cuda.synchronize()

    def _text_features(self, prompts):
        """Return one L2-normalised text embedding row per prompt.

        Only prompts missing from ``self.cache`` are run through the model
        (duplicates encoded once); the results are added to the cache.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        missing = list(dict.fromkeys(p for p in prompts if p not in self.cache))
        if missing:
            tok = self.processor(text=missing, return_tensors="pt",
                                 padding=True).to(self.device)
            with torch.no_grad():
                emb = self.model.get_text_features(**tok)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            self.cache.update(zip(missing, emb))
        return torch.stack([self.cache[p] for p in prompts])

    @classmethod
    def _rank(cls, prompts, img_feat, txt_feat):
        """Score *prompts* against the image and return them best-first.

        Args:
            prompts: label strings, in the same order as the rows of *txt_feat*.
            img_feat: normalised image embedding; only row 0 of the result
                is used, so presumably shape (1, dim) - confirm for batching.
            txt_feat: normalised text embeddings, one row per prompt.

        Returns:
            list of ``{"label", "score"}`` dicts sorted by descending score.
        """
        logits = cls.LOGIT_SCALE * img_feat @ txt_feat.T
        probs = logits.softmax(dim=-1)[0].tolist()
        ranked = sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
        return [{"label": p, "score": float(s)} for p, s in ranked]

    def __call__(self, data):
        """Classify one base64-encoded image against the candidate labels.

        Prints per-stage timings (ms) to stdout.

        Args:
            data: request dict. The payload is ``data["inputs"]`` when that
                key exists, otherwise *data* itself. It must contain
                ``"image"`` (base64 string, optionally a ``data:`` URL) and
                ``"candidate_labels"`` (list of prompts, or one string).

        Returns:
            list of ``{"label": str, "score": float}``, highest score first.

        Raises:
            KeyError: if ``"image"`` or ``"candidate_labels"`` is missing.
            ValueError: if ``candidate_labels`` is empty.
        """
        timings = {}                       # stage name -> elapsed ms
        t0 = time.perf_counter()

        payload  = data.get("inputs", data)
        img_b64  = payload["image"]
        prompts  = self._as_label_list(payload["candidate_labels"])

        # -- text embeddings (per-process cache) -------------------
        t = time.perf_counter()
        txt_feat = self._text_features(prompts)
        self._sync()
        timings["encode_text"] = (time.perf_counter() - t) * 1000  # ms

        # -- image decode + preprocessing ---------------------------
        t = time.perf_counter()
        raw = self._decode_base64(img_b64)
        img = Image.open(io.BytesIO(raw)).convert("RGB")
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)
        self._sync()
        timings["decode_resize"] = (time.perf_counter() - t) * 1000

        # -- image embedding ----------------------------------------
        t = time.perf_counter()
        # Mixed precision on GPU only. torch.cuda.amp.autocast() is
        # deprecated and warns on every call when CUDA is unavailable.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type=self.device,
                                             enabled=use_amp):
            img_feat = self.model.get_image_features(**img_in)
        # Autocast may return reduced-precision features; normalise in fp32.
        img_feat = img_feat.float()
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        self._sync()
        timings["encode_image"] = (time.perf_counter() - t) * 1000

        # -- similarity, softmax, ranking -----------------------------
        t = time.perf_counter()
        result = self._rank(prompts, img_feat, txt_feat.float())
        timings["similarity_softmax"] = (time.perf_counter() - t) * 1000

        # -- log timings ---------------------------------------------
        total = (time.perf_counter() - t0) * 1000
        print(f"[CLIP timings] total={total:.1f} ms | " +
              " | ".join(f"{k}={v:.1f}" for k, v in timings.items()),
              flush=True)

        return result

# import io, base64, torch
# from PIL import Image
# from transformers import CLIPModel, CLIPProcessor

# class EndpointHandler:
#     """
#     CLIP ViT‑L/14 zero‑shot classifier.
#     Expects JSON: {
#       "inputs": {
#         "image": "<base64>",
#         "candidate_labels": ["prompt‑1", "prompt‑2", ...]
#       }
#     }
#     """

#     def __init__(self, path=""):
#         self.model = CLIPModel.from_pretrained(path)
#         self.processor = CLIPProcessor.from_pretrained(path)
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.model.to(self.device).eval()
#         self.cache: dict[str, torch.Tensor] = {}          # prompt -> emb

#     def __call__(self, data):
#         payload = data.get("inputs", data)
#         img_b64 = payload["image"]
#         prompts = payload.get("candidate_labels", [])
#         if not prompts:
#             return {"error": "candidate_labels list is empty"}

#         # --- text embeddings with per‑process cache ----------
#         missing = [p for p in prompts if p not in self.cache]
#         if missing:
#             tok = self.processor(text=missing, return_tensors="pt",
#                                  padding=True).to(self.device)
#             with torch.no_grad():
#                 emb = self.model.get_text_features(**tok)
#                 emb = emb / emb.norm(dim=-1, keepdim=True)
#             for p, e in zip(missing, emb):
#                 self.cache[p] = e
#         txt_feat = torch.stack([self.cache[p] for p in prompts])

#         # --- image embedding ---------------------------------
#         img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_in = self.processor(images=img, return_tensors="pt").to(self.device)
       
#         with torch.no_grad(), torch.cuda.amp.autocast():
#             img_feat = self.model.get_image_features(**img_in)
        
#         img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
#         # txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
        
#         img_feat = img_feat.float()   #  ← add these two lines
#         txt_feat = txt_feat.float()   #  ←
        
#         probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

       

#         return [
#             {"label": p, "score": float(s)}
#             for p, s in sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
#         ]