# testmobileclip/handler.py — author: finhdev, commit 2fb4fd2 ("Update handler.py"), 7.2 kB
# (Hugging Face file-viewer metadata converted to a comment so the file parses.)
# handler.py (repo root)
import io, base64, torch, open_clip
from PIL import Image
# optional: from open_clip import fuse_conv_bn_sequential # if you want re‑param
class EndpointHandler:
    """
    MobileCLIP-B ('datacompdr') zero-shot classifier with a per-process
    text-embedding cache.

    Expected client JSON:
        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["a photo of a cat", ...]
            }
        }

    Returns a list of ``{"label": str, "score": float}`` dicts sorted by
    descending score, or ``{"error": str}`` for malformed input.
    """

    def __init__(self, path=""):
        """Load MobileCLIP-B weights and tokenizer; move the model to GPU if available."""
        # Load the exact weights the local run uses.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "mobileclip_b", pretrained="datacompdr"
        )
        # Optional: fuse conv+bn for speed
        # self.model = fuse_conv_bn_sequential(self.model).eval()
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # prompt -> L2-normalized text embedding (kept on self.device).
        # NOTE(review): unbounded — fine for a fixed label set; consider an LRU
        # if clients may send arbitrary prompts.
        self.cache: dict[str, torch.Tensor] = {}

    def __call__(self, data):
        """Classify one base64 image against ``candidate_labels`` (zero-shot)."""
        # Unwrap the Hugging Face `inputs` envelope if present.
        payload = data.get("inputs", data)

        # ---- input validation -------------------------------------------
        img_b64 = payload.get("image")
        if not img_b64:
            # Fix: the original raised an uncaught KeyError here, while an
            # empty label list returned a structured error. Be consistent.
            return {"error": "missing 'image' field"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # ---- image -> tensor --------------------------------------------
        try:
            img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # bad base64 / not a decodable image
            return {"error": f"could not decode image: {exc}"}
        img_t = self.preprocess(img).unsqueeze(0).to(self.device)

        # ---- text embeddings, cached and normalized once ----------------
        new = [lbl for lbl in labels if lbl not in self.cache]
        if new:
            tok = self.tokenizer(new).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tok)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for lbl, e in zip(new, emb):
                self.cache[lbl] = e
        txt_t = torch.stack([self.cache[lbl] for lbl in labels])

        # ---- forward ----------------------------------------------------
        # Fix: torch.cuda.amp.autocast is deprecated and CUDA-only, yet the
        # original entered it unconditionally even on CPU. Use the modern
        # torch.autocast and enable mixed precision only on CUDA.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_f = self.model.encode_image(img_t)
            img_f = img_f / img_f.norm(dim=-1, keepdim=True)
            # 100 is CLIP's conventional logit scale (exp of the learned
            # temperature saturates near 100).
            probs = (100 * img_f @ txt_t.T).softmax(dim=-1)[0].tolist()

        # ---- sorted output ----------------------------------------------
        return [
            {"label": lbl, "score": float(p)}
            for lbl, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
# ---------------------------------------------------------------------------
# Superseded earlier implementations, kept commented out for reference only.
# ---------------------------------------------------------------------------
# import io, base64, torch
# from PIL import Image
# import open_clip
# class EndpointHandler:
# """
# Zero‑shot classifier for MobileCLIP‑B (OpenCLIP) with a text‑embedding cache.
# Client JSON:
# {
# "inputs": {
# "image": "<base64 PNG/JPEG>",
# "candidate_labels": ["cat", "dog", ...]
# }
# }
# """
# # ------------------------------------------------- #
# # INITIALISATION #
# # ------------------------------------------------- #
# def __init__(self, path: str = ""):
# weights = f"{path}/mobileclip_b.pt"
# self.model, _, self.preprocess = open_clip.create_model_and_transforms(
# "MobileCLIP-B", pretrained=weights
# )
# self.model.eval()
# self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# self.model.to(self.device)
# # cache: {prompt -> 1×512 tensor on device}
# self.label_cache: dict[str, torch.Tensor] = {}
# # ------------------------------------------------- #
# # INFERENCE #
# # ------------------------------------------------- #
# def __call__(self, data):
# payload = data.get("inputs", data)
# img_b64 = payload["image"]
# labels = payload.get("candidate_labels", [])
# if not labels:
# return {"error": "candidate_labels list is empty"}
# # --- image ----
# image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# # --- text (with cache) ----
# missing = [l for l in labels if l not in self.label_cache]
# if missing:
# tokens = self.tokenizer(missing).to(self.device)
# with torch.no_grad():
# emb = self.model.encode_text(tokens)
# emb = emb / emb.norm(dim=-1, keepdim=True)
# for l, e in zip(missing, emb):
# self.label_cache[l] = e
# txt_feat = torch.stack([self.label_cache[l] for l in labels])
# # --- forward & softmax ----
# with torch.no_grad(), torch.cuda.amp.autocast():
# img_feat = self.model.encode_image(img_tensor)
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
# probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
# # --- sorted output ----
# return [
# {"label": l, "score": float(p)}
# for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
# ]
# # handler.py (repo root)
# import io, base64, torch
# from PIL import Image
# import open_clip
# class EndpointHandler:
# """
# Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).
# Expected client JSON *to the endpoint*:
# {
# "inputs": {
# "image": "<base64 PNG/JPEG>",
# "candidate_labels": ["cat", "dog", ...]
# }
# }
# """
# def __init__(self, path: str = ""):
# weights = f"{path}/mobileclip_b.pt"
# self.model, _, self.preprocess = open_clip.create_model_and_transforms(
# "MobileCLIP-B", pretrained=weights
# )
# self.model.eval()
# self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# self.model.to(self.device)
# def __call__(self, data):
# # ── unwrap Hugging Face's `inputs` envelope ───────────
# payload = data.get("inputs", data)
# img_b64 = payload["image"]
# labels = payload.get("candidate_labels", [])
# if not labels:
# return {"error": "candidate_labels list is empty"}
# # Decode & preprocess image
# image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# # Tokenise labels
# text_tokens = self.tokenizer(labels).to(self.device)
# # Forward pass
# with torch.no_grad(), torch.cuda.amp.autocast():
# img_feat = self.model.encode_image(img_tensor)
# txt_feat = self.model.encode_text(text_tokens)
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
# txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
# probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
# # Sorted output
# return [
# {"label": l, "score": float(p)}
# for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
# ]