File size: 4,421 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.models import X3D
from mmaction.testing import check_norm_state, generate_backbone_demo_inputs


def _assert_backbone_output_shape(backbone, input_shape, expected_shape):
    """Run ``backbone`` on random demo inputs and check the output shape.

    Parrots only implements 3D convolution on GPU. Under parrots, the check
    therefore runs on CUDA when it is available and is skipped otherwise.
    Note that ``Module.cuda()`` moves ``backbone`` in place, so the caller's
    model ends up on the GPU in that case.

    Args:
        backbone (torch.nn.Module): The backbone under test.
        input_shape (tuple[int]): Shape passed to
            ``generate_backbone_demo_inputs``.
        expected_shape (tuple[int]): Expected shape of the output feature.
    """
    imgs = generate_backbone_demo_inputs(input_shape)
    if torch.__version__ == 'parrots':
        # parrots 3dconv is only implemented on gpu
        if not torch.cuda.is_available():
            return
        backbone = backbone.cuda()
        imgs = imgs.cuda()
    feat = backbone(imgs)
    assert feat.shape == torch.Size(expected_shape)


def test_x3d_backbone():
    """Test x3d backbone."""
    # Each of these configurations must be rejected by X3D's argument
    # checks; all share the x3d_s widths/depths below.
    invalid_cfgs = [
        # In X3D: 1 <= num_stages <= 4
        dict(num_stages=0),
        dict(num_stages=5),
        # len(spatial_strides) == num_stages
        dict(spatial_strides=(1, 2), num_stages=4),
        # se_style in ['half', 'all']
        dict(se_style=None),
        # se_ratio should be None or > 0
        dict(se_style='half', se_ratio=0),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, **cfg)

    # x3d_s (gamma_d=2.2) and x3d_l (gamma_d=5.0), no pretrained.
    # With norm_eval=True the norm layers must stay in eval mode after
    # train(); with norm_eval=False they must follow train().
    for norm_eval in (True, False):
        for gamma_d in (2.2, 5.0):
            model = X3D(
                gamma_w=1.0,
                gamma_b=2.25,
                gamma_d=gamma_d,
                norm_eval=norm_eval)
            model.init_weights()
            model.train()
            assert check_norm_state(model.modules(), not norm_eval)

    # x3d_s, no pretrained, frozen_stages, norm_eval False
    frozen_stages = 1
    x3d_s_frozen = X3D(
        gamma_w=1.0,
        gamma_b=2.25,
        gamma_d=2.2,
        norm_eval=False,
        frozen_stages=frozen_stages)

    x3d_s_frozen.init_weights()
    x3d_s_frozen.train()
    # The stem must stay frozen even after train().
    assert x3d_s_frozen.conv1_t.bn.training is False
    for param in x3d_s_frozen.conv1_s.parameters():
        assert param.requires_grad is False
    for param in x3d_s_frozen.conv1_t.parameters():
        assert param.requires_grad is False

    # Stages 1..frozen_stages: BN in eval mode and no trainable params.
    for i in range(1, frozen_stages + 1):
        layer = getattr(x3d_s_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # test zero_init_residual, zero_init_residual is True by default
    for m in x3d_s_frozen.modules():
        if hasattr(m, 'conv3'):
            assert torch.equal(m.conv3.bn.weight,
                               torch.zeros_like(m.conv3.bn.weight))
            assert torch.equal(m.conv3.bn.bias,
                               torch.zeros_like(m.conv3.bn.bias))

    # x3d_s inference
    _assert_backbone_output_shape(x3d_s_frozen, (1, 3, 13, 64, 64),
                                  (1, 432, 13, 2, 2))

    # x3d_s inference at the larger (16-frame, 96x96) input size used by
    # x3d_m; the backbone under test is still x3d_s_frozen.
    _assert_backbone_output_shape(x3d_s_frozen, (1, 3, 16, 96, 96),
                                  (1, 432, 16, 3, 3))