mmaction2 / tests /datasets /transforms /test_formating.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData
from mmengine.testing import assert_dict_has_keys
from numpy.testing import assert_array_equal
from mmaction.datasets.transforms import (FormatAudioShape, FormatGCNInput,
FormatShape, PackActionInputs,
Transpose)
from mmaction.registry import TRANSFORMS
from mmaction.structures import ActionDataSample
from mmaction.utils import register_all_modules
# Populate the mmaction registries before any test runs; presumably this is
# what lets ``TRANSFORMS.build`` below resolve string ``type`` names such as
# 'PackActionInputs' and 'PackLocalizationInputs' (TODO confirm).
register_all_modules()
class TestPackActionInputs(unittest.TestCase):
    """Tests for the ``PackActionInputs`` transform."""

    def test_transform(self):
        """Pack every supported input modality and check the output.

        Each case checks that the packed dict exposes ``inputs`` and
        ``data_samples``. Cases that carry a label also check ``gt_label``.
        """
        # A results dict without any input must be rejected.
        with self.assertRaises(ValueError):
            PackActionInputs()(dict())

        # keypoint input
        results = dict(keypoint=np.random.randn(2, 300, 17, 3), label=1)
        transform = PackActionInputs()
        results = transform(results)
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (2, 300, 17, 3))
        self.assertEqual(results['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # heatmap_imgs input
        results = dict(heatmap_imgs=np.random.randn(2, 17, 56, 56), label=1)
        transform = PackActionInputs()
        results = transform(results)
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (2, 17, 56, 56))
        self.assertEqual(results['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # audios input: the list label is packed the same way as in the
        # imgs case below.
        results = dict(audios=np.random.randn(3, 1, 128, 80), label=[1])
        transform = PackActionInputs()
        results = transform(results)
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertEqual(results['inputs'].shape, (3, 1, 128, 80))
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # text input
        results = dict(text=np.random.randn(77))
        transform = PackActionInputs()
        results = transform(results)
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertEqual(results['inputs'].shape, (77, ))
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIsInstance(results['data_samples'], ActionDataSample)

        # imgs input with label and meta information
        data = dict(
            imgs=np.random.randn(2, 256, 256, 3),
            label=[1],
            filename='test.txt',
            original_shape=(256, 256, 3),
            img_shape=(256, 256, 3),
            flip_direction='vertical')
        transform = PackActionInputs()
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIsInstance(results['data_samples'], ActionDataSample)
        self.assertEqual(results['data_samples'].img_shape, (256, 256, 3))
        self.assertEqual(results['data_samples'].gt_label,
                         torch.LongTensor([1]))

        # Grayscale images: averaging the last axis drops the channel dim.
        data['imgs'] = data['imgs'].mean(-1)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (2, 256, 256))

        # imgs input with gt_bboxes: boxes and proposals are packed as
        # InstanceData on the data sample.
        data = dict(
            imgs=np.random.randn(256, 256, 3),
            gt_bboxes=np.array([[0, 0, 340, 224]]),
            gt_labels=[1],
            proposals=np.array([[0, 0, 340, 224]]),
            filename='test.txt')
        transform = PackActionInputs()
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], ActionDataSample)
        self.assertIsInstance(results['data_samples'].gt_instances,
                              InstanceData)
        self.assertIsInstance(results['data_samples'].proposals, InstanceData)

        # imgs and text input: several collect_keys yield a dict of inputs.
        data = dict(
            imgs=np.random.randn(2, 256, 256, 3), text=np.random.randn(77))
        transform = PackActionInputs(collect_keys=('imgs', 'text'))
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['inputs'], dict)
        self.assertEqual(results['inputs']['imgs'].shape, (2, 256, 256, 3))
        self.assertEqual(results['inputs']['text'].shape, (77, ))

    def test_repr(self):
        """``repr`` reports ``collect_keys`` and the configured meta keys."""
        cfg = dict(
            type='PackActionInputs', meta_keys=['flip_direction', 'img_shape'])
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'PackActionInputs(collect_keys=None, '
            "meta_keys=['flip_direction', 'img_shape'])")
