| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Encoding stages implementing various clipping strategies. |
| |
| The base classes, `ClipByNormEncodingStage` and `ClipByValueEncodingStage`, are |
| expected to be subclassed as implementations of |
| `AdaptiveEncodingStageInterface`, to realize a variety of clipping strategies |
| that are adaptive to the data being processed in an iterative execution. |
| """ |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import collections |
| import tensorflow as tf |
|
|
| from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage |
|
|
|
|
@encoding_stage.tf_style_encoding_stage
class ClipByNormEncodingStage(encoding_stage.EncodingStageInterface):
  """Projects its input onto an L-2 ball of a configured radius.

  Encoding applies `tf.clip_by_norm` with the configured norm, and decoding
  passes the clipped values through unchanged. See `tf.clip_by_norm` for the
  details of the projection.
  """

  # Key of the clipped tensor in the dictionary produced by `encode`.
  ENCODED_VALUES_KEY = 'clipped_values'
  # Key of the clipping norm in the encode parameters from `get_params`.
  NORM_PARAMS_KEY = 'norm_param'

  def __init__(self, clip_norm):
    """Creates a `ClipByNormEncodingStage`.

    Args:
      clip_norm: A scalar giving the norm (radius) of the L-2 ball onto which
        inputs are projected.
    """
    self._clip_norm = clip_norm

  @property
  def name(self):
    """See base class. This stage is named 'clip_by_norm'."""
    return 'clip_by_norm'

  @property
  def compressible_tensors_keys(self):
    """See base class. Only the clipped values are marked compressible."""
    keys = [self.ENCODED_VALUES_KEY]
    return keys

  @property
  def commutes_with_sum(self):
    """See base class.

    Decoding is an identity map, so it commutes with summation.
    """
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class. `decode` ignores its `shape` argument."""
    return False

  def get_params(self):
    """See base class.

    Returns:
      A pair `(encode_params, decode_params)`. `encode_params` maps
      `NORM_PARAMS_KEY` to the configured clipping norm; `decode_params` is
      empty.
    """
    encode_params = collections.OrderedDict()
    encode_params[self.NORM_PARAMS_KEY] = self._clip_norm
    return encode_params, collections.OrderedDict()

  def encode(self, x, encode_params):
    """See base class.

    Clips `x` by norm and stores the result under `ENCODED_VALUES_KEY`.
    """
    # Cast the norm so that it has the same dtype as the tensor being clipped.
    clip_norm = tf.cast(encode_params[self.NORM_PARAMS_KEY], x.dtype)
    encoded = collections.OrderedDict()
    encoded[self.ENCODED_VALUES_KEY] = tf.clip_by_norm(x, clip_norm)
    return encoded

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class. Returns the clipped values as they were encoded."""
    del decode_params, num_summands, shape  # Unused: decoding is identity.
    clipped_values = encoded_tensors[self.ENCODED_VALUES_KEY]
    return tf.identity(clipped_values)
|
|
|
|
@encoding_stage.tf_style_encoding_stage
class ClipByValueEncodingStage(encoding_stage.EncodingStageInterface):
  """Clips every element of its input into a configured `[min, max]` range.

  This is a projection onto an L-infinity ball. Encoding applies
  `tf.clip_by_value` with the configured bounds, and decoding passes the
  clipped values through unchanged. See `tf.clip_by_value` for details.
  """

  # Key of the clipped tensor in the dictionary produced by `encode`.
  ENCODED_VALUES_KEY = 'clipped_values'
  # Key of the lower clipping bound in the encode parameters.
  MIN_PARAMS_KEY = 'min_param'
  # Key of the upper clipping bound in the encode parameters.
  MAX_PARAMS_KEY = 'max_param'

  def __init__(self, clip_value_min, clip_value_max):
    """Creates a `ClipByValueEncodingStage`.

    Args:
      clip_value_min: A scalar, the lower bound that values are clipped to.
      clip_value_max: A scalar, the upper bound that values are clipped to.
    """
    self._clip_value_min = clip_value_min
    self._clip_value_max = clip_value_max

  @property
  def name(self):
    """See base class. This stage is named 'clip_by_value'."""
    return 'clip_by_value'

  @property
  def compressible_tensors_keys(self):
    """See base class. Only the clipped values are marked compressible."""
    keys = [self.ENCODED_VALUES_KEY]
    return keys

  @property
  def commutes_with_sum(self):
    """See base class.

    Decoding is an identity map, so it commutes with summation.
    """
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class. `decode` ignores its `shape` argument."""
    return False

  def get_params(self):
    """See base class.

    Returns:
      A pair `(encode_params, decode_params)`. `encode_params` maps
      `MIN_PARAMS_KEY` and `MAX_PARAMS_KEY` (in that order) to the configured
      bounds; `decode_params` is empty.
    """
    encode_params = collections.OrderedDict()
    encode_params[self.MIN_PARAMS_KEY] = self._clip_value_min
    encode_params[self.MAX_PARAMS_KEY] = self._clip_value_max
    return encode_params, collections.OrderedDict()

  def encode(self, x, encode_params):
    """See base class.

    Clips `x` elementwise into the configured range and stores the result
    under `ENCODED_VALUES_KEY`.
    """
    # Cast both bounds so that they have the same dtype as `x`.
    target_dtype = x.dtype
    lower = tf.cast(encode_params[self.MIN_PARAMS_KEY], target_dtype)
    upper = tf.cast(encode_params[self.MAX_PARAMS_KEY], target_dtype)
    encoded = collections.OrderedDict()
    encoded[self.ENCODED_VALUES_KEY] = tf.clip_by_value(x, lower, upper)
    return encoded

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class. Returns the clipped values as they were encoded."""
    del decode_params, num_summands, shape  # Unused: decoding is identity.
    clipped_values = encoded_tensors[self.ENCODED_VALUES_KEY]
    return tf.identity(clipped_values)
|
|