| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""Wrappers changing the layouts of the tensors that transforms operate on."""
|
|
| from jax import tree_util as jtu |
| import jax.numpy as jnp |
| import numpy as np |
|
|
| from optax._src import base |
|
|
|
|
def flatten(
    inner: base.GradientTransformation
) -> base.GradientTransformationExtraArgs:
  """Flattens parameters and gradients for init and update of inner transform.

  This can reduce the overhead of performing many calculations on lots of small
  variables, at the cost of slightly increased memory usage.

  Args:
    inner: Inner transformation to flatten inputs for.

  Returns:
    New ``GradientTransformationExtraArgs``
  """
  inner = base.with_extra_args_support(inner)

  def _ravel_concat(tree):
    """Concatenates every leaf of ``tree`` into a single flat 1-D vector."""
    leaves = jtu.tree_leaves(tree)
    return jnp.concatenate([jnp.reshape(leaf, [-1]) for leaf in leaves])

  def _restore(template, vector):
    """Splits ``vector`` back into the structure and shapes of ``template``."""
    leaves, treedef = jtu.tree_flatten(template)
    # The split points are the cumulative sizes of all leaves but the last;
    # the final segment implicitly runs to the end of the vector.
    sizes = [np.size(leaf) for leaf in leaves]
    split_points = list(np.cumsum(sizes[:-1]))
    pieces = jnp.split(vector, split_points)
    shaped = [
        jnp.reshape(piece, leaf.shape)
        for piece, leaf in zip(pieces, leaves)
    ]
    return jtu.tree_unflatten(treedef, shaped)

  def init_fn(params):
    # The inner transform only ever sees the flattened parameter vector.
    return inner.init(_ravel_concat(params))

  def update_fn(updates, state, params=None, **extra_args):
    flat_params = None if params is None else _ravel_concat(params)
    flat_updates, new_state = inner.update(
        _ravel_concat(updates), state, flat_params, **extra_args
    )
    # Unflatten using the incoming updates as the structural template.
    return _restore(updates, flat_updates), new_state

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
|
|