File size: 2,006 Bytes
d39b279 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
import torch.nn as nn
import torchvision
import time
from .mamba_base import MambaConfig, ResidualBlock
def create_reorder_index(N, device):
    """Build a column-wise serpentine ordering of the cells of an N x N grid.

    Flat index ``i`` is treated as row ``i // N``, column ``i % N``. Columns
    are visited left to right; even-numbered columns are walked with
    ascending flat index (top to bottom) and odd-numbered columns with
    descending flat index (bottom to top), so consecutive entries are always
    neighbouring cells.

    Args:
        N: Side length of the square grid.
        device: Device on which the index tensor is allocated.

    Returns:
        A 1-D ``torch.long`` tensor of length ``N * N`` listing the flat
        indices in visiting order (empty when ``N <= 0``).
    """
    new_order = []
    for col in range(N):
        # Flat indices of this column, top to bottom.
        column = range(col, N * N, N)
        new_order.extend(column if col % 2 == 0 else reversed(column))
    # Explicit dtype: an empty list would otherwise become a float tensor,
    # which cannot be used as an index.
    return torch.tensor(new_order, dtype=torch.long, device=device)
def reorder_data(data, N):
    """Permute axis 2 of ``data`` into the order from ``create_reorder_index``.

    Args:
        data: A 4-D ``torch.Tensor``. Axis 2 is expected to hold the ``N * N``
            grid cells (presumably laid out as (batch, time, cells,
            channels) -- confirm against callers).
        N: Side length of the grid, forwarded to ``create_reorder_index``.

    Returns:
        A tensor with the same shape as ``data`` whose axis-2 entries have
        been gathered in the reordered sequence.
    """
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    order = create_reorder_index(N, data.device)
    # Unpacking exactly four sizes also rejects inputs that are not 4-D.
    batch, frames, _, _ = data.shape
    # Broadcast the 1-D order over batch/time, then over the trailing axis.
    gather_index = order.expand(batch, frames, -1).unsqueeze(-1).expand_as(data)
    return torch.gather(data, 2, gather_index)
class Videomae_Net(nn.Module):
    """VideoMAE backbone followed by a BatchNorm + linear classification head.

    The backbone (``model.videomae``) and the optional ``model.fc_norm`` layer
    are taken from a pretrained ``VideoMAEForVideoClassification``. The pooled
    features are then passed through ``bn1`` and ``fc1``.

    Args:
        channel_size: Unused; kept so existing call sites keep working.
        dropout: Unused; kept so existing call sites keep working.
        class_num: Number of output values produced by ``fc1``.
        pretrained_path: Location handed to
            ``VideoMAEForVideoClassification.from_pretrained``. Defaults to
            the path that used to be hard-coded here.
    """

    def __init__(
        self,
        channel_size=512,
        dropout=0.2,
        class_num=1,
        pretrained_path="/ossfs/workspace/GenVideo/pretrained_weights/videomae",
    ):
        super().__init__()
        self.model = VideoMAEForVideoClassification.from_pretrained(pretrained_path)
        # Size the head from the loaded checkpoint rather than assuming 768,
        # so checkpoints with a different hidden width also work.
        hidden_size = self.model.config.hidden_size
        self.fc1 = nn.Linear(hidden_size, class_num)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self._init_params()

    def _init_params(self):
        """Initialise ``fc1`` with Xavier-normal weights and a zero bias."""
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x):
        """Run the backbone, pool its token sequence and apply the head.

        Args:
            x: Input accepted by ``self.model.videomae`` (the module's
                self-test feeds a (1, 16, 3, 224, 224) tensor; presumably
                (batch, frames, channels, height, width) -- confirm).

        Returns:
            The output of ``fc1`` applied to the batch-normalised pooled
            features, i.e. ``class_num`` values per sample.
        """
        outputs = self.model.videomae(x)
        sequence_output = outputs[0]
        if self.model.fc_norm is not None:
            # Average over the sequence axis, then normalise.
            pooled = self.model.fc_norm(sequence_output.mean(1))
        else:
            # No fc_norm: use the first sequence position as the summary.
            pooled = sequence_output[:, 0]
        return self.fc1(self.bn1(pooled))
if __name__ == '__main__':
    # Smoke test: push one random clip through the network and report the
    # shape of the result.
    model = Videomae_Net()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference mode for the module (BatchNorm then uses running statistics).
    model.eval()
    # Presumably (batch, frames, channels, height, width) -- confirm.
    input_data = torch.randn(1, 16, 3, 224, 224).to(device)
    # No gradients are needed for a forward-only check.
    with torch.no_grad():
        output = model(input_data)
    print(output.shape)
|