File size: 950 Bytes
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import XCLIPVisionModel
import os
import sys
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math

from transformers import XCLIPVisionModel
class XCLIP(nn.Module):
    """Video-level classifier on top of an X-CLIP vision backbone.

    Each frame is encoded independently by the pretrained backbone; frame
    embeddings are mean-pooled over time, layer-normalized, and projected
    to ``class_num`` logits.

    Args:
        channel_size: Unused; kept for backward compatibility with callers.
        dropout: Unused; kept for backward compatibility with callers.
        class_num: Number of output logits (default 1, i.e. a single
            binary/real-vs-fake score — matches the original behavior).
    """

    def __init__(
        self, channel_size=512, dropout=0.2, class_num=1
    ):
        super(XCLIP, self).__init__()

        # NOTE(review): `channel_size` and `dropout` are accepted but never
        # used; they are retained only so existing call sites keep working.
        # TODO confirm whether dropout was meant to be applied before `head`.
        self.backbone = XCLIPVisionModel.from_pretrained(
            "GenVideo/pretrained_weights/xclip"
        )

        # Derive the embedding width from the backbone config instead of
        # hard-coding 768, so other X-CLIP vision variants also work.
        hidden_size = self.backbone.config.hidden_size
        self.fc_norm = nn.LayerNorm(hidden_size)
        # Was hard-coded to 1; `class_num` (default 1) is now honored.
        self.head = nn.Linear(hidden_size, class_num)

    def forward(self, x):
        """Classify a batch of videos.

        Args:
            x: Tensor of shape ``(batch, time, channels, height, width)``.

        Returns:
            Logits of shape ``(batch, class_num)``.
        """
        b, t, c, h, w = x.shape
        # Fold time into the batch dimension so the image backbone sees
        # one frame per row; channel count is taken from the input rather
        # than assumed to be 3.
        frames = x.view(b * t, c, h, w)
        outputs = self.backbone(frames, output_hidden_states=True)
        # pooler_output: (b * t, hidden_size) -> (b, t, hidden_size)
        sequence_output = outputs['pooler_output'].reshape(b, t, -1)
        # Mean-pool over time, then normalize before the linear head.
        video_level_features = self.fc_norm(sequence_output.mean(1))
        pred = self.head(video_level_features)

        return pred