import os
from collections import OrderedDict
from typing import Dict, List, Optional, Union

import torch
from mmcv.cnn.bricks import DropPath
from mmengine.logging import MMLogger
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import _load_checkpoint
from torch import nn

from mmaction.registry import MODELS

logger = MMLogger.get_current_instance()

MODEL_PATH = 'https://download.openmmlab.com/mmaction/v1.0/recognition'
_MODELS = {
    'ViT-B/16':
    os.path.join(MODEL_PATH, 'uniformerv2/clipVisualEncoder',
                 'vit-base-p16-res224_clip-rgb_20221219-b8a5da86.pth'),
    'ViT-L/14':
    os.path.join(MODEL_PATH, 'uniformerv2/clipVisualEncoder',
                 'vit-large-p14-res224_clip-rgb_20221219-9de7543e.pth'),
    'ViT-L/14_336':
    os.path.join(MODEL_PATH, 'uniformerv2/clipVisualEncoder',
                 'vit-large-p14-res336_clip-rgb_20221219-d370f9e5.pth'),
}


class QuickGELU(BaseModule):
    """Quick GELU function. Forked from https://github.com/openai/CLIP/blob/d50
    d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py.

    Args:
        x (torch.Tensor): The input features of shape :math:`(B, N, C)`.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)
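
# A minimal shape check for QuickGELU (illustrative sketch, not part of the
# original file): the activation x * sigmoid(1.702 * x) is elementwise, so
# the output shape matches the input:
#     >>> gelu = QuickGELU()
#     >>> gelu(torch.randn(2, 197, 768)).shape  # (B, N, C)
#     torch.Size([2, 197, 768])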

class Local_MHRA(BaseModule):
    """Local MHRA.

    Args:
        d_model (int): Number of input channels.
        dw_reduction (float): Downsample ratio of input channels.
            Defaults to 1.5.
        pos_kernel_size (int): Kernel size of local MHRA.
            Defaults to 3.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        d_model: int,
        dw_reduction: float = 1.5,
        pos_kernel_size: int = 3,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        padding = pos_kernel_size // 2
        re_d_model = int(d_model // dw_reduction)
        self.pos_embed = nn.Sequential(
            nn.BatchNorm3d(d_model),
            nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0),
            nn.Conv3d(
                re_d_model,
                re_d_model,
                kernel_size=(pos_kernel_size, 1, 1),
                stride=(1, 1, 1),
                padding=(padding, 0, 0),
                groups=re_d_model),
            nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0),
        )

        # Zero-init the output projection so the block starts as an identity
        # residual branch.
        logger.info('Init zero for Conv in pos_emb')
        nn.init.constant_(self.pos_embed[3].weight, 0)
        nn.init.constant_(self.pos_embed[3].bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pos_embed(x)
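
# Illustrative sketch (assumed sizes, not part of the original file): local
# MHRA acts as a depth-wise temporal convolution over 5-D video features and
# preserves the input shape:
#     >>> lmhra = Local_MHRA(d_model=64)
#     >>> lmhra(torch.randn(2, 64, 8, 14, 14)).shape  # (N, C, T, H, W)
#     torch.Size([2, 64, 8, 14, 14])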

class ResidualAttentionBlock(BaseModule):
    """Local UniBlock.

    Args:
        d_model (int): Number of input channels.
        n_head (int): Number of attention heads.
        drop_path (float): Stochastic depth rate.
            Defaults to 0.0.
        dw_reduction (float): Downsample ratio of input channels.
            Defaults to 1.5.
        no_lmhra (bool): Whether to remove local MHRA.
            Defaults to False.
        double_lmhra (bool): Whether to use double local MHRA.
            Defaults to True.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        drop_path: float = 0.0,
        dw_reduction: float = 1.5,
        no_lmhra: bool = False,
        double_lmhra: bool = True,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        self.n_head = n_head
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        logger.info(f'Drop path rate: {drop_path}')

        self.no_lmhra = no_lmhra
        self.double_lmhra = double_lmhra
        logger.info(f'No L_MHRA: {no_lmhra}')
        logger.info(f'Double L_MHRA: {double_lmhra}')
        if not no_lmhra:
            self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction)
            if double_lmhra:
                self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction)

        # Self-attention and FFN from the original ViT block.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = nn.LayerNorm(d_model)

    def attention(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x, x, x, need_weights=False, attn_mask=None)[0]

    def forward(self, x: torch.Tensor, T: int = 8) -> torch.Tensor:
        # x: (L, N * T, C), where L = 1 + H * W includes the class token.
        if not self.no_lmhra:
            # Local MHRA on the spatial tokens (class token excluded).
            tmp_x = x[1:, :, :]
            L, NT, C = tmp_x.shape
            N = NT // T
            H = W = int(L**0.5)
            tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0,
                                                      1).contiguous()
            tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x))
            tmp_x = tmp_x.view(N, C, T,
                               L).permute(3, 0, 2,
                                          1).contiguous().view(L, NT, C)
            x = torch.cat([x[:1, :, :], tmp_x], dim=0)
        # Multi-head self-attention over all tokens.
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        # Optional second local MHRA between attention and FFN.
        if not self.no_lmhra and self.double_lmhra:
            tmp_x = x[1:, :, :]
            tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0,
                                                      1).contiguous()
            tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x))
            tmp_x = tmp_x.view(N, C, T,
                               L).permute(3, 0, 2,
                                          1).contiguous().view(L, NT, C)
            x = torch.cat([x[:1, :, :], tmp_x], dim=0)
        # FFN.
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
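
# Illustrative sketch (assumed sizes, not part of the original file): the
# local UniBlock consumes sequence-first features of shape (L, N * T, C),
# where L = 1 + H * W counts the class token:
#     >>> blk = ResidualAttentionBlock(d_model=768, n_head=12)
#     >>> x = torch.randn(1 + 14 * 14, 2 * 8, 768)  # N=2 clips, T=8 frames
#     >>> blk(x, T=8).shape
#     torch.Size([197, 16, 768])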

class Extractor(BaseModule):
    """Global UniBlock.

    Args:
        d_model (int): Number of input channels.
        n_head (int): Number of attention heads.
        mlp_factor (float): Ratio of hidden dimensions in MLP layers.
            Defaults to 4.0.
        dropout (float): Stochastic dropout rate.
            Defaults to 0.0.
        drop_path (float): Stochastic depth rate.
            Defaults to 0.0.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_factor: float = 4.0,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        logger.info(f'Drop path rate: {drop_path}')
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        d_mlp = round(mlp_factor * d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_mlp)),
                         ('gelu', QuickGELU()),
                         ('dropout', nn.Dropout(dropout)),
                         ('c_proj', nn.Linear(d_mlp, d_model))]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.ln_3 = nn.LayerNorm(d_model)

        # Xavier-init the input projections; zero-init the output projections
        # so the block starts as an identity residual branch.
        nn.init.xavier_uniform_(self.attn.in_proj_weight)
        nn.init.constant_(self.attn.out_proj.weight, 0.)
        nn.init.constant_(self.attn.out_proj.bias, 0.)
        nn.init.xavier_uniform_(self.mlp[0].weight)
        nn.init.constant_(self.mlp[-1].weight, 0.)
        nn.init.constant_(self.mlp[-1].bias, 0.)

    def attention(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Cross-attention: queries come from x, keys and values from y. The
        # packed in_proj weights of nn.MultiheadAttention are sliced into
        # their q/k/v parts.
        d_model = self.ln_1.weight.size(0)
        q = (x @ self.attn.in_proj_weight[:d_model].T
             ) + self.attn.in_proj_bias[:d_model]
        k = (y @ self.attn.in_proj_weight[d_model:-d_model].T
             ) + self.attn.in_proj_bias[d_model:-d_model]
        v = (y @ self.attn.in_proj_weight[-d_model:].T
             ) + self.attn.in_proj_bias[-d_model:]
        Tx, Ty, N = q.size(0), k.size(0), q.size(1)
        q = q.view(Tx, N, self.attn.num_heads,
                   self.attn.head_dim).permute(1, 2, 0, 3)
        k = k.view(Ty, N, self.attn.num_heads,
                   self.attn.head_dim).permute(1, 2, 0, 3)
        v = v.view(Ty, N, self.attn.num_heads,
                   self.attn.head_dim).permute(1, 2, 0, 3)
        # Scaled dot-product attention.
        aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim**0.5))
        aff = aff.softmax(dim=-1)
        out = aff @ v
        out = out.permute(2, 0, 1, 3).flatten(2)
        out = self.attn.out_proj(out)
        return out

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
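
# Illustrative sketch (assumed sizes, not part of the original file): the
# global UniBlock cross-attends query tokens x over backbone tokens y, both
# sequence-first:
#     >>> dec = Extractor(d_model=768, n_head=12)
#     >>> cls_tok = torch.randn(1, 2, 768)       # (Tx, N, C)
#     >>> feats = torch.randn(8 * 197, 2, 768)   # (Ty, N, C)
#     >>> dec(cls_tok, feats).shape
#     torch.Size([1, 2, 768])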

class Transformer(BaseModule):
    """Backbone of UniFormerV2.

    Args:
        width (int): Number of input channels in local UniBlock.
        layers (int): Number of layers of local UniBlock.
        heads (int): Number of attention heads in local UniBlock.
        backbone_drop_path_rate (float): Stochastic depth rate
            in local UniBlock. Defaults to 0.0.
        t_size (int): Temporal dimension after patch embedding.
            Defaults to 8.
        dw_reduction (float): Downsample ratio of input channels in local MHRA.
            Defaults to 1.5.
        no_lmhra (bool): Whether to remove local MHRA in local UniBlock.
            Defaults to True.
        double_lmhra (bool): Whether to use double local MHRA
            in local UniBlock. Defaults to False.
        return_list (List[int]): Layer indices of input features
            for global UniBlock. Defaults to [8, 9, 10, 11].
        n_layers (int): Number of layers of global UniBlock.
            Defaults to 4.
        n_dim (int): Number of input channels in global UniBlock.
            Defaults to 768.
        n_head (int): Number of attention heads in global UniBlock.
            Defaults to 12.
        mlp_factor (float): Ratio of hidden dimensions in MLP layers
            in global UniBlock. Defaults to 4.0.
        drop_path_rate (float): Stochastic depth rate in global UniBlock.
            Defaults to 0.0.
        mlp_dropout (List[float]): Stochastic dropout rate in each MLP layer
            in global UniBlock. Defaults to [0.5, 0.5, 0.5, 0.5].
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        backbone_drop_path_rate: float = 0.,
        t_size: int = 8,
        dw_reduction: float = 1.5,
        no_lmhra: bool = True,
        double_lmhra: bool = False,
        return_list: List[int] = [8, 9, 10, 11],
        n_layers: int = 4,
        n_dim: int = 768,
        n_head: int = 12,
        mlp_factor: float = 4.0,
        drop_path_rate: float = 0.,
        mlp_dropout: List[float] = [0.5, 0.5, 0.5, 0.5],
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        self.T = t_size
        self.return_list = return_list
        # Local UniBlocks with a linearly increasing stochastic depth rate.
        b_dpr = [
            x.item()
            for x in torch.linspace(0, backbone_drop_path_rate, layers)
        ]
        self.resblocks = ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                drop_path=b_dpr[i],
                dw_reduction=dw_reduction,
                no_lmhra=no_lmhra,
                double_lmhra=double_lmhra,
            ) for i in range(layers)
        ])

        # Global UniBlocks.
        assert n_layers == len(return_list)
        self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim))
        self.dpe = ModuleList([
            nn.Conv3d(
                n_dim,
                n_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                groups=n_dim) for _ in range(n_layers)
        ])
        for m in self.dpe:
            nn.init.constant_(m.bias, 0.)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.dec = ModuleList([
            Extractor(
                n_dim,
                n_head,
                mlp_factor=mlp_factor,
                dropout=mlp_dropout[i],
                drop_path=dpr[i],
            ) for i in range(n_layers)
        ])
        # Learnable balance between the backbone class token and the temporal
        # class token.
        self.norm = nn.LayerNorm(n_dim)
        self.balance = nn.Parameter(torch.zeros((n_dim)))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T_down = self.T
        L, NT, C = x.shape
        N = NT // T_down
        H = W = int((L - 1)**0.5)
        cls_token = self.temporal_cls_token.repeat(1, N, 1)

        j = -1
        for i, resblock in enumerate(self.resblocks):
            x = resblock(x, T_down)
            if i in self.return_list:
                j += 1
                tmp_x = x.clone()
                tmp_x = tmp_x.view(L, N, T_down, C)
                # Add the dynamic position encoding to the spatial tokens.
                _, tmp_feats = tmp_x[:1], tmp_x[1:]
                tmp_feats = tmp_feats.permute(1, 3, 2,
                                              0).reshape(N, C, T_down, H, W)
                tmp_feats = self.dpe[j](tmp_feats.clone()).view(
                    N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
                tmp_x[1:] = tmp_x[1:] + tmp_feats
                # Refine the temporal class token with the global UniBlock.
                tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1)  # T * L, N, C
                cls_token = self.dec[j](cls_token, tmp_x)

        weight = self.sigmoid(self.balance)
        residual = x.view(L, N, T_down, C)[0].mean(1)  # class token, (N, C)
        out = self.norm((1 - weight) * cls_token[0, :, :] + weight * residual)
        return out
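
# Illustrative sketch (assumed sizes, not part of the original file): the
# backbone maps (L, N * T, C) token features to one (N, C) video embedding:
#     >>> trans = Transformer(width=768, layers=12, heads=12)
#     >>> x = torch.randn(197, 2 * 8, 768)  # L=197, N=2 clips, T=8 frames
#     >>> trans(x).shape
#     torch.Size([2, 768])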

@MODELS.register_module()
class UniFormerV2(BaseModule):
    """UniFormerV2:

    A PyTorch implementation of: `UniFormerV2: Spatiotemporal
    Learning by Arming Image ViTs with Video UniFormer
    <https://arxiv.org/abs/2211.09552>`

    Args:
        input_resolution (int): Spatial resolution of the input.
            Defaults to 224.
        patch_size (int): Spatial size of one patch.
            Defaults to 16.
        width (int): Number of input channels in local UniBlock.
            Defaults to 768.
        layers (int): Number of layers of local UniBlock.
            Defaults to 12.
        heads (int): Number of attention heads in local UniBlock.
            Defaults to 12.
        backbone_drop_path_rate (float): Stochastic depth rate
            in local UniBlock. Defaults to 0.0.
        t_size (int): Temporal dimension after patch embedding.
            Defaults to 8.
        kernel_size (int): Temporal kernel size of the patch embedding
            convolution. Defaults to 3.
        temporal_downsample (bool): Whether to downsample the temporal
            dimension. Defaults to False.
        dw_reduction (float): Downsample ratio of input channels in local MHRA.
            Defaults to 1.5.
        no_lmhra (bool): Whether to remove local MHRA in local UniBlock.
            Defaults to True.
        double_lmhra (bool): Whether to use double local MHRA in local
            UniBlock. Defaults to False.
        return_list (List[int]): Layer indices of input features
            for global UniBlock. Defaults to [8, 9, 10, 11].
        n_layers (int): Number of layers of global UniBlock.
            Defaults to 4.
        n_dim (int): Number of input channels in global UniBlock.
            Defaults to 768.
        n_head (int): Number of attention heads in global UniBlock.
            Defaults to 12.
        mlp_factor (float): Ratio of hidden dimensions in MLP layers
            in global UniBlock. Defaults to 4.0.
        drop_path_rate (float): Stochastic depth rate in global UniBlock.
            Defaults to 0.0.
        mlp_dropout (List[float]): Stochastic dropout rate in each MLP layer
            in global UniBlock. Defaults to [0.5, 0.5, 0.5, 0.5].
        clip_pretrained (bool): Whether to load pretrained CLIP visual encoder.
            Defaults to True.
        pretrained (str): Name of pretrained model.
            Defaults to None.
        init_cfg (dict or list[dict]): Initialization config dict. Defaults to
            ``[
            dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
            dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
            ]``.
    """

    def __init__(
        self,
        # backbone (local UniBlock)
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        backbone_drop_path_rate: float = 0.,
        t_size: int = 8,
        kernel_size: int = 3,
        dw_reduction: float = 1.5,
        temporal_downsample: bool = False,
        no_lmhra: bool = True,
        double_lmhra: bool = False,
        # global UniBlock
        return_list: List[int] = [8, 9, 10, 11],
        n_layers: int = 4,
        n_dim: int = 768,
        n_head: int = 12,
        mlp_factor: float = 4.0,
        drop_path_rate: float = 0.,
        mlp_dropout: List[float] = [0.5, 0.5, 0.5, 0.5],
        # pretrain
        clip_pretrained: bool = True,
        pretrained: Optional[str] = None,
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
            dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        self.pretrained = pretrained
        self.clip_pretrained = clip_pretrained
        self.input_resolution = input_resolution
        padding = (kernel_size - 1) // 2
        if temporal_downsample:
            self.conv1 = nn.Conv3d(
                3,
                width, (kernel_size, patch_size, patch_size),
                (2, patch_size, patch_size), (padding, 0, 0),
                bias=False)
            t_size = t_size // 2
        else:
            self.conv1 = nn.Conv3d(
                3,
                width, (1, patch_size, patch_size),
                (1, patch_size, patch_size), (0, 0, 0),
                bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = nn.LayerNorm(width)

        self.transformer = Transformer(
            width,
            layers,
            heads,
            dw_reduction=dw_reduction,
            backbone_drop_path_rate=backbone_drop_path_rate,
            t_size=t_size,
            no_lmhra=no_lmhra,
            double_lmhra=double_lmhra,
            return_list=return_list,
            n_layers=n_layers,
            n_dim=n_dim,
            n_head=n_head,
            mlp_factor=mlp_factor,
            drop_path_rate=drop_path_rate,
            mlp_dropout=mlp_dropout,
        )

    def _inflate_weight(self,
                        weight_2d: torch.Tensor,
                        time_dim: int,
                        center: bool = True) -> torch.Tensor:
        """Inflate a 2D convolution kernel to 3D along the time dimension."""
        logger.info(f'Init center: {center}')
        if center:
            # Place the 2D kernel at the middle temporal index; the other
            # temporal slices stay zero.
            weight_3d = torch.zeros(*weight_2d.shape)
            weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
            middle_idx = time_dim // 2
            weight_3d[:, :, middle_idx, :, :] = weight_2d
        else:
            # Tile the 2D kernel over time and rescale to keep the response.
            weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
            weight_3d = weight_3d / time_dim
        return weight_3d
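
    # Illustrative sketch (assumed sizes, hypothetical kernel): center-style
    # inflation means that, at initialization, each output frame matches the
    # 2D convolution of the corresponding input frame:
    #     >>> m = UniFormerV2(clip_pretrained=False)
    #     >>> w2d = torch.randn(768, 3, 16, 16)  # (C_out, C_in, H, W)
    #     >>> m._inflate_weight(w2d, time_dim=3).shape
    #     torch.Size([768, 3, 3, 16, 16])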

    def _load_pretrained(self, pretrained: Optional[str] = None) -> None:
        """Load CLIP pretrained visual encoder.

        The visual encoder is extracted from CLIP.
        https://github.com/openai/CLIP

        Args:
            pretrained (str): Model name of pretrained CLIP visual encoder.
                Defaults to None.
        """
        assert pretrained is not None, \
            'please specify the CLIP pretrained checkpoint'

        model_path = _MODELS[pretrained]
        logger.info(f'Load CLIP pretrained model from {model_path}')
        state_dict = _load_checkpoint(model_path, map_location='cpu')
        state_dict_3d = self.state_dict()
        for k in state_dict.keys():
            # Inflate 2D weights whose shapes disagree with this 3D model.
            if k in state_dict_3d.keys(
            ) and state_dict[k].shape != state_dict_3d[k].shape:
                if len(state_dict_3d[k].shape) <= 2:
                    logger.info(f'Ignore: {k}')
                    continue
                logger.info(f'Inflate: {k}, {state_dict[k].shape}' +
                            f' => {state_dict_3d[k].shape}')
                time_dim = state_dict_3d[k].shape[2]
                state_dict[k] = self._inflate_weight(state_dict[k], time_dim)
        self.load_state_dict(state_dict, strict=False)

    def init_weights(self):
        """Initialize the weights in backbone."""
        if self.clip_pretrained:
            logger = MMLogger.get_current_instance()
            logger.info(f'load model from: {self.pretrained}')
            self._load_pretrained(self.pretrained)
        else:
            if self.pretrained:
                self.init_cfg = dict(
                    type='Pretrained', checkpoint=self.pretrained)
            super().init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, T, H, W) video clips.
        x = self.conv1(x)  # patch embedding, shape = (N, width, T, H', W')
        N, C, T, H, W = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
        # Prepend the class token and add spatial position embeddings.
        x = torch.cat([
            self.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        out = self.transformer(x)
        return out
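
# Illustrative usage sketch (the config values below are assumptions, not a
# released recipe): since the backbone is registered, it can be built from a
# config dict and fed (N, C, T, H, W) clips:
#     >>> import torch
#     >>> from mmaction.registry import MODELS
#     >>> cfg = dict(type='UniFormerV2', clip_pretrained=False)
#     >>> model = MODELS.build(cfg)
#     >>> model(torch.randn(1, 3, 8, 224, 224)).shape
#     torch.Size([1, 768])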