import math

import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import XCLIPVisionConfig, XCLIPVisionModel

from .clip import clip
from .mamba_base import MambaConfig, ResidualBlock
def create_reorder_index(N, device):
    # Build a column-major "snake" (boustrophedon) scan over an N x N grid:
    # even columns are read top-to-bottom, odd columns bottom-to-top, so
    # consecutive indices are always spatially adjacent.
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            new_order.extend(range(col, N * N, N))
        else:
            new_order.extend(range(col + N * (N - 1), col - 1, -N))
    return torch.tensor(new_order, device=device)
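
# A quick illustration of the scan order (not part of the model): a 2 x 2 grid
# stored row-major as [0, 1, 2, 3] is visited down column 0, then up column 1.
#
#     >>> create_reorder_index(2, torch.device("cpu"))
#     tensor([0, 2, 3, 1])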
def reorder_data(data, N):
    # Reorder dim 2 of a (B, t, N*N, c) tensor along the snake-scan index so
    # that the token sequence fed to Mamba follows a spatially continuous path.
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    device = data.device
    new_order = create_reorder_index(N, device)
    B, t, _, _ = data.shape
    index = new_order.repeat(B, t, 1).unsqueeze(-1)
    reordered_data = torch.gather(data, 2, index.expand_as(data))
    return reordered_data
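
# Minimal sketch of reorder_data on dummy values (the 4 along dim 2 plays the
# role of the N*N patch tokens for N = 2):
#
#     >>> x = torch.arange(4).float().view(1, 1, 4, 1)   # (B, t, N*N, c)
#     >>> reorder_data(x, 2).flatten().tolist()
#     [0.0, 2.0, 3.0, 1.0]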
class XCLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=768, class_num=1):
        super(XCLIP_DeMamba, self).__init__()
        # self.encoder = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
        # Train the vision backbone from scratch with the default X-CLIP config
        # (224 x 224 input, patch size 32, i.e. a 7 x 7 patch grid).
        config = XCLIPVisionConfig()
        self.encoder = XCLIPVisionModel(config)
        channel = channel_size
        self.fusing_ratios = 1
        self.patch_nums = (14 // self.fusing_ratios) ** 2  # assumes a 14 x 14 patch grid
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)
        # self.fc1 = nn.Linear((self.patch_nums + 1) * channel, class_num)
        # 38400 = 768 * (49 + 1): 49 patch tokens plus the pooled global token
        # for the default config above.
        self.fc1 = nn.Linear(38400, class_num)
        # self.fc_norm = nn.LayerNorm(self.patch_nums * channel)
        # Created lazily in forward() once the flattened feature size is known.
        self.fc_norm = None
        self.fc_norm2 = nn.LayerNorm(768)
        self.initialize_weights(self.fc1)
        self.dropout = nn.Dropout(p=0.0)
    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
    def forward(self, x):
        # x: (b, t, 3, h, w) -- a batch of b clips with t frames each.
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        outputs = self.encoder(images, output_hidden_states=True)
        # Drop the CLS token; keep only the per-patch tokens.
        sequence_output = outputs['last_hidden_state'][:, 1:, :]
        _, _, c = sequence_output.shape
        # Global clip-level feature: pooled output averaged over frames.
        global_feat = outputs['pooler_output'].reshape(b, t, -1)
        global_feat = global_feat.mean(1)
        global_feat = self.fc_norm2(global_feat)
        sequence_output = sequence_output.view(b, t, -1, c)
        _, _, f_w, _ = sequence_output.shape
        f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))
        s = f_h // self.fusing_ratios
        # Split the f_h x f_w patch grid into fusing_ratios^2 sub-regions of
        # size s x s, then fold the sub-regions into the batch dimension.
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s
        # Snake-scan the patches inside each sub-region, then flatten time and
        # space into one token sequence per group for the Mamba block.
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)
        # Lazily build the LayerNorm once the flattened feature size is known.
        if self.fc_norm is None:
            self.fc_norm = nn.LayerNorm(video_level_features.size(-1)).to(video_level_features.device)
        video_level_features = self.fc_norm(video_level_features)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        pred = self.fc1(video_level_features)
        pred = self.dropout(pred)
        return pred
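
# A minimal shape sketch for this class (assumed 224 x 224 inputs to match the
# default XCLIPVisionConfig; 2 clips of 8 frames is an arbitrary choice):
#
#     >>> model = XCLIP_DeMamba()
#     >>> clips = torch.randn(2, 8, 3, 224, 224)   # (b, t, 3, h, w)
#     >>> model(clips).shape
#     torch.Size([2, 1])
#
# The 7 x 7 grid gives 49 patch tokens, so the concatenated feature is
# 768 * 49 + 768 = 38400, matching fc1.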
class CLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=512, class_num=1):
        super(CLIP_DeMamba, self).__init__()
        self.clip_model, _ = clip.load('ViT-B-14')
        self.clip_model = self.clip_model.float()
        channel = channel_size
        self.fusing_ratios = 2
        self.patch_nums = (14 // self.fusing_ratios) ** 2  # assumes a 14 x 14 patch grid
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)
        self.fc1 = nn.Linear(channel * (self.patch_nums + 1), class_num)
        self.bn1 = nn.BatchNorm1d(channel)
        self.initialize_weights(self.fc1)
    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
    def forward(self, x):
        # x: (b, t, 3, h, w) -- a batch of b clips with t frames each.
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        # Assumes the local clip wrapper's encode_image returns per-patch
        # tokens of shape (b * t, num_patches, c), not a pooled vector.
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, -1, c)
        # Global feature: mean over all frames and patches.
        global_feat = sequence_output.reshape(b, -1, c)
        global_feat = global_feat.mean(1)
        _, _, f_w, _ = sequence_output.shape
        f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))
        s = f_h // self.fusing_ratios
        # Split the f_h x f_w patch grid into fusing_ratios^2 sub-regions of
        # size s x s, then fold the sub-regions into the batch dimension.
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s
        # Snake-scan the patches inside each sub-region, then flatten time and
        # space into one token sequence per group for the Mamba block.
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        x = self.fc1(video_level_features)
        return x
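
# Assumed shapes for this class (inferred from fc1, not verified against the
# local clip module): encode_image yields (b * t, 196, 512), i.e. a 14 x 14
# grid, so s = 14 // 2 = 7 and the concatenated feature is
# 512 * 49 + 512 = 25600 = channel * (patch_nums + 1).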
if __name__ == '__main__':
    model = CLIP_DeMamba()
    print(model)
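
    # Hedged smoke test for XCLIP_DeMamba (the variant that needs no local
    # CLIP weights): 224 x 224 frames match the default XCLIPVisionConfig;
    # the batch and frame counts are arbitrary.
    xclip_model = XCLIP_DeMamba()
    dummy_clips = torch.randn(2, 8, 3, 224, 224)  # (b, t, 3, h, w)
    with torch.no_grad():
        print(xclip_model(dummy_clips).shape)  # expected: torch.Size([2, 1])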