File size: 5,998 Bytes
5146e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import pickle
import orbax.checkpoint
from flax.training import orbax_utils
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import ArrayRestoreArgs, RestoreArgs

# Suppress orbax's log output: raise the "absl" logger's threshold so that
# only records at ERROR level or above are emitted.
import logging
logger = logging.getLogger("absl")
logger.setLevel(logging.ERROR)

# Two dirname() calls on this file's absolute path yield the parent of the
# directory containing this file; that directory is treated as the repo root
# and is used below to resolve relative checkpoint paths.
REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def save_train_run(out, savedir, savename):
    '''Write the pytree `out` as an orbax checkpoint and return its path.

    A relative `savedir` is resolved against REPO_PATH so the path handed to
    orbax is absolute; an absolute `savedir` is used unchanged. The directory
    is created when it does not exist yet, and the checkpoint is stored at
    `<savedir>/<savename>`.
    '''
    target_dir = savedir if os.path.isabs(savedir) else os.path.join(REPO_PATH, savedir)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, savename)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    # Per-leaf save arguments derived from the pytree being saved.
    leaf_save_args = orbax_utils.save_args_from_target(out)
    ckptr.save(destination, out, save_args=leaf_save_args)
    return destination

def load_checkpoints(path, ckpt_key="checkpoints", custom_loader_cfg: dict=None):
    '''Return the `ckpt_key` entry of a saved train run.

    Without `custom_loader_cfg` the whole run is restored with load_train_run
    and indexed by `ckpt_key`. Otherwise `custom_loader_cfg["name"]` selects
    the loader:

    - "open_ended": the restored run unpacks into (partner_out, ego_out);
      `custom_loader_cfg["type"] == "ego"` picks ego_out, any other value
      picks partner_out. For ckpt_key "final_buffer" the buffer's "params"
      entry is returned instead of the buffer itself.
    - "partial_load": only `ckpt_key` is restored, via _load_partial.
    - "fcp": axes 1 and 2 of every checkpoint leaf are merged into one.

    Raises ValueError for any other loader name.
    '''
    if custom_loader_cfg is None:
        return load_train_run(path)[ckpt_key]

    loader_name = custom_loader_cfg["name"]

    if loader_name == "open_ended":
        # Open-ended loader needs the full checkpoint
        partner_out, ego_out = load_train_run(path)
        if custom_loader_cfg["type"] == "ego":
            selected = ego_out
        else:
            selected = partner_out
        if ckpt_key == "final_buffer":
            return selected["final_buffer"]["params"]
        return selected[ckpt_key]

    if loader_name == "partial_load":
        return _load_partial(path, ckpt_key)

    if loader_name == "fcp":
        # FCP saves checkpoints with shape
        #   (NUM_SEEDS, PARTNER_POP_SIZE, NUM_CHECKPOINTS, ...)
        # Reshape to the (NUM_SEEDS, FCP_POP_SIZE, ...) layout that
        # AgentPopulation expects, where FCP_POP_SIZE = PARTNER_POP_SIZE *
        # NUM_CHECKPOINTS. Mirrors get_fcp_population in
        # teammate_generation/fcp.py so a saved FCP run can be reused as a
        # partner pool by ego_agent_training/run.py.
        def _merge_pop_and_ckpt_axes(x):
            merged = x.shape[1] * x.shape[2]
            return x.reshape(x.shape[0], merged, *x.shape[3:])

        fcp_ckpts = load_train_run(path)[ckpt_key]
        return jax.tree.map(_merge_pop_and_ckpt_axes, fcp_ckpts)

    raise ValueError(f"Invalid custom loader name: {custom_loader_cfg['name']}")

