File size: 5,198 Bytes
91ea5a9
 
 
7a9e335
 
 
 
 
91ea5a9
 
 
 
 
7a9e335
 
 
91ea5a9
 
 
7a9e335
91ea5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a9e335
91ea5a9
7a9e335
 
91ea5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a9e335
91ea5a9
 
7a9e335
91ea5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a9e335
91ea5a9
 
 
 
 
 
 
7a9e335
 
91ea5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
from PIL import Image


# Display name -> internal backend id. ZeroShotVideoClassifier looks up the id
# from the display name and dispatches loading/encoding/scoring on it.
MODELS = dict(
    [
        ("CLIP ViT-B/32", "clip"),
        ("SigLIP 2 Base", "siglip2"),
        ("X-CLIP Base", "xclip"),
    ]
)


class ZeroShotVideoClassifier:
    """Zero-shot video classifier built on image/video-text contrastive models.

    Frames are embedded with the selected backbone and compared against text
    prompts of the form ``"a video of <label>"``. The backbone is chosen by a
    key of ``MODELS``: OpenCLIP ViT-B/32 (``"clip"``), SigLIP 2 (``"siglip2"``)
    or X-CLIP (``"xclip"``).

    Args:
        model_key: A key of ``MODELS`` selecting the backbone.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Raises:
        KeyError: If ``model_key`` is not a key of ``MODELS``.
    """

    # Prompt template shared by all backends.
    _PROMPT = "a video of {}"

    def __init__(self, model_key: str = "CLIP ViT-B/32", device: str = None):
        self.model_key = model_key
        self.backend = MODELS[model_key]
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load the backbone and its preprocessing for ``self.backend``.

        Every branch puts the model in eval mode on ``self.device``.
        """
        if self.backend == "clip":
            import open_clip
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                "ViT-B-32", pretrained="openai"
            )
            self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
            self.model.to(self.device).eval()

        elif self.backend == "siglip2":
            from transformers import AutoModel, AutoProcessor
            self.model = AutoModel.from_pretrained("google/siglip2-base-patch16-224").eval()
            self.processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
            self.model.to(self.device)

        elif self.backend == "xclip":
            from transformers import XCLIPModel, XCLIPProcessor
            self.model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch16").eval()
            self.processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch16")
            self.model.to(self.device)

        else:
            raise ValueError(f"Unsupported backend: {self.backend!r}")

    # ------------------------------------------------------------------ helpers

    def _prompts(self, labels: list[str]) -> list[str]:
        """Wrap each label in the shared prompt template."""
        return [self._PROMPT.format(label) for label in labels]

    def _to_device(self, batch) -> dict:
        """Return a dict copy of a tokenizer/processor output with tensors on ``self.device``.

        Without this, the HF backends fail on CUDA because the model lives on
        the GPU while the processor output stays on the CPU.
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    @staticmethod
    def _l2_normalize(embs: torch.Tensor) -> torch.Tensor:
        """Scale each row (last dimension) of ``embs`` to unit L2 norm."""
        return embs / embs.norm(dim=-1, keepdim=True)

    @staticmethod
    def _as_embedding(out) -> torch.Tensor:
        """Extract the embedding tensor from a ``get_*_features`` result.

        Depending on the ``transformers`` version, these methods may return
        either the embeddings as a plain tensor or a model-output object whose
        ``pooler_output`` holds them. Both forms are accepted.
        """
        if isinstance(out, torch.Tensor):
            return out
        pooled = getattr(out, "pooler_output", None)
        if pooled is not None:
            return pooled
        # Last-resort fallback kept from the original implementation:
        # first token of the first output field.
        return out[0][:, 0, :]

    # --------------------------------------------------------------- CLIP

    def _encode_text_clip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm OpenCLIP text embeddings, one row per label."""
        texts = self.tokenizer(self._prompts(labels)).to(self.device)
        with torch.no_grad():
            return self._l2_normalize(self.model.encode_text(texts))

    def _encode_frames_clip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm video embedding.

        The embedding is the mean of the unit-norm per-frame embeddings,
        re-normalized so its dot product with text embeddings is a cosine
        similarity.
        """
        tensors = torch.stack([self.preprocess(f) for f in frames]).to(self.device)
        with torch.no_grad():
            embs = self._l2_normalize(self.model.encode_image(tensors))
            return self._l2_normalize(embs.mean(dim=0, keepdim=True))

    # --------------------------------------------------------------- SigLIP 2

    def _encode_text_siglip2(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm SigLIP 2 text embeddings, one row per label."""
        inputs = self.processor(
            text=self._prompts(labels),
            # SigLIP-family models expect fixed max-length padding.
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        with torch.no_grad():
            out = self.model.get_text_features(**self._to_device(inputs))
            return self._l2_normalize(self._as_embedding(out))

    def _encode_frames_siglip2(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm video embedding (mean of frame embeddings)."""
        inputs = self.processor(images=frames, return_tensors="pt")
        with torch.no_grad():
            out = self.model.get_image_features(**self._to_device(inputs))
            embs = self._l2_normalize(self._as_embedding(out))
            return self._l2_normalize(embs.mean(dim=0, keepdim=True))

    # --------------------------------------------------------------- X-CLIP

    def _xclip_text_inputs(self, labels: list[str]) -> dict:
        """Tokenize the label prompts for X-CLIP and move them to ``self.device``."""
        inputs = self.processor.tokenizer(
            self._prompts(labels),
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return self._to_device(inputs)

    def _xclip_video_inputs(self, frames: list[Image.Image]) -> torch.Tensor:
        """Preprocess ``frames`` into X-CLIP ``pixel_values`` on ``self.device``.

        X-CLIP's vision tower is built for a fixed frame count
        (``vision_config.num_frames``). When the clip has a different count,
        frames are resampled uniformly (repeating frames if there are too few),
        so any non-empty clip is accepted.
        """
        vision_cfg = getattr(self.model.config, "vision_config", None)
        num_frames = getattr(vision_cfg, "num_frames", None)
        if num_frames and len(frames) != num_frames:
            picks = np.linspace(0, len(frames) - 1, num_frames).round().astype(int)
            frames = [frames[i] for i in picks]
        frames_np = [np.array(f) for f in frames]
        inputs = self.processor.image_processor(images=frames_np, return_tensors="pt")
        return inputs["pixel_values"].to(self.device)

    def _encode_text_xclip(self, labels: list[str]) -> torch.Tensor:
        """Return unit-norm X-CLIP text embeddings (no video-conditioned prompts).

        Not used by ``classify``, which scores with the full forward pass.
        Kept for callers that want raw embeddings.
        """
        with torch.no_grad():
            out = self.model.get_text_features(**self._xclip_text_inputs(labels))
            return self._l2_normalize(self._as_embedding(out))

    def _encode_frames_xclip(self, frames: list[Image.Image]) -> torch.Tensor:
        """Return a (1, dim) unit-norm X-CLIP video embedding."""
        with torch.no_grad():
            out = self.model.get_video_features(pixel_values=self._xclip_video_inputs(frames))
            return self._l2_normalize(self._as_embedding(out))

    def _logits_xclip(self, frames: list[Image.Image], labels: list[str]) -> torch.Tensor:
        """Return X-CLIP ``logits_per_video`` with shape (1, num_labels).

        Uses ``XCLIPModel.forward``, which adds video-conditioned prompts to
        the text embeddings before scoring. Comparing ``get_text_features``
        with ``get_video_features`` directly would skip that step.
        """
        with torch.no_grad():
            out = self.model(
                pixel_values=self._xclip_video_inputs(frames),
                **self._xclip_text_inputs(labels),
            )
        return out.logits_per_video

    # --------------------------------------------------------------- public API

    def classify(self, frames: list[Image.Image], labels: list[str], top_k: int = 5) -> list[dict]:
        """Rank ``labels`` by how well they describe the video given by ``frames``.

        Args:
            frames: Non-empty list of frames of one video.
            labels: Non-empty list of candidate class names.
            top_k: Maximum number of results to return (values below 0 are
                treated as 0).

        Returns:
            Up to ``top_k`` dicts ``{"label": str, "score": float}``, highest
            score first. The CLIP and X-CLIP scores are softmax probabilities
            over ``labels``. SigLIP 2 scores are independent sigmoid
            probabilities, so they need not sum to 1.

        Raises:
            ValueError: If ``frames`` or ``labels`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("frames must contain at least one image")
        if len(labels) == 0:
            raise ValueError("labels must contain at least one label")

        # Scoring also touches learnable parameters (logit_scale/bias), so run
        # it without autograd as well.
        with torch.no_grad():
            if self.backend == "clip":
                text_embs = self._encode_text_clip(labels)
                video_emb = self._encode_frames_clip(frames)
                # Scale by CLIP's learned temperature. A softmax over raw
                # cosines in [-1, 1] would be almost uniform.
                logits = self.model.logit_scale.exp() * (video_emb @ text_embs.T)
                probs = logits.softmax(dim=-1)

            elif self.backend == "siglip2":
                text_embs = self._encode_text_siglip2(labels)
                video_emb = self._encode_frames_siglip2(frames)
                logit_scale = self.model.logit_scale.exp()
                logit_bias = self.model.logit_bias
                logits = video_emb @ text_embs.T * logit_scale + logit_bias
                probs = torch.sigmoid(logits)

            elif self.backend == "xclip":
                probs = self._logits_xclip(frames, labels).softmax(dim=-1)

            else:
                raise ValueError(f"Unsupported backend: {self.backend!r}")

        # Take row 0 instead of squeeze(), so one label still yields a 1-D array.
        scores = probs[0].float().cpu().numpy()
        top_indices = np.argsort(scores)[::-1][: max(top_k, 0)]
        return [{"label": labels[i], "score": float(scores[i])} for i in top_indices]