Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
1.22 kB
from collections.abc import Callable
from typing import Any
from flax import nnx
from flax import struct
import jax
import optax
from openpi.models import model as _model
from openpi.shared import array_typing as at
@at.typecheck
@struct.dataclass
class TrainState:
    """Bundle of values carried through training, declared as a flax struct dataclass.

    Fields declared with `struct.field(pytree_node=False)` (`tx`, `ema_decay`) are marked as
    non-pytree fields; the remaining fields are ordinary dataclass fields.
    """

    # Step counter; annotated as an integer array-like with an empty shape spec.
    step: at.Int[at.ArrayLike, ""]
    # Model parameters held as an `nnx.State`.
    params: nnx.State
    # NNX graph definition typed for `_model.BaseModel`.
    # NOTE(review): presumably merged with `params` to rebuild the model -- confirm at call sites.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state. NOTE(review): presumably the state belonging to `tx` -- confirm.
    opt_state: optax.OptState
    # Optax gradient transformation; declared with pytree_node=False.
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    # Optional EMA decay value; declared with pytree_node=False and has no default, so it must be
    # passed explicitly. NOTE(review): None presumably means EMA is disabled -- confirm.
    ema_decay: float | None = struct.field(pytree_node=False)
    # Optional EMA parameter state; defaults to None.
    # NOTE(review): presumably the EMA-averaged copy of `params` -- confirm.
    ema_params: nnx.State | None = None
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as text for logging, one ``<key path>: <leaf text>`` line per leaf.

    Args:
        tree: The PyTree to describe.
        interp_func: Maps each leaf value to the string shown after its key path. Defaults to
            ``str``; pass a custom function to produce more meaningful leaf descriptions.

    Returns:
        The per-leaf lines joined with newlines, in flattening order.
    """
    # Keep the flattened (path, leaf) pairs under their own name rather than rebinding `tree`.
    path_leaf_pairs, _treedef = jax.tree_util.tree_flatten_with_path(tree)
    lines = [
        f"{jax.tree_util.keystr(key_path)}: {interp_func(leaf)}"
        for key_path, leaf in path_leaf_pairs
    ]
    return "\n".join(lines)
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Render a PyTree of arrays as text for logging, showing each leaf as ``<shape>@<dtype>``.

    Args:
        tree: A PyTree whose leaves expose ``shape`` and ``dtype`` attributes.

    Returns:
        One ``<key path>: <shape>@<dtype>`` line per leaf, joined with newlines.
    """

    def _shape_and_dtype(leaf):
        # Summarize a leaf by its shape and dtype instead of printing its contents.
        return f"{leaf.shape}@{leaf.dtype}"

    return tree_to_info(tree, _shape_and_dtype)