"""Helpers for saving and loading training runs as Orbax checkpoints or pickles."""
import logging
import os
import pickle
from typing import Optional

import orbax.checkpoint
from flax.training import orbax_utils
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import ArrayRestoreArgs, RestoreArgs

# suppress logging from orbax
logger = logging.getLogger("absl")
logger.setLevel(logging.ERROR)

# compute path to repo root by using this file's path
REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def _resolve_repo_path(path):
    """Return *path* unchanged if absolute, else joined onto ``REPO_PATH``."""
    if os.path.isabs(path):
        return path
    return os.path.join(REPO_PATH, path)


def _merge_fcp_axes(x):
    """Collapse axes 1 and 2 of *x* into a single axis, keeping the rest.

    Raises:
        ValueError: If *x* has fewer than three axes.
    """
    if len(x.shape) < 3:
        raise ValueError(
            "FCP checkpoints must have at least 3 axes "
            "(NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...); "
            f"got shape {tuple(x.shape)}"
        )
    return x.reshape(x.shape[0], x.shape[1] * x.shape[2], *x.shape[3:])


def save_train_run(out, savedir, savename):
    """Save a train run as an orbax checkpoint.

    Orbax requires absolute paths, so a relative ``savedir`` is resolved
    against the repo root (``REPO_PATH``).

    Args:
        out: Pytree to checkpoint.
        savedir: Target directory; created if it does not exist.
        savename: Name of the checkpoint inside ``savedir``.

    Returns:
        The absolute path the checkpoint was written to.
    """
    savedir = _resolve_repo_path(savedir)
    # exist_ok=True already tolerates an existing directory, so no separate
    # (race-prone) existence check is needed.
    os.makedirs(savedir, exist_ok=True)

    savepath = os.path.join(savedir, savename)
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(out)
    # Save the checkpoint
    checkpointer.save(savepath, out, save_args=save_args)
    return savepath


def load_checkpoints(path, ckpt_key="checkpoints", custom_loader_cfg: Optional[dict] = None):
    """Load one entry of an orbax checkpoint, optionally via a custom loader.

    Relative paths are resolved against the repo root by the loaders this
    function delegates to.

    Args:
        path: Checkpoint path (absolute, or relative to the repo root).
        ckpt_key: Key to extract from the restored run.
        custom_loader_cfg: ``None`` for the default loader, or a dict whose
            ``"name"`` is one of ``"open_ended"``, ``"partial_load"`` or
            ``"fcp"``. The ``"open_ended"`` loader also reads ``"type"``.

    Returns:
        The value stored under ``ckpt_key`` (reshaped for ``"fcp"``; the
        ``"params"`` sub-entry when ``ckpt_key == "final_buffer"`` under
        ``"open_ended"``).

    Raises:
        ValueError: If the loader name is not recognised, or an ``"fcp"``
            leaf has fewer than three axes.
    """
    if custom_loader_cfg is None:
        return load_train_run(path)[ckpt_key]

    loader_name = custom_loader_cfg["name"]

    if loader_name == "open_ended":
        # Open-ended loader needs the full checkpoint
        partner_out, ego_out = load_train_run(path)
        out = ego_out if custom_loader_cfg["type"] == "ego" else partner_out
        if ckpt_key == "final_buffer":
            return out["final_buffer"]["params"]
        return out[ckpt_key]

    if loader_name == "partial_load":
        return _load_partial(path, ckpt_key)

    if loader_name == "fcp":
        # FCP saves checkpoints with shape
        #   (NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...)
        # Reshape to the (NUM_SEEDS, FCP_POP_SIZE, ...) layout that
        # AgentPopulation expects, where FCP_POP_SIZE = PARTNER_POP_SIZE *
        # NUM_CHECKPOINTS. Mirrors get_fcp_population in
        # teammate_generation/fcp.py so a saved FCP run can be reused as a
        # partner pool by ego_agent_training/run.py.
        ckpts = load_train_run(path)[ckpt_key]
        return jax.tree.map(_merge_fcp_axes, ckpts)

    raise ValueError(f"Invalid custom loader name: {custom_loader_cfg['name']}")


