File size: 4,576 Bytes
2402804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.structures import LabelData

from mmdet.structures import ReIDDataSample


def _equal(a, b):
    """Return ``True`` if ``a`` and ``b`` are equal.

    For ``torch.Tensor`` / ``np.ndarray`` inputs the shapes must match and
    every element must compare equal; otherwise plain ``==`` is used.

    Args:
        a: Value under test (tensor, array or any comparable object).
        b: Expected value.

    Returns:
        bool for tensor/array inputs; whatever ``a == b`` yields otherwise.
    """
    if isinstance(a, (torch.Tensor, np.ndarray)):
        # Without this check, broadcasting would make e.g. ``[1, 1]`` and
        # ``[1]`` compare equal and hide shape bugs in the code under test.
        b_shape = getattr(b, 'shape', None)
        if b_shape is not None and tuple(a.shape) != tuple(b_shape):
            return False
        # Collapse the 0-d tensor / ``np.bool_`` into a plain Python bool.
        return bool((a == b).all())
    return a == b


class TestReIDDataSample(TestCase):
    """Unit tests for ``ReIDDataSample`` metainfo and label handling."""

    def _assert_score_close(self, actual, expected):
        """Assert tensor ``actual`` matches ``expected`` in shape and value.

        ``expected`` is converted to ``actual``'s dtype and device first
        (the behaviour of the deprecated ``torch.testing.assert_allclose``),
        so a plain Python list can be passed as the reference.
        """
        expected = torch.as_tensor(
            expected, dtype=actual.dtype, device=actual.device)
        # ``torch.allclose`` broadcasts, so compare shapes explicitly.
        self.assertEqual(tuple(actual.shape), tuple(expected.shape))
        self.assertTrue(
            torch.allclose(actual, expected), f'{actual} != {expected}')

    def test_init(self):
        """Metainfo given at construction is stored and retrievable."""
        img_shape = (256, 128)
        ori_shape = (64, 64)
        num_classes = 5
        meta_info = dict(
            img_shape=img_shape, ori_shape=ori_shape, num_classes=num_classes)
        data_sample = ReIDDataSample(metainfo=meta_info)
        self.assertIn('img_shape', data_sample)
        self.assertIn('ori_shape', data_sample)
        self.assertIn('num_classes', data_sample)
        self.assertTrue(_equal(data_sample.get('img_shape'), img_shape))
        self.assertTrue(_equal(data_sample.get('ori_shape'), ori_shape))
        self.assertTrue(_equal(data_sample.get('num_classes'), num_classes))

    def test_set_gt_label(self):
        """``set_gt_label`` accepts scalars, tensors, arrays and sequences."""
        data_sample = ReIDDataSample(metainfo=dict(num_classes=5))
        method = data_sample.set_gt_label

        # Test number
        method(1)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor with single number
        method(torch.tensor(2))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test array with single number
        method(np.array(3))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor
        _label = torch.tensor([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, _label))

        # Test array
        _label = np.array([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.from_numpy(_label)))

        # Test Sequence
        _label = [1, 2, 3.]
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.tensor(_label)))

        # Test set num_classes (propagated from the sample's metainfo)
        self.assertEqual(label.num_classes, 5)

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_score(self):
        """``set_gt_score`` stores/validates scores and infers num_classes."""
        data_sample = ReIDDataSample(metainfo={'num_classes': 5})
        method = data_sample.set_gt_score

        # Test set
        score = [0.1, 0.1, 0.6, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertIn('score', sample_gt_label)
        self._assert_score_close(sample_gt_label.score, score)
        self.assertEqual(sample_gt_label.num_classes, 5)

        # Test set again. Re-read ``gt_label`` so the check does not depend
        # on the previous ``LabelData`` object being updated in place.
        score = [0.2, 0.1, 0.5, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self._assert_score_close(sample_gt_label.score, score)

        # Test invalid type
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            method(score)

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            method(torch.tensor([score]))

        # Test invalid num_classes
        with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'):
            method(torch.tensor(score + [0.1]))

        # Test auto infer num_classes when metainfo does not provide it
        data_sample = ReIDDataSample()
        data_sample.set_gt_score(torch.tensor(score))
        self.assertEqual(data_sample.gt_label.num_classes, len(score))

    def test_del_gt_label(self):
        """``gt_label`` can be set and then removed with ``del``."""
        data_sample = ReIDDataSample()
        self.assertNotIn('gt_label', data_sample)
        data_sample.set_gt_label(1)
        self.assertIn('gt_label', data_sample)
        del data_sample.gt_label
        self.assertNotIn('gt_label', data_sample)