class TestPackLocalizationInputs(unittest.TestCase):
    """Tests for the ``PackLocalizationInputs`` transform."""

    def _assert_packed(self, packed):
        """Assert the structure shared by every packed localization sample."""
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIn('data_samples', packed)
        self.assertIsInstance(packed['data_samples'], ActionDataSample)
        self.assertIsInstance(packed['data_samples'].gt_instances,
                              InstanceData)

    def test_transform(self):
        """Pack raw and BSP features; fail when no feature is present."""
        sample = dict(
            raw_feature=np.random.randn(400, 5),
            gt_bbox=np.array([[0.1, 0.3], [0.375, 0.625]]),
            filename='test.txt')
        packer = TRANSFORMS.build(
            dict(type='PackLocalizationInputs', keys=('gt_bbox', )))

        # raw_feature input
        self._assert_packed(packer(copy.deepcopy(sample)))

        # Once the only feature key is removed, packing must raise.
        sample.pop('raw_feature')
        with self.assertRaises(ValueError):
            packer(copy.deepcopy(sample))

        # bsp_feature input
        sample['bsp_feature'] = np.random.randn(100, 32)
        self._assert_packed(packer(copy.deepcopy(sample)))

    def test_repr(self):
        """``repr`` lists the configured meta keys."""
        packer = TRANSFORMS.build(
            dict(
                type='PackLocalizationInputs',
                meta_keys=['video_name', 'feature_frame']))
        self.assertEqual(
            repr(packer),
            "PackLocalizationInputs(meta_keys=['video_name', 'feature_frame'])"
        )
def test_transpose():
    """``Transpose`` permutes the axes of the selected keys."""
    axes_order = [2, 0, 1]
    target_keys = ['imgs']
    transposer = Transpose(target_keys, axes_order)

    # HWC -> CHW
    output = transposer(dict(imgs=np.random.randn(256, 256, 3)))
    assert output['imgs'].shape == (3, 256, 256)

    expected_repr = type(transposer).__name__ + \
        f'(keys={target_keys}, order={axes_order})'
    assert repr(transposer) == expected_repr
