Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmocr.models.textdet.necks import FPNC, FPN_UNet | |
| def test_fpnc(): | |
| in_channels = [64, 128, 256, 512] | |
| size = [112, 56, 28, 14] | |
| for flag in [False, True]: | |
| fpnc = FPNC( | |
| in_channels=in_channels, | |
| bias_on_lateral=flag, | |
| bn_re_on_lateral=flag, | |
| bias_on_smooth=flag, | |
| bn_re_on_smooth=flag, | |
| conv_after_concat=flag) | |
| fpnc.init_weights() | |
| inputs = [] | |
| for i in range(4): | |
| inputs.append(torch.rand(1, in_channels[i], size[i], size[i])) | |
| outputs = fpnc.forward(inputs) | |
| assert list(outputs.size()) == [1, 256, 112, 112] | |
| def test_fpn_unet_neck(): | |
| s = 64 | |
| feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8] | |
| in_channels = [8, 16, 32, 64] | |
| out_channels = 4 | |
| # len(in_channcels) is not equal to 4 | |
| with pytest.raises(AssertionError): | |
| FPN_UNet(in_channels + [128], out_channels) | |
| # `out_channels` is not int type | |
| with pytest.raises(AssertionError): | |
| FPN_UNet(in_channels, [2, 4]) | |
| feats = [ | |
| torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i]) | |
| for i in range(len(in_channels)) | |
| ] | |
| fpn_unet_neck = FPN_UNet(in_channels, out_channels) | |
| fpn_unet_neck.init_weights() | |
| out_neck = fpn_unet_neck(feats) | |
| assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4]) | |