Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from unittest import TestCase | |
| import torch | |
| from mmpose.models.backbones import HourglassAENet, HourglassNet | |
| class TestHourglass(TestCase): | |
| def test_hourglass_backbone(self): | |
| with self.assertRaises(AssertionError): | |
| # HourglassNet's num_stacks should larger than 0 | |
| HourglassNet(num_stacks=0) | |
| with self.assertRaises(AssertionError): | |
| # len(stage_channels) should equal len(stage_blocks) | |
| HourglassNet( | |
| stage_channels=[256, 256, 384, 384, 384], | |
| stage_blocks=[2, 2, 2, 2, 2, 4]) | |
| with self.assertRaises(AssertionError): | |
| # len(stage_channels) should larger than downsample_times | |
| HourglassNet( | |
| downsample_times=5, | |
| stage_channels=[256, 256, 384, 384, 384], | |
| stage_blocks=[2, 2, 2, 2, 2]) | |
| # Test HourglassNet-52 | |
| model = HourglassNet(num_stacks=1) | |
| model.init_weights() | |
| model.train() | |
| imgs = torch.randn(1, 3, 256, 256) | |
| feat = model(imgs) | |
| self.assertEqual(len(feat), 1) | |
| self.assertEqual(feat[0].shape, torch.Size([1, 256, 64, 64])) | |
| # Test HourglassNet-104 | |
| model = HourglassNet(num_stacks=2) | |
| model.init_weights() | |
| model.train() | |
| imgs = torch.randn(1, 3, 256, 256) | |
| feat = model(imgs) | |
| self.assertEqual(len(feat), 2) | |
| self.assertEqual(feat[0].shape, torch.Size([1, 256, 64, 64])) | |
| self.assertEqual(feat[1].shape, torch.Size([1, 256, 64, 64])) | |
| def test_hourglass_ae_backbone(self): | |
| with self.assertRaises(AssertionError): | |
| # HourglassAENet's num_stacks should larger than 0 | |
| HourglassAENet(num_stacks=0) | |
| with self.assertRaises(AssertionError): | |
| # len(stage_channels) should larger than downsample_times | |
| HourglassAENet( | |
| downsample_times=5, stage_channels=[256, 256, 384, 384, 384]) | |
| # num_stack=1 | |
| model = HourglassAENet(num_stacks=1) | |
| model.init_weights() | |
| model.train() | |
| imgs = torch.randn(1, 3, 256, 256) | |
| feat = model(imgs) | |
| self.assertEqual(len(feat), 1) | |
| self.assertEqual(feat[0].shape, torch.Size([1, 34, 64, 64])) | |
| # num_stack=2 | |
| model = HourglassAENet(num_stacks=2) | |
| model.init_weights() | |
| model.train() | |
| imgs = torch.randn(1, 3, 256, 256) | |
| feat = model(imgs) | |
| self.assertEqual(len(feat), 2) | |
| self.assertEqual(feat[0].shape, torch.Size([1, 34, 64, 64])) | |
| self.assertEqual(feat[1].shape, torch.Size([1, 34, 64, 64])) | |