File size: 835 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import VisionTransformer


def test_vit_backbone():
    """Check VisionTransformer output shapes in train and eval mode."""
    inputs = torch.randn(1, 3, 8, 64, 64)
    expected_shape = torch.Size([1, 768])

    # Two configurations exercised by the original test: one with qkv bias,
    # stochastic depth and layer-scale init; one with learnable positional
    # embeddings, dropout, and cls-token (no mean pooling) readout.
    configs = [
        dict(
            img_size=64,
            num_frames=8,
            qkv_bias=True,
            drop_path_rate=0.2,
            init_values=0.1),
        dict(
            img_size=64,
            num_frames=8,
            use_learnable_pos_emb=True,
            drop_rate=0.1,
            use_mean_pooling=False),
    ]

    for cfg in configs:
        model = VisionTransformer(**cfg)
        model.init_weights()

        # The (batch, embed_dim) output shape must hold both before and
        # after switching to eval mode.
        assert model(inputs).shape == expected_shape
        model.eval()
        assert model(inputs).shape == expected_shape