# InternVideo2 Stage2 video encoder wrapper (HuggingFace PreTrainedModel).
from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from config import InternVideo2Config as config
import warnings
import torch
warnings.filterwarnings("ignore")
# model_config = config()
# model = IV2S2(model_config)
# print(model)
class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper exposing the InternVideo2
    Stage2 vision tower as a standalone video encoder.

    The underlying ``IV2S2`` model is instantiated in half precision and
    moved to ``config.device``; inputs are therefore expected to be fp16
    tensors on that device (see the ``__main__`` smoke test).
    """

    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Half precision to cut memory; assumes config.device supports
        # fp16 (e.g. CUDA) — TODO confirm for CPU-only deployments.
        self.model = IV2S2(config).half().to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode video frames into pooled vision embeddings.

        Args:
            x: Video tensor of shape ``(B, N, C, H, W)`` or ``(N, C, H, W)``;
               a 4-D input is treated as a batch of one.

        Returns:
            torch.Tensor: Pooled vision embeddings — the second element of
            ``encode_vision``'s output. NOTE(review): the original comments
            disagreed on the exact shape ((B*N, hidden) vs (B, N*98, hidden));
            verify against ``IV2S2.encode_vision``.
        """
        if x.dim() == 4:
            # Promote a single clip to a batch of one.
            x = x.unsqueeze(0)
        # encode_vision expects channels before frames: (B, C, N, H, W).
        x = x.permute(0, 2, 1, 3, 4)
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]
        return pooled_vision_embeds
if __name__ == "__main__":
    # Smoke test: encode a random fp16 clip batch and report the output shape.
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    # forward expects (B, N, C, H, W): 2 clips, 8 frames, 3 channels.
    # (The original used (2, 3, 8, 224, 224), which reads as channels-first
    # and would reach the model as 8-channel frames after the permute.)
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)
    print(output.shape)