|
|
|
|
|
import io |
|
|
import base64 |
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
import open_clip |
|
|
from open_clip import fuse_conv_bn_sequential |
|
|
|
|
|
class EndpointHandler:
    """Zero-shot image classifier for MobileCLIP-B (OpenCLIP).

    Expects JSON payload:
        {
            "image": "<base64-encoded PNG/JPEG>",
            "candidate_labels": ["cat", "dog", ...]
        }

    Returns:
        [
            {"label": "cat", "score": 0.91},
            {"label": "dog", "score": 0.05},
            ...
        ]
    or an {"error": "..."} dict when the payload is malformed.
    """

    def __init__(self, path: str = ""):
        """Load the MobileCLIP-B checkpoint and prepare it for inference.

        Args:
            path: Directory containing the ``mobileclip_b.pt`` checkpoint
                (Inference Endpoints passes the model repository root).
        """
        weights = f"{path}/mobileclip_b.pt"

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )

        # Fold BatchNorm into preceding convolutions for faster inference.
        # NOTE(review): confirm `fuse_conv_bn_sequential` exists in the pinned
        # open_clip version (MobileCLIP reparameterization helper).
        self.model = fuse_conv_bn_sequential(self.model).eval()

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data):
        """Classify one base64-encoded image against the candidate labels.

        Args:
            data: Mapping with keys ``"image"`` (base64 PNG/JPEG string) and
                ``"candidate_labels"`` (non-empty list of strings).

        Returns:
            List of ``{"label", "score"}`` dicts sorted by descending score,
            or an ``{"error": ...}`` dict for malformed payloads.
        """
        # Fix: the original did data["image"] and crashed with KeyError when
        # the key was absent; validate both fields up front, mirroring the
        # existing graceful handling of empty candidate_labels.
        img_b64 = data.get("image")
        if not img_b64:
            return {"error": "image field is missing"}
        labels = data.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # malformed base64 or undecodable image bytes
            return {"error": f"could not decode image: {exc}"}

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        text_tokens = self.tokenizer(labels).to(self.device)

        # Fix: the original entered torch.cuda.amp.autocast() unconditionally,
        # which is deprecated and warns (then disables itself) on CPU-only
        # hosts. Enable mixed precision only when actually running on CUDA.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_feat = self.model.encode_image(image_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalise so the dot product below is cosine similarity.
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            # 100 is the conventional CLIP logit scale used to sharpen softmax.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        return [
            {"label": label, "score": float(score)}
            for label, score in sorted(
                zip(labels, probs), key=lambda pair: pair[1], reverse=True
            )
        ]
|
|
|