"""Zero-shot video classification with CLIP, SigLIP 2 and X-CLIP backends."""

import warnings
from typing import Optional

# Runs before the heavy imports below so their import-time warnings are hidden
# too. NOTE(review): this silences *all* warnings process-wide as a side effect
# of importing this module; consider scoping it with warnings.catch_warnings().
warnings.filterwarnings("ignore")

import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402

# Display name -> internal backend identifier used by ZeroShotVideoClassifier.
MODELS = {
    "CLIP ViT-B/32": "clip",
    "SigLIP 2 Base": "siglip2",
    "X-CLIP Base": "xclip",
}


class ZeroShotVideoClassifier:
    """Rank free-text labels against a sequence of video frames, zero-shot.

    Labels are wrapped in the prompt ``"a video of {label}"`` and embedded with
    the backend's text encoder; frames are embedded with its image/video
    encoder, and every label is scored against the resulting video embedding.

    Score semantics differ per backend, so scores are only comparable within
    a single backend:

    * ``clip``    -- softmax over the given labels (scores sum to 1).
    * ``siglip2`` -- independent sigmoid probability per label.
    * ``xclip``   -- raw cosine similarity in [-1, 1].
    """

    # Prompt template shared by all backends' text encoders.
    _PROMPT = "a video of {}"

    def __init__(self, model_key: str = "CLIP ViT-B/32", device: Optional[str] = None):
        """Create a classifier and load its pretrained weights.

        Args:
            model_key: One of the keys of ``MODELS``.
            device: Torch device string; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.

        Raises:
            KeyError: If ``model_key`` is not a key of ``MODELS``.
        """
        if model_key not in MODELS:
            raise KeyError(
                f"Unknown model_key {model_key!r}; expected one of {sorted(MODELS)}"
            )
        self.model_key = model_key
        self.backend = MODELS[model_key]
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load the model and its preprocessor for ``self.backend``.

        Backend libraries are imported lazily so only the selected backend's
        dependency (open_clip or transformers) has to be installed.
        """
        if self.backend == "clip":
            import open_clip

            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                "ViT-B-32", pretrained="openai"
            )
            self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
            self.model.to(self.device).eval()
        elif self.backend == "siglip2":
            from transformers import AutoModel, AutoProcessor

            self.model = AutoModel.from_pretrained("google/siglip2-base-patch16-224").eval()
            self.processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
            self.model.to(self.device)
        elif self.backend == "xclip":
            from transformers import XCLIPModel, XCLIPProcessor

            self.model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch16").eval()
            self.processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch16")
            self.model.to(self.device)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _l2_normalize(x: torch.Tensor) -> torch.Tensor:
        """Scale every row of ``x`` (along the last dim) to unit L2 norm."""
        return x / x.norm(dim=-1, keepdim=True)

    @staticmethod
    def _pooled(out) -> torch.Tensor:
        """Extract embeddings from a transformers ``get_*_features`` result.

        Depending on the transformers version these methods return either a
        plain tensor or a ModelOutput carrying ``pooler_output``; accept both.
        """
        if isinstance(out, torch.Tensor):
            return out
        pooled = getattr(out, "pooler_output", None)
        if pooled is not None:
            return pooled
        # Last-resort fallback kept from the original implementation: first
        # token of the first output field.
        return out[0][:, 0, :]

    @staticmethod
    def _resample_frames(frames: list, num_frames: Optional[int]) -> list:
        """Pick ``num_frames`` frames evenly spread over ``frames``.

        Frames are repeated when fewer than ``num_frames`` are given. Returns
        ``frames`` unchanged (as a list) when ``num_frames`` is falsy or
        already matches.
        """
        if not num_frames or len(frames) == num_frames:
            return list(frames)
        indices = np.linspace(0, len(frames) - 1, num=num_frames).round().astype(int)
        return [frames[i] for i in indices]

    def _prompts(self, labels: list[str]) -> list[str]:
        """Wrap each label in the shared prompt template."""
        return [self._PROMPT.format(label) for label in labels]

    # ------------------------------------------------------------------- CLIP

    def _encode_text_clip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm CLIP text embeddings, one row per label."""
        tokens = self.tokenizer(self._prompts(labels)).to(self.device)
        with torch.no_grad():
            embs = self.model.encode_text(tokens)
        return self._l2_normalize(embs)

    def _encode_frames_clip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm video embedding: mean of frame embeddings."""
        pixels = torch.stack([self.preprocess(f) for f in frames]).to(self.device)
        with torch.no_grad():
            embs = self._l2_normalize(self.model.encode_image(pixels))
        # The mean of unit vectors is shorter than 1; re-normalise so the dot
        # product with text embeddings is a true cosine similarity.
        return self._l2_normalize(embs.mean(dim=0, keepdim=True))

    # ---------------------------------------------------------------- SigLIP 2

    def _encode_text_siglip2(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm SigLIP 2 text embeddings, one row per label."""
        inputs = self.processor(
            text=self._prompts(labels),
            # SigLIP text towers are trained with fixed-length (64) padding.
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            embs = self._pooled(self.model.get_text_features(**inputs))
        return self._l2_normalize(embs)

    def _encode_frames_siglip2(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm video embedding: mean of frame embeddings."""
        inputs = self.processor(images=frames, return_tensors="pt").to(self.device)
        with torch.no_grad():
            embs = self._l2_normalize(self._pooled(self.model.get_image_features(**inputs)))
        return self._l2_normalize(embs.mean(dim=0, keepdim=True))

    # ------------------------------------------------------------------ X-CLIP

    def _encode_text_xclip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm X-CLIP text embeddings, one row per label."""
        inputs = self.processor.tokenizer(
            self._prompts(labels), return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        with torch.no_grad():
            embs = self._pooled(self.model.get_text_features(**inputs))
        return self._l2_normalize(embs)

    def _encode_frames_xclip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm X-CLIP video embedding for the frames."""
        # X-CLIP's multi-frame integration layer has a position embedding sized
        # to the configured frame count, so feed exactly that many frames.
        vision_config = getattr(self.model.config, "vision_config", None)
        num_frames = getattr(vision_config, "num_frames", None)
        frames_np = [np.array(f) for f in self._resample_frames(frames, num_frames)]
        inputs = self.processor.image_processor(images=frames_np, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)
        with torch.no_grad():
            embs = self._pooled(self.model.get_video_features(pixel_values=pixel_values))
        return self._l2_normalize(embs)

    # ----------------------------------------------------------------- public

    def classify(self, frames: list[Image.Image], labels: list[str], top_k: int = 5) -> list[dict]:
        """Score ``labels`` against ``frames`` and return the best ``top_k``.

        Args:
            frames: Frames of one video, in temporal order.
            labels: Candidate class names (free text).
            top_k: Maximum number of results; values <= 0 return an empty list.

        Returns:
            Up to ``top_k`` dicts ``{"label": str, "score": float}`` sorted by
            descending score. See the class docstring for per-backend score
            semantics.

        Raises:
            ValueError: If ``frames`` or ``labels`` is empty.
        """
        if not frames:
            raise ValueError("frames must contain at least one frame")
        if not labels:
            raise ValueError("labels must contain at least one label")

        with torch.no_grad():
            if self.backend == "clip":
                sims = self._encode_frames_clip(frames) @ self._encode_text_clip(labels).T
                # Without the learned temperature, softmax over cosine
                # similarities (a narrow range) is nearly uniform.
                scores = (sims * self.model.logit_scale.exp()).softmax(dim=-1)
            elif self.backend == "siglip2":
                sims = self._encode_frames_siglip2(frames) @ self._encode_text_siglip2(labels).T
                logits = sims * self.model.logit_scale.exp() + self.model.logit_bias
                scores = torch.sigmoid(logits)
            elif self.backend == "xclip":
                scores = self._encode_frames_xclip(frames) @ self._encode_text_xclip(labels).T
            else:
                raise ValueError(f"Unsupported backend {self.backend!r}")

        # Row 0 of the (1, num_labels) matrix: stays 1-D even for one label,
        # unlike squeeze().
        scores = scores[0].float().cpu().numpy()
        order = np.argsort(scores)[::-1][: max(top_k, 0)]
        return [{"label": labels[i], "score": float(scores[i])} for i in order]