# handler.py — custom Hugging Face Inference Endpoints handler (repo root)
import base64
import contextlib
import io

import torch
from PIL import Image

import open_clip
# Make sure the mobileclip library is installed in your Hugging Face environment
# You might need to add it to your requirements.txt
from mobileclip.modules.common.mobileone import reparameterize_model
class EndpointHandler:
    """Zero-shot image classifier for MobileCLIP-B (served via OpenCLIP).

    Expected client JSON (optionally wrapped in Hugging Face's ``inputs``
    envelope)::

        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["cat", "dog", ...]
            }
        }
    """

    def __init__(self, path: str = ""):
        """Load ``{path}/mobileclip_b.pt`` and prepare the model for inference.

        Args:
            path: Repository root provided by the endpoint runtime; the
                checkpoint is expected at ``<path>/mobileclip_b.pt``.
        """
        weights = f"{path}/mobileclip_b.pt"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()
        # Fuse MobileOne's training-time branches into single convolutions
        # for faster inference; must happen after eval().
        self.model = reparameterize_model(self.model)
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: dict):
        """Classify one base64-encoded image against candidate labels.

        Args:
            data: Request payload, either the payload dict itself or wrapped
                in an ``inputs`` key.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts sorted by
            descending score, or an ``{"error": str}`` dict for malformed
            requests.
        """
        # ── unwrap Hugging Face's `inputs` envelope ───────────
        payload = data.get("inputs", data)

        # Structured errors for malformed requests, consistent with the
        # empty-labels case below (previously a missing "image" key raised
        # an unhandled KeyError).
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image field is missing"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # Decode & preprocess image
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Tokenise labels
        text_tokens = self.tokenizer(labels).to(self.device)

        # Forward pass. Enable mixed precision only on CUDA: the previous
        # unconditional torch.cuda.amp.autocast() is deprecated and merely
        # warns + no-ops on CPU-only hosts.
        amp_ctx = (
            torch.autocast(device_type="cuda")
            if self.device == "cuda"
            else contextlib.nullcontext()
        )
        with torch.no_grad(), amp_ctx:
            img_feat = self.model.encode_image(img_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            # 100 is the conventional CLIP logit scale (exp of learned
            # temperature) — assumed to match the checkpoint; TODO confirm.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # Sorted output, highest score first
        return [
            {"label": label, "score": float(score)}
            for label, score in sorted(
                zip(labels, probs), key=lambda pair: pair[1], reverse=True
            )
        ]
# NOTE: a previous, pre-reparameterization copy of this handler used to live
# here as commented-out code; it was removed — see version history if needed.