| | |
| | import torch |
| |
|
| | from mmpose.core.post_processing.group import HeatmapParser |
| |
|
| |
|
def test_group():
    """Exercise ``HeatmapParser.nms`` and ``HeatmapParser.parse``.

    A parser is built with ``tag_per_joint=True``, checked on a tiny
    single-channel heatmap, then run on a 17-joint heatmap/tag pair that
    has four hand-placed peaks.  The same config dict is afterwards
    switched to ``tag_per_joint=False`` and a second parser is checked
    with two further ``parse`` calls.
    """
    parser_cfg = {
        'num_joints': 17,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'max_num_people': 1,
    }
    heatmap_parser = HeatmapParser(parser_cfg)

    # A weaker response right next to a stronger one must come out of
    # nms as zero.
    tiny_heatmap = torch.zeros(1, 1, 5, 5)
    tiny_heatmap[0, 0, 3, 3] = 1
    tiny_heatmap[0, 0, 3, 2] = 0.8
    suppressed = heatmap_parser.nms(tiny_heatmap)
    assert suppressed[0, 0, 3, 2] == 0

    # (joint index, row/column position, value) of each synthetic peak;
    # the heatmap and the tag map receive the same value at each spot.
    peaks = (
        (0, 10, 0.8),
        (1, 12, 0.9),
        (4, 8, 0.8),
        (8, 6, 0.9),
    )
    heatmaps = torch.zeros(1, 17, 32, 32)
    tags = torch.zeros(1, 17, 32, 32, 1)
    for joint, pos, value in peaks:
        heatmaps[0, joint, pos, pos] = value
        tags[0, joint, pos, pos] = value

    # NOTE(review): the 10.25 (vs. the placed position 10) presumably
    # comes from the two flags passed as True -- confirm against parse.
    poses, pose_scores = heatmap_parser.parse(heatmaps, tags, True, True)
    assert poses[0][0, 0, 0] == 10.25
    assert abs(pose_scores[0] - 0.2) < 0.001

    # Flip the flag on the same dict and build a fresh parser from it.
    parser_cfg['tag_per_joint'] = False
    heatmap_parser = HeatmapParser(parser_cfg)
    poses, pose_scores = heatmap_parser.parse(heatmaps, tags, False, False)
    assert poses[0][0, 0, 0] == 10.
    poses, pose_scores = heatmap_parser.parse(heatmaps, tags, False, True)
    assert poses[0][0, 0, 0] == 10.
| |
|
| |
|
def test_group_score_per_joint():
    """Check ``HeatmapParser.parse`` scores with ``score_per_joint=True``.

    With a 17-joint configuration and ``score_per_joint`` enabled, the
    first entry of the returned scores is expected to have length 17.
    """
    parser_cfg = {
        'num_joints': 17,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'max_num_people': 1,
        'score_per_joint': True,
    }
    heatmap_parser = HeatmapParser(parser_cfg)

    # A weaker response right next to a stronger one must come out of
    # nms as zero.
    tiny_heatmap = torch.zeros(1, 1, 5, 5)
    tiny_heatmap[0, 0, 3, 3] = 1
    tiny_heatmap[0, 0, 3, 2] = 0.8
    suppressed = heatmap_parser.nms(tiny_heatmap)
    assert suppressed[0, 0, 3, 2] == 0

    # (joint index, row/column position, value) of each synthetic peak;
    # the heatmap and the tag map receive the same value at each spot.
    peaks = (
        (0, 10, 0.8),
        (1, 12, 0.9),
        (4, 8, 0.8),
        (8, 6, 0.9),
    )
    heatmaps = torch.zeros(1, 17, 32, 32)
    tags = torch.zeros(1, 17, 32, 32, 1)
    for joint, pos, value in peaks:
        heatmaps[0, joint, pos, pos] = value
        tags[0, joint, pos, pos] = value

    poses, pose_scores = heatmap_parser.parse(heatmaps, tags, True, True)
    assert len(pose_scores[0]) == 17
| |
|