File size: 4,576 Bytes
2402804 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import LabelData
from mmdet.structures import ReIDDataSample
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestReIDDataSample(TestCase):
    """Unit tests for ``ReIDDataSample`` metainfo and ground-truth label APIs."""

    def test_init(self):
        """Metainfo passed at construction is stored and retrievable."""
        img_shape = (256, 128)
        ori_shape = (64, 64)
        num_classes = 5
        meta_info = dict(
            img_shape=img_shape, ori_shape=ori_shape, num_classes=num_classes)
        data_sample = ReIDDataSample(metainfo=meta_info)
        self.assertIn('img_shape', data_sample)
        self.assertIn('ori_shape', data_sample)
        self.assertIn('num_classes', data_sample)
        self.assertTrue(_equal(data_sample.get('img_shape'), img_shape))
        self.assertTrue(_equal(data_sample.get('ori_shape'), ori_shape))
        self.assertTrue(_equal(data_sample.get('num_classes'), num_classes))

    def test_set_gt_label(self):
        """``set_gt_label`` wraps numbers, tensors, arrays and sequences in a
        ``LabelData`` and rejects unsupported types with ``TypeError``."""
        data_sample = ReIDDataSample(metainfo=dict(num_classes=5))
        method = data_sample.set_gt_label

        # Test number
        method(1)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor with single number
        method(torch.tensor(2))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test array with single number
        method(np.array(3))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor
        _label = torch.tensor([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, _label))

        # Test array
        _label = np.array([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.from_numpy(_label)))

        # Test Sequence
        _label = [1, 2, 3.]
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.tensor(_label)))

        # Test set num_classes
        self.assertEqual(label.num_classes, 5)

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_score(self):
        """``set_gt_score`` validates type, rank and length of the score
        tensor, and infers ``num_classes`` when metainfo lacks it."""
        data_sample = ReIDDataSample(metainfo={'num_classes': 5})
        method = data_sample.set_gt_score

        # Test set
        score = [0.1, 0.1, 0.6, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertIn('score', sample_gt_label)
        # ``assert_allclose`` is deprecated; ``assert_close`` also checks
        # dtype/device, so compare against a tensor built like the input.
        torch.testing.assert_close(sample_gt_label.score, torch.tensor(score))
        self.assertEqual(sample_gt_label.num_classes, 5)

        # Test set again. Re-fetch ``gt_label`` so the check does not rely on
        # the existing LabelData being updated in place.
        score = [0.2, 0.1, 0.5, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        torch.testing.assert_close(sample_gt_label.score, torch.tensor(score))

        # Test invalid type
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            method(score)

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            method(torch.tensor([score]))

        # Test invalid num_classes
        with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'):
            method(torch.tensor(score + [0.1]))

        # Test auto infer num_classes
        data_sample = ReIDDataSample()
        data_sample.set_gt_score(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertEqual(sample_gt_label.num_classes, len(score))

    def test_del_gt_label(self):
        """Deleting ``gt_label`` removes it from the sample."""
        data_sample = ReIDDataSample()
        self.assertNotIn('gt_label', data_sample)
        data_sample.set_gt_label(1)
        self.assertIn('gt_label', data_sample)
        del data_sample.gt_label
        self.assertNotIn('gt_label', data_sample)
|