| |
|
| |
|
| |
|
| |
|
| |
|
| | import numpy as np |
| | import hypothesis.strategies as st |
| | import unittest |
| | import caffe2.python.hypothesis_test_util as hu |
| | from caffe2.python import core, workspace |
| | from hypothesis import given, settings |
| | import caffe2.python.ideep_test_util as mu |
| |
|
| | @st.composite |
| | def _tensor_splits(draw, add_axis=False): |
| | """Generates (axis, split_info, tensor_splits) tuples.""" |
| | tensor = draw(hu.tensor(min_dim=2, min_value=4)) |
| | axis = draw(st.integers(-len(tensor.shape), len(tensor.shape) - 1)) |
| | if add_axis: |
| | |
| | |
| | return ( |
| | axis, |
| | np.ones(tensor.shape[axis], dtype=np.int32), |
| | [ |
| | np.array(tensor.take(i, axis=axis)) |
| | for i in range(tensor.shape[axis]) |
| | ] |
| | ) |
| | else: |
| | |
| | |
| | splits = sorted(draw( |
| | st.lists(elements=st.integers(0, tensor.shape[axis]), max_size=4) |
| | ) + [0, tensor.shape[axis]]) |
| | |
| | splits = list(set(splits)) |
| | return ( |
| | axis, |
| | np.array(np.diff(splits), dtype=np.int32), |
| | [ |
| | tensor.take(range(splits[i], splits[i + 1]), axis=axis) |
| | for i in range(len(splits) - 1) |
| | ], |
| | ) |
| |
|
| |
|
| | @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") |
| | class TestConcatSplitOps(hu.HypothesisTestCase): |
| | @given(tensor_splits=_tensor_splits(), |
| | **mu.gcs) |
| | @settings(deadline=10000) |
| | def test_concat(self, tensor_splits, gc, dc): |
| | axis, _, splits = tensor_splits |
| |
|
| | op = core.CreateOperator( |
| | "Concat", |
| | ['X_{}'.format(i) for i in range(len(splits))], |
| | ['concat_result', 'split_info'], |
| | axis=axis |
| | ) |
| |
|
| | self.assertDeviceChecks(dc, op, splits, [0, 1]) |
| | self.assertGradientChecks(gc, op, splits, 0, [0]) |
| |
|
| | @given(tensor_splits=_tensor_splits(), |
| | split_as_arg=st.booleans(), |
| | **mu.gcs) |
| | @settings(deadline=10000) |
| | def test_split(self, tensor_splits, split_as_arg, gc, dc): |
| | axis, split_info, splits = tensor_splits |
| |
|
| | split_as_arg = True |
| |
|
| | if split_as_arg: |
| | input_names = ['input'] |
| | input_tensors = [np.concatenate(splits, axis=axis)] |
| | kwargs = dict(axis=axis, split=split_info) |
| | else: |
| | input_names = ['input', 'split'] |
| | input_tensors = [np.concatenate(splits, axis=axis), split_info] |
| | kwargs = dict(axis=axis) |
| |
|
| | op = core.CreateOperator( |
| | "Split", |
| | input_names, |
| | ['X_{}'.format(i) for i in range(len(split_info))], |
| | **kwargs |
| | ) |
| |
|
| | def split_ref(input, split=split_info): |
| | s = np.cumsum([0] + list(split)) |
| | return [ |
| | np.array(input.take(np.arange(s[i], s[i + 1]), axis=axis)) |
| | for i in range(len(split)) |
| | ] |
| | outputs_with_grad = range(len(split_info)) |
| | self.assertDeviceChecks(dc, op, input_tensors, outputs_with_grad) |
| | self.assertGradientChecks(gc, op, input_tensors, 0, outputs_with_grad) |
| |
|
| | @given(tensor_splits=_tensor_splits(add_axis=True), **mu.gcs) |
| | @settings(deadline=10000) |
| | def test_concat_add_axis(self, tensor_splits, gc, dc): |
| | axis, _, splits = tensor_splits |
| | op = core.CreateOperator( |
| | "Concat", |
| | ['X_{}'.format(i) for i in range(len(splits))], |
| | ['concat_result', 'split_info'], |
| | axis=axis, |
| | add_axis=1 |
| | ) |
| |
|
| | self.assertDeviceChecks(dc, op, splits, [0, 1]) |
| |
|
| | for i in range(len(splits)): |
| | self.assertGradientChecks(gc, op, splits, i, [0]) |
| |
|
| |
|
| | @given(tensor_splits=_tensor_splits(add_axis=True), **mu.gcs) |
| | def test_concat_with_TensorCPU(self, tensor_splits, gc, dc): |
| | axis, _, splits = tensor_splits |
| | op0 = core.CreateOperator( |
| | "Concat", |
| | ['X_{}'.format(i) for i in range(len(splits))], |
| | ['concat_result0', 'split_info0'], |
| | axis=axis, |
| | add_axis=1, |
| | device_option=dc[0] |
| | ) |
| | op1 = core.CreateOperator( |
| | "Concat", |
| | ['X_{}'.format(i) for i in range(len(splits))], |
| | ['concat_result1', 'split_info1'], |
| | axis=axis, |
| | add_axis=1, |
| | device_option=dc[1] |
| | ) |
| |
|
| | for i, X in enumerate(splits): |
| | workspace.FeedBlob('X_{}'.format(i), X, dc[0]) |
| |
|
| | workspace.RunOperatorOnce(op0) |
| | res0 = workspace.FetchBlob('concat_result0') |
| | inf0 = workspace.FetchBlob('split_info0') |
| |
|
| | workspace.RunOperatorOnce(op1) |
| | res1 = workspace.FetchBlob('concat_result1') |
| | inf1 = workspace.FetchBlob('split_info1') |
| |
|
| | if not np.allclose(res0, res1, atol=0.0, rtol=0.0): |
| | print(res1.flatten()) |
| | print(res0.flatten()) |
| | print(np.max(np.abs(res1 - res0))) |
| | self.assertTrue(False) |
| |
|
| | if not np.allclose(inf0, inf1, atol=0.0, rtol=0.0): |
| | print(inf1.flatten()) |
| | print(inf0.flatten()) |
| | print(np.max(np.abs(inf1 - inf0))) |
| | self.assertTrue(False) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | unittest.main() |
| |
|