import base64
import io
import os

import open_clip
import torch
from mobileclip.modules.common.mobileone import reparameterize_model
from PIL import Image


class EndpointHandler:
    """Zero-shot image classifier for MobileCLIP-B using OpenCLIP.

    Expects a JSON payload of the form::

        {
            "image": "<base64-encoded PNG/JPEG>",
            "candidate_labels": ["cat", "dog", ...]
        }

    ``image`` may optionally carry a ``data:<mime>;base64,`` data-URI prefix.
    ``candidate_labels`` may also be a single string, which is treated as a
    one-element list.

    Calling the handler returns a list of ``{"label": ..., "score": float}``
    dicts sorted by descending score. Scores are a softmax over the supplied
    candidate labels.
    """

    # Multiplier applied to the cosine similarities before the softmax (the
    # fixed value the handler has always used). Override in a subclass to make
    # the score distribution more or less peaked.
    LOGIT_SCALE = 100

    def __init__(self, path: str = ""):
        """Load MobileCLIP-B together with its preprocessing transform and tokenizer.

        Args:
            path: Directory containing ``mobileclip_b.pt``. An empty string
                means the current working directory.
        """
        # os.path.join: the previous f"{path}/mobileclip_b.pt" turned an empty
        # path into "/mobileclip_b.pt", i.e. a file at the filesystem root.
        weights = os.path.join(path, "mobileclip_b.pt")
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )

        # reparameterize_model comes from the mobileclip package; presumably it
        # fuses MobileOne's train-time branches for faster inference --
        # TODO confirm against the mobileclip source.
        self.model = reparameterize_model(model)
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: dict) -> list:
        """Score ``data["image"]`` against ``data["candidate_labels"]``.

        Args:
            data: Request payload; see the class docstring for the schema.

        Returns:
            A list of ``{"label": ..., "score": float}`` dicts, highest score
            first. Empty when no candidate labels are supplied; in that case
            the image is not decoded.

        Raises:
            KeyError: if ``data`` has no ``"image"`` entry. Errors from base64
                decoding or ``PIL.Image.open`` propagate for a malformed image.
        """
        img_b64 = data["image"]
        labels = self._normalize_labels(data.get("candidate_labels", []))
        if not labels:
            # Nothing to rank against: skip decoding and the model entirely.
            return []

        image = Image.open(io.BytesIO(self._decode_image_bytes(img_b64))).convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        text_tokens = self.tokenizer(labels).to(self.device)

        # torch.cuda.amp.autocast() is deprecated and warns on CPU-only hosts
        # (where it was disabled anyway). Enabling autocast only on CUDA keeps
        # CPU inference in full precision, exactly as before.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type=self.device, enabled=use_amp):
            img_feat = self.model.encode_image(image_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalise so the dot product below is a cosine similarity.
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            logits = self.LOGIT_SCALE * img_feat @ txt_feat.T
            # Row 0: the single image in this batch.
            probs = logits.softmax(dim=-1)[0].tolist()

        return self._rank(labels, probs)

    @staticmethod
    def _normalize_labels(labels) -> list:
        """Coerce the ``candidate_labels`` value into a list of labels."""
        if labels is None:
            return []
        if isinstance(labels, str):
            # A bare string would otherwise be zipped with the scores
            # character by character, yielding one-letter labels.
            return [labels]
        return list(labels)

    @staticmethod
    def _decode_image_bytes(img_b64) -> bytes:
        """Base64-decode an image payload, stripping an optional ``data:...,`` prefix."""
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            _, sep, payload = img_b64.partition(",")
            if sep:
                img_b64 = payload
        return base64.b64decode(img_b64)

    @staticmethod
    def _rank(labels, probs) -> list:
        """Pair each label with its score and sort by descending score."""
        ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
        return [{"label": label, "score": float(score)} for label, score in ranked]