| |
| from unittest import TestCase |
|
|
| import torch |
|
|
| from mmpose.models.backbones import HourglassAENet, HourglassNet |
|
|
|
|
class TestHourglass(TestCase):
    """Smoke tests for the ``HourglassNet`` and ``HourglassAENet`` backbones.

    Each test first checks that invalid constructor arguments raise
    ``AssertionError``. It then runs forward passes on a random image and
    checks the number and shape of the returned feature maps.
    """

    def _check_outputs(self, model, num_stacks, channels):
        """Initialise ``model``, run it on one random image, check outputs.

        Args:
            model: Backbone instance under test (already constructed).
            num_stacks: Number of feature maps the model should return.
            channels: Channel count expected for every returned feature map.
        """
        model.init_weights()
        model.train()

        # A single 3x256x256 input; the assertions below expect every
        # output map to come back at 64x64 spatial resolution.
        image = torch.randn(1, 3, 256, 256)
        outputs = model(image)

        self.assertEqual(len(outputs), num_stacks)
        target_shape = torch.Size([1, channels, 64, 64])
        for output in outputs:
            self.assertEqual(output.shape, target_shape)

    def test_hourglass_backbone(self):
        # A stack count of zero is expected to be rejected.
        with self.assertRaises(AssertionError):
            HourglassNet(num_stacks=0)

        # Five stage_channels entries vs. six stage_blocks entries: the
        # length mismatch is expected to be rejected.
        with self.assertRaises(AssertionError):
            HourglassNet(
                stage_channels=[256, 256, 384, 384, 384],
                stage_blocks=[2, 2, 2, 2, 2, 4])

        # Here the two lists have equal length, so the failure presumably
        # comes from having only five stages for downsample_times=5.
        # NOTE(review): confirm against HourglassNet's argument checks.
        with self.assertRaises(AssertionError):
            HourglassNet(
                downsample_times=5,
                stage_channels=[256, 256, 384, 384, 384],
                stage_blocks=[2, 2, 2, 2, 2])

        # One and then two stacked hourglasses; the model should return
        # one 256-channel map per stack.
        for stacks in (1, 2):
            self._check_outputs(HourglassNet(num_stacks=stacks), stacks, 256)

    def test_hourglass_ae_backbone(self):
        # A stack count of zero is expected to be rejected.
        with self.assertRaises(AssertionError):
            HourglassAENet(num_stacks=0)

        # downsample_times=5 with five stage_channels entries is expected
        # to be rejected (presumably too few stages for that many
        # downsamplings). NOTE(review): confirm against HourglassAENet.
        with self.assertRaises(AssertionError):
            HourglassAENet(
                downsample_times=5, stage_channels=[256, 256, 384, 384, 384])

        # With otherwise-default settings, the model should return one
        # 34-channel map per stack.
        for stacks in (1, 2):
            self._check_outputs(HourglassAENet(num_stacks=stacks), stacks, 34)
|
|