Spaces:
Running
Running
| import os | |
| import pickle | |
| import orbax.checkpoint | |
| from flax.training import orbax_utils | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from orbax.checkpoint import ArrayRestoreArgs, RestoreArgs | |
# Only let ERROR-and-above records through the "absl" logger, which keeps
# orbax's logging quiet.
import logging
logger = logging.getLogger("absl")
logger.setLevel(logging.ERROR)

# Repo root = parent of the directory that contains this file.
_module_dir = os.path.dirname(os.path.abspath(__file__))
REPO_PATH = os.path.dirname(_module_dir)
def save_train_run(out, savedir, savename):
    '''Write a train run to disk as an orbax PyTree checkpoint.

    A relative savedir is resolved against REPO_PATH (orbax requires
    absolute paths); an absolute savedir is used as given. The directory is
    created when it does not exist yet.

    Returns the checkpoint path, i.e. savedir joined with savename.'''
    if os.path.isabs(savedir):
        target_dir = savedir
    else:
        target_dir = os.path.join(REPO_PATH, savedir)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    savepath = os.path.join(target_dir, savename)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    # Per-leaf save arguments are derived from the pytree being saved.
    ckptr.save(
        savepath,
        out,
        save_args=orbax_utils.save_args_from_target(out),
    )
    return savepath
def load_checkpoints(path, ckpt_key="checkpoints", custom_loader_cfg: dict=None):
    '''Return the subtree stored under ckpt_key in an orbax checkpoint.

    With custom_loader_cfg=None the whole run is restored via load_train_run
    and indexed with ckpt_key. Otherwise custom_loader_cfg["name"] selects
    the loading strategy:

      "open_ended"   - the restored run is unpacked as (partner_out, ego_out);
                       custom_loader_cfg["type"] == "ego" selects ego_out,
                       anything else selects partner_out. For
                       ckpt_key == "final_buffer" the "params" entry of the
                       final buffer is returned.
      "partial_load" - only ckpt_key is restored, via _load_partial.
      "fcp"          - every leaf has its axes 1 and 2 merged into one.

    Raises ValueError for any other loader name.'''
    if custom_loader_cfg is None:
        return load_train_run(path)[ckpt_key]

    loader_name = custom_loader_cfg["name"]

    if loader_name == "open_ended":
        # Open-ended runs need the full checkpoint: it holds both outputs.
        partner_out, ego_out = load_train_run(path)
        if custom_loader_cfg["type"] == "ego":
            selected = ego_out
        else:
            selected = partner_out
        if ckpt_key == "final_buffer":
            return selected["final_buffer"]["params"]
        return selected[ckpt_key]

    if loader_name == "partial_load":
        return _load_partial(path, ckpt_key)

    if loader_name == "fcp":
        # FCP saves checkpoints with shape
        # (NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...). Collapse them
        # to the (NUM_SEEDS, FCP_POP_SIZE, ...) layout AgentPopulation
        # expects, with FCP_POP_SIZE = PARTNER_POP_SIZE * NUM_CHECKPOINTS.
        # This mirrors get_fcp_population in teammate_generation/fcp.py so a
        # saved FCP run can serve as a partner pool for
        # ego_agent_training/run.py.
        def _merge_pop_and_ckpt_axes(leaf):
            dims = leaf.shape
            return leaf.reshape(dims[0], dims[1] * dims[2], *dims[3:])

        fcp_ckpts = load_train_run(path)[ckpt_key]
        return jax.tree.map(_merge_pop_and_ckpt_axes, fcp_ckpts)

    raise ValueError(f"Invalid custom loader name: {custom_loader_cfg['name']}")
def _load_partial(path, ckpt_key):
    '''Restore a single top-level key of an orbax checkpoint.

    Meant to avoid the OOM that restoring the entire pytree can cause
    (e.g. by skipping metrics). A relative path is resolved against
    REPO_PATH. Every leaf is restored with a single-device sharding on the
    first CPU device.

    Raises KeyError when ckpt_key is not among the checkpoint's keys.'''
    if not os.path.isabs(path):
        path = os.path.join(REPO_PATH, path)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    cpu_device = jax.devices('cpu')[0]
    cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)

    # NOTE(review): load_train_run reads `metadata.tree`, whereas this
    # function indexes the metadata object directly - confirm both forms
    # match the installed orbax version.
    ckpt_meta = ckptr.metadata(path)
    if ckpt_key not in ckpt_meta:
        raise KeyError(f"Key '{ckpt_key}' not found in checkpoint. Available keys: {list(ckpt_meta.keys())}")
    wanted_meta = ckpt_meta[ckpt_key]

    def _placeholder(leaf_meta):
        # Array-like metadata becomes an uninitialised array of the same
        # shape/dtype; anything else is passed through unchanged.
        if hasattr(leaf_meta, 'shape'):
            return np.empty(leaf_meta.shape, dtype=leaf_meta.dtype)
        return leaf_meta

    def _cpu_restore_args(_leaf_meta):
        return orbax.checkpoint.ArrayRestoreArgs(sharding=cpu_sharding)

    # Presumably the target item containing only ckpt_key, combined with the
    # Transform entry, is what makes orbax skip the other top-level keys -
    # TODO confirm against the orbax restore documentation.
    restored = ckptr.restore(
        path,
        item={ckpt_key: jax.tree.map(_placeholder, wanted_meta)},
        transforms={ckpt_key: orbax.checkpoint.Transform()},
        restore_args={ckpt_key: jax.tree.map(_cpu_restore_args, wanted_meta)},
    )
    return restored[ckpt_key]
