# Rohit Mugalya
# updated the interface with all three models
# 91ea5a9
import warnings
warnings.filterwarnings("ignore")
import torch
import numpy as np
from PIL import Image
# Maps a human-readable model name to the backend identifier that
# ZeroShotVideoClassifier dispatches on.
MODELS: dict[str, str] = {
    "CLIP ViT-B/32": "clip",
    "SigLIP 2 Base": "siglip2",
    "X-CLIP Base": "xclip",
}
class ZeroShotVideoClassifier:
    """Zero-shot video classifier over one of the backends listed in ``MODELS``.

    A video is passed as a list of PIL frames and the candidate classes as
    free-text labels. Every label is wrapped in the prompt
    ``"a video of {label}"`` and the labels are ranked by the similarity
    between the video embedding and the text embeddings.

    Attributes:
        model_key: Name used to select the backend (a key of ``MODELS``).
        backend: Backend identifier looked up from ``MODELS``.
        device: Torch device string the model and its inputs are placed on.
    """

    def __init__(self, model_key: str = "CLIP ViT-B/32", device: str = None):
        """Select the backend and load its weights.

        Args:
            model_key: A key of ``MODELS``.
            device: Torch device string. When falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            KeyError: If ``model_key`` is not a key of ``MODELS``.
        """
        self.model_key = model_key
        self.backend = MODELS[model_key]
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load the model and its tokenizer/processor for ``self.backend``.

        The third-party packages are imported inside each branch, so only the
        package of the selected backend has to be importable.

        Raises:
            ValueError: If ``self.backend`` is not a supported identifier.
        """
        if self.backend == "clip":
            import open_clip

            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                "ViT-B-32", pretrained="openai"
            )
            self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
            self.model.to(self.device).eval()
        elif self.backend == "siglip2":
            from transformers import AutoModel, AutoProcessor

            checkpoint = "google/siglip2-base-patch16-224"
            self.model = AutoModel.from_pretrained(checkpoint).eval()
            self.processor = AutoProcessor.from_pretrained(checkpoint)
            self.model.to(self.device)
        elif self.backend == "xclip":
            from transformers import XCLIPModel, XCLIPProcessor

            checkpoint = "microsoft/xclip-base-patch16"
            self.model = XCLIPModel.from_pretrained(checkpoint).eval()
            self.processor = XCLIPProcessor.from_pretrained(checkpoint)
            self.model.to(self.device)
        else:
            # Fail here rather than later with a missing-attribute error.
            raise ValueError(f"Unsupported backend: {self.backend!r}")

    @staticmethod
    def _prompts(labels: list[str]) -> list[str]:
        """Wrap each label in the ``"a video of ..."`` prompt template."""
        return ["a video of {}".format(label) for label in labels]

    @staticmethod
    def _l2_normalize(embs: torch.Tensor) -> torch.Tensor:
        """Scale every vector along the last dimension to unit L2 norm."""
        return embs / embs.norm(dim=-1, keepdim=True)

    @staticmethod
    def _pooled_features(out) -> torch.Tensor:
        """Extract the embedding tensor from a ``get_*_features`` result.

        A plain tensor is returned unchanged; indexing it like a model output
        would fail. Otherwise ``pooler_output`` is used when the attribute
        exists, falling back to position 0 along dim 1 of ``out[0]``.
        """
        if torch.is_tensor(out):
            return out
        if hasattr(out, "pooler_output"):
            return out.pooler_output
        return out[0][:, 0, :]

    @staticmethod
    def _rank(scores: torch.Tensor, labels: list[str], top_k: int) -> list[dict]:
        """Return the ``top_k`` best labels, highest score first.

        ``scores`` is flattened with ``reshape(-1)`` rather than ``squeeze()``
        so that a single label still yields a 1-D array that can be indexed.
        """
        flat = scores.detach().reshape(-1).cpu().numpy()
        order = np.argsort(flat)[::-1][:top_k]
        return [{"label": labels[i], "score": float(flat[i])} for i in order]

    def _encode_text_clip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm CLIP text embeddings, one row per label."""
        tokens = self.tokenizer(self._prompts(labels)).to(self.device)
        with torch.no_grad():
            embs = self.model.encode_text(tokens)
        return self._l2_normalize(embs)

    def _encode_frames_clip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return the mean of the unit-norm CLIP frame embeddings as one row.

        NOTE(review): unlike the SigLIP 2 path, the mean is not re-normalised.
        Left unchanged because normalising would alter the reported scores.
        """
        batch = torch.stack([self.preprocess(frame) for frame in frames]).to(self.device)
        with torch.no_grad():
            embs = self.model.encode_image(batch)
        return self._l2_normalize(embs).mean(dim=0, keepdim=True)

    def _encode_text_siglip2(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm SigLIP 2 text embeddings, one row per label."""
        inputs = self.processor(
            text=self._prompts(labels),
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.device)  # inputs must be on the same device as the model
        with torch.no_grad():
            out = self.model.get_text_features(**inputs)
        return self._l2_normalize(self._pooled_features(out))

    def _encode_frames_siglip2(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return the re-normalised mean of the SigLIP 2 frame embeddings."""
        inputs = self.processor(images=frames, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model.get_image_features(**inputs)
        frame_embs = self._l2_normalize(self._pooled_features(out))
        return self._l2_normalize(frame_embs.mean(dim=0, keepdim=True))

    def _encode_text_xclip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm X-CLIP text embeddings, one row per label."""
        inputs = self.processor.tokenizer(
            self._prompts(labels),
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)  # inputs must be on the same device as the model
        with torch.no_grad():
            out = self.model.get_text_features(**inputs)
        return self._l2_normalize(self._pooled_features(out))

    def _encode_frames_xclip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return the unit-norm X-CLIP video embedding for ``frames``.

        All frames are handed to the image processor in one call.
        NOTE(review): presumably they are treated as a single clip, and the
        checkpoint may expect a fixed number of frames; confirm with callers.
        """
        frames_np = [np.array(frame) for frame in frames]
        inputs = self.processor.image_processor(images=frames_np, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)
        with torch.no_grad():
            out = self.model.get_video_features(pixel_values=pixel_values)
        return self._l2_normalize(self._pooled_features(out))

    def classify(self, frames: list[Image.Image], labels: list[str], top_k: int = 5) -> list[dict]:
        """Rank ``labels`` by how well they match the video in ``frames``.

        Args:
            frames: Frames of the video as PIL images; must not be empty.
            labels: Candidate class names.
            top_k: Maximum number of results to return.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, highest score
            first; empty when ``labels`` is empty. The score is a softmax over
            the labels for ``clip``, a per-label sigmoid for ``siglip2`` and
            the raw cosine similarity for ``xclip``.

        Raises:
            ValueError: If ``frames`` is empty or the backend is unsupported.
        """
        if len(labels) == 0:
            return []
        if len(frames) == 0:
            raise ValueError("frames must contain at least one image")

        if self.backend == "clip":
            text_embs = self._encode_text_clip(labels)
            video_emb = self._encode_frames_clip(frames)
            # NOTE(review): the softmax is over unscaled cosine similarities,
            # so the probabilities are close to uniform. Applying the model's
            # logit scale would sharpen them but change the reported scores.
            scores = (video_emb @ text_embs.T).reshape(-1).softmax(dim=-1)
        elif self.backend == "siglip2":
            text_embs = self._encode_text_siglip2(labels)
            video_emb = self._encode_frames_siglip2(frames)
            logit_scale = self.model.logit_scale.exp()
            logit_bias = self.model.logit_bias
            logits = video_emb @ text_embs.T * logit_scale + logit_bias
            scores = torch.sigmoid(logits)
        elif self.backend == "xclip":
            text_embs = self._encode_text_xclip(labels)
            video_emb = self._encode_frames_xclip(frames)
            scores = video_emb @ text_embs.T
        else:
            raise ValueError(f"Unsupported backend: {self.backend!r}")

        return self._rank(scores, labels, top_k)