import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import numpy as np
import torch
from PIL import Image
from transformers import XCLIPModel, XCLIPProcessor


class EndpointHandler:
    """Inference-endpoint handler wrapping an X-CLIP video model.

    A request carries a sequence of base64-encoded image frames. When
    ``candidate_labels`` are supplied the handler performs zero-shot video
    classification; otherwise it returns an L2-normalised video embedding.
    """

    def __init__(self, path=""):
        """Load the X-CLIP model and processor from *path*.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        self.model = XCLIPModel.from_pretrained(path)
        self.processor = XCLIPProcessor.from_pretrained(path)
        self.model.eval()
        # Number of frames sampled per clip: taken from the vision config when
        # it exposes ``num_frames``, otherwise 8.
        try:
            self.num_frames = self.model.config.vision_config.num_frames
        except AttributeError:
            self.num_frames = 8

    def _decode(self, b64: Union[str, bytes]) -> Image.Image:
        """Decode one base64-encoded image (plain or ``data:`` URI) to RGB.

        Raises:
            binascii.Error / PIL.UnidentifiedImageError on malformed input.
        """
        # Strip a "data:image/...;base64," header: the non-validating
        # b64decode drops ':;,' but keeps the header's letters, corrupting
        # the payload.
        if isinstance(b64, str) and b64.startswith("data:"):
            b64 = b64.partition(",")[2]
        return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")

    def _sample(self, frames: List[Image.Image]) -> List[np.ndarray]:
        """Pick ``self.num_frames`` evenly spaced frames as numpy arrays.

        Clips shorter than ``num_frames`` get repeated frames, because the
        rounded ``linspace`` indices then contain duplicates.
        """
        idx = np.linspace(0, len(frames) - 1, self.num_frames).round().astype(int)
        return [np.array(frames[i]) for i in idx]

    def _classify(self, video: List[np.ndarray], labels: List[str]) -> List[Dict[str, Any]]:
        """Score the clip against each label; scores are a softmax over labels."""
        proc = self.processor(text=labels, videos=video, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = self.model(**proc).logits_per_video
        # Row 0 is the single clip sent; softmax runs across the label axis.
        probs = logits.softmax(dim=1)[0]
        return [
            {"label": label, "score": float(score)}
            for label, score in zip(labels, probs)
        ]

    def _embed(self, video: List[np.ndarray]) -> Dict[str, Any]:
        """Return the clip's L2-normalised video embedding and its size."""
        proc = self.processor(videos=video, return_tensors="pt")
        with torch.no_grad():
            feats = self.model.get_video_features(pixel_values=proc["pixel_values"])
        # normalize() clamps the norm at a small eps, so an all-zero vector
        # does not turn into NaNs; non-zero vectors are unchanged.
        feats = torch.nn.functional.normalize(feats, p=2, dim=-1)
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}

    def __call__(self, data: Dict[str, Any]) -> Union[Dict, List]:
        """Handle one request.

        Args:
            data: Either ``{"inputs": {...}}`` or the inner dict directly. The
                inner value is a dict with ``frames`` (base64 strings) and
                optional ``candidate_labels``, or a bare list of frames.

        Returns:
            A list of ``{"label", "score"}`` dicts when labels are given,
            otherwise ``{"embedding": [...], "dim": int}``.

        Raises:
            ValueError: if no frames are provided.
            KeyError: if a dict payload lacks ``frames``.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            frames_b64 = inputs["frames"]
            labels = inputs.get("candidate_labels")
        else:
            frames_b64 = inputs
            labels = None

        # A lone string would otherwise be iterated character by character.
        if isinstance(frames_b64, str):
            frames_b64 = [frames_b64]
        if isinstance(labels, str):
            labels = [labels]

        if not frames_b64:
            raise ValueError("No frames provided")
        frames = [self._decode(f) for f in frames_b64]
        video = self._sample(frames)

        if labels:
            return self._classify(video, labels)
        return self._embed(video)