|
|
|
|
|
import base64
import io
from collections import OrderedDict

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom zero-shot classifier replicating local OpenAI-CLIP logic.

    Client JSON must look like:

        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["car", "teddy bear", ...]
            }
        }

    The outer ``"inputs"`` wrapper is optional: if it is absent, the top-level
    dict itself is read for ``"image"`` and ``"candidate_labels"``. The image
    string may also be a ``data:<mime>;base64,...`` URI.

    On success, ``__call__`` returns a list of ``{"label": str, "score": float}``
    dicts sorted by descending score. The scores are a softmax over the
    candidate labels. If ``candidate_labels`` is missing or empty,
    ``{"error": "candidate_labels list is empty"}`` is returned instead.
    """

    def __init__(self, path: str = "", cache_size: int = 4096):
        """
        Load the CLIP model and processor and prepare the text-feature cache.

        Args:
            path: Passed unchanged to ``from_pretrained`` for both the model
                and the processor.
            cache_size: Maximum number of prompt embeddings kept in the
                least-recently-used text cache. A value <= 0 disables caching.
        """
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()

        # Prompt text -> L2-normalized text embedding, ordered oldest-used
        # first. It is bounded because the labels come from clients and are
        # otherwise unlimited.
        self.cache_size = cache_size
        self.cache: OrderedDict[str, torch.Tensor] = OrderedDict()

    def __call__(self, data):
        """
        Classify one image against the requested candidate labels.

        Args:
            data: Request dict shaped as described in the class docstring.

        Returns:
            A list of ``{"label", "score"}`` dicts sorted by descending score,
            or ``{"error": ...}`` when no candidate labels were supplied.

        Raises:
            KeyError: If the payload has no ``"image"`` entry.
        """
        payload = data.get("inputs", data)

        img_b64 = payload["image"]
        names = payload.get("candidate_labels", [])
        if not names:
            return {"error": "candidate_labels list is empty"}

        prompts = [f"a photo of a {p}" for p in names]

        txt_feat = self._text_features(prompts)
        img_feat = self._image_features(img_b64)

        # The fixed 100x temperature matches the OpenAI CLIP reference usage.
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
        return self._rank(names, probs)

    def _text_features(self, prompts):
        """Return one fp32, L2-normalized text embedding per prompt, in order."""
        # The current request's embeddings are gathered locally, so cache
        # eviction during this call can never lose one we still need.
        found = {}
        for prompt in dict.fromkeys(prompts):
            if prompt in self.cache:
                self.cache.move_to_end(prompt)  # mark as recently used
                found[prompt] = self.cache[prompt]

        missing = [p for p in dict.fromkeys(prompts) if p not in found]
        if missing:
            # truncation=True keeps long labels within CLIP's text context
            # length instead of making the text encoder fail.
            txt_in = self.processor(
                text=missing, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            with torch.no_grad():
                txt_emb = self.model.get_text_features(**txt_in)
                txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
            for prompt, emb in zip(missing, txt_emb):
                found[prompt] = emb
                # Clone so a cached row does not keep the whole batch tensor alive.
                self._cache_put(prompt, emb.clone())

        return torch.stack([found[p] for p in prompts]).float()

    def _image_features(self, img_b64):
        """Decode the request image and return its fp32, L2-normalized embedding."""
        with Image.open(io.BytesIO(self._decode_image_bytes(img_b64))) as raw:
            img = raw.convert("RGB")
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)

        # Mixed precision only where it is supported (CUDA). The generic
        # torch.autocast avoids the deprecated torch.cuda.amp API and the
        # per-request "CUDA not available" warning on CPU hosts.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_feat = self.model.get_image_features(**img_in)

        # Normalize in fp32 so the dot product with the fp32 text features is
        # computed at full precision.
        img_feat = img_feat.float()
        return img_feat / img_feat.norm(dim=-1, keepdim=True)

    def _cache_put(self, prompt, emb):
        """Insert ``prompt -> emb`` into the LRU cache, evicting the oldest entries."""
        if self.cache_size <= 0:
            return
        self.cache[prompt] = emb
        self.cache.move_to_end(prompt)
        while len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)

    @staticmethod
    def _decode_image_bytes(img_b64):
        """Decode a base64 image string (plain or ``data:`` URI) into raw bytes."""
        # ':' is not in the base64 alphabet, so a plain payload never starts
        # with "data:". Stripping the URI header prevents its characters
        # (e.g. '/', "base64") from being decoded as image data.
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            _, _, img_b64 = img_b64.partition(",")
        return base64.b64decode(img_b64)

    @staticmethod
    def _rank(names, probs):
        """Pair labels with probabilities, sorted by descending probability."""
        ranked = sorted(zip(names, probs), key=lambda x: x[1], reverse=True)
        return [{"label": n, "score": float(p)} for n, p in ranked]
|
|
|