|
|
from .timesformer_config import TimesformerScrollprizeConfig |
|
|
from transformers import PreTrainedModel |
|
|
import torch |
|
|
from timesformer_pytorch import TimeSformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimesformerScrollprizeModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapping a single-channel ``TimeSformer``.

    The backbone's per-sample logits are reshaped by :meth:`forward` into a
    square, single-channel map (4x4 when ``config.num_classes == 16``).
    """

    config_class = TimesformerScrollprizeConfig

    def __init__(self, config):
        """Build the TimeSformer backbone from ``config``.

        Args:
            config: A ``TimesformerScrollprizeConfig``. These attributes are
                required: ``dim``, ``window_size`` (used as the image size),
                ``patch_size``, ``num_frames``, ``num_classes``, ``depth`` and
                ``n_heads``. These are optional: ``dim_head``,
                ``attn_dropout`` and ``ff_dropout``. When absent, they default
                to 64, 0.1 and 0.1, the values previously hard-coded here.
        """
        super().__init__(config)
        self.backbone = TimeSformer(
            dim=config.dim,
            image_size=config.window_size,
            patch_size=config.patch_size,
            num_frames=config.num_frames,
            num_classes=config.num_classes,
            channels=1,  # backbone is built for single-channel frames
            depth=config.depth,
            heads=config.n_heads,
            # Optional overrides; defaults preserve the original architecture
            # so existing checkpoints still load.
            dim_head=getattr(config, "dim_head", 64),
            attn_dropout=getattr(config, "attn_dropout", 0.1),
            ff_dropout=getattr(config, "ff_dropout", 0.1),
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the backbone and reshape its logits into a square map.

        Args:
            tensor: 5-D input. Axes 1 and 2 are swapped before it reaches the
                backbone. NOTE(review): this presumably means the input is
                ``(batch, channels, frames, H, W)`` and the backbone expects
                ``(batch, frames, channels, H, W)``. Confirm against callers.

        Returns:
            Tensor of shape ``(batch, 1, side, side)``, where
            ``side * side`` equals the number of logits per sample.

        Raises:
            ValueError: if the number of logits per sample is not a perfect
                square, so it cannot be laid out as a square map.
        """
        logits = self.backbone(torch.permute(tensor, (0, 2, 1, 3, 4)))
        # Assumes logits are (batch, num_logits), i.e. one row per sample.
        batch = logits.shape[0]
        num_logits = logits.shape[-1]
        side = int(round(num_logits ** 0.5))
        if side * side != num_logits:
            raise ValueError(
                f"cannot reshape {num_logits} logits per sample into a square "
                f"map; config.num_classes must be a perfect square"
            )
        # Use the explicit batch size instead of -1. With -1, a logits width
        # other than side*side would silently mix samples across the batch
        # instead of raising.
        return logits.view(batch, 1, side, side)