File size: 2,350 Bytes
08ce2dc 9188b68 35037e4 08ce2dc 9188b68 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 9188b68 08ce2dc 9188b68 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 08ce2dc 35037e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# handler.py
import io
import base64
import torch
from PIL import Image
import open_clip
from open_clip import fuse_conv_bn_sequential
class EndpointHandler:
    """
    Zero-shot image classifier for MobileCLIP-B (OpenCLIP).

    Expects JSON payload:
        {
            "image": "<base64-encoded PNG/JPEG>",
            "candidate_labels": ["cat", "dog", ...]
        }

    Returns:
        [
            {"label": "cat", "score": 0.91},
            {"label": "dog", "score": 0.05},
            ...
        ]
    sorted by descending score, or {"error": "..."} for a malformed request.
    """

    def __init__(self, path: str = ""):
        # `path` points to the repo root inside the container.
        weights = f"{path}/mobileclip_b.pt"
        # Load model + preprocessing transforms from OpenCLIP.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        # Fuse conv + BN for faster inference (same idea as MobileCLIP's
        # re-parameterization).
        # NOTE(review): `fuse_conv_bn_sequential` is not a documented open_clip
        # export — confirm it exists in the pinned open_clip version, otherwise
        # the module-level import fails at load time.
        self.model = fuse_conv_bn_sequential(self.model).eval()
        # Tokenizer matching the model's text tower.
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        # Prefer GPU when available; move weights once, before the first request.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data):
        # 1. Parse and validate the request.
        img_b64 = data.get("image")
        if not img_b64:
            # Fixed: the original did data["image"] and raised an unhandled
            # KeyError on a missing field; return a structured error, matching
            # how the empty-labels case is reported.
            return {"error": "image field is missing"}
        labels = data.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}
        # 2. Decode & preprocess the image.
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        # 3. Tokenize the candidate labels.
        text_tokens = self.tokenizer(labels).to(self.device)
        # 4. Forward pass. Fixed: torch.cuda.amp.autocast is deprecated and was
        #    entered unconditionally even on CPU; enable autocast only when the
        #    selected device is CUDA.
        with torch.no_grad(), torch.autocast(
            device_type="cuda", enabled=self.device == "cuda"
        ):
            img_feat = self.model.encode_image(image_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            # L2-normalize so the dot product below is a cosine similarity.
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            # 100 is the usual CLIP logit scale (exp of the learned temperature);
            # presumably matches this checkpoint — verify against training config.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
        # 5. Return labels sorted by descending probability.
        return [
            {"label": label, "score": float(score)}
            for label, score in sorted(
                zip(labels, probs), key=lambda pair: pair[1], reverse=True
            )
        ]
|