|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
|
|
|
from mmaction.models.common import Conv2plus1d
|
|
|
|
|
|
|
|
|
def test_conv2plus1d():
    """Check construction, weight init and forward shape of Conv2plus1d."""
    # A two-element kernel size must be rejected with an AssertionError at
    # construction time.
    # NOTE(review): presumably the module requires a 3-element (or scalar)
    # kernel spec -- confirm against the Conv2plus1d implementation.
    with pytest.raises(AssertionError):
        Conv2plus1d(3, 8, (2, 2))

    module = Conv2plus1d(3, 8, 2)
    module.init_weights()

    # After init_weights() the ``bn_s`` layer is expected to carry unit
    # weights and zero biases.
    bn_weight = module.bn_s.weight
    bn_bias = module.bn_s.bias
    assert torch.equal(bn_weight, torch.ones_like(bn_weight))
    assert torch.equal(bn_bias, torch.zeros_like(bn_bias))

    # Run a random 5-D input through the module and verify the output shape.
    inputs = torch.rand(1, 3, 8, 256, 256)
    expected_shape = torch.Size([1, 8, 7, 255, 255])
    assert module(inputs).shape == expected_shape
|
|
|
|