def test_format_shape():
    """Check the ``input_shape`` produced by ``FormatShape`` per layout."""
    # An unsupported layout must be rejected when the transform is built.
    with pytest.raises(ValueError):
        FormatShape('NHWC')

    # 'NCHW', RGB modality: 3 frames of HxWx3 -> 3 samples of 3xHxW.
    rgb_frames = dict(
        imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
    nchw_rgb = FormatShape('NCHW')
    assert nchw_rgb(rgb_frames)['input_shape'] == (3, 3, 224, 224)

    # 'NCHW', Flow modality: the 2 channels of all 3 frames are stacked
    # into one 6-channel sample.
    flow_frames = dict(
        imgs=np.random.randn(3, 224, 224, 2),
        num_clips=1,
        clip_len=3,
        modality='Flow')
    nchw_flow = FormatShape('NCHW')
    assert nchw_flow(flow_frames)['input_shape'] == (1, 6, 224, 224)

    # 'NCTHW' with a single clip of 3 frames.
    single_clip = dict(
        imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
    ncthw = FormatShape('NCTHW')
    assert ncthw(single_clip)['input_shape'] == (1, 3, 3, 224, 224)

    # 'NCTHW' with num_clips=2, clip_len=3: 18 frames -> 6 clips of 3.
    # The same NCTHW instance is reused by the following cases.
    multi_clip = dict(
        imgs=np.random.randn(18, 224, 224, 3), num_clips=2, clip_len=3)
    assert ncthw(multi_clip)['input_shape'] == (6, 3, 3, 224, 224)
    assert assert_dict_has_keys(multi_clip, ['imgs', 'input_shape'])

    # 'NCTHW' with both RGB frames and heatmaps: clip_len is given per
    # modality and each modality reports its own shape key.
    mixed = ncthw(
        dict(
            imgs=np.random.randn(6, 224, 224, 3),
            heatmap_imgs=np.random.randn(12, 17, 56, 56),
            num_clips=2,
            clip_len=dict(RGB=3, Pose=6)))
    assert mixed['input_shape'] == (2, 3, 3, 224, 224)
    assert mixed['heatmap_input_shape'] == (2, 17, 6, 56, 56)
    assert repr(ncthw) == "FormatShape(input_format='NCTHW')"

    # 'NCTHW_Heatmap': 12 heatmap frames -> 2 clips of 6.
    heatmap_frames = dict(
        imgs=np.random.randn(12, 17, 56, 56), num_clips=2, clip_len=6)
    ncthw_heatmap = FormatShape('NCTHW_Heatmap')
    assert ncthw_heatmap(heatmap_frames)['input_shape'] == (2, 17, 6, 56, 56)

    # 'NPTCHW': 72 frames = 8 proposals x 9 clips of 1 frame.
    proposal_frames = dict(
        imgs=np.random.randn(72, 224, 224, 3),
        num_clips=9,
        clip_len=1,
        num_proposals=8)
    nptchw = FormatShape('NPTCHW')
    assert nptchw(proposal_frames)['input_shape'] == (8, 9, 3, 224, 224)
def test_format_audio_shape():
    """``FormatAudioShape('NCTF')`` inserts a channel axis after N."""
    # Only known layouts are accepted.
    with pytest.raises(ValueError):
        FormatAudioShape('XXXX')

    nctf = FormatAudioShape('NCTF')
    packed = nctf(dict(audios=np.random.randn(3, 128, 8)))
    assert packed['input_shape'] == (3, 1, 128, 8)

    expected_repr = f"{type(nctf).__name__}(input_format='NCTF')"
    assert repr(nctf) == expected_repr
def test_format_gcn_input():
    """Check person padding/truncation and clip splitting of FormatGCNInput.

    In the cases below, the formatted ``keypoint`` array has the shape
    ``(num_clips, num_person, T // num_clips, V, C)``.
    """
    # Unknown padding modes are rejected at construction time.
    with pytest.raises(AssertionError):
        FormatGCNInput(mode='invalid')

    # keypoint_score is merged with the 2 coordinate channels into C=3.
    results = dict(
        keypoint=np.random.randn(2, 10, 17, 2),
        keypoint_score=np.random.randn(2, 10, 17))
    format_shape = FormatGCNInput(num_person=2, mode='zero')
    results = format_shape(results)
    assert results['keypoint'].shape == (1, 2, 10, 17, 3)
    assert repr(format_shape) == 'FormatGCNInput(num_person=2, mode=zero)'

    # num_clips=4 splits the 40 frames into 4 clips of 10 frames.
    results = dict(keypoint=np.random.randn(2, 40, 25, 3), num_clips=4)
    format_shape = FormatGCNInput(num_person=2, mode='zero')
    results = format_shape(results)
    assert results['keypoint'].shape == (4, 2, 10, 25, 3)

    # 'zero' mode: the missing person is zero-padded and the real person is
    # kept unchanged. A copy is passed so the reference stays intact even
    # if the transform works in place.
    keypoint = np.random.randn(1, 10, 25, 3)
    results = dict(keypoint=keypoint.copy())
    format_shape = FormatGCNInput(num_person=2, mode='zero')
    results = format_shape(results)
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
    assert_array_equal(results['keypoint'][:, 0], keypoint)
    assert_array_equal(results['keypoint'][:, 1], np.zeros((1, 10, 25, 3)))

    # 'loop' mode: the single real person is repeated into the missing slot.
    keypoint = np.random.randn(1, 10, 25, 3)
    results = dict(keypoint=keypoint.copy())
    format_shape = FormatGCNInput(num_person=2, mode='loop')
    results = format_shape(results)
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
    assert_array_equal(results['keypoint'][:, 0], keypoint)
    assert_array_equal(results['keypoint'][:, 1], results['keypoint'][:, 0])

    # Persons beyond num_person are dropped; the first ones are kept.
    keypoint = np.random.randn(3, 10, 25, 3)
    results = dict(keypoint=keypoint.copy())
    format_shape = FormatGCNInput(num_person=2, mode='zero')
    results = format_shape(results)
    assert results['keypoint'].shape == (1, 2, 10, 25, 3)
    assert_array_equal(results['keypoint'][0], keypoint[:2])