| from collections.abc import Callable | |
| from typing import Any | |
| from flax import nnx | |
| from flax import struct | |
| import jax | |
| import optax | |
| from openpi.models import model as _model | |
| from openpi.shared import array_typing as at | |
@struct.dataclass
class TrainState:
    """Training state, bundled as a flax ``struct`` dataclass so it can cross JAX transformations.

    ``@struct.dataclass`` is required here: it generates the constructor and registers the
    class as a JAX pytree. It is also what makes the ``struct.field(pytree_node=False)``
    markers below take effect. Without it this is a plain class that cannot be built from
    its fields.
    """

    # Current training step; an integer array-like with empty shape spec "" (a scalar).
    step: at.Int[at.ArrayLike, ""]
    # Model parameters as an nnx.State.
    params: nnx.State
    # Graph definition of the model; presumably combined with `params` (e.g. via
    # nnx.merge) to rebuild the model — confirm at call sites.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state for `tx`.
    opt_state: optax.OptState
    # Optimizer transformation. pytree_node=False keeps it as static metadata rather than
    # a pytree leaf, since it is a collection of functions, not arrays.
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    # EMA decay rate, also static. NOTE(review): presumably None means EMA is disabled — confirm.
    ema_decay: float | None = struct.field(pytree_node=False)
    # EMA copy of the parameters; defaults to None (presumably when EMA is not in use).
    ema_params: nnx.State | None = None
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as text for logging, one ``<key path>: <description>`` line per leaf.

    Args:
        tree: Any PyTree understood by ``jax.tree_util``.
        interp_func: Turns each leaf into the text shown after its key path.
            Defaults to ``str``; pass a custom callable for more meaningful output.

    Returns:
        The per-leaf lines joined by newlines, in JAX flattening order
        (an empty string when the tree has no leaves).
    """

    def _format_entry(key_path: Any, leaf: Any) -> str:
        return f"{jax.tree_util.keystr(key_path)}: {interp_func(leaf)}"

    # The second element of the result (the treedef) is not needed for display.
    path_leaf_pairs, _treedef = jax.tree_util.tree_flatten_with_path(tree)
    return "\n".join([_format_entry(key_path, leaf) for key_path, leaf in path_leaf_pairs])
def array_tree_to_info(tree: at.PyTree) -> str:
    """Summarize a PyTree of arrays for logging as ``<key path>: <shape>@<dtype>`` lines.

    Every leaf is expected to expose ``.shape`` and ``.dtype``. A leaf without
    them raises ``AttributeError`` while the summary is being built.
    """

    def _shape_at_dtype(array: Any) -> str:
        return f"{array.shape}@{array.dtype}"

    return tree_to_info(tree, _shape_at_dtype)