"""Hugging Face ``PreTrainedModel`` wrapper around a TimeSformer backbone."""

import torch
from timesformer_pytorch import TimeSformer
from transformers import PreTrainedModel

from .timesformer_config import TimesformerScrollprizeConfig


class TimesformerScrollprizeModel(PreTrainedModel):
    """Expose a ``TimeSformer`` network through the ``PreTrainedModel`` API.

    The backbone is built from the fields of a
    ``TimesformerScrollprizeConfig`` and stored as ``self.backbone``.
    """

    config_class = TimesformerScrollprizeConfig

    def __init__(self, config: TimesformerScrollprizeConfig):
        """Build the TimeSformer backbone from ``config``.

        Reads ``dim``, ``window_size``, ``patch_size``, ``num_frames``,
        ``num_classes``, ``depth`` and ``n_heads`` from ``config``. The input
        channel count, per-head dimension and both dropout rates are fixed
        here rather than taken from the config.
        """
        super().__init__(config)

        backbone_kwargs = {
            "dim": config.dim,
            "image_size": config.window_size,
            "patch_size": config.patch_size,
            "num_frames": config.num_frames,
            "num_classes": config.num_classes,
            "channels": 1,
            "depth": config.depth,
            "heads": config.n_heads,
            "dim_head": 64,
            "attn_dropout": 0.1,
            "ff_dropout": 0.1,
        }
        self.backbone = TimeSformer(**backbone_kwargs)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``tensor`` and return the output as 4x4 maps.

        Args:
            tensor: 5-D input whose axes 1 and 2 are swapped before the
                backbone is called.
                NOTE(review): presumably (batch, channels, frames, H, W), so
                that the backbone receives (batch, frames, channels, H, W) --
                confirm against the caller.

        Returns:
            The backbone output viewed with shape ``(-1, 1, 4, 4)``.
        """
        # Swapping axes 1 and 2 is the same reordering as permuting with
        # (0, 2, 1, 3, 4).
        frames_first = tensor.transpose(1, 2)
        logits = self.backbone(frames_first)
        # NOTE(review): the fixed 4x4 view presumably relies on
        # config.num_classes == 16; any other value would change the leading
        # dimension or raise -- confirm.
        return logits.view(-1, 1, 4, 4)