| |
| from unittest import TestCase |
|
|
| import torch |
| from torch.nn.modules.batchnorm import _BatchNorm |
|
|
| from mmpose.models.backbones import LiteHRNet |
| from mmpose.models.backbones.litehrnet import LiteHRModule |
| from mmpose.models.backbones.resnet import Bottleneck |
|
|
|
|
| class TestLiteHrnet(TestCase): |
|
|
| @staticmethod |
| def is_norm(modules): |
| """Check if is one of the norms.""" |
| if isinstance(modules, (_BatchNorm, )): |
| return True |
| return False |
|
|
| @staticmethod |
| def all_zeros(modules): |
| """Check if the weight(and bias) is all zero.""" |
| weight_zero = torch.equal(modules.weight.data, |
| torch.zeros_like(modules.weight.data)) |
| if hasattr(modules, 'bias'): |
| bias_zero = torch.equal(modules.bias.data, |
| torch.zeros_like(modules.bias.data)) |
| else: |
| bias_zero = True |
|
|
| return weight_zero and bias_zero |
|
|
| def test_litehrmodule(self): |
| |
| block = LiteHRModule( |
| num_branches=1, |
| num_blocks=1, |
| in_channels=[ |
| 40, |
| ], |
| reduce_ratio=8, |
| module_type='LITE') |
|
|
| x = torch.randn(2, 40, 56, 56) |
| x_out = block([[x]]) |
| self.assertEqual(x_out[0][0].shape, torch.Size([2, 40, 56, 56])) |
|
|
| block = LiteHRModule( |
| num_branches=1, |
| num_blocks=1, |
| in_channels=[ |
| 40, |
| ], |
| reduce_ratio=8, |
| module_type='NAIVE') |
|
|
| x = torch.randn(2, 40, 56, 56) |
| x_out = block([x]) |
| self.assertEqual(x_out[0].shape, torch.Size([2, 40, 56, 56])) |
|
|
| with self.assertRaises(ValueError): |
| block = LiteHRModule( |
| num_branches=1, |
| num_blocks=1, |
| in_channels=[ |
| 40, |
| ], |
| reduce_ratio=8, |
| module_type='none') |
|
|
| def test_litehrnet_backbone(self): |
| extra = dict( |
| stem=dict(stem_channels=32, out_channels=32, expand_ratio=1), |
| num_stages=3, |
| stages_spec=dict( |
| num_modules=(2, 4, 2), |
| num_branches=(2, 3, 4), |
| num_blocks=(2, 2, 2), |
| module_type=('LITE', 'LITE', 'LITE'), |
| with_fuse=(True, True, True), |
| reduce_ratios=(8, 8, 8), |
| num_channels=( |
| (40, 80), |
| (40, 80, 160), |
| (40, 80, 160, 320), |
| )), |
| with_head=True) |
|
|
| model = LiteHRNet(extra, in_channels=3) |
|
|
| imgs = torch.randn(2, 3, 224, 224) |
| feat = model(imgs) |
| self.assertIsInstance(feat, tuple) |
| self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56])) |
|
|
| |
| model = LiteHRNet(extra, in_channels=3) |
| model.init_weights() |
| for m in model.modules(): |
| if isinstance(m, Bottleneck): |
| self.assertTrue(self.all_zeros(m.norm3)) |
| model.train() |
|
|
| imgs = torch.randn(2, 3, 224, 224) |
| feat = model(imgs) |
| self.assertIsInstance(feat, tuple) |
| self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56])) |
|
|
| extra = dict( |
| stem=dict(stem_channels=32, out_channels=32, expand_ratio=1), |
| num_stages=3, |
| stages_spec=dict( |
| num_modules=(2, 4, 2), |
| num_branches=(2, 3, 4), |
| num_blocks=(2, 2, 2), |
| module_type=('NAIVE', 'NAIVE', 'NAIVE'), |
| with_fuse=(True, True, True), |
| reduce_ratios=(8, 8, 8), |
| num_channels=( |
| (40, 80), |
| (40, 80, 160), |
| (40, 80, 160, 320), |
| )), |
| with_head=True) |
|
|
| model = LiteHRNet(extra, in_channels=3) |
|
|
| imgs = torch.randn(2, 3, 224, 224) |
| feat = model(imgs) |
| self.assertIsInstance(feat, tuple) |
| self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56])) |
|
|
| |
| model = LiteHRNet(extra, in_channels=3) |
| model.init_weights() |
| for m in model.modules(): |
| if isinstance(m, Bottleneck): |
| self.assertTrue(self.all_zeros(m.norm3)) |
| model.train() |
|
|
| imgs = torch.randn(2, 3, 224, 224) |
| feat = model(imgs) |
| self.assertIsInstance(feat, tuple) |
| self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56])) |
|
|