File size: 4,764 Bytes
233acb0 aa10251 e1369ab 233acb0 9188b68 35037e4 aa10251 9188b68 35037e4 e1369ab 825b375 e1369ab 825b375 35037e4 825b375 35037e4 08ce2dc e1369ab 35037e4 825b375 aa10251 35037e4 e1369ab 35037e4 825b375 35037e4 9188b68 e1369ab aa10251 e1369ab 9188b68 825b375 e1369ab 08ce2dc 9188b68 e1369ab 08ce2dc 233acb0 35037e4 e1369ab aa10251 e1369ab aa10251 08ce2dc e1369ab 35037e4 233acb0 08ce2dc e1369ab 35037e4 e1369ab 35037e4 08ce2dc 35037e4 aa10251 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# handler.py (repo root)
import io, base64, torch
from PIL import Image
import open_clip
class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP) with a
    text-embedding cache.

    Expected client JSON:
        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["cat", "dog", ...]
            }
        }

    Returns a list of ``{"label": str, "score": float}`` dicts sorted by
    descending score, or an ``{"error": str}`` dict on bad input.
    """

    # ------------------------------------------------- #
    #                 INITIALISATION                     #
    # ------------------------------------------------- #
    def __init__(self, path: str = ""):
        """Load MobileCLIP-B weights from *path* and move the model to GPU if available."""
        weights = f"{path}/mobileclip_b.pt"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # cache: {prompt -> normalized 1-D text-embedding tensor on self.device}
        # NOTE(review): unbounded — fine for a modest label vocabulary, but it
        # grows with every distinct label ever requested.
        self.label_cache: dict[str, torch.Tensor] = {}

    # ------------------------------------------------- #
    #                    INFERENCE                       #
    # ------------------------------------------------- #
    def __call__(self, data: dict) -> "list[dict] | dict":
        """Classify one base64-encoded image against `candidate_labels`."""
        # Unwrap Hugging Face's `inputs` envelope; accept a flat payload too.
        payload = data.get("inputs", data)
        img_b64 = payload.get("image")
        labels = payload.get("candidate_labels", [])
        # Return structured errors (consistent with the empty-labels case)
        # instead of raising KeyError / letting decode failures escape.
        if img_b64 is None:
            return {"error": "image field is missing"}
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # --- image ----
        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:
            return {"error": f"could not decode image: {exc}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # --- text (with cache) ----
        # Only labels not yet cached are tokenised and encoded.
        missing = [lbl for lbl in labels if lbl not in self.label_cache]
        if missing:
            tokens = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tokens)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for lbl, vec in zip(missing, emb):
                self.label_cache[lbl] = vec
        txt_feat = torch.stack([self.label_cache[lbl] for lbl in labels])

        # --- forward & softmax ----
        # torch.cuda.amp.autocast is deprecated and warns / is meaningless on
        # CPU-only hosts; use the device-aware API and enable it only on CUDA.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=(self.device == "cuda")
        ):
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            # 100 is the conventional CLIP logit scale (exp of the trained
            # temperature) — presumably matching the checkpoint; TODO confirm.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # --- sorted output ----
        return [
            {"label": lbl, "score": float(p)}
            for lbl, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
# # handler.py (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip
# class EndpointHandler:
# """
# Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).
# Expected client JSON *to the endpoint*:
# {
# "inputs": {
# "image": "<base64 PNG/JPEG>",
# "candidate_labels": ["cat", "dog", ...]
# }
# }
# """
# def __init__(self, path: str = ""):
# weights = f"{path}/mobileclip_b.pt"
# self.model, _, self.preprocess = open_clip.create_model_and_transforms(
# "MobileCLIP-B", pretrained=weights
# )
# self.model.eval()
# self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# self.model.to(self.device)
# def __call__(self, data):
# # ── unwrap Hugging Face's `inputs` envelope ───────────
# payload = data.get("inputs", data)
# img_b64 = payload["image"]
# labels = payload.get("candidate_labels", [])
# if not labels:
# return {"error": "candidate_labels list is empty"}
# # Decode & preprocess image
# image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# # Tokenise labels
# text_tokens = self.tokenizer(labels).to(self.device)
# # Forward pass
# with torch.no_grad(), torch.cuda.amp.autocast():
# img_feat = self.model.encode_image(img_tensor)
# txt_feat = self.model.encode_text(text_tokens)
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
# txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
# probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
# # Sorted output
# return [
# {"label": l, "score": float(p)}
# for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
# ]
|