| from transformers.modeling_utils import PreTrainedModel | |
| import torch | |
| from timesformer_pytorch import TimeSformer | |
| from .timesformer_config import TimesformerConfig | |
class TimesformerModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``timesformer_pytorch.TimeSformer``.

    The backbone's head output is reshaped into a single-channel 4x4 map per
    sample, so ``config.num_classes`` must be 16.
    """

    config_class = TimesformerConfig

    def __init__(self, config):
        """Build the TimeSformer backbone from ``config``.

        Args:
            config: A ``TimesformerConfig`` supplying ``dim``, ``window_size``
                (passed as the backbone's ``image_size``), ``patch_size``,
                ``num_frames``, ``num_classes``, ``depth`` and ``n_heads``.
                Channel count (1), per-head dim (64) and both dropout rates
                (0.1) are fixed here rather than read from the config.
        """
        super().__init__(config)
        self.backbone = TimeSformer(
            dim=config.dim,
            image_size=config.window_size,
            patch_size=config.patch_size,
            num_frames=config.num_frames,
            num_classes=config.num_classes,
            channels=1,  # single-channel input frames
            depth=config.depth,
            heads=config.n_heads,
            dim_head=64,
            attn_dropout=0.1,
            ff_dropout=0.1,
        )
        # Let transformers run its standard post-construction hooks (weight init etc.).
        self.post_init()

    def forward(self, tensor):
        """Run the backbone and reshape its output to ``(batch, 1, 4, 4)``.

        Args:
            tensor: 5-D input tensor; axes 1 and 2 are swapped before the
                backbone call. NOTE(review): presumably laid out as
                ``(batch, channels, frames, height, width)`` and permuted into
                the ``(batch, frames, channels, height, width)`` order that
                ``TimeSformer`` consumes — confirm against callers.

        Returns:
            Tensor of shape ``(batch, 1, 4, 4)``.

        Raises:
            RuntimeError: if the backbone emits other than 16 features per
                sample (i.e. ``config.num_classes != 16``).
        """
        x = self.backbone(torch.permute(tensor, (0, 2, 1, 3, 4)))
        # Use the explicit batch size rather than -1 so a num_classes other than
        # 16 raises, instead of silently folding extra features into the batch axis.
        return x.view(x.size(0), 1, 4, 4)