testmobileclip / handler.py
finhdev's picture
Update handler.py
35037e4 verified
raw
history blame
2.07 kB
# handler.py – place in repo root
import base64
import io
import os

import open_clip
import torch
from mobileclip.modules.common.mobileone import reparameterize_model
from PIL import Image
class EndpointHandler:
"""
Zero‑shot image classifier for MobileCLIP‑B using OpenCLIP.
Expects JSON:
{
"image": "<base64‑encoded PNG/JPEG>",
"candidate_labels": ["cat", "dog", ...]
}
"""
def __init__(self, path: str = ""):
# Hugging Face Endpoints clones the repo into `path`.
# The weights file is mobileclip_b.pt (already in the repo).
weights = f"{path}/mobileclip_b.pt"
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
"MobileCLIP-B", pretrained=weights
)
# Re‑parameterize once for faster inference (as per MobileCLIP docs)
self.model = reparameterize_model(self.model)
self.model.eval()
# OpenCLIP tokenizer (same as CLIP)
self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
def __call__(self, data):
# Decode input
img_b64 = data["image"]
labels = data.get("candidate_labels", [])
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# Preprocess
image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
text_tokens = self.tokenizer(labels).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
img_feat = self.model.encode_image(image_tensor)
txt_feat = self.model.encode_text(text_tokens)
img_feat /= img_feat.norm(dim=-1, keepdim=True)
txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
return [
{"label": l, "score": float(p)} for l, p in sorted(
zip(labels, probs), key=lambda x: x[1], reverse=True
)
]