| from .timesformer_config import TimesformerScrollprizeConfig |
| from transformers import PreTrainedModel |
| import torch |
| from timesformer_pytorch import TimeSformer |
|
|
| |
|
|
|
|
|
|
|
|
class TimesformerScrollprizeModel(PreTrainedModel):
    """Hugging Face model wrapping a single-channel ``TimeSformer`` backbone.

    The backbone's flat per-sample output is reshaped in :meth:`forward` into a
    single-channel square map whose side is ``sqrt(config.num_classes)``.
    """

    # Config type PreTrainedModel associates with this model (used when
    # loading/saving, e.g. via from_pretrained / save_pretrained).
    config_class = TimesformerScrollprizeConfig

    def __init__(self, config):
        """Build the TimeSformer backbone from ``config``.

        Args:
            config: A ``TimesformerScrollprizeConfig``. Its ``dim``,
                ``window_size``, ``patch_size``, ``num_frames``,
                ``num_classes``, ``depth`` and ``n_heads`` attributes are
                forwarded to ``TimeSformer``.
        """
        super().__init__(config)
        self.backbone = TimeSformer(
            dim=config.dim,
            image_size=config.window_size,
            patch_size=config.patch_size,
            num_frames=config.num_frames,
            num_classes=config.num_classes,
            channels=1,  # single-channel input
            depth=config.depth,
            heads=config.n_heads,
            # Fixed hyperparameters, not exposed through the config.
            dim_head=64,
            attn_dropout=0.1,
            ff_dropout=0.1,
        )

    def forward(self, tensor):
        """Run the backbone and reshape its output into square maps.

        Args:
            tensor: 5-D input tensor. Axes 1 and 2 are swapped before it is
                passed to the backbone. Presumably the caller supplies
                (batch, channels, frames, height, width) and TimeSformer wants
                (batch, frames, channels, height, width) -- TODO confirm
                against the data pipeline.

        Returns:
            The backbone output viewed as ``(-1, 1, side, side)``, where
            ``side * side`` equals the backbone's last output dimension
            (presumably ``config.num_classes``). For 16 classes this is
            ``(batch, 1, 4, 4)``, matching the previous hard-coded behavior.

        Raises:
            ValueError: If the backbone's output width is not a perfect
                square, so it cannot be laid out as a square map.
        """
        # Swap axes 1 and 2 (the channel and frame axes, per the note above).
        video = torch.permute(tensor, (0, 2, 1, 3, 4))
        logits = self.backbone(video)

        # Derive the map side from the actual output width instead of assuming
        # 16 outputs. A hard-coded 4x4 view would silently fold one sample into
        # several bogus "samples" whenever num_classes != 16.
        num_outputs = logits.shape[-1]
        side = int(round(num_outputs ** 0.5))
        if side * side != num_outputs:
            raise ValueError(
                f"backbone output size {num_outputs} is not a perfect square; "
                "cannot reshape it into a (side x side) map"
            )
        return logits.view(-1, 1, side, side)