File size: 743 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmaction.models.common import Conv2plus1d
def test_conv2plus1d():
    """Check Conv2plus1d construction, weight init and forward shape.

    Three things are verified:

    1. Constructing the module with a 2-element kernel size raises
       ``AssertionError``.
    2. After ``init_weights()``, the ``bn_s`` submodule's weight is all
       ones and its bias is all zeros.
    3. A random ``(1, 3, 8, 256, 256)`` input yields an output of shape
       ``(1, 8, 7, 255, 255)`` for a module built with kernel size 2.
    """
    # Length of kernel size, stride and padding must be the same.
    # NOTE(review): presumably the 2-tuple mismatches the default
    # stride/padding lengths -- confirm against Conv2plus1d's signature.
    with pytest.raises(AssertionError):
        Conv2plus1d(3, 8, (2, 2))

    module = Conv2plus1d(3, 8, 2)
    module.init_weights()

    # Bind the parameters once so each check reads as "tensor vs. its
    # expected constant fill".
    bn_weight = module.bn_s.weight
    bn_bias = module.bn_s.bias
    assert torch.equal(bn_weight, torch.ones_like(bn_weight))
    assert torch.equal(bn_bias, torch.zeros_like(bn_bias))

    inputs = torch.rand(1, 3, 8, 256, 256)
    feats = module(inputs)
    expected_shape = torch.Size([1, 8, 7, 255, 255])
    assert feats.shape == expected_shape
|