def _load_partial(path, ckpt_key):
    '''Restore a single top-level key of an orbax checkpoint.

    Only the subtree under `ckpt_key` is requested from orbax, so the rest of
    the pytree (e.g. metrics) is not loaded; this is meant to avoid running
    out of memory on large runs. A relative `path` is resolved against
    REPO_PATH. Every leaf is restored with a single-device sharding on the
    first CPU device.

    Raises KeyError when `ckpt_key` is absent from the checkpoint metadata.
    '''
    abs_path = path if os.path.isabs(path) else os.path.join(REPO_PATH, path)

    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    cpu_device = jax.devices('cpu')[0]
    cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)
    meta = ckptr.metadata(abs_path)

    if ckpt_key not in meta:
        raise KeyError(f"Key '{ckpt_key}' not found in checkpoint. Available keys: {list(meta.keys())}")

    subtree_meta = meta[ckpt_key]

    def _placeholder(leaf_meta):
        # Array-like metadata becomes an uninitialised array of matching
        # shape/dtype; anything else is passed through untouched.
        if hasattr(leaf_meta, 'shape'):
            return np.empty(leaf_meta.shape, dtype=leaf_meta.dtype)
        return leaf_meta

    def _cpu_restore_args(_):
        return orbax.checkpoint.ArrayRestoreArgs(sharding=cpu_sharding)

    restored = ckptr.restore(
        abs_path,
        item={ckpt_key: jax.tree.map(_placeholder, subtree_meta)},
        transforms={ckpt_key: orbax.checkpoint.Transform()},
        restore_args={ckpt_key: jax.tree.map(_cpu_restore_args, subtree_meta)},
    )
    return restored[ckpt_key]

def load_train_run(path):
    '''Restore a full train run from an orbax checkpoint.

    A relative `path` is resolved against REPO_PATH so orbax receives an
    absolute path. The checkpoint is restored with numpy restore args (each
    array-like leaf read as np.ndarray) when JAX_AHT_FORCE_CPU_RESTORE is "1"
    or JAX_PLATFORMS is "cpu" (case-insensitive). Otherwise a default restore
    is attempted first, and the numpy restore is used as a fallback only for
    the sharding / missing-device errors matched below; any other error is
    re-raised. np.ndarray leaves of the result are converted to jax arrays.
    '''
    if not os.path.isabs(path):
        path = os.path.join(REPO_PATH, path)
    ckptr = orbax.checkpoint.PyTreeCheckpointer()

    def _leaf_restore_args(leaf_meta):
        array_like = hasattr(leaf_meta, "shape") and hasattr(leaf_meta, "dtype")
        if array_like:
            return ArrayRestoreArgs(restore_type=np.ndarray)
        return RestoreArgs()

    def _restore_as_numpy():
        tree_meta = ckptr.metadata(path).tree
        per_leaf_args = jax.tree_util.tree_map(_leaf_restore_args, tree_meta)
        return ckptr.restore(path, restore_args=per_leaf_args)

    env = os.environ
    cpu_only = (
        env.get("JAX_AHT_FORCE_CPU_RESTORE", "0") == "1"
        or env.get("JAX_PLATFORMS", "").lower() == "cpu"
    )

    if cpu_only:
        restored = _restore_as_numpy()
    else:
        try:
            restored = ckptr.restore(path)
        except Exception as exc:
            # Only these two failure messages are known to be fixed by
            # restoring to numpy; everything else propagates unchanged.
            recoverable_markers = (
                "sharding passed to deserialization",
                "Device cuda:0 was not found in jax.local_devices()",
            )
            text = str(exc)
            if not any(marker in text for marker in recoverable_markers):
                raise
            restored = _restore_as_numpy()

    def _to_jax(leaf):
        return jnp.array(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, restored)

def save_train_run_as_pickle(out, savedir, savename):
    '''Pickle `out` to `<savedir>/<savename>.pkl` and return that path.

    `savedir` is used as given (a relative directory is relative to the
    current working directory) and is created, including parents, when it
    does not exist. An empty `savedir` means the current directory.
    '''
    # os.makedirs("") raises, and joining "" with "/" would point at the
    # filesystem root, so only create the directory when one was named.
    if savedir:
        os.makedirs(savedir, exist_ok=True)

    savepath = os.path.join(savedir, f"{savename}.pkl")
    with open(savepath, "wb") as f:
        pickle.dump(out, f)
    return savepath

def load_checkpoints_from_pickle(path, ckpt_key="checkpoints"):
    '''Unpickle the train run stored at `path` and return its `ckpt_key` entry.'''
    return load_train_run_from_pickle(path)[ckpt_key]

def load_train_run_from_pickle(path):
    '''Return the object unpickled from the file at `path`.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    '''
    with open(path, "rb") as pkl_file:
        return pickle.load(pkl_file)