| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmaction.models.common import TAM | |
def test_TAM():
    """Test TAM constructor validation, input-channel checking and output shape.

    Covers three cases:
    - ``alpha`` / ``beta`` values of 0 are rejected with ``AssertionError``.
    - A forward pass whose input channel count differs from the module's
      configured channels is rejected with ``AssertionError``.
    - A valid forward pass returns a tensor with the same shape as its input.
    """
    with pytest.raises(AssertionError):
        # alpha must be a positive integer
        TAM(16, 8, alpha=0, beta=4)
    with pytest.raises(AssertionError):
        # beta must be a positive integer
        TAM(16, 8, alpha=2, beta=0)

    # Build the module and the input OUTSIDE the ``raises`` context. Only the
    # forward call can then satisfy the expected AssertionError. Otherwise a
    # failure in the constructor would make this check pass vacuously.
    tam = TAM(16, 8)
    x = torch.rand(64, 8, 112, 112)
    with pytest.raises(AssertionError):
        # the channels number of x should be equal to self.in_channels of TAM
        tam(x)

    tam = TAM(16, 8)
    x = torch.rand(32, 16, 112, 112)
    output = tam(x)
    # The output is expected to keep exactly the input's shape.
    assert output.shape == torch.Size([32, 16, 112, 112])