Spaces:
Running
Running
File size: 5,998 Bytes
5146e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import os
import pickle
import orbax.checkpoint
from flax.training import orbax_utils
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import ArrayRestoreArgs, RestoreArgs
import logging

# Orbax output goes through the "absl" logger; raise its threshold so only
# errors are shown.
logger = logging.getLogger("absl")
logger.setLevel(logging.ERROR)

# Absolute path of the repository root: the parent of the directory that
# contains this module. Relative checkpoint paths are resolved against it.
REPO_PATH = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
def save_train_run(out, savedir, savename):
    '''Save a train run as an orbax PyTree checkpoint.

    Orbax requires absolute paths, so a relative ``savedir`` is resolved
    against ``REPO_PATH`` (the repository root) before saving.

    Args:
        out: PyTree to checkpoint.
        savedir: Directory to save into (absolute, or relative to the repo
            root). Created, including missing parents, if it does not exist.
        savename: Name of the checkpoint entry inside ``savedir``.

    Returns:
        The absolute path the checkpoint was written to.
    '''
    if not os.path.isabs(savedir):
        savedir = os.path.join(REPO_PATH, savedir)
    # exist_ok=True already handles an existing directory, so a separate
    # os.path.exists() check is redundant and can race with other writers.
    os.makedirs(savedir, exist_ok=True)
    savepath = os.path.join(savedir, savename)
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(out)
    # NOTE(review): orbax's Checkpointer.save appears to refuse to overwrite
    # an existing destination unless force=True. Confirm against the pinned
    # orbax version if re-saving to the same savename is expected.
    checkpointer.save(savepath, out, save_args=save_args)
    return savepath
def load_checkpoints(path, ckpt_key="checkpoints", custom_loader_cfg: dict=None):
    '''Load the ``ckpt_key`` entry of a saved orbax train run.

    The loading strategy is chosen by ``custom_loader_cfg``:

    * ``None``: restore the whole run and return ``run[ckpt_key]``.
    * ``{"name": "open_ended", "type": ...}``: the restored run unpacks into
      ``(partner_out, ego_out)``. ``type == "ego"`` selects ``ego_out``; any
      other value selects ``partner_out``. For ``ckpt_key == "final_buffer"``
      the nested ``["final_buffer"]["params"]`` entry is returned.
    * ``{"name": "partial_load"}``: restore only ``ckpt_key`` via
      ``_load_partial`` instead of the whole checkpoint.
    * ``{"name": "fcp"}``: restore ``run[ckpt_key]`` and merge axes 1 and 2 of
      every leaf (see the comment in that branch).

    This function does not touch the path itself. Resolving relative paths
    against the repo root (orbax needs absolute paths) happens in the loader
    helpers it calls.

    Raises:
        ValueError: if ``custom_loader_cfg["name"]`` is none of the above.
    '''
    if custom_loader_cfg is None:
        return load_train_run(path)[ckpt_key]

    loader_name = custom_loader_cfg["name"]

    if loader_name == "open_ended":
        # The full checkpoint is needed because it unpacks into both halves.
        partner_out, ego_out = load_train_run(path)
        chosen = ego_out if custom_loader_cfg["type"] == "ego" else partner_out
        if ckpt_key == "final_buffer":
            return chosen["final_buffer"]["params"]
        return chosen[ckpt_key]

    if loader_name == "partial_load":
        return _load_partial(path, ckpt_key)

    if loader_name == "fcp":
        # FCP checkpoints are laid out as
        # (NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...).
        # AgentPopulation expects (NUM_SEEDS, FCP_POP_SIZE, ...), with
        # FCP_POP_SIZE = PARTNER_POP_SIZE * NUM_CHECKPOINTS. This mirrors
        # get_fcp_population in teammate_generation/fcp.py, so a saved FCP run
        # can be reused as a partner pool by ego_agent_training/run.py.
        def _merge_partner_and_ckpt_axes(leaf):
            dims = leaf.shape
            return leaf.reshape(dims[0], dims[1] * dims[2], *dims[3:])

        return jax.tree.map(_merge_partner_and_ckpt_axes, load_train_run(path)[ckpt_key])

    raise ValueError(f"Invalid custom loader name: {custom_loader_cfg['name']}")
