| """Apply transformed gradient updates to parameters.""" |
|
|
| import chex |
| import jax |
| import jax.numpy as jnp |
|
|
| from optax._src import base |
|
|
|
|
def apply_updates(params: base.Params, updates: base.Updates) -> base.Params:
  """Applies an update to the corresponding parameters.

  This is a utility function that applies an update to a set of parameters and
  then returns the updated parameters to the caller. As an example, the update
  may be a gradient transformed by a sequence of `GradientTransformation`s.
  This function is exposed for convenience, but it simply adds updates and
  parameters; you may also apply updates to parameters manually, using
  `tree_map` (e.g. if you want to manipulate updates in custom ways before
  applying them).

  Args:
    params: a tree of parameters.
    updates: a tree of updates; the tree structure and the shape of the leaf
      nodes must match that of `params`.

  Returns:
    Updated parameters, with the same structure, shape and dtype as `params`.
  """
  return jax.tree_util.tree_map(
      lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype),
      params, updates)


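# A minimal usage sketch of `apply_updates` (illustrative addition, not part of
# the original module). The parameter tree, gradients, and learning rate below
# are hypothetical; in practice `updates` would usually come from a chain of
# `GradientTransformation`s rather than the manual scaling shown here.
def _example_apply_updates() -> base.Params:
  params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
  # Pretend these are gradients of some loss with respect to `params`.
  grads = jax.tree_util.tree_map(jnp.ones_like, params)
  learning_rate = 0.1  # hypothetical value
  # Plain SGD: negate and scale the gradients, then add them to the parameters.
  updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
  return apply_updates(params, updates)

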
def incremental_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    step_size: chex.Numeric
) -> base.Params:
  """Incrementally updates parameters via Polyak averaging.

  Polyak averaging tracks an (exponential moving) average of the past
  parameters of a model, for use at test/evaluation time.

  References:
    [Polyak et al., 1991](https://epubs.siam.org/doi/10.1137/0330046)

  Args:
    new_tensors: the latest value of the tensors.
    old_tensors: a moving average of the values of the tensors.
    step_size: the step size used to update the Polyak average on each step.

  Returns:
    an updated moving average `step_size * new + (1 - step_size) * old` of the
    parameters.
  """
  return jax.tree_util.tree_map(
      lambda new, old: step_size * new + (1.0 - step_size) * old,
      new_tensors, old_tensors)


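# A minimal usage sketch of `incremental_update` (illustrative addition, not
# part of the original module). The parameter trees and step size are
# hypothetical.
def _example_incremental_update() -> base.Params:
  online_params = {'w': jnp.ones((3,))}
  ema_params = jax.tree_util.tree_map(jnp.zeros_like, online_params)
  # Move the moving average 1% of the way towards the latest parameters:
  # every leaf becomes 0.01 * online + 0.99 * ema.
  return incremental_update(online_params, ema_params, step_size=0.01)

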
def periodic_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    steps: chex.Array,
    update_period: int
) -> base.Params:
  """Periodically updates all parameters with new values.

  A slow copy of a model's parameters, updated every K actual updates, can be
  used to implement forms of self-supervision (in supervised learning), or to
  stabilise temporal difference learning updates (in reinforcement learning).

  References:
    [Grill et al., 2020](https://arxiv.org/abs/2006.07733)
    [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)

  Args:
    new_tensors: the latest value of the tensors.
    old_tensors: a slow copy of the model's parameters.
    steps: number of update steps so far on the "online" network.
    update_period: the number of steps between updates of the "target" network.

  Returns:
    a slow copy of the model's parameters, updated every `update_period` steps.
  """
  return jax.lax.cond(
      jnp.mod(steps, update_period) == 0,
      lambda _: new_tensors,
      lambda _: old_tensors,
      None)


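# A minimal usage sketch of `periodic_update` (illustrative addition, not part
# of the original module). The parameter trees, step counter, and update period
# are hypothetical.
def _example_periodic_update() -> base.Params:
  online_params = {'w': jnp.ones((3,))}
  target_params = jax.tree_util.tree_map(jnp.zeros_like, online_params)
  step = jnp.asarray(8)
  # With update_period=4 the target copy is refreshed on steps 0, 4, 8, ...;
  # here `steps % update_period == 0`, so the online parameters are returned,
  # while on step 9 the old target parameters would be returned unchanged.
  return periodic_update(online_params, target_params, step, update_period=4)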