| |
| import pytest |
| import torch |
|
|
| from mmcv.cnn.bricks import HSigmoid |
|
|
|
|
def _check_hsigmoid(act, bias, divisor, min_value=0.0, max_value=1.0):
    """Check ``act`` against the reference ``min(max((x + bias) / divisor,
    min_value), max_value)`` on a fixed random input.

    The reference is built from ``torch.max``/``torch.min`` rather than
    ``clamp`` so that it is computed independently of the layer under test.

    Args:
        act: The ``HSigmoid`` instance to exercise.
        bias (float): Offset added to the input before scaling.
        divisor (float): Value the shifted input is divided by.
        min_value (float): Lower bound of the reference output.
        max_value (float): Upper bound of the reference output.
    """
    input_shape = torch.Size([1, 3, 64, 64])
    # A private, seeded generator makes failures reproducible without
    # touching the global RNG state that other tests may rely on.
    generator = torch.Generator().manual_seed(0)
    x = torch.randn(input_shape, generator=generator)
    # Build the reference *before* running the layer, so that an activation
    # that modifies its argument in place cannot corrupt the expected values.
    # ``full_like`` keeps the bounds in x's floating dtype even if integer
    # bounds are passed in.
    expected_output = torch.min(
        torch.max((x + bias) / divisor, torch.full_like(x, min_value)),
        torch.full_like(x, max_value))
    output = act(x)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)


def test_hsigmoid():
    """Test ``HSigmoid`` argument validation and forward values."""
    # test assertion divisor can not be zero
    with pytest.raises(AssertionError):
        HSigmoid(divisor=0)

    # test with default parameters: expected min(max((x + 3) / 6, 0), 1)
    _check_hsigmoid(HSigmoid(), bias=3, divisor=6)

    # test with designated parameters: expected min(max((x + 1) / 2, 0), 1)
    # (positional order presumably bias, divisor, min_value, max_value)
    _check_hsigmoid(HSigmoid(1, 2, 0, 1), bias=1, divisor=2)
|
|