jaxaht-benchmark / common /save_load_utils.py
lainwired's picture
Initial jaxaht-benchmark deployment
5146e76
import os
import pickle
import orbax.checkpoint
from flax.training import orbax_utils
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import ArrayRestoreArgs, RestoreArgs
# Orbax logs through absl; raise that logger's threshold so only errors show.
import logging

logger = logging.getLogger("absl")
logger.setLevel(logging.ERROR)

# Repo root is the parent of the directory that holds this file.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_PATH = os.path.dirname(_THIS_DIR)
def save_train_run(out, savedir, savename):
    '''Write a train-run pytree to disk as an orbax checkpoint.

    Orbax requires absolute paths, so a relative ``savedir`` is resolved
    against the repo root (``REPO_PATH``).

    Args:
        out: pytree to checkpoint.
        savedir: target directory; created if it does not exist yet.
        savename: name of the checkpoint inside ``savedir``.

    Returns:
        The path the checkpoint was written to.
    '''
    # Anchor relative directories at the repo root.
    target_dir = savedir if os.path.isabs(savedir) else os.path.join(REPO_PATH, savedir)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    ckpt_path = os.path.join(target_dir, savename)
    # save_args are derived from the pytree that is being saved.
    orbax.checkpoint.PyTreeCheckpointer().save(
        ckpt_path,
        out,
        save_args=orbax_utils.save_args_from_target(out),
    )
    return ckpt_path
def load_checkpoints(path, ckpt_key="checkpoints", custom_loader_cfg: dict=None):
    '''Load the ``ckpt_key`` entry of a saved train run.

    Args:
        path: checkpoint path, forwarded to ``load_train_run`` or
            ``_load_partial``.
        ckpt_key: key of the subtree to return.
        custom_loader_cfg: optional dict whose ``"name"`` selects a loader:
            * ``None`` (no dict): restore everything, return ``restored[ckpt_key]``.
            * ``"open_ended"``: the restored run is unpacked as
              ``(partner_out, ego_out)``; ``custom_loader_cfg["type"] == "ego"``
              picks ``ego_out``, any other value picks ``partner_out``.
            * ``"partial_load"``: restore only ``ckpt_key`` via ``_load_partial``.
            * ``"fcp"``: restore everything and merge axes 1 and 2 of every leaf.

    Raises:
        ValueError: if ``custom_loader_cfg["name"]`` is not one of the above.
    '''
    if custom_loader_cfg is None:
        return load_train_run(path)[ckpt_key]

    loader_name = custom_loader_cfg["name"]

    if loader_name == "open_ended":
        # Open-ended loader needs the full checkpoint
        partner_out, ego_out = load_train_run(path)
        if custom_loader_cfg["type"] == "ego":
            selected = ego_out
        else:
            selected = partner_out
        # The final buffer keeps its parameters under a nested "params" entry.
        if ckpt_key == "final_buffer":
            return selected["final_buffer"]["params"]
        return selected[ckpt_key]

    if loader_name == "partial_load":
        return _load_partial(path, ckpt_key)

    if loader_name == "fcp":
        # FCP saves checkpoints with shape
        #   (NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...)
        # Reshape to the (NUM_SEEDS, FCP_POP_SIZE, ...) layout that
        # AgentPopulation expects, where FCP_POP_SIZE = PARTNER_POP_SIZE *
        # NUM_CHECKPOINTS. Mirrors get_fcp_population in
        # teammate_generation/fcp.py so a saved FCP run can be reused as a
        # partner pool by ego_agent_training/run.py.
        def _merge_pop_and_ckpt_axes(x):
            merged = x.shape[1] * x.shape[2]
            return x.reshape(x.shape[0], merged, *x.shape[3:])

        fcp_ckpts = load_train_run(path)[ckpt_key]
        return jax.tree.map(_merge_pop_and_ckpt_axes, fcp_ckpts)

    raise ValueError(f"Invalid custom loader name: {custom_loader_cfg['name']}")
