# handler.py (repo root)
import io, base64, torch
from PIL import Image
import open_clip
from open_clip import fuse_conv_bn_sequential
class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP), served as a
    Hugging Face Inference Endpoint custom handler.

    Client JSON format:
        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["cat", "dog", ...]
            }
        }

    Returns a list of {"label": str, "score": float} dicts sorted by
    descending score, or a {"error": str} dict on malformed input.
    """

    # ----------------------------------------------------- #
    #  INITIALISATION (once)                                 #
    # ----------------------------------------------------- #
    def __init__(self, path: str = ""):
        """Load model weights, transforms and tokenizer once at start-up.

        Args:
            path: Repository root containing ``mobileclip_b.pt``.
        """
        weights = f"{path}/mobileclip_b.pt"
        # Load model + image preprocessing transforms.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        # Fuse Conv+BN pairs for faster inference (inference-only graph).
        # NOTE(review): assumes `fuse_conv_bn_sequential` is exported by the
        # installed open_clip version — confirm against requirements.txt.
        self.model = fuse_conv_bn_sequential(self.model).eval()
        # Tokeniser
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        # Device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # -------- text-embedding cache --------
        # key: prompt string • value: normalized embedding tensor on device.
        # Unbounded on purpose: label vocabularies are expected to be small.
        self.label_cache: dict[str, torch.Tensor] = {}

    # ----------------------------------------------------- #
    #  INFERENCE (per request)                               #
    # ----------------------------------------------------- #
    def __call__(self, data):
        # 1. Unwrap the HF "inputs" envelope (clients may send it bare).
        payload = data.get("inputs", data)

        # Validate input: return structured errors instead of raising bare
        # KeyError/binascii errors (consistent with the empty-labels path).
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "missing 'image' field"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # 2. Decode & preprocess image
        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # bad base64 / corrupt image bytes
            return {"error": f"could not decode image: {exc}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # 3. Text embeddings with cache (encode only labels not seen before)
        missing = [lbl for lbl in labels if lbl not in self.label_cache]
        if missing:
            tokens = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tokens)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for lbl, vec in zip(missing, emb):
                self.label_cache[lbl] = vec  # store on device
        txt_feat = torch.stack([self.label_cache[lbl] for lbl in labels])

        # 4. Forward pass for image. torch.cuda.amp.autocast is deprecated
        #    and is a no-op (with a warning) on CPU, so enable AMP only on CUDA.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
            img_feat = self.model.encode_image(img_tensor)
        # Under CUDA autocast the image features come out fp16 while the
        # cached text features are fp32; the matmul below runs OUTSIDE the
        # autocast region and would raise a dtype-mismatch RuntimeError.
        # Cast back to fp32 before mixing with txt_feat.
        img_feat = img_feat.float()
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # 5. Similarity & softmax (100 = CLIP's conventional logit scale)
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # 6. Return list sorted by descending score
        return [
            {"label": lbl, "score": float(p)}
            for lbl, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
# --------------------------------------------------------------------- #
# Legacy implementation (no text-embedding cache, no Conv+BN fusion),   #
# superseded by the EndpointHandler above — kept below for reference.   #
# --------------------------------------------------------------------- #
# # handler.py (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip
# class EndpointHandler:
# """
# Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).
# Expected client JSON *to the endpoint*:
# {
# "inputs": {
# "image": "<base64 PNG/JPEG>",
# "candidate_labels": ["cat", "dog", ...]
# }
# }
# """
# def __init__(self, path: str = ""):
# weights = f"{path}/mobileclip_b.pt"
# self.model, _, self.preprocess = open_clip.create_model_and_transforms(
# "MobileCLIP-B", pretrained=weights
# )
# self.model.eval()
# self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# self.model.to(self.device)
# def __call__(self, data):
# # ── unwrap Hugging Face's `inputs` envelope ───────────
# payload = data.get("inputs", data)
# img_b64 = payload["image"]
# labels = payload.get("candidate_labels", [])
# if not labels:
# return {"error": "candidate_labels list is empty"}
# # Decode & preprocess image
# image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# # Tokenise labels
# text_tokens = self.tokenizer(labels).to(self.device)
# # Forward pass
# with torch.no_grad(), torch.cuda.amp.autocast():
# img_feat = self.model.encode_image(img_tensor)
# txt_feat = self.model.encode_text(text_tokens)
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
# txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
# probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
# # Sorted output
# return [
# {"label": l, "score": float(p)}
# for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
# ]