File size: 1,374 Bytes
fd3be3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from config import InternVideo2Config as config
import warnings
import torch
warnings.filterwarnings("ignore")

# model_config = config()
# model = IV2S2(model_config)
# print(model)

class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """HF-compatible wrapper around the InternVideo2 Stage-2 vision tower.

    Wraps ``IV2S2`` in a ``transformers.PreTrainedModel`` so the video
    encoder can participate in the HF save/load machinery.
    """

    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Half precision to match the fp16 inputs fed in __main__; the
        # target device is taken from the config, not from the caller.
        self.model = IV2S2(config).half().to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of videos into pooled vision embeddings.

        Args:
            x: Video tensor of shape (B, N, C, H, W), or (N, C, H, W)
               for a single un-batched clip (N frames, C channels).

        Returns:
            torch.Tensor: The pooled vision embeddings, i.e. the second
            element returned by ``encode_vision``. NOTE(review): the
            original comments disagree on the exact shape
            ((B*N, hidden_size) vs (B, N*98, hidden_size)) — confirm
            against IV2S2.encode_vision before relying on it.
        """
        # Promote a single clip to a batch of one.
        if x.dim() == 4:
            x = x.unsqueeze(0)
        # encode_vision expects channels before the frame axis:
        # (B, N, C, H, W) -> (B, C, N, H, W).
        x = x.permute(0, 2, 1, 3, 4)
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]
        return pooled_vision_embeds
    
if __name__ == "__main__":
    # Smoke test: run a random fp16 batch of 2 clips x 3 frames through
    # the encoder on the configured device.
    cfg = config()
    encoder = InternVideo2Stage2VideoEncoder(cfg)
    dummy_clips = torch.randn(
        2, 3, 8, 224, 224, dtype=torch.float16
    ).to(cfg.device)
    features = encoder(dummy_clips)