| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import itertools |
| |
|
| | from absl.testing import parameterized |
| | import numpy as np |
| | import tensorflow as tf |
| |
|
| | from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc |
| | from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils |
| |
|
| |
|
# The tests in this module are written for graph (TF1-style) execution, so
# switch eager execution off once, at import time, if it is currently on.
if tf.executing_eagerly():
  tf.compat.v1.disable_eager_execution()
| |
|
| |
|
class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
  """Tests for `misc.SplitBySmallValueEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.SplitBySmallValueEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([50], minval=-1.0, maxval=1.0)

  @property
  def is_lossless(self):
    """See base class."""
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    self._assert_is_integer(
        data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])

  def _assert_is_integer(self, indices):
    """Asserts that the evaluated `indices` array has dtype int32.

    A TestCase assertion is used instead of a bare `assert` so the check is
    not silently stripped when Python runs with the -O flag.

    Args:
      indices: An evaluated array holding the encoded indices.
    """
    self.assertEqual(indices.dtype, np.int32)

  @parameterized.parameters([tf.float32, tf.float64])
  def test_input_types(self, x_dtype):
    """Checks encode/decode of float32 and float64 inputs with threshold 0.05.

    Args:
      x_dtype: The TensorFlow float dtype of the input tensor.
    """
    np_dtype = x_dtype.as_numpy_dtype
    x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
    threshold = 0.05
    stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    self._assert_is_integer(test_data.encoded_x[
        misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])

    # With threshold 0.05 only the first two entries (1.0 and 0.1) are
    # expected to be encoded; the remaining positions decode back to zero.
    expected_encoded_values = np.array([1.0, 0.1], dtype=np_dtype)
    expected_encoded_indices = np.array([0, 1], dtype=np.int32)
    expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.], dtype=np_dtype)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
                        expected_encoded_values)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x, expected_decoded_x)

  def test_all_zero_input_works(self):
    """An all-zero input must decode back to all zeros."""
    stage = misc.SplitBySmallValueEncodingStage()
    test_data = self.run_one_to_many_encode_decode(stage,
                                                   lambda: tf.zeros([50]))

    self.assertAllEqual(np.zeros([50], dtype=np.float32), test_data.decoded_x)

  def test_all_below_threshold_works(self):
    """Inputs entirely within (-0.01, 0.01) with threshold 0.1 encode empty."""
    stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
    x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    # No value survives, so both encoded tensors are empty and decoding
    # reproduces a zero tensor of the original shape and dtype.
    expected_encoded_indices = np.array([], dtype=np.int32)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x,
                        np.zeros([50], dtype=x.dtype.as_numpy_dtype))
| |
|
| |
|
class DifferenceBetweenIntegersEncodingStageTest(
    test_utils.BaseEncodingStageTest):
  """Tests for `misc.DifferenceBetweenIntegersEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.DifferenceBetweenIntegersEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)

  @property
  def is_lossless(self):
    """See base class."""
    return True

  def common_asserts_for_test_data(self, data):
    """See base class."""
    self.assertAllEqual(data.x, data.decoded_x)

  @parameterized.parameters(
      itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64]))
  def test_with_multiple_input_shapes(self, input_dims, dtype):
    """Round-trips random integer vectors of several lengths and dtypes.

    Args:
      input_dims: Shape of the random input vector.
      dtype: Integer dtype of the random input vector.
    """

    def x_fn():
      return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)

    test_data = self.run_one_to_many_encode_decode(
        self.default_encoding_stage(), x_fn)
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_static(self):
    """An empty input whose length is known statically round-trips."""
    x = tf.convert_to_tensor([], dtype=tf.int32)
    # Precondition: the length 0 is known when the graph is built.
    self.assertEqual(x.shape.as_list(), [0])

    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)

    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_dynamic(self):
    """An empty input whose length is only known at run time round-trips."""
    # Gathering the nonzero entries of an all-zero tensor yields an empty
    # tensor whose static length is unknown (None).
    y = tf.zeros((10,))
    indices = tf.compat.v2.where(tf.abs(y) > 1e-8)
    x = tf.cast(tf.gather_nd(y, indices), tf.int32)
    self.assertEqual(x.shape.as_list(), [None])

    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)

    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    # TestCase assertions (not bare `assert`) so they still run under -O.
    self.assertEqual(test_data.x.shape, (0,))
    self.assertEqual(
        test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape, (0,))
    self.assertEqual(test_data.decoded_x.shape, (0,))

  @parameterized.parameters([tf.bool, tf.float32])
  def test_encode_unsupported_type_raises(self, dtype):
    """Non-integer inputs (bool, float32) must raise a TypeError.

    Args:
      dtype: The unsupported dtype the default input is cast to.
    """
    stage = self.default_encoding_stage()
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(TypeError, 'Unsupported input type'):
      self.run_one_to_many_encode_decode(
          stage, lambda: tf.cast(self.default_input(), dtype))

  def test_encode_unsupported_input_shape_raises(self):
    """A rank-2 input must raise a ValueError about the number of dims."""
    x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
    stage = self.default_encoding_stage()
    params, _ = stage.get_params()
    with self.assertRaisesRegex(ValueError, 'Number of dimensions must be 1'):
      stage.encode(x, params)
| |
|
| |
|
# Run the test cases via the TensorFlow test runner when executed directly.
if __name__ == '__main__':
  tf.test.main()
| |
|