def load_train_run(path):
    '''Restore a full train run from an orbax checkpoint.

    A relative path is resolved against REPO_PATH (orbax requires absolute
    paths). Any numpy-array leaves in the restored pytree are converted to
    jax arrays before returning.

    Restore strategy:
      * If JAX_AHT_FORCE_CPU_RESTORE == "1" or JAX_PLATFORMS is "cpu"
        (case-insensitive), leaves are restored as numpy arrays using
        restore args built from the checkpoint metadata.
      * Otherwise a plain restore is attempted first. If it fails with one
        of two known sharding/device error messages, the numpy-based
        restore is used instead; any other exception is re-raised.'''
    ckpt_path = path if os.path.isabs(path) else os.path.join(REPO_PATH, path)
    ckptr = orbax.checkpoint.PyTreeCheckpointer()

    def _numpy_restore_args(leaf_meta):
        # Leaves that look like arrays are restored as np.ndarray; all other
        # leaves get default restore args.
        looks_like_array = hasattr(leaf_meta, "shape") and hasattr(leaf_meta, "dtype")
        if looks_like_array:
            return ArrayRestoreArgs(restore_type=np.ndarray)
        return RestoreArgs()

    def _restore_as_numpy():
        ckpt_meta = ckptr.metadata(ckpt_path)
        per_leaf_args = jax.tree_util.tree_map(_numpy_restore_args, ckpt_meta.tree)
        return ckptr.restore(ckpt_path, restore_args=per_leaf_args)

    env = os.environ
    cpu_only = (
        env.get("JAX_AHT_FORCE_CPU_RESTORE", "0") == "1"
        or env.get("JAX_PLATFORMS", "").lower() == "cpu"
    )

    if cpu_only:
        restored = _restore_as_numpy()
    else:
        try:
            restored = ckptr.restore(ckpt_path)
        except Exception as exc:
            # Presumably these messages appear when the checkpoint was saved
            # with a device/sharding that is unavailable here - TODO confirm.
            # Only those failures are retried; everything else propagates.
            error_text = str(exc)
            recoverable_markers = (
                "sharding passed to deserialization",
                "Device cuda:0 was not found in jax.local_devices()",
            )
            if not any(marker in error_text for marker in recoverable_markers):
                raise
            restored = _restore_as_numpy()

    def _to_jax(leaf):
        # Only np.ndarray leaves are converted; other leaves are kept as-is.
        if isinstance(leaf, np.ndarray):
            return jnp.array(leaf)
        return leaf

    return jax.tree_util.tree_map(_to_jax, restored)
def save_train_run_as_pickle(out, savedir, savename):
    '''Pickle a train run to "<savedir>/<savename>.pkl".

    Unlike save_train_run, savedir is NOT resolved against REPO_PATH: a
    relative savedir is interpreted relative to the current working
    directory. Missing directories are created. An empty savedir means
    "the current directory".

    Args:
        out: the object to pickle.
        savedir: directory to write into; created if it does not exist.
        savename: file name without the ".pkl" extension.

    Returns:
        The path of the written pickle file.'''
    # os.makedirs("") raises FileNotFoundError, and formatting the path as
    # "<savedir>/<name>" with an empty savedir would point at the filesystem
    # root, so an empty savedir is skipped here and left to os.path.join.
    if savedir:
        os.makedirs(savedir, exist_ok=True)
    savepath = os.path.join(savedir, f"{savename}.pkl")
    with open(savepath, "wb") as f:
        pickle.dump(out, f)
    return savepath
def load_checkpoints_from_pickle(path, ckpt_key="checkpoints"):
    '''Unpickle the train run stored at path and return its ckpt_key entry.'''
    train_run = load_train_run_from_pickle(path)
    checkpoints = train_run[ckpt_key]
    return checkpoints
def load_train_run_from_pickle(path):
    '''Read the pickle file at path and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only load files from a
    trusted source.'''
    with open(path, "rb") as pkl_file:
        return pickle.load(pkl_file)