|
|
|
|
|
from unittest import TestCase |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from mmengine.structures import LabelData |
|
|
|
|
|
from mmdet.structures import ReIDDataSample |
|
|
|
|
|
|
|
|
def _equal(a, b): |
|
|
if isinstance(a, (torch.Tensor, np.ndarray)): |
|
|
return (a == b).all() |
|
|
else: |
|
|
return a == b |
|
|
|
|
|
|
|
|
class TestReIDDataSample(TestCase):
    """Unit tests for ``ReIDDataSample`` metainfo and ``gt_label`` handling."""

    def test_init(self):
        """Metainfo passed to the constructor is readable from the sample."""
        img_shape = (256, 128)
        ori_shape = (64, 64)
        num_classes = 5
        meta_info = dict(
            img_shape=img_shape, ori_shape=ori_shape, num_classes=num_classes)
        data_sample = ReIDDataSample(metainfo=meta_info)

        self.assertIn('img_shape', data_sample)
        self.assertIn('ori_shape', data_sample)
        self.assertIn('num_classes', data_sample)
        self.assertTrue(_equal(data_sample.get('img_shape'), img_shape))
        self.assertTrue(_equal(data_sample.get('ori_shape'), ori_shape))
        self.assertTrue(_equal(data_sample.get('num_classes'), num_classes))

    def test_set_gt_label(self):
        """``set_gt_label`` wraps supported label types in ``LabelData``."""
        data_sample = ReIDDataSample(metainfo=dict(num_classes=5))
        method = data_sample.set_gt_label

        # Python int scalar -> LongTensor label.
        method(1)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # 0-d torch tensor -> LongTensor label.
        method(torch.tensor(2))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # 0-d numpy array -> LongTensor label.
        method(np.array(3))
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # 1-d torch tensor is kept with the same values.
        _label = torch.tensor([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, _label))

        # 1-d numpy array is converted to a tensor with the same values.
        _label = np.array([1, 2, 3])
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.from_numpy(_label)))

        # Python list (mixed int/float) is converted to a tensor.
        _label = [1, 2, 3.]
        method(_label)
        label = data_sample.get('gt_label')
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue(_equal(label.label, torch.tensor(_label)))

        # The sample's num_classes metainfo is visible on the label.
        self.assertEqual(label.num_classes, 5)

        # Unsupported label types are rejected.
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_score(self):
        """``set_gt_score`` stores a 1-d score tensor on ``gt_label``."""
        data_sample = ReIDDataSample(metainfo={'num_classes': 5})
        method = data_sample.set_gt_score

        # First assignment creates gt_label with the score and num_classes.
        score = [0.1, 0.1, 0.6, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertIn('score', sample_gt_label)
        torch.testing.assert_close(sample_gt_label.score, torch.tensor(score))
        self.assertEqual(sample_gt_label.num_classes, 5)

        # A second assignment overwrites the score. Re-read gt_label so the
        # check reflects the sample's current state instead of relying on the
        # earlier reference having been updated in place.
        score = [0.2, 0.1, 0.5, 0.1, 0.1]
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        torch.testing.assert_close(sample_gt_label.score, torch.tensor(score))

        # Non-tensor input is rejected.
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            method(score)

        # Scores must be 1-d.
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            method(torch.tensor([score]))

        # Score length must match the num_classes metainfo.
        with self.assertRaisesRegex(AssertionError, r'length of value \(6\)'):
            method(torch.tensor(score + [0.1]))

        # Without num_classes metainfo, it is inferred from the score length.
        data_sample = ReIDDataSample()
        method = data_sample.set_gt_score
        method(torch.tensor(score))
        sample_gt_label = data_sample.gt_label
        self.assertEqual(sample_gt_label.num_classes, len(score))

    def test_del_gt_label(self):
        """Deleting ``gt_label`` removes it from the sample."""
        data_sample = ReIDDataSample()
        self.assertNotIn('gt_label', data_sample)
        data_sample.set_gt_label(1)
        self.assertIn('gt_label', data_sample)
        del data_sample.gt_label
        self.assertNotIn('gt_label', data_sample)
|
|
|