| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Integration tests for CQAT, PCQAT cases.""" |
| from absl.testing import parameterized |
| import numpy as np |
| import tensorflow as tf |
|
|
| from tensorflow_model_optimization.python.core.clustering.keras import cluster |
| from tensorflow_model_optimization.python.core.clustering.keras import cluster_config |
| from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster |
| from tensorflow_model_optimization.python.core.keras.compat import keras |
| from tensorflow_model_optimization.python.core.quantization.keras import quantize |
| from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import ( |
| default_8bit_cluster_preserve_quantize_scheme,) |
| from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import ( |
| strip_clustering_cqat,) |
|
|
|
|
# Short alias used by the model-building code throughout the tests below.
layers = keras.layers
|
|
|
|
class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end integration tests for CQAT and PCQAT.

  CQAT: quantization-aware training that preserves weight clusters.
  PCQAT: quantization-aware training that preserves both the weight
  clusters and the sparsity pattern of a pruned model.
  """

  def setUp(self):
    super(ClusterPreserveIntegrationTest, self).setUp()
    # Default clustering configuration shared by the dense-model tests.
    self.cluster_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR
    }

  def compile_and_fit(self, model):
    """Compiles `model` and fits it for one epoch on random 10-d inputs."""
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.fit(
        np.random.rand(20, 10),
        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
        batch_size=20,
    )

  def _get_number_of_unique_weights(self, stripped_model, layer_nr,
                                    weight_name):
    """Counts the distinct scalar values in one weight of one layer.

    Args:
      stripped_model: model after clustering/CQAT stripping.
      layer_nr: index into `stripped_model.layers` of the layer to inspect.
      weight_name: substring ('kernel', 'bias', ...) identifying the weight.

    Returns:
      Number of unique scalar values in that weight tensor.
    """
    layer = stripped_model.layers[layer_nr]
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      # Quantize wrappers rename their weights, so match by substring.
      # NOTE(review): if several trainable weights match, the last match
      # wins; if none match, `weight` stays unbound and this raises.
      for weight_item in layer.trainable_weights:
        if weight_name in weight_item.name:
          weight = weight_item
    else:
      weight = getattr(layer, weight_name)
    weights_as_list = weight.numpy().flatten()
    nr_of_unique_weights = len(set(weights_as_list))
    return nr_of_unique_weights

  def _get_sparsity(self, model):
    """Returns the fraction of zero values in every kernel of `model`."""
    sparsity_list = []
    for layer in model.layers:
      for weights in layer.trainable_weights:
        if 'kernel' in weights.name:
          np_weights = keras.backend.get_value(weights)
          sparsity = 1.0 - np.count_nonzero(np_weights) / float(
              np_weights.size)
          sparsity_list.append(sparsity)

    return sparsity_list

  def _get_clustered_model(self, preserve_sparsity):
    """Cluster the (sparse) model and return clustered_model.

    Args:
      preserve_sparsity: if True, part of the first Dense kernel is zeroed
        before clustering so that there is a sparsity pattern to preserve.
        Note that sparsity-aware clustering itself is always requested; the
        flag only controls whether the input model is actually sparse.

    Returns:
      The clustered (not yet stripped) model.
    """
    tf.random.set_seed(1)
    original_model = keras.Sequential([
        layers.Dense(5, activation='softmax', input_shape=(10,)),
        layers.Flatten(),
    ])

    # Manually inject sparsity into the Dense layer's kernel.
    if preserve_sparsity:
      first_layer_weights = original_model.layers[0].get_weights()
      first_layer_weights[0][:][0:2] = 0.0
      original_model.layers[0].set_weights(first_layer_weights)

    # Sparsity-preserving clustering configuration.
    clustering_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
        'preserve_sparsity': True
    }

    clustered_model = experimental_cluster.cluster_weights(
        original_model, **clustering_params)

    return clustered_model

  def _get_conv_model(self,
                      nr_of_channels,
                      data_format=None,
                      kernel_size=(3, 3)):
    """Returns functional model with Conv2D layer."""
    inp = keras.layers.Input(shape=(32, 32), batch_size=100)
    # Insert the channel axis where the given data_format expects it.
    shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
    x = keras.layers.Reshape(shape)(inp)
    x = keras.layers.Conv2D(
        filters=nr_of_channels,
        kernel_size=kernel_size,
        data_format=data_format,
        activation='relu',
    )(x)
    x = keras.layers.MaxPool2D(2, 2)(x)
    out = keras.layers.Flatten()(x)
    model = keras.Model(inputs=inp, outputs=out)
    return model

  def _compile_and_fit_conv_model(self, model, nr_epochs=1):
    """Compile and fit conv model from _get_conv_model."""
    x_train = np.random.uniform(size=(500, 32, 32))
    y_train = np.random.randint(low=0, high=1024, size=(500,))
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )

    model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1)

    return model

  def _get_conv_clustered_model(self,
                                nr_of_channels,
                                nr_of_clusters,
                                data_format,
                                preserve_sparsity,
                                kernel_size=(3, 3)):
    """Returns clustered per channel model with Conv2D layer."""
    tf.random.set_seed(42)
    model = self._get_conv_model(nr_of_channels, data_format, kernel_size)

    if preserve_sparsity:
      # The Conv2D layer is expected right after the Input and Reshape layers.
      assert model.layers[2].name == 'conv2d'

      conv_layer_weights = model.layers[2].get_weights()
      shape = conv_layer_weights[0].shape
      conv_layer_weights_flatten = conv_layer_weights[0].flatten()

      # Zero out just over half of the kernel to create sparsity.
      nr_elems = len(conv_layer_weights_flatten)
      conv_layer_weights_flatten[0:1 + nr_elems // 2] = 0.0
      pruned_conv_layer_weights = tf.reshape(conv_layer_weights_flatten, shape)
      conv_layer_weights[0] = pruned_conv_layer_weights
      model.layers[2].set_weights(conv_layer_weights)

    clustering_params = {
        'number_of_clusters':
            nr_of_clusters,
        'cluster_centroids_init':
            cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
        'cluster_per_channel':
            True,
        'preserve_sparsity':
            preserve_sparsity
    }

    clustered_model = experimental_cluster.cluster_weights(model,
                                                           **clustering_params)
    clustered_model = self._compile_and_fit_conv_model(clustered_model)

    # Returned model is still wrapped (strip_clustering not applied yet).
    return clustered_model

  def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
    """PCQAT training on the input model.

    Applies the cluster-preserve quantize scheme, fine-tunes, strips the
    clustering wrappers and returns (sparsity list, unique kernel count).
    """
    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))

    self.compile_and_fit(quant_aware_model)

    stripped_pcqat_model = strip_clustering_cqat(quant_aware_model)

    # Layer 1 is the quantized Dense layer (layer 0 is the input quantize
    # layer added by quantize_apply).
    num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
        stripped_pcqat_model, 1, 'kernel')

    sparsity_pcqat = self._get_sparsity(stripped_pcqat_model)

    return sparsity_pcqat, num_of_unique_weights_pcqat

  def testEndToEndClusterPreserve(self):
    """Runs CQAT end to end and whole model is quantized."""
    original_model = keras.Sequential(
        [layers.Dense(5, activation='softmax', input_shape=(10,))]
    )
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model))

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)
    stripped_cqat_model = strip_clustering_cqat(quant_aware_model)

    # CQAT must not change the number of unique kernel values produced by
    # clustering (layer index 1: layer 0 is the input quantize layer).
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 1, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndClusterPreservePerLayer(self):
    """Runs CQAT end to end and model is quantized per layers."""
    original_model = keras.Sequential([
        layers.Dense(5, activation='relu', input_shape=(10,)),
        layers.Dense(5, activation='softmax', input_shape=(10,)),
    ])
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 1, 'kernel')

    def apply_quantization_to_dense(layer):
      # Annotate every Dense layer for quantization; leave others as-is.
      if isinstance(layer, keras.layers.Dense):
        return quantize.quantize_annotate_layer(layer)
      return layer

    quant_aware_annotate_model = keras.models.clone_model(
        clustered_model,
        clone_function=apply_quantization_to_dense,
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)
    stripped_cqat_model = strip_clustering_cqat(
        quant_aware_model)

    # CQAT must keep the cluster count of the second Dense layer
    # (index 2 after the input quantize layer is inserted).
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 2, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndClusterPreserveOneLayer(self):
    """Runs CQAT end to end and model is quantized only for a single layer."""
    original_model = keras.Sequential([
        layers.Dense(5, activation='relu', input_shape=(10,)),
        layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
    ])
    clustered_model = cluster.cluster_weights(
        original_model,
        **self.cluster_params)
    self.compile_and_fit(clustered_model)
    clustered_model = cluster.strip_clustering(clustered_model)
    num_of_unique_weights_clustering = self._get_number_of_unique_weights(
        clustered_model, 1, 'kernel')

    def apply_quantization_to_dense(layer):
      # Annotate only the layer named 'qat' for quantization.
      if isinstance(layer, keras.layers.Dense):
        if layer.name == 'qat':
          return quantize.quantize_annotate_layer(layer)
      return layer

    quant_aware_annotate_model = keras.models.clone_model(
        clustered_model,
        clone_function=apply_quantization_to_dense,
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    self.compile_and_fit(quant_aware_model)

    stripped_cqat_model = strip_clustering_cqat(
        quant_aware_model)

    # CQAT must keep the cluster count of the single quantized layer.
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, 1, 'kernel')
    self.assertAllEqual(num_of_unique_weights_clustering,
                        num_of_unique_weights_cqat)

  def testEndToEndPruneClusterPreserveQAT(self):
    """Runs PCQAT end to end when we quantize the whole model."""
    preserve_sparsity = True
    clustered_model = self._get_clustered_model(preserve_sparsity)
    # Save the kernel before fine-tuning.
    # NOTE(review): weights[1] is assumed to be the kernel variable held by
    # the clustering wrapper — confirm against the wrapper's weight order.
    first_layer_weights = clustered_model.layers[0].weights[1]
    stripped_model_before_tuning = cluster.strip_clustering(
        clustered_model)
    nr_of_unique_weights_before = self._get_number_of_unique_weights(
        stripped_model_before_tuning, 0, 'kernel')

    self.compile_and_fit(clustered_model)

    stripped_model_clustered = cluster.strip_clustering(clustered_model)
    weights_after_tuning = stripped_model_clustered.layers[0].kernel
    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        stripped_model_clustered, 0, 'kernel')

    # Fine-tuning must not change the number of unique (clustered) weights.
    self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)

    # The rows zeroed out before clustering must still be zero after
    # fine-tuning, i.e. the sparsity pattern is preserved through training.
    self.assertTrue(
        np.array_equal(first_layer_weights[:][0:2],
                       weights_after_tuning[:][0:2]))

    # Record the sparsity of the clustered model before PCQAT.
    sparsity_pruning = self._get_sparsity(stripped_model_clustered)

    # Annotate the whole stripped model for quantization.
    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model_clustered)
    )

    # PCQAT should keep (or improve) the sparsity and keep the number of
    # unique weights.
    preserve_sparsity = True
    sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
        preserve_sparsity, quant_aware_annotate_model)
    self.assertAllGreaterEqual(np.array(sparsity_pcqat),
                               sparsity_pruning[0])
    self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)

  def testEndToEndClusterPreserveQATClusteredPerChannel(
      self, data_format='channels_last'):
    """Runs CQAT end to end for the model that is clustered per channel."""

    nr_of_channels = 12
    nr_of_clusters = 4

    clustered_model = self._get_conv_clustered_model(
        nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=False)
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer sits after the Input and Reshape layers.
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    # Per-channel clustering allows up to clusters * channels unique values.
    nr_unique_weights = -1

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        nr_unique_weights = len(np.unique(weight.numpy()))
        self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    # Train the quantization-aware model for a few epochs.
    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # CQAT must not increase the total number of unique kernel values
    # (layer index shifts by one after quantize_apply inserts a layer).
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)

    # Each channel must still contain exactly nr_of_clusters unique values,
    # i.e. per-channel clustering survived CQAT.
    layer = stripped_cqat_model.layers[layer_nr]
    weight_to_check = None
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      for weight_item in layer.trainable_weights:
        if 'kernel' in weight_item.name:
          weight_to_check = weight_item

    assert weight_to_check is not None

    for i in range(nr_of_channels):
      nr_unique_weights_per_channel = len(
          np.unique(weight_to_check[:, :, :, i]))
      assert nr_unique_weights_per_channel == nr_of_clusters

  def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
    """Runs PCQAT end to end for the model that is clustered per channel."""

    nr_of_channels = 12
    nr_of_clusters = 4

    clustered_model = self._get_conv_clustered_model(
        nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True)
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer sits after the Input and Reshape layers.
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    # Per-channel clustering allows up to clusters * channels unique values.
    nr_unique_weights = -1

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        nr_unique_weights = len(np.unique(weight.numpy()))
        self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)

    # The pruned-then-clustered model should keep more than the injected
    # ~50% sparsity.
    control_sparsity = self._get_sparsity(stripped_model)
    self.assertGreater(control_sparsity[0], 0.5)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme())

    # Train the quantization-aware model for a few epochs.
    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # PCQAT must not increase the total number of unique kernel values.
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)

    # Each channel must still contain exactly nr_of_clusters unique values,
    # i.e. per-channel clustering survived PCQAT.
    layer = stripped_cqat_model.layers[layer_nr]
    weight_to_check = None
    if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
      for weight_item in layer.trainable_weights:
        if 'kernel' in weight_item.name:
          weight_to_check = weight_item

    assert weight_to_check is not None

    for i in range(nr_of_channels):
      nr_unique_weights_per_channel = len(
          np.unique(weight_to_check[:, :, :, i]))
      assert nr_unique_weights_per_channel == nr_of_clusters

    # Sparsity must not degrade relative to the pre-PCQAT model.
    cqat_sparsity = self._get_sparsity(stripped_cqat_model)
    self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])

  def testEndToEndPCQATClusteredPerChannelConv2d1x1(self,
                                                    data_format='channels_last'
                                                    ):
    """Runs PCQAT for model containing a 1x1 Conv2D.

    (with insufficient number of weights per channel).

    Args:
      data_format: Format of input data.
    """
    nr_of_channels = 12
    nr_of_clusters = 4

    # A 1x1 kernel gives each channel fewer weights than the requested
    # number of clusters, so clustering is expected to warn.
    with self.assertWarnsRegex(Warning,
                               r'Layer conv2d does not have enough weights'):
      clustered_model = self._get_conv_clustered_model(
          nr_of_channels,
          nr_of_clusters,
          data_format,
          preserve_sparsity=True,
          kernel_size=(1, 1))
    stripped_model = cluster.strip_clustering(clustered_model)

    # The Conv2D layer sits after the Input and Reshape layers.
    conv2d_layer = stripped_model.layers[2]
    self.assertEqual(conv2d_layer.name, 'conv2d')

    for weight in conv2d_layer.weights:
      if 'kernel' in weight.name:
        # With too few weights per channel, fewer unique values than
        # clusters * channels can be produced.
        nr_original_weights = len(np.unique(weight.numpy()))
        self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters)

        # Each channel of the 1x1 kernel is non-empty and holds at most
        # nr_of_clusters entries along the inspected axis.
        for channel in range(nr_of_channels):
          channel_weights = (
              weight[:, channel, :, :]
              if data_format == 'channels_first' else weight[:, :, :, channel])
          nr_channel_weights = len(channel_weights)
          self.assertGreater(nr_channel_weights, 0)
          self.assertLessEqual(nr_channel_weights, nr_of_clusters)

    # Sparsity injected before clustering must still be present.
    control_sparsity = self._get_sparsity(stripped_model)
    self.assertGreater(control_sparsity[0], 0.5)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(stripped_model))

    with self.assertWarnsRegex(
        Warning, r'No clustering performed on layer quant_conv2d'):
      quant_aware_model = quantize.quantize_apply(
          quant_aware_annotate_model,
          scheme=default_8bit_cluster_preserve_quantize_scheme
          .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

    # Train the quantization-aware model for a few epochs.
    model = self._compile_and_fit_conv_model(quant_aware_model, 3)

    stripped_cqat_model = strip_clustering_cqat(model)

    # Since clustering was skipped on the layer, the unique-weight count
    # must match the original clustered model exactly.
    layer_nr = 3
    num_of_unique_weights_cqat = self._get_number_of_unique_weights(
        stripped_cqat_model, layer_nr, 'kernel')
    self.assertEqual(num_of_unique_weights_cqat, nr_original_weights)

    # Sparsity must not degrade relative to the pre-PCQAT model.
    cqat_sparsity = self._get_sparsity(stripped_cqat_model)
    self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])

  def testPassingNonPrunedModelToPCQAT(self):
    """Runs PCQAT as CQAT if the input model is not pruned."""
    preserve_sparsity = False
    clustered_model = self._get_clustered_model(preserve_sparsity)

    clustered_model = cluster.strip_clustering(clustered_model)
    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')

    # Apply PCQAT (preserve_sparsity=True) to the non-pruned model; it
    # should effectively fall back to plain cluster preservation.
    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(True))

    self.compile_and_fit(quant_aware_model)
    stripped_pcqat_model = strip_clustering_cqat(
        quant_aware_model)

    # The number of unique weights must be unchanged by PCQAT.
    num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
        stripped_pcqat_model, 1, 'kernel')
    self.assertAllEqual(nr_of_unique_weights_after,
                        num_of_unique_weights_pcqat)

  @parameterized.parameters((0.), (2.))
  def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights):
    """If pruned_clustered_model has uniform weights, it won't break PCQAT."""
    preserve_sparsity = True
    original_model = keras.Sequential([
        layers.Dense(5, activation='softmax', input_shape=(10,)),
        layers.Flatten(),
    ])

    # Set every kernel value to the same constant (0.0 or 2.0).
    first_layer_weights = original_model.layers[0].get_weights()
    first_layer_weights[0][:] = uniform_weights
    original_model.layers[0].set_weights(first_layer_weights)

    # Sparsity-preserving clustering configuration.
    clustering_params = {
        'number_of_clusters': 4,
        'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
        'preserve_sparsity': True
    }

    clustered_model = experimental_cluster.cluster_weights(
        original_model, **clustering_params)
    clustered_model = cluster.strip_clustering(clustered_model)

    nr_of_unique_weights_after = self._get_number_of_unique_weights(
        clustered_model, 0, 'kernel')
    sparsity_pruning = self._get_sparsity(clustered_model)

    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    # PCQAT must run without error and keep both sparsity and cluster count.
    sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
        preserve_sparsity, quant_aware_annotate_model)
    self.assertAllGreaterEqual(np.array(sparsity_pcqat),
                               sparsity_pruning[0])
    self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)

  def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self):
    """PCQAT zero centroid masks stay the same and trainable variables are updating between epochs."""
    preserve_sparsity = True
    clustered_model = self._get_clustered_model(preserve_sparsity)
    clustered_model = cluster.strip_clustering(clustered_model)

    # Apply PCQAT to the stripped clustered model.
    quant_aware_annotate_model = (
        quantize.quantize_annotate_model(clustered_model)
    )

    quant_aware_model = quantize.quantize_apply(
        quant_aware_annotate_model,
        scheme=default_8bit_cluster_preserve_quantize_scheme
        .Default8BitClusterPreserveQuantizeScheme(True))

    quant_aware_model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )

    class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback):
      """Check the updates of trainable variables and centroid masks."""

      def on_epoch_begin(self, batch, logs=None):
        # Capture the zero-centroid mask at the start of the epoch.
        # NOTE(review): relies on the private `_weight_vars` attribute of
        # the quantize wrapper; [0][2] is assumed to be its variables dict
        # and weights[3] the layer kernel — confirm against TF-MOT internals.
        vars_dictionary = self.model.layers[1]._weight_vars[0][2]
        self.centroid_mask = vars_dictionary['centroids_mask']
        self.zero_centroid_index_begin = np.where(
            self.centroid_mask == 0)[0]

        # Snapshot trainable variables to compare against at epoch end.
        self.layer_kernel = (
            self.model.layers[1].weights[3].numpy()
        )
        self.original_weight = vars_dictionary['ori_weights_vars_tf'].numpy()
        self.centroids = vars_dictionary['cluster_centroids_tf'].numpy()

      def on_epoch_end(self, batch, logs=None):
        # The set of zero centroids must be unchanged by training.
        vars_dictionary = self.model.layers[1]._weight_vars[0][2]
        self.zero_centroid_index_end = np.where(
            vars_dictionary['centroids_mask'] == 0)[0]
        assert np.array_equal(
            self.zero_centroid_index_begin,
            self.zero_centroid_index_end
        )

        # The kernel, the original weights and the centroids must all have
        # been updated by the optimizer during the epoch.
        assert not np.array_equal(
            self.layer_kernel,
            self.model.layers[1].weights[3].numpy()
        )
        assert not np.array_equal(
            self.original_weight,
            vars_dictionary['ori_weights_vars_tf'].numpy()
        )
        assert not np.array_equal(
            self.centroids,
            vars_dictionary['cluster_centroids_tf'].numpy()
        )

    # Run several epochs with several steps each so that the callback can
    # observe updates; a single epoch might leave some variables unchanged.
    quant_aware_model.fit(
        np.random.rand(20, 10),
        keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
        steps_per_epoch=5,
        epochs=3,
        callbacks=[CheckCentroidsAndTrainableVarsCallback()],
    )
|
|
|
|
# Standard TF test entry point: discovers and runs the test methods above.
if __name__ == '__main__':
  tf.test.main()
|
|