File size: 2,350 Bytes
08ce2dc
 
 
 
9188b68
35037e4
 
08ce2dc
9188b68
 
35037e4
08ce2dc
 
 
35037e4
 
 
 
08ce2dc
 
 
 
 
 
35037e4
08ce2dc
35037e4
08ce2dc
35037e4
08ce2dc
 
35037e4
 
 
 
08ce2dc
 
35037e4
08ce2dc
35037e4
 
08ce2dc
35037e4
 
9188b68
 
08ce2dc
 
 
 
 
9188b68
08ce2dc
 
35037e4
 
08ce2dc
 
 
 
35037e4
 
 
08ce2dc
 
35037e4
 
08ce2dc
35037e4
08ce2dc
 
35037e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# handler.py
import base64
import io
import os

import torch
from PIL import Image

import open_clip
from open_clip import fuse_conv_bn_sequential

class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP).

    Expects JSON payload:
      {
        "image": "<base64-encoded PNG/JPEG>",
        "candidate_labels": ["cat", "dog", ...]
      }
    ``image`` may also be a data URI (``data:image/png;base64,...``); the
    prefix up to the first comma is stripped. ``candidate_labels`` may be a
    single string, which is treated as a one-element list.

    Returns:
      [
        {"label": "cat", "score": 0.91},
        {"label": "dog", "score": 0.05},
        ...
      ]
    sorted by descending score. Scores are a softmax over the candidate
    labels, so they sum to 1. A malformed request (no labels, missing or
    undecodable image) returns ``{"error": "<message>"}`` instead.
    """

    # Checkpoint file name expected inside the repository directory.
    WEIGHTS_FILENAME = "mobileclip_b.pt"
    # Fixed multiplier applied to cosine similarities before the softmax
    # (hard-coded here rather than read from the model).
    LOGIT_SCALE = 100.0

    def __init__(self, path: str = "") -> None:
        """Load MobileCLIP-B from *path* and move it to the best device.

        Args:
            path: Directory containing ``mobileclip_b.pt`` (the repo root
                inside the container). ``""`` means the current directory.
        """
        # os.path.join: f"{path}/..." turned the default path="" into
        # "/mobileclip_b.pt", i.e. a file at the filesystem root.
        weights = os.path.join(path, self.WEIGHTS_FILENAME)

        # Load model + transforms from OpenCLIP
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )

        # Fuse conv + BN for faster inference (same idea as MobileCLIP re-param)
        self.model = fuse_conv_bn_sequential(self.model).eval()

        # Tokenizer for label prompts
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

        # Device selection
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: dict):
        """Classify ``data["image"]`` against ``data["candidate_labels"]``.

        Args:
            data: Request payload as described in the class docstring.

        Returns:
            A list of ``{"label", "score"}`` dicts sorted by descending
            score, or ``{"error": message}`` for a malformed request.
        """
        # 1. Parse and validate the request before any model work.
        labels = data.get("candidate_labels", [])
        if isinstance(labels, str):
            # A bare string would otherwise be zipped character by
            # character against the scores below.
            labels = [labels]
        if not labels:
            return {"error": "candidate_labels list is empty"}
        labels = list(labels)

        img_b64 = data.get("image")
        if not img_b64:
            return {"error": "image is missing or empty"}
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            # Strip a "data:<mime>;base64," prefix; raw base64 never has ':'.
            img_b64 = img_b64.partition(",")[2]

        # 2. Decode & preprocess image. binascii.Error is a ValueError;
        # PIL's UnidentifiedImageError is an OSError; TypeError covers
        # non-string payloads passed to b64decode.
        try:
            raw = base64.b64decode(img_b64)
            image = Image.open(io.BytesIO(raw)).convert("RGB")
        except (ValueError, TypeError, OSError) as err:
            return {"error": f"could not decode image: {err}"}
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # 3. Tokenize labels
        text_tokens = self.tokenizer(labels).to(self.device)

        # 4. Forward pass. Mixed precision only on CUDA; the deprecated
        # torch.cuda.amp.autocast merely warned and disabled itself on CPU.
        device_type = torch.device(self.device).type
        with torch.inference_mode():
            with torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
                img_feat = self.model.encode_image(image_tensor)
                txt_feat = self.model.encode_text(text_tokens)
            # Score in float32 outside autocast: inside it the matmul of the
            # scaled logits would run in fp16 and lose precision.
            img_feat = img_feat.float()
            txt_feat = txt_feat.float()
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            logits = self.LOGIT_SCALE * img_feat @ txt_feat.T
            probs = logits.softmax(dim=-1)[0].tolist()

        # 5. Return labels sorted by descending score.
        ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
        return [{"label": label, "score": float(score)} for label, score in ranked]