| | |
| | import pytest |
| | import torch |
| |
|
| | from mmdet.models.backbones.hrnet import HRModule, HRNet |
| | from mmdet.models.backbones.resnet import BasicBlock, Bottleneck |
| |
|
| |
|
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
def test_hrmodule(block):
    """Run a two-branch HRModule forward pass and check the output shapes.

    The module is exercised twice: once with the constructor's default
    output setting (two feature maps expected back) and once with
    ``multiscale_output=False`` (a single feature map expected back).
    """
    # (extra constructor kwargs, expected number of returned feature maps)
    cases = (
        ({}, 2),
        ({'multiscale_output': False}, 1),
    )

    for extra_kwargs, num_outputs in cases:
        branch_widths = (32, 64)
        # Input widths are the branch widths scaled by ``block.expansion``.
        in_channels = [width * block.expansion for width in branch_widths]
        module = HRModule(
            num_branches=2,
            blocks=block,
            in_channels=in_channels,
            num_blocks=(4, 4),
            num_channels=branch_widths,
            **extra_kwargs)

        # One random input per branch; the second is at half the resolution.
        input_shapes = [
            (1, in_channels[0], 64, 64),
            (1, in_channels[1], 32, 32),
        ]
        outputs = module([torch.randn(*shape) for shape in input_shapes])

        assert len(outputs) == num_outputs
        # Every returned map must keep the shape of the matching input.
        for output, shape in zip(outputs, input_shapes):
            assert output.shape == torch.Size(shape)
| |
|
| |
|
def test_hrnet_backbone():
    """Check HRNet config validation and the output shapes of a forward pass.

    Two invalid configurations are expected to raise ``AssertionError``;
    the corrected four-stage configuration is then run on a random
    256x256 image, with and without ``multiscale_output=False``.
    """
    # Start from a configuration that only describes three stages.
    extra = {
        'stage1': {
            'num_modules': 1,
            'num_branches': 1,
            'block': 'BOTTLENECK',
            'num_blocks': (4, ),
            'num_channels': (64, ),
        },
        'stage2': {
            'num_modules': 1,
            'num_branches': 2,
            'block': 'BASIC',
            'num_blocks': (4, 4),
            'num_channels': (32, 64),
        },
        'stage3': {
            'num_modules': 4,
            'num_branches': 3,
            'block': 'BASIC',
            'num_blocks': (4, 4, 4),
            'num_channels': (32, 64, 128),
        },
    }

    def _forward(**hrnet_kwargs):
        # Build the backbone from the current ``extra``, initialise it, put
        # it in training mode and return its features for one random image.
        net = HRNet(extra=extra, **hrnet_kwargs)
        net.init_weights()
        net.train()
        return net(torch.randn(1, 3, 256, 256))

    # Constructing without a ``stage4`` entry is expected to fail.
    with pytest.raises(AssertionError):
        HRNet(extra=extra)

    # Add a fourth stage whose ``num_branches`` (3) disagrees with its
    # four-entry ``num_blocks`` / ``num_channels``; this must fail as well.
    extra['stage4'] = {
        'num_modules': 3,
        'num_branches': 3,
        'block': 'BASIC',
        'num_blocks': (4, 4, 4, 4),
        'num_channels': (32, 64, 128, 256),
    }
    with pytest.raises(AssertionError):
        HRNet(extra=extra)

    # Repair the branch count so the configuration becomes valid.
    extra['stage4']['num_branches'] = 4

    # Default construction: four feature maps are returned.
    multi_scale = _forward()
    assert len(multi_scale) == 4
    assert multi_scale[0].shape == torch.Size([1, 32, 64, 64])
    assert multi_scale[3].shape == torch.Size([1, 256, 8, 8])

    # With ``multiscale_output=False`` only one feature map is returned.
    single_scale = _forward(multiscale_output=False)
    assert len(single_scale) == 1
    assert single_scale[0].shape == torch.Size([1, 32, 64, 64])
| |
|