| from typing import Dict, Any, List, Union |
| import base64 |
| from io import BytesIO |
|
|
| import torch |
| import numpy as np |
| from PIL import Image |
| from transformers import XCLIPProcessor, XCLIPModel |
|
|
|
|
class EndpointHandler:
    """Request handler wrapping an X-CLIP video model.

    Presumably implements the Hugging Face Inference Endpoints custom-handler
    contract (``__init__(path)`` + ``__call__(data)``) -- confirm against the
    deployment.

    Each request selects one of two modes:

    * ``candidate_labels`` supplied: zero-shot classification. Returns a list of
      ``{"label": ..., "score": float}`` dicts in label order; scores are a
      softmax across the supplied labels.
    * otherwise: returns ``{"embedding": [float, ...], "dim": int}``, the
      L2-normalised video embedding.
    """

    def __init__(self, path=""):
        """Load the X-CLIP model and processor from ``path`` and set eval mode."""
        self.model = XCLIPModel.from_pretrained(path)
        self.processor = XCLIPProcessor.from_pretrained(path)
        self.model.eval()
        # Frames per clip the vision tower is configured for; fall back to 8
        # when the config does not expose ``vision_config.num_frames``.
        try:
            self.num_frames = self.model.config.vision_config.num_frames
        except AttributeError:
            self.num_frames = 8

    @staticmethod
    def _parse_inputs(data: Dict[str, Any]) -> tuple:
        """Split a request payload into ``(frames_b64, labels)``.

        ``data["inputs"]`` (or ``data`` itself when that key is absent) is either
        a dict with ``"frames"`` and optional ``"candidate_labels"``, or the
        frame sequence directly. A lone frame string and a lone label string are
        wrapped in a list so they are not iterated character by character.

        Returns:
            ``frames_b64`` as a list (possibly empty) and ``labels`` as given,
            a one-element list for a non-empty string, or ``None``.

        Raises:
            KeyError: a dict payload has no ``"frames"`` key.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            frames_b64 = inputs["frames"]
            labels = inputs.get("candidate_labels")
        else:
            frames_b64, labels = inputs, None

        if isinstance(frames_b64, (str, bytes)):
            frames_b64 = [frames_b64] if frames_b64 else []
        # An empty label string stays falsy so the request falls back to embedding.
        if isinstance(labels, str):
            labels = [labels] if labels else None
        return (list(frames_b64) if frames_b64 else []), labels

    def _decode(self, b64: str) -> "Image.Image":
        """Decode one base64-encoded image into an RGB PIL image.

        A leading ``data:<mime>;base64,`` URI header is stripped first. Without
        this, ``b64decode`` would silently drop only the non-alphabet characters
        of the header and decode its letters as payload. Raw base64 can never
        start with ``data:`` (``:`` is outside the alphabet), so it is unaffected.
        """
        if isinstance(b64, str) and b64.startswith("data:") and "," in b64:
            b64 = b64.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")

    def _frame_indices(self, n: int) -> np.ndarray:
        """Return ``self.num_frames`` evenly spaced integer indices into ``range(n)``.

        Indices repeat when ``n < self.num_frames``. ``n`` must be positive.
        """
        return np.linspace(0, n - 1, self.num_frames).round().astype(int)

    def _sample(self, frames: List["Image.Image"]) -> List[np.ndarray]:
        """Pick ``self.num_frames`` evenly spaced frames and convert them to arrays."""
        return [np.array(frames[i]) for i in self._frame_indices(len(frames))]

    def _load_video(self, frames_b64: List[str]) -> List[np.ndarray]:
        """Decode only the frames selected by ``_frame_indices`` into arrays.

        Raises:
            ValueError: ``frames_b64`` is empty.
        """
        if not frames_b64:
            raise ValueError("No frames provided")
        idx = self._frame_indices(len(frames_b64)).tolist()
        # Decode each selected frame once, even if it is picked repeatedly
        # (which happens when the clip is shorter than ``num_frames``).
        decoded = {i: self._decode(frames_b64[i]) for i in set(idx)}
        return [np.array(decoded[i]) for i in idx]

    def _classify(self, video: List[np.ndarray], labels: List[str]) -> List[Dict[str, Any]]:
        """Score ``video`` against ``labels``; scores are a softmax over the labels."""
        proc = self.processor(text=labels, videos=video, return_tensors="pt", padding=True)
        with torch.no_grad():
            # Softmax across labels (dim=1), then take the single video's row.
            probs = self.model(**proc).logits_per_video.softmax(dim=1)[0]
        return [{"label": label, "score": float(score)} for label, score in zip(labels, probs)]

    def _embed(self, video: List[np.ndarray]) -> Dict[str, Any]:
        """Return the L2-normalised X-CLIP embedding of ``video`` and its length."""
        proc = self.processor(videos=video, return_tensors="pt")
        with torch.no_grad():
            feats = self.model.get_video_features(pixel_values=proc["pixel_values"])
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}

    def __call__(self, data: Dict[str, Any]) -> Union[Dict, List]:
        """Handle one request; see the class docstring for the two response modes.

        Raises:
            KeyError: a dict payload lacks ``"frames"``.
            ValueError: no frames were provided.
        """
        frames_b64, labels = self._parse_inputs(data)
        video = self._load_video(frames_b64)
        if labels:
            return self._classify(video, labels)
        return self._embed(video)