File size: 461 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models.common import ConvAudio
def test_conv_audio():
    """Check the output shapes produced by ``ConvAudio`` for two ``op`` modes.

    A (1, 3, 8, 8) random input is fed to:
      * a module built with the default ``op`` -> expected (1, 16, 8, 8);
      * a module built with ``op='sum'``       -> expected (1, 8, 8, 8).
    Only shapes are asserted; output values are not checked.
    """
    # NOTE(review): positional args presumably are
    # (in_channels, out_channels, kernel_size) -- confirm against
    # ConvAudio's signature.
    default_module = ConvAudio(3, 8, 3)
    default_module.init_weights()

    batch = torch.rand(1, 3, 8, 8)

    # Default op: channel dim comes out as 16, i.e. twice the 8 requested.
    expected_default = torch.Size([1, 16, 8, 8])
    assert default_module(batch).shape == expected_default

    # 'sum' op: channel dim stays at the 8 requested.
    sum_module = ConvAudio(3, 8, 3, op='sum')
    expected_sum = torch.Size([1, 8, 8, 8])
    assert sum_module(batch).shape == expected_sum
|