|
|
|
|
|
import copy
|
|
|
import unittest
|
|
|
|
|
|
import numpy as np
|
|
|
import pytest
|
|
|
import torch
|
|
|
from mmengine.structures import InstanceData
|
|
|
from mmengine.testing import assert_dict_has_keys
|
|
|
from numpy.testing import assert_array_equal
|
|
|
|
|
|
from mmaction.datasets.transforms import (FormatAudioShape, FormatGCNInput,
|
|
|
FormatShape, PackActionInputs,
|
|
|
Transpose)
|
|
|
from mmaction.registry import TRANSFORMS
|
|
|
from mmaction.structures import ActionDataSample
|
|
|
from mmaction.utils import register_all_modules
|
|
|
|
|
|
# Register every mmaction module into the default scope so TRANSFORMS.build
# calls below can resolve transform names by string type.
register_all_modules()
|
|
|
|
|
|
|
|
|
class TestPackActionInputs(unittest.TestCase):
    """Unit tests for the ``PackActionInputs`` packing transform."""

    def test_transform(self):
        # A result dict without any recognised input key must be rejected.
        with self.assertRaises(ValueError):
            PackActionInputs()(dict())

        # Skeleton keypoints are packed into ``inputs`` unchanged.
        packed = PackActionInputs()(
            dict(keypoint=np.random.randn(2, 300, 17, 3), label=1))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertEqual(packed['inputs'].shape, (2, 300, 17, 3))
        self.assertEqual(packed['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # Pose heatmap volumes.
        packed = PackActionInputs()(
            dict(heatmap_imgs=np.random.randn(2, 17, 56, 56), label=1))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertEqual(packed['inputs'].shape, (2, 17, 56, 56))
        self.assertEqual(packed['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # Audio features.
        packed = PackActionInputs()(
            dict(audios=np.random.randn(3, 1, 128, 80), label=[1]))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertEqual(packed['inputs'].shape, (3, 1, 128, 80))
        self.assertIsInstance(packed['inputs'], torch.Tensor)

        # Text embeddings (no label).
        packed = PackActionInputs()(dict(text=np.random.randn(77)))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertEqual(packed['inputs'].shape, (77, ))
        self.assertIsInstance(packed['inputs'], torch.Tensor)

        # RGB frames together with meta information.
        data = dict(
            imgs=np.random.randn(2, 256, 256, 3),
            label=[1],
            filename='test.txt',
            original_shape=(256, 256, 3),
            img_shape=(256, 256, 3),
            flip_direction='vertical')

        transform = PackActionInputs()
        packed = transform(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIsInstance(packed['data_samples'], ActionDataSample)
        self.assertEqual(packed['data_samples'].img_shape, (256, 256, 3))
        self.assertEqual(packed['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # Grayscale frames: averaging the channel axis away still packs.
        data['imgs'] = data['imgs'].mean(-1)
        packed = transform(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertEqual(packed['inputs'].shape, (2, 256, 256))

        # Spatio-temporal detection annotations become InstanceData fields.
        data = dict(
            imgs=np.random.randn(256, 256, 3),
            gt_bboxes=np.array([[0, 0, 340, 224]]),
            gt_labels=[1],
            proposals=np.array([[0, 0, 340, 224]]),
            filename='test.txt')

        packed = PackActionInputs()(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['data_samples'], ActionDataSample)
        self.assertIsInstance(packed['data_samples'].gt_instances,
                              InstanceData)
        self.assertIsInstance(packed['data_samples'].proposals, InstanceData)

        # Multiple modalities collected into a dict of tensors.
        data = dict(
            imgs=np.random.randn(2, 256, 256, 3), text=np.random.randn(77))

        transform = PackActionInputs(collect_keys=('imgs', 'text'))
        packed = transform(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['inputs'], dict)
        self.assertEqual(packed['inputs']['imgs'].shape, (2, 256, 256, 3))
        self.assertEqual(packed['inputs']['text'].shape, (77, ))

    def test_repr(self):
        # Built via the registry so the configured meta_keys show up in repr.
        transform = TRANSFORMS.build(
            dict(
                type='PackActionInputs',
                meta_keys=['flip_direction', 'img_shape']))
        self.assertEqual(
            repr(transform), 'PackActionInputs(collect_keys=None, '
            "meta_keys=['flip_direction', 'img_shape'])")
|
|
|
|
|
|
|
|
|
class TestPackLocalizationInputs(unittest.TestCase):
    """Unit tests for the ``PackLocalizationInputs`` packing transform."""

    def test_transform(self):
        data = dict(
            raw_feature=np.random.randn(400, 5),
            gt_bbox=np.array([[0.1, 0.3], [0.375, 0.625]]),
            filename='test.txt')

        transform = TRANSFORMS.build(
            dict(type='PackLocalizationInputs', keys=('gt_bbox', )))
        packed = transform(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['data_samples'], ActionDataSample)
        self.assertIsInstance(packed['data_samples'].gt_instances,
                              InstanceData)

        # With no feature key present the transform must raise.
        del data['raw_feature']
        with self.assertRaises(ValueError):
            transform(copy.deepcopy(data))

        # ``bsp_feature`` is an equally valid source for ``inputs``.
        data['bsp_feature'] = np.random.randn(100, 32)
        packed = transform(copy.deepcopy(data))
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['data_samples'], ActionDataSample)
        self.assertIsInstance(packed['data_samples'].gt_instances,
                              InstanceData)

    def test_repr(self):
        transform = TRANSFORMS.build(
            dict(
                type='PackLocalizationInputs',
                meta_keys=['video_name', 'feature_frame']))
        self.assertEqual(
            repr(transform),
            "PackLocalizationInputs(meta_keys=['video_name', 'feature_frame'])"
        )
|
|
|
|
|
|
|
|
|
def test_transpose():
    """``Transpose`` permutes the axes of every requested key."""
    keys = ['imgs']
    order = [2, 0, 1]
    transpose = Transpose(keys, order)
    # HWC -> CHW for a single image.
    out = transpose(dict(imgs=np.random.randn(256, 256, 3)))
    assert out['imgs'].shape == (3, 256, 256)
    expected_repr = (f'{transpose.__class__.__name__}'
                     f'(keys={keys}, order={order})')
    assert repr(transpose) == expected_repr
|
|
|
|
|
|
|
|
|
def test_format_shape():
    """``FormatShape`` rearranges frame stacks for each input layout."""
    # Unknown layouts are rejected at construction time.
    with pytest.raises(ValueError):
        FormatShape('NHWC')

    # NCHW: channel-last frames become channel-first.
    results = dict(
        imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
    format_shape = FormatShape('NCHW')
    assert format_shape(results)['input_shape'] == (3, 3, 224, 224)

    # Flow modality stacks the x/y frame pairs along the channel axis.
    results = dict(
        imgs=np.random.randn(3, 224, 224, 2),
        num_clips=1,
        clip_len=3,
        modality='Flow')
    format_shape = FormatShape('NCHW')
    assert format_shape(results)['input_shape'] == (1, 6, 224, 224)

    # NCTHW: a temporal dimension of length clip_len is split out.
    results = dict(
        imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
    format_shape = FormatShape('NCTHW')
    assert format_shape(results)['input_shape'] == (1, 3, 3, 224, 224)

    # Multiple clips multiply the leading (batch) dimension.
    results = dict(
        imgs=np.random.randn(18, 224, 224, 3), num_clips=2, clip_len=3)
    assert format_shape(results)['input_shape'] == (6, 3, 3, 224, 224)
    assert assert_dict_has_keys(results, ['imgs', 'input_shape'])

    # RGB frames and pose heatmaps formatted in a single pass, with a
    # per-modality clip_len mapping.
    results = dict(
        imgs=np.random.randn(6, 224, 224, 3),
        heatmap_imgs=np.random.randn(12, 17, 56, 56),
        num_clips=2,
        clip_len=dict(RGB=3, Pose=6))
    results = format_shape(results)
    assert results['input_shape'] == (2, 3, 3, 224, 224)
    assert results['heatmap_input_shape'] == (2, 17, 6, 56, 56)

    assert repr(format_shape) == "FormatShape(input_format='NCTHW')"

    # NCTHW_Heatmap: heatmap volumes get the same clip/len reshaping.
    results = dict(
        imgs=np.random.randn(12, 17, 56, 56), num_clips=2, clip_len=6)
    format_shape = FormatShape('NCTHW_Heatmap')
    assert format_shape(results)['input_shape'] == (2, 17, 6, 56, 56)

    # NPTCHW: proposals form the leading dimension.
    results = dict(
        imgs=np.random.randn(72, 224, 224, 3),
        num_clips=9,
        clip_len=1,
        num_proposals=8)
    format_shape = FormatShape('NPTCHW')
    assert format_shape(results)['input_shape'] == (8, 9, 3, 224, 224)
|
|
|
|
|
|
|
|
|
def test_format_audio_shape():
    """``FormatAudioShape`` inserts a channel axis into spectrograms."""
    # Unknown layouts are rejected at construction time.
    with pytest.raises(ValueError):
        FormatAudioShape('XXXX')

    # (N, T, F) spectrograms gain a singleton channel dim: (N, 1, T, F).
    results = dict(audios=np.random.randn(3, 128, 8))
    format_shape = FormatAudioShape('NCTF')
    assert format_shape(results)['input_shape'] == (3, 1, 128, 8)
    expected_repr = (f'{format_shape.__class__.__name__}'
                     "(input_format='NCTF')")
    assert repr(format_shape) == expected_repr
|
|
|
|
|
|
|
|
|
def test_format_gcn_input():
    """``FormatGCNInput`` pads or trims skeletons to a fixed person count."""
    # An unsupported padding mode fails the constructor's assertion.
    with pytest.raises(AssertionError):
        FormatGCNInput(mode='invalid')

    # Keypoint scores are appended as a third coordinate channel.
    results = FormatGCNInput(num_person=2, mode='zero')(
        dict(
            keypoint=np.random.randn(2, 10, 17, 2),
            keypoint_score=np.random.randn(2, 10, 17)))
    assert results['keypoint'].shape == (1, 2, 10, 17, 3)
    fmt = FormatGCNInput(num_person=2, mode='zero')
    assert repr(fmt) == 'FormatGCNInput(num_person=2, mode=zero)'

    # ``num_clips`` splits the temporal axis into per-clip samples.
    results = FormatGCNInput(num_person=2, mode='zero')(
        dict(keypoint=np.random.randn(2, 40, 25, 3), num_clips=4))
    assert results['keypoint'].shape == (4, 2, 10, 25, 3)

    # 'zero' mode pads missing persons with all-zero skeletons.
    results = FormatGCNInput(num_person=2, mode='zero')(
        dict(keypoint=np.random.randn(1, 10, 25, 3)))
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
    assert_array_equal(results['keypoint'][:, 1], np.zeros((1, 10, 25, 3)))

    # 'loop' mode repeats existing persons instead of zero-padding.
    results = FormatGCNInput(num_person=2, mode='loop')(
        dict(keypoint=np.random.randn(1, 10, 25, 3)))
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
    assert_array_equal(results['keypoint'][:, 1], results['keypoint'][:, 0])

    # Persons beyond ``num_person`` are dropped.
    results = FormatGCNInput(num_person=2, mode='zero')(
        dict(keypoint=np.random.randn(3, 10, 25, 3)))
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
|
|
|
|