testmobileclip / handler.py
finhdev's picture
Update handler.py
08ce2dc verified
raw
history blame
2.35 kB
# handler.py
import io
import base64
import torch
from PIL import Image
import open_clip
from open_clip import fuse_conv_bn_sequential
class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP).

    Expects JSON payload:
        {
            "image": "<base64-encoded PNG/JPEG>",
            "candidate_labels": ["cat", "dog", ...]
        }

    Returns (sorted by descending score):
        [
            {"label": "cat", "score": 0.91},
            {"label": "dog", "score": 0.05},
            ...
        ]
    or, on malformed input, a dict of the form {"error": "<message>"}.
    """

    def __init__(self, path: str = "") -> None:
        """Load model, transforms and tokenizer from the repo at *path*.

        Args:
            path: Repo root inside the container; must contain
                ``mobileclip_b.pt``.
        """
        weights = f"{path}/mobileclip_b.pt"
        # Load model + transforms from OpenCLIP.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        # Fuse conv + BN for faster inference (same idea as MobileCLIP re-param).
        # NOTE(review): `fuse_conv_bn_sequential` is not exported by every
        # open_clip release — confirm it exists in the pinned version.
        self.model = fuse_conv_bn_sequential(self.model).eval()
        # Tokenizer for label prompts.
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        # Device selection: prefer GPU when available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: dict):
        """Classify one base64 image against the given candidate labels.

        Args:
            data: Request payload with keys ``image`` (base64 string) and
                ``candidate_labels`` (non-empty list of strings).

        Returns:
            List of ``{"label", "score"}`` dicts sorted by descending score,
            or an ``{"error": ...}`` dict for malformed requests.
        """
        # 1. Parse request — return structured errors instead of raising,
        # consistent with the candidate_labels check below.
        img_b64 = data.get("image")
        if not img_b64:
            return {"error": "image field is missing or empty"}
        labels = data.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # 2. Decode & preprocess image. A broad catch is deliberate here:
        # this is the service boundary and a bad payload must not crash the
        # worker with a raw traceback.
        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:
            return {"error": f"could not decode image: {exc}"}
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # 3. Tokenize labels.
        text_tokens = self.tokenizer(labels).to(self.device)

        # 4. Forward pass. BUG FIX: the original entered
        # torch.cuda.amp.autocast() unconditionally, which is deprecated and
        # wrong on CPU-only hosts; use device-aware autocast, enabled only on
        # CUDA so CPU inference runs in full precision.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_feat = self.model.encode_image(image_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalize both embeddings so the dot product is a cosine
            # similarity; 100 is the standard CLIP logit scale.
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # 5. Return label/score pairs sorted by descending score.
        return [
            {"label": label, "score": float(score)}
            for label, score in sorted(
                zip(labels, probs), key=lambda pair: pair[1], reverse=True
            )
        ]