xclip-base-patch32 / handler.py
kas1293's picture
Update handler.py
0dd652c verified
Raw
History Blame Contribute Delete
1.97 kB
from typing import Dict, Any, List, Union
import base64
from io import BytesIO
import torch
import numpy as np
from PIL import Image
from transformers import XCLIPProcessor, XCLIPModel
class EndpointHandler:
    """Serve an X-CLIP model for video embedding and zero-shot classification.

    A request carries a list of base64-encoded frames. When candidate labels
    are supplied, the clip is scored against each label; otherwise an
    L2-normalised video embedding is returned.
    """

    def __init__(self, path=""):
        """Load model and processor from *path* and switch the model to eval mode."""
        self.model = XCLIPModel.from_pretrained(path)
        self.processor = XCLIPProcessor.from_pretrained(path)
        self.model.eval()
        # Fall back to 8 frames when the config does not expose a frame count.
        try:
            self.num_frames = self.model.config.vision_config.num_frames
        except AttributeError:
            self.num_frames = 8

    @staticmethod
    def _strip_data_url(b64: str) -> str:
        """Return *b64* without a leading ``data:<mime>;base64,`` prefix, if any."""
        if isinstance(b64, str) and b64.startswith("data:") and "," in b64:
            return b64.split(",", 1)[1]
        return b64

    @staticmethod
    def _normalize_labels(labels) -> List[str]:
        """Coerce ``candidate_labels`` into a list of labels.

        ``None`` becomes an empty list and a bare string becomes a one-item
        list, so that it is never iterated character by character.
        """
        if labels is None:
            return []
        if isinstance(labels, str):
            return [labels]
        return list(labels)

    def _decode(self, b64: str) -> "Image.Image":
        """Decode one base64-encoded image into an RGB PIL image."""
        # A data-URL prefix must be removed first: b64decode silently drops
        # the ':', ';' and ',' characters and would decode the remaining
        # prefix letters as if they were payload, corrupting the image bytes.
        payload = self._strip_data_url(b64)
        return Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")

    def _sample(self, frames: List["Image.Image"]) -> List[np.ndarray]:
        """Pick ``self.num_frames`` evenly spaced frames and return them as arrays.

        Frames are repeated when fewer than ``self.num_frames`` are supplied.

        Raises:
            ValueError: if *frames* is empty.
        """
        if not frames:
            raise ValueError("No frames provided")
        idx = np.linspace(0, len(frames) - 1, self.num_frames).round().astype(int)
        return [np.array(frames[i]) for i in idx]

    def __call__(self, data: Dict[str, Any]) -> Union[Dict, List]:
        """Run inference on one request.

        Args:
            data: Request payload. The frames are read from
                ``data["inputs"]["frames"]``, from ``data["inputs"]`` when it
                is not a dict, or from ``data["frames"]`` when there is no
                ``"inputs"`` key. Optional ``candidate_labels`` (a string or a
                list of strings) are read from the same dict as the frames.

        Returns:
            With labels: a list of ``{"label", "score"}`` dicts, one per
            label, with scores from a softmax over the labels. Without
            labels: ``{"embedding": [...], "dim": int}``.

        Raises:
            ValueError: if no frames are provided.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            frames_b64 = inputs["frames"]
            labels = self._normalize_labels(inputs.get("candidate_labels"))
        else:
            frames_b64 = inputs
            labels = []

        # A lone string is a single frame, not a sequence of characters.
        if isinstance(frames_b64, str):
            frames_b64 = [frames_b64]

        frames = [self._decode(f) for f in (frames_b64 or [])]
        if not frames:
            raise ValueError("No frames provided")
        video = self._sample(frames)

        if labels:
            proc = self.processor(text=labels, videos=video,
                                  return_tensors="pt", padding=True)
            with torch.no_grad():
                probs = self.model(**proc).logits_per_video.softmax(dim=1)[0]
            return [
                {"label": label, "score": score}
                for label, score in zip(labels, probs.tolist())
            ]

        proc = self.processor(videos=video, return_tensors="pt")
        with torch.no_grad():
            feats = self.model.get_video_features(pixel_values=proc["pixel_values"])
        # L2-normalise the embedding.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}