def _load_partial(path, ckpt_key):
    '''Restore a single top-level key of an orbax checkpoint.

    Only the subtree under ``ckpt_key`` is requested, which avoids OOM from
    loading the entire pytree (e.g. skipping metrics). Every leaf is restored
    with a sharding that targets the first CPU device.

    Args:
        path: checkpoint path; relative paths are resolved against REPO_PATH.
        ckpt_key: top-level key to restore.

    Returns:
        The restored subtree stored under ``ckpt_key``.

    Raises:
        KeyError: if ``ckpt_key`` is not present in the checkpoint metadata.
    '''
    abs_path = path if os.path.isabs(path) else os.path.join(REPO_PATH, path)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    cpu_device = jax.devices('cpu')[0]
    host_sharding = jax.sharding.SingleDeviceSharding(cpu_device)

    # NOTE(review): the metadata object is indexed directly here, while
    # load_train_run reads `metadata.tree` - confirm both forms are valid for
    # the installed orbax version.
    tree_meta = ckptr.metadata(abs_path)
    if ckpt_key not in tree_meta:
        raise KeyError(f"Key '{ckpt_key}' not found in checkpoint. Available keys: {list(tree_meta.keys())}")
    key_meta = tree_meta[ckpt_key]

    def _placeholder(leaf_meta):
        # Array-like metadata becomes an uninitialised array of the same
        # shape/dtype; anything else is passed through untouched.
        if hasattr(leaf_meta, 'shape'):
            return np.empty(leaf_meta.shape, dtype=leaf_meta.dtype)
        return leaf_meta

    def _cpu_restore_args(_leaf_meta):
        return orbax.checkpoint.ArrayRestoreArgs(sharding=host_sharding)

    # item / transforms / restore_args all mention only ckpt_key, so the
    # other top-level entries of the checkpoint are not requested.
    restored = ckptr.restore(
        abs_path,
        item={ckpt_key: jax.tree.map(_placeholder, key_meta)},
        transforms={ckpt_key: orbax.checkpoint.Transform()},
        restore_args={ckpt_key: jax.tree.map(_cpu_restore_args, key_meta)},
    )
    return restored[ckpt_key]
def load_train_run(path):
    '''Restore a complete train run from an orbax checkpoint.

    Orbax requires absolute paths, so a relative ``path`` is resolved against
    the repo root (``REPO_PATH``).

    Restore strategy:
      * If the env var ``JAX_AHT_FORCE_CPU_RESTORE`` is ``"1"`` or
        ``JAX_PLATFORMS`` equals ``"cpu"`` (case-insensitive), array leaves
        are restored directly as numpy arrays.
      * Otherwise a plain restore is attempted; if it fails with one of two
        known sharding/device error messages, the numpy restore is used as a
        fallback. Any other exception is re-raised.

    Returns:
        The restored pytree, with every ``np.ndarray`` leaf converted to a
        jax array.
    '''
    # determine whether path is relative or absolute
    if not os.path.isabs(path):
        path = os.path.join(REPO_PATH, path)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()

    def _restore_args_for(leaf):
        # Leaves that look like arrays are requested as numpy arrays.
        looks_like_array = hasattr(leaf, "shape") and hasattr(leaf, "dtype")
        if looks_like_array:
            return ArrayRestoreArgs(restore_type=np.ndarray)
        return RestoreArgs()

    def _restore_as_numpy():
        # NOTE(review): reads `metadata.tree`, whereas _load_partial indexes
        # the metadata object directly - confirm against the orbax version.
        tree_meta = ckptr.metadata(path).tree
        per_leaf_args = jax.tree_util.tree_map(_restore_args_for, tree_meta)
        return ckptr.restore(path, restore_args=per_leaf_args)

    env = os.environ
    cpu_only = (
        env.get("JAX_AHT_FORCE_CPU_RESTORE", "0") == "1"
        or env.get("JAX_PLATFORMS", "").lower() == "cpu"
    )

    if cpu_only:
        restored = _restore_as_numpy()
    else:
        try:
            restored = ckptr.restore(path)
        except Exception as exc:
            # Only these two error messages trigger the numpy fallback.
            known_failures = (
                "sharding passed to deserialization",
                "Device cuda:0 was not found in jax.local_devices()",
            )
            text = str(exc)
            if not any(marker in text for marker in known_failures):
                raise
            restored = _restore_as_numpy()

    def _to_jax(leaf):
        # convert pytree leaves from np arrays to jax arrays
        return jnp.array(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, restored)
def save_train_run_as_pickle(out, savedir, savename):
    '''Pickle a train run to ``<savedir>/<savename>.pkl``.

    ``savedir`` is used as given (it is not resolved against REPO_PATH) and
    is created, including missing parents, when it does not exist.

    Args:
        out: picklable object to save.
        savedir: target directory; an empty string means the current
            working directory.
        savename: file name without the ``.pkl`` extension.

    Returns:
        The path of the written pickle file.
    '''
    # os.makedirs("") raises, so an empty savedir (= current directory) is
    # skipped. exist_ok=True already covers the "directory exists" case, so
    # no separate os.path.exists check is needed.
    if savedir:
        os.makedirs(savedir, exist_ok=True)
    # os.path.join avoids a doubled separator when savedir ends with one.
    savepath = os.path.join(savedir, f"{savename}.pkl")
    with open(savepath, "wb") as f:
        pickle.dump(out, f)
    return savepath
def load_checkpoints_from_pickle(path, ckpt_key="checkpoints"):
    '''Return the ``ckpt_key`` entry of a pickled train run.

    Args:
        path: path of the pickle file, passed to load_train_run_from_pickle.
        ckpt_key: key to look up in the unpickled object.
    '''
    return load_train_run_from_pickle(path)[ckpt_key]
def load_train_run_from_pickle(path):
    '''Unpickle and return the object stored in the file at ``path``.'''
    # NOTE: unpickling can execute arbitrary code - only load files you trust.
    with open(path, "rb") as handle:
        return pickle.load(handle)