import torch
import torch.nn as nn
from transformers import VideoMAEForVideoClassification

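def create_reorder_index(N, device):
    """Return a column-wise serpentine ordering of a row-major N x N grid.

    Even columns are read top to bottom and odd columns bottom to top,
    e.g. N=3 gives [0, 3, 6, 7, 4, 1, 2, 5, 8].
    """
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            # Walk down the column: col, col + N, ..., col + N*(N-1).
            new_order.extend(range(col, N * N, N))
        else:
            # Walk back up the column, starting from the last row.
            new_order.extend(range(col + N * (N - 1), col - 1, -N))
    return torch.tensor(new_order, device=device)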

def reorder_data(data, N):
    """Reorder the N*N tokens along dim 2 of a (B, t, N*N, C) tensor."""
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    new_order = create_reorder_index(N, data.device)
    B, t, _, _ = data.shape
    # Broadcast the (N*N,) index to data's shape so gather moves whole token vectors.
    index = new_order.repeat(B, t, 1).unsqueeze(-1)
    return torch.gather(data, 2, index.expand_as(data))
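

# Illustrative sketch, not part of the original module: a torch-free reference
# for the ordering built by create_reorder_index, assuming the N*N tokens are
# stored row-major. `_snake_order_reference` is a hypothetical helper name; it
# is meant only for sanity checks such as
#     create_reorder_index(N, "cpu").tolist() == _snake_order_reference(N)
def _snake_order_reference(n):
    # Lay the flat indices out as an n x n row-major grid.
    grid = [[row * n + col for col in range(n)] for row in range(n)]
    order = []
    for col in range(n):
        column = [grid[row][col] for row in range(n)]
        # Even columns are read top to bottom, odd columns bottom to top.
        order.extend(column if col % 2 == 0 else reversed(column))
    return order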

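class Videomae_Net(nn.Module):
    """VideoMAE backbone followed by BatchNorm and a linear classification head."""

    def __init__(
        self, channel_size=512, dropout=0.2, class_num=1
    ):
        # channel_size and dropout are accepted but not used by this model.
        super().__init__()
        self.model = VideoMAEForVideoClassification.from_pretrained("/ossfs/workspace/GenVideo/pretrained_weights/videomae")
        # 768 is assumed to match the hidden size of the loaded backbone.
        self.fc1 = nn.Linear(768, class_num)
        self.bn1 = nn.BatchNorm1d(768)

        self._init_params()

    def _init_params(self):
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x):
        # x: pixel values of shape (batch, frames, channels, height, width).
        x = self.model.videomae(x)
        sequence_output = x[0]  # last hidden state: (batch, tokens, hidden)
        if self.model.fc_norm is not None:
            # Mean-pool over tokens, then apply the backbone's final norm.
            sequence_output = self.model.fc_norm(sequence_output.mean(1))
        else:
            # Otherwise use the first token as the clip representation.
            sequence_output = sequence_output[:, 0]
        x = self.bn1(sequence_output)
        x = self.fc1(x)
        return x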



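if __name__ == '__main__':
    model = Videomae_Net()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # One clip of 16 RGB frames at 224x224: (batch, frames, channels, height, width).
    input_data = torch.randn(1, 16, 3, 224, 224).to(device)

    # Smoke test: gradients are not needed for a forward pass in eval mode.
    with torch.no_grad():
        output = model(input_data)
    print(output.shape)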