def _load_partial(path, ckpt_key):
    """Load only a single top-level key from an orbax checkpoint.

    Avoids OOM from loading the entire pytree (e.g. skipping metrics).

    Args:
        path: Checkpoint path (absolute, or relative to the repo root).
        ckpt_key: Top-level key to restore.

    Returns:
        The restored subtree stored under ``ckpt_key``.

    Raises:
        KeyError: If ``ckpt_key`` is not present in the checkpoint metadata.
    """
    path = _resolve_repo_path(path)

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])

    # NOTE(review): the metadata object is indexed directly here, while
    # load_train_run reads `.tree` from it. Which form is valid presumably
    # depends on the installed Orbax version -- confirm.
    meta = checkpointer.metadata(path)
    if ckpt_key not in meta:
        raise KeyError(f"Key '{ckpt_key}' not found in checkpoint. Available keys: {list(meta.keys())}")

    subtree_meta = meta[ckpt_key]
    # Template pytree containing only the requested key; array-like metadata
    # leaves become empty numpy arrays of matching shape and dtype.
    item = {ckpt_key: jax.tree.map(
        lambda m: np.empty(m.shape, dtype=m.dtype) if hasattr(m, 'shape') else m,
        subtree_meta,
    )}
    transforms = {ckpt_key: orbax.checkpoint.Transform()}
    restore_args = {ckpt_key: jax.tree.map(
        lambda _: ArrayRestoreArgs(sharding=cpu_sharding),
        subtree_meta,
    )}

    restored = checkpointer.restore(
        path, item=item, transforms=transforms, restore_args=restore_args
    )
    return restored[ckpt_key]


def load_train_run(path):
    """Load a full train run from an orbax checkpoint.

    Orbax requires absolute paths, so a relative ``path`` is resolved against
    the repo root (``REPO_PATH``).

    A plain restore is attempted first. The checkpoint is instead restored
    as numpy arrays when either environment variable
    ``JAX_AHT_FORCE_CPU_RESTORE`` equals ``"1"`` or ``JAX_PLATFORMS`` equals
    ``"cpu"`` (case-insensitive), or when the plain restore fails with one of
    the known sharding / missing-device error messages. Any other restore
    error is re-raised.

    Args:
        path: Checkpoint path (absolute, or relative to the repo root).

    Returns:
        The restored pytree, with numpy-array leaves converted to jax arrays.
    """
    path = _resolve_repo_path(path)
    # load the checkpoint
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()

    def _restore_with_numpy_args():
        """Restore with array leaves requested as ``np.ndarray``."""
        metadata = checkpointer.metadata(path)

        def _mk_restore_args(leaf):
            if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
                return ArrayRestoreArgs(restore_type=np.ndarray)
            return RestoreArgs()

        restore_args = jax.tree_util.tree_map(_mk_restore_args, metadata.tree)
        return checkpointer.restore(path, restore_args=restore_args)

    force_cpu_restore = (
        os.environ.get("JAX_AHT_FORCE_CPU_RESTORE", "0") == "1"
        or os.environ.get("JAX_PLATFORMS", "").lower() == "cpu"
    )

    if force_cpu_restore:
        restored = _restore_with_numpy_args()
    else:
        try:
            restored = checkpointer.restore(path)
        except Exception as exc:
            # Deliberately broad: the failure is identified by its message,
            # and anything that is not a known recoverable case is re-raised.
            msg = str(exc)
            recoverable = (
                "sharding passed to deserialization" in msg
                or "Device cuda:0 was not found in jax.local_devices()" in msg
            )
            if not recoverable:
                raise
            restored = _restore_with_numpy_args()

    # convert pytree leaves from np arrays to jax arrays
    restored = jax.tree_util.tree_map(
        lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x,
        restored,
    )
    return restored


def save_train_run_as_pickle(out, savedir, savename):
    """Pickle a train run to ``<savedir>/<savename>.pkl``.

    Unlike ``save_train_run``, ``savedir`` is used as given and is not
    resolved against the repo root.

    Args:
        out: Object to pickle.
        savedir: Target directory; created if it does not exist.
        savename: File name without the ``.pkl`` extension.

    Returns:
        The path of the written pickle file.
    """
    os.makedirs(savedir, exist_ok=True)
    savepath = os.path.join(savedir, f"{savename}.pkl")
    with open(savepath, "wb") as f:
        pickle.dump(out, f)
    return savepath


def load_checkpoints_from_pickle(path, ckpt_key="checkpoints"):
    """Load a pickled train run and return the entry stored under *ckpt_key*."""
    return load_train_run_from_pickle(path)[ckpt_key]


def load_train_run_from_pickle(path):
    """Load and return the object pickled at *path*.

    SECURITY: ``pickle.load`` can execute arbitrary code while deserializing.
    Only call this on files from a trusted source.
    """
    with open(path, "rb") as f:
        return pickle.load(f)