Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Exponential moving average optimizer.""" | |
| from typing import List, Optional | |
| import tensorflow as tf, tf_keras | |
| # pylint: disable=protected-access | |
def maybe_merge_call(fn, strategy, *args, **kwargs):
  """Maybe invoke `fn` via `merge_call` which may or may not be fulfilled.

  The caller of this utility function requests to invoke `fn` via `merge_call`
  at `tf.distribute.Strategy`'s best efforts. It is `tf.distribute`'s internal
  whether the request is honored, depending on the `Strategy`. See
  `tf.distribute.ReplicaContext.merge_call()` for more information.

  This is adapted from tensorflow/python/distribute/merge_call_interim.py.

  Args:
    fn: the function to be invoked.
    strategy: the `tf.distribute.Strategy` to call `fn` with.
    *args: the positional arguments to be passed in to `fn`.
    **kwargs: the keyword arguments to be passed in to `fn`.

  Returns:
    The return value of the `fn` call.
  """
  # Guard clause: strategies that opt out of merge_call get a direct call,
  # with `strategy` itself passed as the first positional argument.
  if not strategy.extended._use_merge_call():
    return fn(strategy, *args, **kwargs)
  # Otherwise hand `fn` to the replica context; `merge_call` supplies the
  # strategy object to `fn` on our behalf.
  replica_context = tf.distribute.get_replica_context()
  return replica_context.merge_call(fn, args=args, kwargs=kwargs)
class ExponentialMovingAverage(tf_keras.optimizers.legacy.Optimizer):
  """Optimizer that computes an exponential moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will use by default the average values instead of the original ones.

  Example of usage for training:
  ```python
  opt = tf_keras.optimizers.SGD(learning_rate)
  opt = ExponentialMovingAverage(opt)
  opt.shadow_copy(model)
  ```

  At test time, swap the shadow variables to evaluate on the averaged weights:
  ```python
  opt.swap_weights()
  # Test eval the model here
  opt.swap_weights()
  ```
  """

  def __init__(self,
               optimizer: tf_keras.optimizers.Optimizer,
               trainable_weights_only: bool = True,
               average_decay: float = 0.99,
               start_step: int = 0,
               dynamic_decay: bool = True,
               name: str = 'ExponentialMovingAverage',
               **kwargs):
    """Construct a new ExponentialMovingAverage optimizer.

    Args:
      optimizer: `tf_keras.optimizers.Optimizer` that will be
        used to compute and apply gradients.
      trainable_weights_only: 'bool', if True, only model trainable weights will
        be updated. Otherwise, all model weights will be updated. This mainly
        affects batch normalization parameters.
      average_decay: float. Decay to use to maintain the moving averages
        of trained variables.
      start_step: int. What step to start the moving average.
      dynamic_decay: bool. Whether to change the decay based on the number
        of optimizer updates. Decay will start at 0.1 and gradually increase
        up to `average_decay` after each optimizer update. This behavior is
        similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
      name: Optional name for the operations created when applying
        gradients. Defaults to "moving_average".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`,
        `clipvalue`, `lr`, `decay`}.
    """
    super().__init__(name, **kwargs)
    self._average_decay = average_decay
    self._trainable_weights_only = trainable_weights_only
    self._start_step = tf.constant(start_step, tf.float32)
    self._dynamic_decay = dynamic_decay
    self._optimizer = optimizer
    # Track the wrapped optimizer so it is saved/restored with checkpoints.
    self._track_trackable(self._optimizer, 'ema_base_optimizer')
    self._average_weights = None
    self._model_weights = None

  def shadow_copy(self, model: tf_keras.Model):
    """Creates shadow variables for the given model weights."""
    if self._trainable_weights_only:
      self._model_weights = model.trainable_variables
    else:
      self._model_weights = model.variables
    # One zero-initialized 'average' slot per tracked model weight.
    for var in self._model_weights:
      self.add_slot(var, 'average', initializer='zeros')

    self._average_weights = [
        self.get_slot(var, 'average') for var in self._model_weights
    ]

  def has_shadow_copy(self):
    """Whether this optimizer has created shadow variables."""
    return self._model_weights is not None and self._average_weights is not None

  def _create_slots(self, var_list):
    # Delegate slot creation to the wrapped optimizer.
    self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

  def apply_gradients(self, grads_and_vars, name: Optional[str] = None):
    """Applies gradients via the wrapped optimizer, then updates the EMA."""
    result = self._optimizer.apply_gradients(grads_and_vars, name)
    # update_average runs under merge_call only if the strategy requests it.
    maybe_merge_call(self.update_average, tf.distribute.get_strategy())
    return result

  def update_average(self, strategy):
    """Updates the shadow (average) variables toward the model weights.

    Args:
      strategy: the `tf.distribute.Strategy` whose `extended.update` is used
        to apply the per-variable EMA update.
    """
    # Compute current decay value.
    step = tf.cast(self.iterations, tf.float32)
    if step < self._start_step:
      # Before the configured start step the average simply tracks the
      # raw weights (decay 0 copies `normal` into `average`).
      decay = tf.constant(0., tf.float32)
    elif self._dynamic_decay:
      # Ramp decay from ~0.1 up to `average_decay`, mirroring the
      # `num_updates` behavior of tf.train.ExponentialMovingAverage.
      decay = step - self._start_step
      decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
    else:
      decay = self._average_decay

    def _apply_moving(average, normal):
      # average -= (1 - decay) * (average - normal)
      diff = average - normal
      average.assign_sub(tf.cast(1.0 - decay, average.dtype) * diff)
      return average

    # Update moving average with the latest value.
    for average, normal in zip(self._average_weights, self._model_weights):
      strategy.extended.update(
          average, _apply_moving, args=(normal,), group=False
      )

  def swap_weights(self):
    """Swap the average and moving weights.

    This is a convenience method to allow one to evaluate the averaged weights
    at test time. Loads the weights stored in `self._average` into the model,
    keeping a copy of the original model weights. Swapping twice will return
    the original weights.

    Raises:
      ValueError: if called from a replica (non-cross-replica) context.
    """
    if tf.distribute.in_cross_replica_context():
      strategy = tf.distribute.get_strategy()
      strategy.run(self._swap_weights, args=())
    else:
      raise ValueError(
          'Swapping weights must occur under a tf.distribute.Strategy.'
      )

  def _swap_weights(self):
    """In-place swap of each (average, model) variable pair on all devices."""

    def fn_0(a, b):
      a.assign_add(b)
      return a

    def fn_1(b, a):
      b.assign(a - b)
      return b

    def fn_2(a, b):
      a.assign_sub(b)
      return a

    def _swap(strategy, a_and_b):
      """Swap `a` and `b` and mirror to all devices."""
      # Classic three-step additive swap; avoids allocating a temporary.
      for a, b in a_and_b:
        strategy.extended.update(a, fn_0, args=(b,))  # a = a + b
        strategy.extended.update(b, fn_1, args=(a,))  # b = a - b
        strategy.extended.update(a, fn_2, args=(b,))  # a = a - b

    # Use merge_call if requested by strategy and always for TPUStrategy as
    # the use of merge_call is not recommended and deprecated for other
    # strategies such as mirrored strategy (MS) and multi-worker mirrored
    # strategy (MWMS) if nccl/collective_ops are used, which can operate in
    # pure replica context.
    strategy = tf.distribute.get_strategy()
    if isinstance(strategy, tf.distribute.TPUStrategy):
      maybe_merge_call(
          _swap,
          strategy,
          zip(self._average_weights, self._model_weights),
      )
    else:
      _swap(
          strategy,
          zip(self._average_weights, self._model_weights),
      )

  def assign_average_vars(self, var_list: List[tf.Variable]):
    """Assign variables in var_list with their respective averages.

    Args:
      var_list: List of model variables to be assigned to their average.

    Returns:
      assign_op: The op corresponding to the assignment operation of
        variables to their average.
    """
    # Non-trainable variables have no 'average' slot and are skipped.
    assign_op = tf.group([
        var.assign(self.get_slot(var, 'average')) for var in var_list
        if var.trainable
    ])
    return assign_op

  def _create_hypers(self):
    self._optimizer._create_hypers()  # pylint: disable=protected-access

  def _prepare(self, var_list):
    return self._optimizer._prepare(var_list=var_list)  # pylint: disable=protected-access

  @property
  def iterations(self):
    # Expose the wrapped optimizer's step counter as this optimizer's own.
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  @property
  def weights(self):
    # return self._weights + self._optimizer.weights
    return self._optimizer.weights

  def variables(self):
    return self._weights + [self.iterations]

  @property
  def lr(self):
    return self._optimizer._get_hyper('learning_rate')

  @lr.setter
  def lr(self, lr):
    self._optimizer._set_hyper('learning_rate', lr)

  @property
  def learning_rate(self):
    return self._optimizer._get_hyper('learning_rate')

  @learning_rate.setter
  def learning_rate(self, learning_rate):  # pylint: disable=redefined-outer-name
    self._optimizer._set_hyper('learning_rate', learning_rate)

  def _resource_apply_dense(self, grad, var):
    # Gradient application is fully delegated to the wrapped optimizer.
    return self._optimizer._resource_apply_dense(grad, var)

  def _resource_apply_sparse(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse(grad, var, indices)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse_duplicate_indices(
        grad, var, indices)

  def get_config(self):
    # NOTE(review): `_start_step` is a tf.Tensor here (set in __init__);
    # confirm it round-trips through serialization as intended.
    config = {
        'optimizer': tf_keras.optimizers.serialize(self._optimizer),
        'average_decay': self._average_decay,
        'start_step': self._start_step,
        'dynamic_decay': self._dynamic_decay,
    }
    base_config = super(ExponentialMovingAverage, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the optimizer (and its wrapped optimizer) from a config."""
    optimizer = tf_keras.optimizers.deserialize(
        config.pop('optimizer'),
        custom_objects=custom_objects,
    )
    return cls(optimizer, **config)