| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Utilities for partitioning.""" |
|
|
| from typing import Any, Mapping, MutableMapping, Optional, Tuple |
|
|
| import flax.core |
| import flax.serialization |
| import flax.struct |
| import jax.numpy as jnp |
| from flax import traverse_util |
| from flax.core import scope as flax_scope |
| from flax.linen import partitioning as flax_partitioning |
|
|
|
|
# Short aliases for the variable-collection types declared in `flax.core.scope`.
FrozenDict = flax_scope.FrozenDict
VariableDict = flax_scope.VariableDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict

# Shared frozen empty mapping, used as the default value for "no collections".
EMPTY_DICT = flax.core.freeze({})
|
|
|
|
def _validate_params_axes(params_axes, params):
    """Check that every entry of `params` has a matching axis-names entry.

    Both trees are flattened to "/"-joined key paths; any path present in
    `params` but absent from the axis names derived from `params_axes` is
    reported.

    Args:
        params_axes: Axes metadata, converted with
            `flax_partitioning.get_axis_names` before comparison.
        params: The (nested) parameter dict to validate against.

    Raises:
        ValueError: If one or more parameter paths have no axis names.
    """
    named_axes = flax_partitioning.get_axis_names(params_axes)
    param_paths = set(traverse_util.flatten_dict(params, sep="/"))
    axes_paths = set(traverse_util.flatten_dict(named_axes, sep="/"))

    missing_params_axes = param_paths - axes_paths
    if missing_params_axes:
        raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
|
|
|
|
def _split_variables_and_axes(variables_and_axes: FrozenVariableDict) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Separate variable collections from their `*_axes` companion collections.

    A collection named `<name>_axes` is stored in the axes dict under `<name>`
    and validated against the `<name>` collection of the input; every other
    collection is copied to the variables dict unchanged.

    Args:
        variables_and_axes: Mapping from collection name to collection.

    Returns:
        A `(variables, axes)` pair of frozen dicts.

    Raises:
        KeyError: If a `<name>_axes` collection has no `<name>` counterpart.
        ValueError: If a collection has entries without axis names (raised by
            `_validate_params_axes`).
    """
    axes_suffix = "_axes"
    plain_collections = {}
    axes_collections = {}

    for name, collection in variables_and_axes.items():
        if not name.endswith(axes_suffix):
            plain_collections[name] = collection
            continue

        # Collection name with the trailing "_axes" removed.
        base_name = name[: -len(axes_suffix)]
        axes_collections[base_name] = collection
        _validate_params_axes(collection, variables_and_axes[base_name])

    return flax.core.freeze(plain_collections), flax.core.freeze(axes_collections)
|
|
|
|
class InferenceState(flax.struct.PyTreeNode):
    """Optimizer-free state exposing the FlaxOptimTrainState-style interface.

    Carries a step counter, the model parameters, any further variable
    collections and, optionally, axes metadata for both. Methods that only
    make sense with an optimizer raise `NotImplementedError`.
    """

    # Step counter; `create` initializes it to `jnp.array(0)`.
    step: jnp.ndarray
    # The "params" collection (as populated by `create`).
    params: flax_scope.FrozenVariableDict
    # Axes metadata for `params`; None when no "params_axes" collection was given.
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # Remaining collections whose names do not end in "_axes".
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # Axes metadata for `flax_mutables`; None when there is none.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build a state at step 0 from a dict of model variable collections.

        Args:
            model_variables: Frozen dict that must contain a "params"
                collection and may contain "params_axes" plus further
                collections (optionally paired with `<name>_axes`).

        Returns:
            A new `InferenceState`.

        Raises:
            ValueError: If axes metadata is present but does not cover every
                entry of the collection it describes.
        """
        remaining, params = model_variables.pop("params")

        params_axes = None
        if "params_axes" in remaining:
            remaining, params_axes = remaining.pop("params_axes")
            _validate_params_axes(params_axes, params)

        # Whatever is left is split into plain collections and their axes.
        mutables, mutables_axes = _split_variables_and_axes(remaining)
        # Normalize an empty axes dict to None.
        mutables_axes = mutables_axes or None

        return InferenceState(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=mutables,
            flax_mutables_axes=mutables_axes,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """Optimizer states of the parameters; always raises for this class."""
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported: always raises `NotImplementedError`."""
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Return the state as a plain nested dict.

        The result has `"target"` (unfrozen params) and `"state"` (holding
        the step); `"flax_mutables"` is added only when it is non-empty.
        """
        serialized = {
            "target": flax.core.unfreeze(self.params),
            "state": {"step": self.step},
        }
        if self.flax_mutables:
            serialized["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return serialized

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Return a copy of this state with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Return a copy of this state with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Return a copy of this state with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Return a copy of this state populated from a `state_dict()`-style mapping.

        Args:
            state_dict: Mapping with `"target"` and `"state"` → `"step"`, and
                optionally `"flax_mutables"`.

        Returns:
            A new state; `flax_mutables` falls back to the empty frozen dict
            when the mapping has no `"flax_mutables"` entry.
        """
        restored_params = flax.core.freeze(state_dict["target"])
        restored_step = state_dict["state"]["step"]

        if "flax_mutables" in state_dict:
            restored_mutables = flax.core.freeze(state_dict["flax_mutables"])
        else:
            restored_mutables = EMPTY_DICT

        return self.replace(
            params=restored_params,
            step=restored_step,
            flax_mutables=restored_mutables,
        )

    def as_logical_axes(self) -> "InferenceState":
        """Return a state whose `params` / `flax_mutables` hold axis names.

        `step` is set to None and the `*_axes` fields are left at their
        defaults in the returned object.
        """
        # NOTE(review): `get_axis_names` is applied to `self.params_axes`
        # unconditionally, so this presumably requires `params_axes` to be
        # set - confirm against callers.
        mutables_axes = self.flax_mutables_axes or EMPTY_DICT
        params_axis_names = flax_partitioning.get_axis_names(self.params_axes)
        mutables_axis_names = flax_partitioning.get_axis_names(mutables_axes)
        return InferenceState(
            step=None,
            params=params_axis_names,
            flax_mutables=mutables_axis_names,
        )
|
|