| import pytest |
| import torch |
| from torch.nn.modules.batchnorm import _BatchNorm |
|
|
| from mmdet.models.necks import FPN, ChannelMapper, DilatedEncoder |
|
|
|
|
def test_fpn():
    """Tests FPN construction-time validation and output shapes/channels."""
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8

    # `num_outs` smaller than the number of used input levels is invalid.
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            num_outs=2)

    # `end_level` beyond the last input level is invalid.
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            end_level=4,
            num_outs=2)

    # With an explicit `end_level`, `num_outs` must match the level range.
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            end_level=3,
            num_outs=1)

    # An unrecognized `add_extra_convs` string is rejected.
    with pytest.raises(AssertionError):
        FPN(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            add_extra_convs='on_xxx',
            num_outs=5)

    # `add_extra_convs=True` should be normalized to 'on_input'.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=True,
        num_outs=5)

    feats = [
        torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
        for i in range(len(in_channels))
    ]
    outs = fpn_model(feats)
    assert fpn_model.add_extra_convs == 'on_input'
    assert len(outs) == fpn_model.num_outs
    # start_level=1, so the first output's spatial size is s // 2 and each
    # subsequent level halves it again.
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Without extra convs the extra outputs come from max pooling.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=False,
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert not fpn_model.add_extra_convs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # A norm_cfg should insert BatchNorm layers into the neck.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=True,
        no_norm_on_lateral=False,
        norm_cfg=dict(type='BN', requires_grad=True),
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert fpn_model.add_extra_convs == 'on_input'
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
    bn_exist = False
    for m in fpn_model.modules():
        if isinstance(m, _BatchNorm):
            bn_exist = True
    assert bn_exist

    # A bilinear upsample cfg must be accepted by the top-down pathway.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=True,
        upsample_cfg=dict(mode='bilinear', align_corners=True),
        num_outs=5)
    fpn_model(feats)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    assert fpn_model.add_extra_convs == 'on_input'
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Scale-factor upsampling should yield the same output sizes.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=True,
        upsample_cfg=dict(scale_factor=2),
        num_outs=5)
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Extra convs sourced from the original inputs.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_input',
        start_level=1,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_input'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Extra convs sourced from the laterals.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_lateral',
        start_level=1,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_lateral'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Extra convs sourced from the FPN outputs.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_output',
        start_level=1,
        num_outs=5)
    assert fpn_model.add_extra_convs == 'on_output'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Deprecated `extra_convs_on_inputs=False` maps True -> 'on_output'.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs=True,
        extra_convs_on_inputs=False,
        start_level=1,
        num_outs=5,
    )
    assert fpn_model.add_extra_convs == 'on_output'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))

    # Deprecated `extra_convs_on_inputs=True` maps True -> 'on_input'.
    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs=True,
        extra_convs_on_inputs=True,
        start_level=1,
        num_outs=5,
    )
    assert fpn_model.add_extra_convs == 'on_input'
    outs = fpn_model(feats)
    assert len(outs) == fpn_model.num_outs
    for i in range(fpn_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
|
|
|
|
def test_channel_mapper():
    """Tests ChannelMapper validation and output channels/sizes."""
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8
    kernel_size = 3
    feats = [
        torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
        for i in range(len(in_channels))
    ]

    # `in_channels` must be a list/tuple, not an int.
    with pytest.raises(AssertionError):
        channel_mapper = ChannelMapper(
            in_channels=10, out_channels=out_channels, kernel_size=kernel_size)

    # Feeding more feature levels than `in_channels` entries must fail.
    with pytest.raises(AssertionError):
        channel_mapper = ChannelMapper(
            in_channels=in_channels[:-1],
            out_channels=out_channels,
            kernel_size=kernel_size)
        channel_mapper(feats)

    channel_mapper = ChannelMapper(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size)

    outs = channel_mapper(feats)
    assert len(outs) == len(feats)
    # The mapper changes only the channel dim; spatial sizes are preserved.
    for i in range(len(feats)):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
|
|
|
|
def test_dilated_encoder():
    """Tests that DilatedEncoder maps channels and keeps spatial size."""
    in_channels, out_channels = 16, 32
    size = 34
    # 16 mid channels, 2 residual blocks.
    encoder = DilatedEncoder(in_channels, out_channels, 16, 2)
    inputs = [torch.rand(1, in_channels, size, size)]
    result = encoder(inputs)[0]
    assert result.shape == (1, out_channels, size, size)
|
|