| import concurrent.futures as futures |
| import dataclasses |
| import logging |
| from typing import Protocol |
|
|
| from etils import epath |
| import jax |
| import orbax.checkpoint as ocp |
|
|
| from openpi.shared import array_typing as at |
| import openpi.shared.normalize as _normalize |
| import openpi.training.data_loader as _data_loader |
| import openpi.training.utils as training_utils |
|
|
|
|
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare a checkpoint directory and build its orbax CheckpointManager.

    Args:
        checkpoint_dir: Directory that holds (or will hold) checkpoints.
        keep_period: If set, every checkpoint whose step is a multiple of this is kept.
        overwrite: Wipe any existing directory and start fresh.
        resume: Resume from an existing directory instead of failing.

    Returns:
        A (manager, resuming) tuple; `resuming` is True only when the directory
        already contains at least one usable (non step-0) checkpoint.

    Raises:
        FileExistsError: The directory exists and neither flag says what to do.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    already_exists = checkpoint_dir.exists()

    if already_exists and not overwrite and not resume:
        raise FileExistsError(
            f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
            "to indicate how to handle it."
        )

    if already_exists and overwrite:
        # Overwrite wins over resume: discard all previous contents.
        checkpoint_dir.rmtree()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Wiped checkpoint directory {checkpoint_dir}")

    resuming = already_exists and not overwrite and resume

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    manager = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            "assets": CallbackHandler(),
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,  # directory is created above
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )

    # A directory with no checkpoints (or only an initial step-0 one) cannot be resumed from.
    if resuming and tuple(manager.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False

    return manager, resuming
|
|
|
|
def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    """Save the train state (and inference params + norm-stat assets) at `step`.

    Args:
        checkpoint_manager: Manager created by `initialize_checkpoint_dir`.
        state: Current train state; EMA params (when present) are saved as the
            inference params.
        data_loader: Source of the data config whose norm stats are saved as assets.
        step: Training step used as the checkpoint name.
    """

    def save_assets(directory: epath.Path):
        # Persist normalization stats alongside the checkpoint so inference can
        # reuse them without access to the training data pipeline.
        data_config = data_loader.data_config()
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(directory / data_config.asset_id, norm_stats)

    with at.disable_typechecking():
        # Split the params used for inference into a separate "params" item.
        train_state, params = _split_params(state)
        items = {
            "assets": save_assets,
            # Fix: "train_state" was previously omitted, so the manager never
            # wrote the item its "train_state" handler is configured for and
            # restore_state (which requests it) could not resume training.
            "train_state": train_state,
            "params": {"params": params},
        }
        checkpoint_manager.save(step, items)
|
|
|
|
def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    """Restore a train state previously written by `save_state`.

    Args:
        checkpoint_manager: Manager pointing at the checkpoint directory.
        state: Template train state whose structure guides restoration.
        data_loader: Unused; kept so the signature mirrors `save_state`.
        step: Step to restore; None selects the manager's default (latest).

    Returns:
        The restored train state with params merged back in.
    """
    del data_loader  # unused, present only for signature symmetry with save_state

    with at.disable_typechecking():
        # The current state provides the abstract tree structure for both items.
        template_state, template_params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={"train_state": template_state, "params": {"params": template_params}},
        )
    return _merge_params(restored["train_state"], restored["params"])
|
|
|
|
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    """Load normalization statistics for `asset_id` from `assets_dir`.

    Returns whatever `_normalize.load` produces for that directory.
    """
    stats_dir = epath.Path(assets_dir) / asset_id
    stats = _normalize.load(stats_dir)
    logging.info(f"Loaded norm stats from {stats_dir}")
    return stats
|
|
|
|
class Callback(Protocol):
    """Structural type for asset-saving callbacks: called with the target directory."""

    def __call__(self, directory: epath.Path) -> None: ...
|
|
|
|
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def __init__(self):
        # One worker so queued callbacks run sequentially, off the main thread.
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def close(self):
        """Shut down the background executor."""
        self._executor.shutdown()

    def save(self, directory: epath.Path, args: "CallbackSave"):
        """Run the callback synchronously on the primary process only."""
        if jax.process_index() != 0:
            return
        args.callback(directory)

    async def async_save(self, directory: epath.Path, args: "CallbackSave") -> list[futures.Future]:
        """Schedule the callback on the executor; orbax awaits the returned futures."""
        future = self._executor.submit(self.save, directory, args)
        return [future]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
|
|
|
|
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    """Save-args wrapper routing a `Callback` to `CallbackHandler` at save time."""

    callback: Callback
|
|
|
|
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs):
    """Restore-args placeholder; `CallbackHandler.restore` always raises."""

    ...
|
|
|
|
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
    """Split the params to checkpoint separately from the rest of the train state.

    When EMA params exist they become the standalone params item (and are
    removed from the state); otherwise the regular params are used and the
    state's params are emptied so they are not stored twice.
    """
    if state.ema_params is None:
        return dataclasses.replace(state, params={}), state.params
    return dataclasses.replace(state, ema_params=None), state.ema_params
|
|
|
|
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
    """Inverse of `_split_params`: put the restored params back on the train state.

    A train state restored with non-empty `params` means `_split_params` had
    stripped EMA params at save time, so the standalone params go back into
    the `ema_params` slot; otherwise they are the regular `params`.
    """
    target_field = "ema_params" if train_state.params else "params"
    return dataclasses.replace(train_state, **{target_field: params["params"]})
|
|