def _load_partial(path, ckpt_key):
    '''Restore a single top-level entry of an orbax checkpoint onto the CPU.

    Only ``ckpt_key`` is restored, so other top-level entries (e.g. metrics)
    are skipped. This avoids OOM when only one entry is needed. Relative
    paths are resolved against ``REPO_PATH`` because orbax requires absolute
    paths.

    Args:
        path: Checkpoint directory, absolute or relative to the repo root.
        ckpt_key: Top-level key of the checkpoint to restore.

    Returns:
        The restored subtree stored under ``ckpt_key``. Array leaves are
        restored with a sharding on the first CPU device.

    Raises:
        KeyError: if ``ckpt_key`` is not a top-level key of the checkpoint.
    '''
    if not os.path.isabs(path):
        path = os.path.join(REPO_PATH, path)
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])

    meta = checkpointer.metadata(path)
    # metadata() may return the metadata tree itself, or a wrapper exposing
    # it as ``.tree`` (the form load_train_run reads). Accept both.
    meta = getattr(meta, "tree", meta)
    if ckpt_key not in meta:
        raise KeyError(f"Key '{ckpt_key}' not found in checkpoint. Available keys: {list(meta.keys())}")
    subtree_meta = meta[ckpt_key]

    def _is_array(m):
        # Same leaf test load_train_run uses when building its restore args.
        return hasattr(m, "shape") and hasattr(m, "dtype")

    # Target structure for the restore. Array leaves become uninitialised
    # numpy buffers of the saved shape/dtype; other leaves pass through as-is.
    target = jax.tree.map(
        lambda m: np.empty(m.shape, dtype=m.dtype) if _is_array(m) else m,
        subtree_meta,
    )
    # Array leaves are restored with the CPU sharding. Non-array leaves get
    # plain RestoreArgs rather than ArrayRestoreArgs, matching ``target``.
    leaf_args = jax.tree.map(
        lambda m: (
            orbax.checkpoint.ArrayRestoreArgs(sharding=cpu_sharding)
            if _is_array(m)
            else orbax.checkpoint.RestoreArgs()
        ),
        subtree_meta,
    )
    restored = checkpointer.restore(
        path,
        item={ckpt_key: target},
        transforms={ckpt_key: orbax.checkpoint.Transform()},
        restore_args={ckpt_key: leaf_args},
    )
    return restored[ckpt_key]
def load_train_run(path):
    '''Restore a complete train run from an orbax checkpoint.

    Orbax requires absolute paths, so a relative ``path`` is resolved against
    ``REPO_PATH`` (the repository root).

    Restore strategy:
      * If ``JAX_AHT_FORCE_CPU_RESTORE=1`` or ``JAX_PLATFORMS=cpu``, array
        leaves are restored directly as numpy arrays.
      * Otherwise a plain restore is tried first. If it fails with one of two
        known error messages (a sharding error during deserialization, or
        ``cuda:0`` missing from ``jax.local_devices()``), it is retried with
        the numpy restore. Any other error propagates unchanged.

    Numpy leaves in the result are converted to jax arrays before returning.

    Args:
        path: Checkpoint directory, absolute or relative to the repo root.

    Returns:
        The restored PyTree.
    '''
    if not os.path.isabs(path):
        path = os.path.join(REPO_PATH, path)
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()

    def _restore_with_numpy_args():
        # Restore array leaves as numpy arrays (restore_type=np.ndarray)
        # rather than device arrays. Used when the saved shardings cannot be
        # honoured in the current process.
        metadata = checkpointer.metadata(path)
        # metadata() may wrap the metadata tree (exposed as ``.tree``) or
        # return the tree itself. Accept both.
        metadata_tree = getattr(metadata, "tree", metadata)

        def _mk_restore_args(leaf):
            if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
                return ArrayRestoreArgs(restore_type=np.ndarray)
            return RestoreArgs()

        restore_args = jax.tree_util.tree_map(_mk_restore_args, metadata_tree)
        return checkpointer.restore(path, restore_args=restore_args)

    force_cpu_restore = (
        os.environ.get("JAX_AHT_FORCE_CPU_RESTORE", "0") == "1"
        or os.environ.get("JAX_PLATFORMS", "").lower() == "cpu"
    )
    if force_cpu_restore:
        restored = _restore_with_numpy_args()
    else:
        try:
            restored = checkpointer.restore(path)
        except Exception as exc:
            # Only the two known device/sharding errors are retried; anything
            # else is re-raised unchanged.
            msg = str(exc)
            recoverable = (
                "sharding passed to deserialization" in msg
                or "Device cuda:0 was not found in jax.local_devices()" in msg
            )
            if not recoverable:
                raise
            restored = _restore_with_numpy_args()

    # Convert numpy leaves (e.g. from the numpy restore path) to jax arrays.
    return jax.tree_util.tree_map(
        lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x,
        restored,
    )
def save_train_run_as_pickle(out, savedir, savename):
    '''Pickle ``out`` to ``<savedir>/<savename>.pkl`` and return that path.

    ``savedir`` (including missing parents) is created if needed. Relative
    paths are used as given, i.e. relative to the current working directory,
    not the repo root.

    Args:
        out: Any picklable object.
        savedir: Directory to write into.
        savename: File stem; ``.pkl`` is appended.

    Returns:
        The path of the written pickle file.
    '''
    # exist_ok=True already covers an existing directory; no exists() check.
    os.makedirs(savedir, exist_ok=True)
    # os.path.join uses the platform separator instead of a hard-coded "/".
    savepath = os.path.join(savedir, f"{savename}.pkl")
    with open(savepath, "wb") as f:
        pickle.dump(out, f)
    return savepath
def load_checkpoints_from_pickle(path, ckpt_key="checkpoints"):
    '''Return the ``ckpt_key`` entry of the train run pickled at ``path``.'''
    return load_train_run_from_pickle(path)[ckpt_key]
def load_train_run_from_pickle(path):
    '''Load and return the object stored in the pickle file at ``path``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only call this on pickles from trusted sources, such as files
    written by ``save_train_run_as_pickle``.
    '''
    with open(path, "rb") as f:
        return pickle.load(f)