| | from collections.abc import Callable |
| | from typing import Any |
| |
|
| | from flax import nnx |
| | from flax import struct |
| | import jax |
| | import optax |
| |
|
| | from openpi.models import model as _model |
| | from openpi.shared import array_typing as at |
| |
|
| |
|
@at.typecheck
@struct.dataclass
class TrainState:
    """Bundle of everything needed to resume training: step, model params/definition, optimizer state, optional EMA.

    Declared with `flax.struct.dataclass`, so instances are JAX pytrees. Fields declared with
    `struct.field(pytree_node=False)` are carried as static metadata instead of pytree leaves.
    """

    # Integer training step; the "" shape spec denotes a scalar (jaxtyping-style spec — confirm `at.Int` follows it).
    step: at.Int[at.ArrayLike, ""]
    # Model parameters; presumably the state half of an `nnx.split` of the model, paired with `model_def` — confirm.
    params: nnx.State
    # Graph definition of the model (a `_model.BaseModel`), i.e. the structure that `params` plugs into.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state; presumably created/updated by `tx` — verify against the training loop.
    opt_state: optax.OptState
    # Optax gradient transformation. Marked pytree_node=False so it is static metadata, not traversed as a leaf.
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    # EMA decay rate, also static. No default is given, so callers must pass it explicitly;
    # NOTE(review): None presumably disables EMA tracking — confirm.
    ema_decay: float | None = struct.field(pytree_node=False)
    # EMA copy of the parameters, or None (the default) when no EMA weights are held.
    ema_params: nnx.State | None = None
| |
|
| |
|
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as newline-separated ``<key path>: <leaf text>`` lines for logging.

    Args:
        tree: The PyTree to describe.
        interp_func: Converts each leaf to the text shown after its key path. Defaults to
            ``str``; pass something more specific (e.g. a shape/dtype formatter) for
            more meaningful output.

    Returns:
        One line per leaf, in JAX flattening order, joined with ``"\\n"``. A tree with no
        leaves yields the empty string.
    """
    leaves_with_paths, _treedef = jax.tree_util.tree_flatten_with_path(tree)
    lines = [f"{jax.tree_util.keystr(key_path)}: {interp_func(leaf)}" for key_path, leaf in leaves_with_paths]
    return "\n".join(lines)
| |
|
| |
|
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Render a PyTree of arrays as ``<key path>: <shape>@<dtype>`` lines for logging.

    Any leaf exposing both ``shape`` and ``dtype`` attributes is summarized by them.
    Every other leaf is shown as ``<type name>: <value>``, so mixed trees can still be
    logged instead of failing on a missing attribute.
    """

    def _describe(leaf: Any) -> str:
        is_array_like = hasattr(leaf, "shape") and hasattr(leaf, "dtype")
        if not is_array_like:
            return f"{type(leaf).__name__}: {leaf}"
        return f"{leaf.shape}@{leaf.dtype}"

    return tree_to_info(tree, _describe)
| |
|