File size: 2,069 Bytes
35037e4
9188b68
 
35037e4
 
 
9188b68
 
35037e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9188b68
 
35037e4
9188b68
 
 
 
35037e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# handler.py  – place in repo root
import base64
import io
import os

import torch
from PIL import Image

import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model

class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B using OpenCLIP.

    Expects JSON:
      {
        "image": "<base64-encoded PNG/JPEG>",
        "candidate_labels": ["cat", "dog", ...]
      }

    "image" may also be a data URI ("data:image/png;base64,...").
    "candidate_labels" may be a list of strings or a single string.

    Returns a list of {"label": str, "score": float} dicts sorted by
    descending score. The scores are a softmax over the candidate labels.
    An empty label list yields an empty result.
    """

    # Weights file shipped in the repository root.
    WEIGHTS_FILENAME = "mobileclip_b.pt"
    # OpenCLIP model/tokenizer identifier.
    MODEL_NAME = "MobileCLIP-B"

    def __init__(self, path: str = ""):
        """Load MobileCLIP-B from ``path`` and prepare it for inference.

        Args:
            path: Directory the repository was cloned into (Hugging Face
                Endpoints passes it). An empty string means the current
                working directory.
        """
        # os.path.join keeps the default path="" relative to the CWD; the
        # previous f"{path}/..." form resolved to "/mobileclip_b.pt".
        weights = os.path.join(path, self.WEIGHTS_FILENAME)
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.MODEL_NAME, pretrained=weights
        )

        # The MobileCLIP docs switch to eval mode before re-parameterizing,
        # which is done once here for faster inference.
        model.eval()
        self.model = reparameterize_model(model)
        self.model.eval()

        # OpenCLIP tokenizer matching the model.
        self.tokenizer = open_clip.get_tokenizer(self.MODEL_NAME)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64 payload (optionally a data URI) into an RGB image."""
        # Drop a "data:<mime>;base64," prefix if present.
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            img_b64 = img_b64.split(",", 1)[-1]
        raw = base64.b64decode(img_b64)
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    def __call__(self, data):
        """Score ``data["image"]`` against ``data["candidate_labels"]``.

        Args:
            data: Request payload as described in the class docstring.

        Returns:
            A list of {"label", "score"} dicts, highest score first.

        Raises:
            KeyError: if ``data`` has no "image" key.
        """
        # Decode first so a missing/invalid image is always reported, even
        # when no labels were supplied.
        image = self._decode_image(data["image"])

        labels = data.get("candidate_labels", [])
        # A bare string would otherwise be zipped character by character.
        labels = [labels] if isinstance(labels, str) else list(labels)
        if not labels:
            return []

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        text_tokens = self.tokenizer(labels).to(self.device)

        # Mixed precision only on CUDA; disabled autocast on CPU avoids the
        # deprecated torch.cuda.amp API and its per-call "CUDA not
        # available" warning.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type=self.device, enabled=use_amp):
            img_feat = self.model.encode_image(image_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalize so the dot product below is cosine similarity.
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            # Fixed scale of 100 before the softmax, as in the original
            # handler (CLIP-style zero-shot scoring).
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
        return [{"label": label, "score": float(score)} for label, score in ranked]