import torch
from timesformer_pytorch import TimeSformer
from transformers.modeling_utils import PreTrainedModel

from .timesformer_config import TimesformerConfig


class TimesformerModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``timesformer_pytorch.TimeSformer``.

    The backbone is built for single-channel input (``channels=1``), uses
    ``config.window_size`` as the spatial image size, and has fixed
    ``dim_head=64`` and dropout rates of 0.1. Its ``config.num_classes``-sized
    output is reshaped in :meth:`forward` into a ``(batch, 1, 4, 4)`` grid,
    so ``config.num_classes`` must be 16.
    """

    config_class = TimesformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = TimeSformer(
            dim=config.dim,
            image_size=config.window_size,
            patch_size=config.patch_size,
            num_frames=config.num_frames,
            num_classes=config.num_classes,
            channels=1,
            depth=config.depth,
            heads=config.n_heads,
            dim_head=64,
            attn_dropout=0.1,
            ff_dropout=0.1,
        )
        # Hugging Face hook: runs weight initialisation / final setup.
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the TimeSformer backbone and reshape its output to a 4x4 grid.

        Args:
            tensor: 5-D input batch. Dimensions 1 and 2 are swapped before
                the backbone is called. Presumably the input is
                ``(batch, channels, frames, height, width)`` and the
                backbone expects ``(batch, frames, channels, height, width)``
                -- TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch, 1, 4, 4)``.

        Raises:
            RuntimeError: if the backbone's per-sample output does not hold
                exactly 16 values (i.e. ``config.num_classes != 16``).
        """
        batch_size = tensor.shape[0]
        x = self.backbone(tensor.permute(0, 2, 1, 3, 4))
        # Pin the batch dimension explicitly: ``view(-1, 1, 4, 4)`` would
        # silently fold extra outputs into the batch axis whenever
        # num_classes is a multiple of 16 other than 16.
        return x.view(batch_size, 1, 4, 4)