| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Tests for ClusterPreserveQuantizeRegistry.""" |
|
|
| import tensorflow as tf |
|
|
| from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry |
| from tensorflow_model_optimization.python.core.keras.compat import keras |
| from tensorflow_model_optimization.python.core.quantization.keras import quantize_config |
| from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry |
| from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry |
|
|
|
|
# Short module-level aliases referenced by the test classes in this file.
layers = keras.layers
QuantizeConfig = quantize_config.QuantizeConfig
|
|
|
|
class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
  """Tests layer support and config validation of the cluster-preserve registry."""

  def setUp(self):
    super().setUp()

    # Registry under test, constructed with its single positional argument set
    # to False.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

    # Built-in Keras layers. build() is called so each layer creates its
    # weights before the registry inspects it.
    self.layer_conv2d = layers.Conv2D(10, (2, 2))
    self.layer_conv2d.build((2, 2))

    self.layer_dense = layers.Dense(10)
    self.layer_dense.build((2, 2))

    self.layer_relu = layers.ReLU()
    self.layer_relu.build((2, 2))

    # A user-defined layer that the registry is expected not to support.
    self.layer_custom = self.CustomLayer()
    self.layer_custom.build()

  class CustomLayer(layers.Layer):
    """A simple custom layer with training weights."""

    def build(self, input_shape=(2, 2)):
      """Adds one trainable, randomly initialized weight of shape `input_shape`."""
      self.add_weight(shape=input_shape,
                      initializer='random_normal',
                      trainable=True)

  class CustomQuantizeConfig(QuantizeConfig):
    """A dummy concrete class for testing unregistered configs.

    Every method is a no-op or returns an empty value; only the type matters
    to the tests that use it.
    """

    def get_weights_and_quantizers(self, layer):
      return []

    def get_activations_and_quantizers(self, layer):
      return []

    def set_quantize_weights(self, layer, quantize_weights):
      pass

    def set_quantize_activations(self, layer, quantize_activations):
      pass

    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

  def testSupportsKerasLayer(self):
    """Built-in Dense, Conv2D and ReLU layers are reported as supported."""
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_dense))
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_conv2d))

    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_relu))

  def testDoesNotSupportCustomLayer(self):
    """A user-defined layer is reported as unsupported."""
    self.assertFalse(
        self.cluster_preserve_quantize_registry.supports(self.layer_custom))

  def testApplyClusterPreserveWithQuantizeConfig(self):
    """Applying a default 8-bit Conv config to a Conv2D layer succeeds."""
    # Smoke test: it passes as long as the call below does not raise; the
    # returned value is not inspected.
    (self.cluster_preserve_quantize_registry
     .apply_cluster_preserve_quantize_config(
         self.layer_conv2d,
         default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
             ['kernel'], ['activation'], False)))

  def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
    """Unregistered configs and unregistered layers raise ValueError."""
    # Pass config *instances*, as real callers do (see the Default8Bit config
    # above). Passing the class object itself would test an input the API
    # never receives. A fresh instance is used for each call so the second
    # check cannot be affected by anything the first call did to the config.
    with self.assertRaises(
        ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
      (self.cluster_preserve_quantize_registry
       .apply_cluster_preserve_quantize_config(
           self.layer_conv2d, self.CustomQuantizeConfig()))

    with self.assertRaises(ValueError,
                           msg='Unregistered layers should raise error.'):
      (self.cluster_preserve_quantize_registry
       .apply_cluster_preserve_quantize_config(
           self.layer_custom, self.CustomQuantizeConfig()))
|
|
|
|
class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
  """Checks that CQAT layers are also supported by clustering and 8-bit QAT."""

  def setUp(self):
    super().setUp()
    self.default_8bit_quantize_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    self.cluster_registry = clustering_registry.ClusteringRegistry()

    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

  def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
    """Every CQAT layer with weights to preserve is in both other registries.

    A layer in the cluster-preserve config map whose config has non-empty
    `weight_attrs` and `quantize_config_attrs` must also be listed in the
    clustering registry's weights map and in the default 8-bit QAT layer map.
    """
    cqat_layers_config_map = (
        self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)
    for cqat_support_layer, layer_config in cqat_layers_config_map.items():
      # Layers with nothing to cluster or nothing to quantize are not required
      # to appear in the other registries.
      if not (layer_config.weight_attrs and layer_config.quantize_config_attrs):
        continue
      # subTest reports every unsupported layer instead of stopping at the
      # first failing assertion.
      with self.subTest(layer=cqat_support_layer):
        self.assertIn(
            cqat_support_layer, self.cluster_registry._LAYERS_WEIGHTS_MAP,
            msg='Clustering doesn\'t support {}'.format(cqat_support_layer))
        self.assertIn(
            cqat_support_layer,
            self.default_8bit_quantize_registry._layer_quantize_map,
            msg='Default 8bit QAT doesn\'t support {}'.format(
                cqat_support_layer))
|
|
|
|
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner for the TestCase classes above.
  tf.test.main()
|
|