paranke's picture
Upload folder using huggingface_hub
475ab49 verified
raw
history blame contribute delete
900 Bytes
import math

import torch
from timesformer_pytorch import TimeSformer
from transformers.modeling_utils import PreTrainedModel

from .timesformer_config import TimesformerConfig
class TimesformerModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``timesformer_pytorch.TimeSformer``.

    The backbone's classification head yields ``config.num_classes`` values per
    sample; ``forward`` folds them into a single-channel square map of shape
    ``(batch, 1, side, side)`` with ``side * side == config.num_classes``.
    """

    config_class = TimesformerConfig

    def __init__(self, config):
        """Build the TimeSformer backbone from ``config``.

        Args:
            config: A ``TimesformerConfig``. Reads ``dim``, ``window_size``
                (passed as the backbone's ``image_size``), ``patch_size``,
                ``num_frames``, ``num_classes``, ``depth`` and ``n_heads``.
                The backbone is built with 1 input channel, ``dim_head=64``
                and 0.1 attention / feed-forward dropout (fixed here, not
                read from the config).
        """
        super().__init__(config)
        self.backbone = TimeSformer(
            dim=config.dim,
            image_size=config.window_size,
            patch_size=config.patch_size,
            num_frames=config.num_frames,
            num_classes=config.num_classes,
            channels=1,
            depth=config.depth,
            heads=config.n_heads,
            dim_head=64,
            attn_dropout=0.1,
            ff_dropout=0.1,
        )
        self.post_init()

    def forward(self, tensor):
        """Run the backbone and reshape its per-sample scores into a square map.

        Args:
            tensor: 5-D input tensor. Axes 1 and 2 are swapped before the
                backbone, which (per the timesformer_pytorch README) takes
                ``(batch, frames, channels, height, width)``. The input is
                therefore presumably ``(batch, channels, frames, height,
                width)`` — TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch, 1, side, side)``, where ``side * side``
            equals the backbone's output width (``config.num_classes``).
            With ``num_classes == 16`` this is ``(batch, 1, 4, 4)``.

        Raises:
            ValueError: if the backbone's output width is not a perfect square.
        """
        # Swap axes 1 and 2 to get the frames-before-channels layout
        # that TimeSformer consumes.
        x = self.backbone(torch.permute(tensor, (0, 2, 1, 3, 4)))

        batch, num_outputs = x.shape[0], x.shape[-1]
        side = math.isqrt(num_outputs)
        if side * side != num_outputs:
            raise ValueError(
                f"backbone output width {num_outputs} (config.num_classes) "
                "is not a perfect square; cannot reshape to (batch, 1, side, side)"
            )
        # Pin the batch size explicitly, rather than inferring it with -1, so a
        # mismatched num_classes can never silently spill scores into extra
        # "batch" rows.
        return x.view(batch, 1, side, side)