# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.cnn.bricks import LayerScale, Scale
def test_scale():
    """Check ``Scale`` built with and without an explicit factor.

    For each construction the test expects the ``scale`` attribute to hold
    the expected value as ``torch.float`` and the forward pass to return a
    tensor with the same shape as its input.
    """
    input_shape = (1, 3, 64, 64)
    # The empty argument tuple exercises the constructor default; the
    # second case passes the factor explicitly.
    cases = (
        ((), 1.),
        ((10.,), 10.),
    )
    for ctor_args, expected_factor in cases:
        layer = Scale(*ctor_args)
        assert layer.scale.data == expected_factor
        assert layer.scale.dtype == torch.float

        feats = torch.rand(*input_shape)
        result = layer(feats)
        assert result.shape == input_shape
def test_layer_scale():
    """Check ``LayerScale`` construction and forward behaviour.

    Covers, in order: rejection of an unknown ``data_format``, the initial
    value of ``weight``, non-inplace forward passes for both data formats
    at two input ranks each, and the inplace forward pass.
    """
    # An unrecognised data_format is expected to trip an assertion.
    bad_cfg = {'dim': 10, 'data_format': 'BNC'}
    with pytest.raises(AssertionError):
        LayerScale(**bad_cfg)

    # test init: weight is expected to be a vector of 1e-5
    default_layer = LayerScale(dim=10)
    expected_weight = torch.ones(10, requires_grad=True) * 1e-5
    assert torch.equal(default_layer.weight, expected_weight)

    # test forward without inplace: (data_format, input shape) pairs
    forward_cases = (
        ('channels_last', (4, 49, 256)),
        ('channels_last', (4, 7, 49, 256)),  # channels_last 2d
        ('channels_first', (4, 256, 7, 7)),
        ('channels_first', (4, 256, 7, 7, 7)),  # channels_first 3D
    )
    for data_format, shape in forward_cases:
        layer = LayerScale(dim=256, inplace=False, data_format=data_format)
        inputs = torch.randn(shape)
        outputs = layer(inputs)
        assert tuple(outputs.size()) == shape
        assert torch.equal(inputs * 1e-5, outputs)

    # test inplace True: the returned tensor must be the input object itself
    inplace_shape = (4, 256, 7, 7)
    inplace_layer = LayerScale(
        dim=256, inplace=True, data_format='channels_first')
    inputs = torch.randn(inplace_shape)
    outputs = inplace_layer(inputs)
    assert tuple(outputs.size()) == inplace_shape
    assert inputs is outputs
|