Commit f9bac78: Upload checkpoint.py with huggingface_hub

checkpoint.py
import gc
import io
import logging
import pickle
import shutil
import traceback
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
from multiprocessing import shared_memory
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast

import numpy as np
import torch
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from packaging import version
from torch.distributed import _remote_device
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.futures import Future

try:
    from torch.distributed.fsdp.flat_param import FlatParamHandle  # type: ignore
except ModuleNotFoundError:
    from torch.distributed.fsdp._flat_param import FlatParamHandle  # type: ignore

from . import util

from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import (
    barrier,
    gc_cuda,
    get_fs_local_rank,
    get_global_rank,
    get_world_size,
)
from .util import (
    _get_s3_client,
    default_thread_count,
    dir_is_empty,
    get_bytes_range,
    get_progress_bar,
    resource_path,
    upload,
    wait_for,
)

__all__ = [
    "save_fsdp_model_and_optim_state",
    "load_fsdp_model_and_optim_state",
    "load_fsdp_optim_state",
    "save_state_dict",
    "load_state_dict",
    "load_model_state",
    "RemoteFileSystemWriter",
    "RemoteFileSystemReader",
    "Checkpointer",
    "FullCheckpointer",
    "TorchNewStyleShardedCheckpointer",
    "TorchLegacyShardedCheckpointer",
    "LocalShardedCheckpointer",
    "build_sharded_checkpointer",
]


log = logging.getLogger(__name__)

MODEL_AND_OPTIM_FOLDER = "model_and_optim"


def save_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
):
    """
    Use this to save a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The directory to save to.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
    :param save_overwrite: Overwrite existing files.

    :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
    """
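
    Example (a minimal sketch; assumes ``torch.distributed`` is initialized and that
    ``fsdp_model`` and ``optim`` are already set up on every rank; the paths are
    hypothetical)::

        save_fsdp_model_and_optim_state(
            "checkpoints/step1000",
            fsdp_model,
            optim,
            upload_to="s3://my-bucket/checkpoints/step1000",
            save_overwrite=True,
        )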
    """
    checkpoint_dir = Path(checkpoint_dir)
    target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
    if save_overwrite:
        if get_fs_local_rank() == 0:
            shutil.rmtree(target_dir, ignore_errors=True)
    elif not dir_is_empty(target_dir):
        raise FileExistsError(target_dir)
    barrier()
    if get_fs_local_rank() == 0:
        target_dir.mkdir(exist_ok=True, parents=True)
    barrier()
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_and_optim_state = {
            "model": fsdp_model.state_dict(),
            "optim": FSDP.optim_state_dict(fsdp_model, optim),
        }
        dist_cp.save_state_dict(
            model_and_optim_state,
            RemoteFileSystemWriter(
                target_dir,
                upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
                save_overwrite=save_overwrite,
            ),
        )


def load_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
):
    """
    Use this to load a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.
    :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.

    :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
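
    Example (a sketch; assumes the same distributed setup that produced the checkpoint,
    with a hypothetical remote path and local cache directory)::

        load_fsdp_model_and_optim_state(
            "s3://my-bucket/checkpoints/step1000",
            fsdp_model,
            optim,
            local_cache="/tmp/checkpoint-cache",
        )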
    """
    load_path = str(checkpoint_dir).rstrip("/")
    local_cache = None if local_cache is None else Path(local_cache)
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        # Load the model state dict in place.
        log.info("Loading model state...")
        model_state = {"model": fsdp_model.state_dict()}
        dist_cp.load_state_dict(
            model_state,
            RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        fsdp_model.load_state_dict(model_state["model"])

        if not load_optimizer_state:
            return

        # Load optim state dict in place.
        log.info("Loading sharded optimizer state...")
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=model_state["model"],
            optimizer_key="optim",
            storage_reader=RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        del model_state
        gc_cuda()
        load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])


def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
    log.info("Flattening sharded optimizer state...")
    # NOTE: Careful! The order of these arguments changed from 2.0 to 2.1... ¯\_(ツ)_/¯
    if version.parse(torch.__version__) < version.parse("2.1.0"):
        flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim)  # type: ignore
    else:
        flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state)  # type: ignore
    del optim_state
    gc.collect()
    log.info("Loading flattened optimizer state...")
    # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
    # which takes up unnecessary GPU memory.
    for state in flattened_osd["state"].values():
        for k in state.keys():
            v = state[k]
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device="cpu")
    gc_cuda()
    optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


def save_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    state_dict: Dict[str, Any],
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
    synchronize: bool = True,
):
    """
    Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
    This can be used during distributed training or not. If used during distributed training,
    ``fname`` should be unique for each rank.

    :param checkpoint_dir: The directory to save to.
    :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
    :param state_dict: The state dict to save.
    :param upload_to: Optional, a remote "directory" to upload the file to.
    :param save_overwrite: Overwrite existing files.
    :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
        this function from a single rank.

    :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
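
    Example (a sketch mirroring how trainer-state shards are written elsewhere in this
    module; ``trainer_state`` is assumed to be an in-memory dict)::

        save_state_dict(
            "checkpoints/step1000",
            f"train/rank{get_global_rank()}.pt",
            trainer_state,
            save_overwrite=True,
        )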
    """
    checkpoint_dir = Path(checkpoint_dir)
    target_path = checkpoint_dir / fname
    if save_overwrite:
        target_path.unlink(missing_ok=True)
    elif target_path.is_file():
        raise FileExistsError(target_path)
    if synchronize:
        barrier()
    target_path.parent.mkdir(exist_ok=True, parents=True)
    if synchronize:
        barrier()
    torch.save(state_dict, target_path)
    if upload_to is not None:
        upload_target = f"{upload_to.rstrip('/')}/{fname}"
        log.info(f"Uploading {target_path} to {upload_target}...")
        upload(target_path, upload_target, save_overwrite=save_overwrite)


def load_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    *,
    local_cache: Optional[PathOrStr] = None,
    map_location: Optional[str] = None,
):
    """
    Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
    This can be used during distributed training or not.

    :param checkpoint_dir: A local or remote checkpoint directory.
    :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.

    :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
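
    Example (a sketch; when ``fname`` ends in ``.pt``, a ``.safetensors`` sibling is
    tried first before falling back to :func:`torch.load()`)::

        trainer_state = load_state_dict("checkpoints/step1000", "train.pt", map_location="cpu")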
    """
    if fname.endswith(".pt"):
        # Try safetensors version first.
        try:
            path = resource_path(
                str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
            )
            return safetensors_file_to_state_dict(path, map_location=map_location)
        except FileNotFoundError:
            pass

    path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
    return torch.load(path, map_location=map_location)


def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
    """
    Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
    Note that ``model`` should not be wrapped with FSDP.
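
    Example (a sketch; ``OLMo`` stands in for whatever un-wrapped module class you use,
    and its construction here is hypothetical)::

        model = OLMo(model_config)
        load_model_state("checkpoints/step1000", model)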
    """
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict,
        RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])


class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
    """
    A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
    directly to a cloud bucket when ``upload_to`` is specified.
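
    Example (a sketch with a hypothetical bucket; the writer drops into the usual
    :func:`torch.distributed.checkpoint.save_state_dict` call)::

        writer = RemoteFileSystemWriter(
            "checkpoints/step1000/model_and_optim",
            upload_to="s3://my-bucket/step1000/model_and_optim",
        )
        dist_cp.save_state_dict(state_dict, writer)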
    """

    def __init__(
        self,
        path: PathOrStr,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: Optional[int] = None,
        per_thread_copy_ahead: int = 10_000_000,
        upload_to: Optional[str] = None,
        save_overwrite: bool = False,
    ) -> None:
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
            # returns because uploading big checkpoint files with multiple threads causes
            # boto3 to fail in weird ways.
            thread_count=thread_count or 1,
            per_thread_copy_ahead=per_thread_copy_ahead,
        )
        self.upload_to = None if upload_to is None else upload_to.rstrip("/")
        self.save_overwrite = save_overwrite

    def write_data(
        self,
        plan: dist_cp.SavePlan,
        planner: dist_cp.SavePlanner,
    ) -> Future[List[WriteResult]]:
        fut = super().write_data(plan, planner)
        if self.upload_to is not None:
            files_to_upload = set()
            for write_result in fut.wait():
                files_to_upload.add(write_result.storage_data.relative_path)

            # Create the global S3 client up front to work around a threading issue in boto.
            if self.upload_to.startswith("s3://"):
                _get_s3_client("s3")
            elif self.upload_to.startswith("r2://"):
                _get_s3_client("r2")

            with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
                futures = []
                for fname in files_to_upload:
                    source = self.path / fname
                    target = f"{self.upload_to}/{fname}"
                    log.info(f"Uploading {source} to {target}...")
                    futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
                for f in as_completed(futures):
                    try:
                        f.result()
                    except BaseException:
                        # NOTE: we might get an error here that can't be pickled, which causes a different failure
                        # later when PyTorch tries to reduce that error across ranks. So here we just make
                        # sure we're raising a simple error type that can be pickled.
                        raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        super().finish(metadata, results)
        if self.upload_to is not None:
            source = self.path / ".metadata"
            target = f"{self.upload_to}/.metadata"
            log.info(f"Uploading {source} to {target}...")
            upload(source, target, save_overwrite=self.save_overwrite)


class RemoteFileSystemReader(dist_cp.StorageReader):
    """
    A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
    that can read data directly from cloud storage as well as a local directory.
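
    Example (a sketch; the remote path and cache directory are hypothetical)::

        reader = RemoteFileSystemReader(
            "s3://my-bucket/step1000/model_and_optim",
            local_cache="/tmp/checkpoint-cache/model_and_optim",
        )
        dist_cp.load_state_dict(state_dict, reader)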
    """

    def __init__(
        self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
    ):
        super().__init__()
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        self.path = str(path).rstrip("/")
        self.cache = None if local_cache is None else Path(local_cache)
        self.thread_count = thread_count or default_thread_count()
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
        self._metadata: Optional[Metadata] = None

    def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
        if self.cache is not None and (path := self.cache / relative_path).is_file():
            return get_bytes_range(path, offset, length)
        else:
            return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

    def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
        sinfo = self.storage_data[read_item.storage_index]
        content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
        return (read_item, content)

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
        # Create the global S3 client up front to work around a threading issue in boto.
        if isinstance(self.path, str):
            if self.path.startswith("s3://"):
                _get_s3_client("s3")
            elif self.path.startswith("r2://"):
                _get_s3_client("r2")

        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            read_item_content_futures = []
            for read_item in plan.items:
                read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
            read_item_content_results = []
            for f in as_completed(read_item_content_futures):
                try:
                    read_item_content_results.append(f.result())
                except BaseException:
                    # NOTE: we might get an error here that can't be pickled, which causes a different failure
                    # later when PyTorch tries to reduce that error across ranks. So here we just make
                    # sure we're raising a simple error type that can be pickled.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

        # Modified from `FileSystemReader.read_data()`
        for read_item, content in read_item_content_results:
            bytes = io.BytesIO(content)
            bytes.seek(0)
            if read_item.type == LoadItemType.BYTE_IO:
                planner.load_bytes(read_item, bytes)
            else:
                tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
                tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
                target_tensor = planner.resolve_tensor(read_item).detach()

                assert (
                    target_tensor.size() == tensor.size()
                ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                target_tensor.copy_(tensor)
                planner.commit_tensor(read_item, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def read_metadata(self) -> Metadata:
        if self._metadata is None:
            with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
                self._metadata = pickle.load(metadata_file)
        return self._metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        del is_coordinator
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
        return global_plan


class Checkpointer(metaclass=ABCMeta):
    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
        self.cfg = cfg
        self.thread_count = thread_count or default_thread_count()

    @abstractmethod
    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        train_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
        """
        raise NotImplementedError

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint.

        Note this is not marked abstract because child classes are not required to implement this.
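
        A sketch of the intended call pattern, where ``checkpointer`` is an instance of
        a subclass that implements this method and the path is hypothetical::

            model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
                "checkpoints/step1000-sharded"
            )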
        """
        del load_path, local_cache, load_optimizer_state, load_trainer_state, device
        raise NotImplementedError

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
        checkpoint_dir = Path(dir)
        if not dir_is_empty(checkpoint_dir):
            if self.cfg.save_overwrite:
                if get_fs_local_rank() == 0:
                    shutil.rmtree(checkpoint_dir, ignore_errors=True)
            else:
                raise FileExistsError(checkpoint_dir)
        # No need to mkdir here since we'll directly replace the temporary directory with
        # this directory below.
        barrier()

        # Prepare temporary directory. We don't have to be as careful here, we can
        # just remove it if it already exists.
        checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
        if get_fs_local_rank() == 0:
            shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
            checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)

        barrier()

        # Yield temporary directory for `.save_checkpoint()` to use.
        yield checkpoint_dir_tmp

        barrier()

        # Finally if all went well replace the temporary directory with the actual
        # checkpoint directory.
        if get_fs_local_rank() == 0:
            # Replace temp directory with target checkpoint directory.
            try:
                checkpoint_dir_tmp.replace(checkpoint_dir)
            except FileNotFoundError:
                # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
                # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
                # file-systems.
                if not checkpoint_dir.exists():
                    raise

        # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
        # replacing the temp directory with the final directory from rank 0 might not be immediately
        # realized in the file systems of the other ranks.
        # So we wait here across all ranks until that final checkpoint directory is visible.
        wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

    def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        if get_global_rank() == 0:
            log.info("Saving config...")
            self.cfg.save(config_path := Path(dir) / "config.yaml")
            if upload_to is not None:
                upload_target = f"{upload_to}/config.yaml"
                log.info(f"Uploading {config_path} to {upload_target}")
                upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)


class FullCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
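
    Example (a sketch; ``cfg`` is the :class:`TrainConfig` used for training, and the
    path is hypothetical)::

        checkpointer = FullCheckpointer(cfg)
        model_state, optim_state = checkpointer.load_checkpoint("checkpoints/step1000")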
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
                # First the model state.
                model_state_dict = fsdp_model.state_dict()
                if get_global_rank() == 0:
                    log.info("Saving model state...")
                    save_state_dict(
                        checkpoint_dir,
                        "model.pt",
                        model_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                del model_state_dict
                barrier()

                # Then the optimizer state.
                optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
                if get_global_rank() == 0:
                    log.info("Saving optim state...")
                    save_state_dict(
                        checkpoint_dir,
                        "optim.pt",
                        optim_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                del optim_state_dict
                barrier()

            # Save trainer state.
            if get_global_rank() == 0:
                log.info("Saving trainer state...")
                save_state_dict(
                    checkpoint_dir,
                    "train.pt",
                    trainer_state,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                    synchronize=False,
                )
            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
        ):
            with torch.no_grad():
                # fill everything with NaN, so we can check afterwards that every parameter has been restored
                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        param.fill_(torch.nan)

                # restore params from checkpoint
                state_dict_to_load = load_state_dict(
                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                )
                (
                    state_dict_to_load,
                    og_keys_to_new,
                ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)

                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        assert param._is_flat_param
                        for fqn, spi in zip(param._fqns, param._shard_param_infos):
                            if not spi.in_shard:
                                continue
                            key = f"{module_name}.{fqn}"
                            key = key.replace("_fsdp_wrapped_module.", "")
                            key = key.lstrip(".")
                            t = state_dict_to_load[key]
                            t = t.flatten()
                            param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
                                t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
                            )

                # make sure that every parameter has been restored
                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        if torch.isnan(param).any():
                            raise ValueError(
                                f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
                            )

            # Load optimizer state.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim_state_dict_to_load = self._make_optim_state_dict_compatible(
                    optim_state_dict_to_load,
                    og_keys_to_new,
                )
                load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
                del optim_state_dict_to_load

        # Load other state.
        try:
            trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
        except FileNotFoundError:
            # for backwards compatibility
            trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _make_optim_state_dict_compatible(
        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
    ) -> Dict[str, Any]:
        # This state dict comes in two forms: one where the state keys are integers and one where the
        # keys are fully qualified parameter names. The latter case is easier to deal with here so we
        # first transform the integer key form into the FQN key form.
        if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
            id_to_fqn: Dict[int, str] = {}
            for group in optim_state_dict["param_groups"]:
                new_param_names = []
                for fqn, id in zip(group["param_names"], group["params"]):
                    fqn = fqn.replace("_fsdp_wrapped_module.", "")
                    id_to_fqn[id] = fqn
                    new_param_names.append(fqn)
                group["param_names"] = new_param_names
                group["params"] = new_param_names
            for id in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
        else:
            # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
            for group in optim_state_dict["param_groups"]:
                group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
                group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
                assert group["param_names"] == group["params"]
            for key in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
                    "state"
                ].pop(key)

        # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
        # First fix param names in the state.
        for og_key, new_keys in og_keys_to_new.items():
            og_state = optim_state_dict["state"].pop(og_key, None)
            if og_state is None:
                continue
            for i, new_key in enumerate(new_keys):
                if i == len(new_keys) - 1:
                    optim_state_dict["state"][new_key] = og_state
                else:
                    optim_state_dict["state"][new_key] = deepcopy(og_state)
        # Now fix param names in the param groups.
        for group in optim_state_dict["param_groups"]:
            og_names = group["params"]
            new_names = []
            for og_key in og_names:
                for new_key in og_keys_to_new[og_key]:
                    new_names.append(new_key)
            group["params"] = new_names
            group["param_names"] = new_names

        return optim_state_dict

    def load_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
        device = device if device is not None else torch.device("cpu")
        model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device)  # type: ignore
        optim_state = None
        if load_optimizer_state:
            optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device)  # type: ignore
        return model_state, optim_state


class TorchNewStyleShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as checkpoint_dir:
            # Save model and optim state.
            save_fsdp_model_and_optim_state(
                checkpoint_dir,
                fsdp_model,
                optim,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save trainer state.
            log.info("Saving trainer state...")
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        # Load model and optimizer state in place.
        log.info("Loading model and optimizer state...")
        load_fsdp_model_and_optim_state(
            load_path,
            fsdp_model,
            optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )

        # Load trainer state dict.
        log.info("Loading trainer state...")
        try:
            trainer_state = load_state_dict(
                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
            )
        except FileNotFoundError:
            # Fall back to rank 0 train state.
            # This can happen when we're restoring a checkpoint with a different world size.
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        barrier()
        return trainer_state


class TorchLegacyShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
    and optim state.

    The world size must be kept consistent when using this checkpointer.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
                optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
            ):
                state_dict = {
                    "model": fsdp_model.state_dict(),
                    "optim": FSDP.optim_state_dict(fsdp_model, optim),
                    **trainer_state,
                }
                save_state_dict(
                    checkpoint_dir,
                    f"rank{get_global_rank()}.pt",
                    state_dict,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                )

            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
        ):
            # Deserialize state dict.
            state_dict = load_state_dict(
                load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )

            # Load model and optimizer state.
            log.info("Loading model state...")
            fsdp_model.load_state_dict(state_dict["model"])
            del state_dict["model"]
            if load_optimizer_state:
                log.info("Loading optimizer state...")
                load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
            del state_dict["optim"]

        barrier()
        return state_dict

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        assert local_cache is None, "this method currently only supports local files"
        full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
        model_state = full_state_dict.pop("model")
        optim_state = full_state_dict.pop("optim")
        return (
            model_state,
            optim_state if load_optimizer_state else None,
            full_state_dict if load_trainer_state else None,
        )

    def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
        key = tuple() if key is None else key
        if isinstance(state, (list, tuple, set)):
            for i, sub_state in enumerate(state):
                self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
        elif isinstance(state, dict):
            for name in state.keys():
                self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
        elif isinstance(state, ShardedTensor):
            self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
            return
        else:
            return

    def _get_shard_placement_and_rank_sizes(
        self, shards_metadata: List[ShardMetadata], world_size: int
    ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
        def shard_size(shard_md):
            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  # type: ignore[attr-defined]

        rank_sizes = [0 for _ in range(world_size)]
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
        for shard_md in shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None
            if shard_rank >= world_size:
                raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += shard_size(shard_md)

        return shard_placement, rank_sizes

    def _copy_sharded_tensor_to_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
    ) -> Any:
        shard0_md = sharded_tensor.metadata()
        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            shard0_md.shards_metadata, world_size
        )

        rank_size = rank_sizes[rank]
        assert rank_size >= 0
        if rank_size == 0:
            return

        assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        sharded_memory_name = "-".join(key + (str(rank),))

        shm = shared_memory.SharedMemory(
            create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
        )
        np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

        for local_shard in sharded_tensor.local_shards():
            shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
            assert shard_rank == rank

            src = local_shard.tensor.flatten()
            shard_offset = shard_placement[local_shard.metadata][1]

            np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()

        shm.close()

    def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
        shard_number = int(shard_filepath.name[4:-3])
        log.info("Starting unsharding shard number %d to shared memory", shard_number)

        with self._patch_sharded_tensor_load():
            shard = torch.load(shard_filepath, map_location="cpu")
            log.debug("Done loading shard number %d", shard_number)

        self._copy_sharded_tensors_to_shared_mem(
            shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
        )
        log.info("Done unsharding shard number %d to shared memory", shard_number)

    def _unshard_using_sharded_mem(
        self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
    ) -> Any:
        return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))

    def _unshard_state_using_shared_mem(
        self, state: Any, world_size: int, device: torch.device, key: Tuple
    ) -> Any:
        if isinstance(state, (list, tuple, set)):
            return state.__class__(
                self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
                for i, sub_state in enumerate(state)
            )
        elif isinstance(state, dict):
            return {
                name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
                for name in state.keys()
            }
        elif isinstance(state, ShardedTensor):
            return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
        elif isinstance(state, torch.Tensor):
            return state.to(device=device)
        else:
            return state

    def _unshard_tensor_using_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
    ) -> torch.Tensor:
        shard0_md = sharded_tensor.metadata()

        def shard_size(shard_md):
            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  # type: ignore[attr-defined]

        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            shard0_md.shards_metadata, world_size
        )

        assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        out = torch.empty(
            *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
        )
        dims = len(sharded_tensor.metadata().size)
        for shard_md, (rank, rank_offset) in shard_placement.items():
            if rank >= world_size:
                raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")

            sharded_memory_name = "-".join(key + (str(rank),))
            shm = shared_memory.SharedMemory(name=sharded_memory_name)

            rank_size = rank_sizes[rank]
            assert rank_size >= 0
            if rank_size == 0:
                continue

            np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

            tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
            tensor = tensor.view(shard_md.shard_sizes)

            out_narrow_view = out
            for dim in range(dims):
                out_narrow_view = out_narrow_view.narrow(
                    dim,
                    shard_md.shard_offsets[dim],
                    shard_md.shard_sizes[dim],
                )

            out_narrow_view.copy_(tensor)

            shm.close()
            shm.unlink()

        return out

    @contextmanager
    def _patch_sharded_tensor_load(self):
        """
        Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
        """

        def _rebuild_from_type_v2_monkey(func, new_type, args, state):
            ret = func(*args)
            if type(ret) is not new_type:
                ret = ret.as_subclass(new_type)

            # Shortcut the construction of ShardedTensor
            # This is in the top 5 of my worst hacks.
            if isinstance(ret, ShardedTensor):
                ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
                return ret

            # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
            # Tensor does define __setstate__ even though it doesn't define
            # __getstate__. So only use __setstate__ if it is NOT the one defined
            # on Tensor
            if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
                ret.__setstate__(state)
            else:
                ret = torch._utils._set_obj_state(ret, state)
            return ret

        original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
        try:
            torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
            yield
        finally:
            torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2

    def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
        """
        The current unsharding implementation consists of:

        1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
        2. Loading 1 shard on the main process as a base unsharded object.
        3. Using the sharded tensors in shared memory to populate the base unsharded object.

        This implementation replaced a prior implementation that instead loaded
        all shards using threads, because that implementation turned out to
        be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
        The current implementation is slower than the old one in many scenarios,
        but is significantly faster in the above-mentioned case (e.g. 30 minutes)
        if there are enough CPUs.
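
        The hand-off between processes relies on named shared-memory blocks. A minimal
        sketch of that pattern (standard library only; the block name is hypothetical)::

            from multiprocessing import shared_memory

            import numpy as np

            shm = shared_memory.SharedMemory(create=True, size=8 * 4, name="demo-rank0")
            np.ndarray((8,), dtype=np.float32, buffer=shm.buf)[:] = 1.0   # producer writes
            attached = shared_memory.SharedMemory(name="demo-rank0")      # consumer attaches
            arr = np.ndarray((8,), dtype=np.float32, buffer=attached.buf).copy()
            attached.close()
            shm.close()
            shm.unlink()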
        """

        input_dir = Path(input_dir)
        skip_keys = skip_keys or set()

        shard_filepaths = list(input_dir.glob("rank*.pt"))
        world_size = len(shard_filepaths)
        if world_size == 0:
            raise RuntimeError("No shards found for unsharding")

        log.info("Number of shards: %d", world_size)
        shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
        min_ram_required_estimate_gb = shard_size_gb * world_size
        log.info(
            "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
        )

        log.info("Copying sharded tensors to shared memory using multiple processes")
        # Copy sharded data to shared memory using multiple processes, so this process can load
        # from memory rather than disk. We spawn a new process instead of forking since shared memory
        # appears to get deleted when forked processes end for some reason.
        executor = ProcessPoolExecutor(
            mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
        )
        futures = []
        for shard_filepath in shard_filepaths:
            shard_rank = int(shard_filepath.name[4:-3])

            if shard_rank >= world_size:
                raise RuntimeError(
                    f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
                )

            futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))

        for f in as_completed(futures):
            f.result()
        executor.shutdown()

        log.info("Loading a shard on the main process to be unsharded state")
        with self._patch_sharded_tensor_load():
            state = torch.load(shard_filepaths[0], map_location="cpu")

        for key in skip_keys:
            if key in state:
                del state[key]

        log.info("Unsharding from %d shards ...", world_size)
        return self._unshard_using_sharded_mem(state, world_size, device, input_dir)


@dataclass
class _LocalShardedCheckpointerMetadata(BaseConfig):
    world_size: int = field(default_factory=get_world_size)


@dataclass
class _FlatParamShard:
    full_shape: torch.Size
    shard_offsets: Tuple[int, int]
    shard_data: Optional[torch.Tensor]

    def copy_into(self, full_tensor: torch.Tensor) -> None:
        assert self.shard_data is not None
        full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
        assert self.shard_data.shape == full_tensor_shard_view.shape
        full_tensor_shard_view.copy_(self.shard_data)


class LocalShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
    The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.

    The world size must be kept consistent when using this checkpointer. However, you can easily
    reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
    using :meth:`unshard_checkpoint()` (no distributed initialization required).
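
    Example (a sketch; unshards from a single process given a hypothetical local
    checkpoint directory)::

        checkpointer = LocalShardedCheckpointer(cfg)
        model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
            "checkpoints/step1000-local-sharded"
        )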
1241
+ """
1242
+
1243
+ # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
+ _FLAT_PARAM_METADATA_TO_SAVE = (
1245
+ "_fqns",
1246
+ "_shard_param_offsets",
1247
+ "_shard_indices",
1248
+ "_numels",
1249
+ "_numels_with_padding",
1250
+ "_shapes",
1251
+ "_shard_numel_padded",
1252
+ "_shard_param_infos",
1253
+ )
1254
+
1255
+ def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
+ """
1257
+ Returns a list of FSDP modules with their FQN.
1258
+ """
1259
+ modules = []
1260
+ for name, module in fsdp_model.named_modules():
1261
+ if isinstance(module, FSDP):
1262
+ modules.append((name, module))
1263
+ return modules
1264
+
1265
+ def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
+ from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
+
1268
+ # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
+ # an FSDP state dict through the built-in methods.
1270
+ if torch.cuda.is_available():
1271
+ torch.cuda.synchronize()
1272
+ _lazy_init(fsdp_model, fsdp_model)
1273
+
1274
+ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
+ return fsdp_model._handles # type: ignore
1277
+ elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
+ # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
+ if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
+ return [fsdp_model._handle] # type: ignore
1281
+ else:
1282
+ return []
1283
+ else:
1284
+ # Need to verify FSDP internals with newer versions.
1285
+ raise NotImplementedError
1286
+
1287
+ @torch.no_grad()
1288
+ def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
+ self._prepare_fsdp_model(fsdp_model)
1290
+ module_data = []
1291
+ for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
+ handle_data = []
1293
+ for handle in self._fsdp_handles(fsdp_module):
1294
+ data: Dict[str, Any] = {}
1295
+ # This is a `FlatParameter` instance.
1296
+ # See `torch.distributed.fsdp.flat_param` for the API.
1297
+ flat_param = handle.flat_param
1298
+ data["flat_param.data"] = flat_param.detach()
1299
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
+ if hasattr(flat_param, key):
1301
+ data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
+ handle_data.append(data)
1303
+ module_data.append({"handles": handle_data, "name": module_fqn})
1304
+ return {"modules": module_data}
1305
+
1306
+ @torch.no_grad()
1307
+ def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
+ """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
+ self._prepare_fsdp_model(fsdp_model)
1310
+ fsdp_modules = self._fsdp_modules(fsdp_model)
1311
+ assert len(model_state["modules"]) == len(fsdp_modules)
1312
+ for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
+ handles = self._fsdp_handles(fsdp_module)
1314
+ assert len(handles) == len(module_data["handles"])
1315
+ for handle, data in zip(handles, module_data["handles"]):
1316
+ flat_param = handle.flat_param
1317
+ # Make sure metadata matches.
1318
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
+ if hasattr(flat_param, key):
1320
+ assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
+ # Load the flat sharded data.
1322
+ flat_param.copy_(data["flat_param.data"])
1323
+
1324
+ def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
+ if get_fs_local_rank() == 0:
1326
+ log.info("Saving metadata...")
1327
+ metadata = _LocalShardedCheckpointerMetadata()
1328
+ metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
+ if upload_to is not None and get_global_rank() == 0:
1330
+ upload_target = f"{upload_to}/metadata.yaml"
1331
+ log.info(f"Uploading {metadata_path} to {upload_target}")
1332
+ upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
+
1334
+ def _load_metadata(
1335
+ self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
+ ) -> _LocalShardedCheckpointerMetadata:
1337
+ metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
+ return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
+
+    def save_checkpoint(
+        self,
+        dir: PathOrStr,
+        fsdp_model: FSDP,
+        optim: Optimizer,
+        trainer_state: Dict[str, Any],
+        *,
+        upload_to: Optional[str] = None,
+    ) -> None:
+        with self._temporary_wd(dir) as checkpoint_dir:
+            # Gather local FSDP flat params data to save.
+            # We also save some flat param metadata like the corresponding fully qualified names (fqns)
+            # of each original parameter so we can validate that the sharding is the same when loading
+            # one of these checkpoints.
+            log.info("Saving local FSDP flat params data...")
+            save_state_dict(
+                checkpoint_dir,
+                f"model/rank{get_global_rank()}.pt",
+                self._get_flat_param_state_to_save(fsdp_model),
+                upload_to=upload_to,
+                save_overwrite=self.cfg.save_overwrite,
+            )
+
+            # Save optimizer state.
+            log.info("Saving local optimizer state...")
+            save_state_dict(
+                checkpoint_dir,
+                f"optim/rank{get_global_rank()}.pt",
+                optim.state_dict(),
+                upload_to=upload_to,
+                save_overwrite=self.cfg.save_overwrite,
+            )
+
+            # Save trainer state.
+            log.info("Saving trainer state...")
+            save_state_dict(
+                checkpoint_dir,
+                f"train/rank{get_global_rank()}.pt",
+                trainer_state,
+                upload_to=upload_to,
+                save_overwrite=self.cfg.save_overwrite,
+            )
+
+            # Save metadata.
+            self._save_metadata(checkpoint_dir, upload_to=upload_to)
+
+            # Save config. We do this last b/c the presence of a config in a remote checkpoint
+            # "directory" indicates that the folder is valid, as opposed to a partially
+            # uploaded checkpoint directory that failed before completing.
+            self._save_config(checkpoint_dir, upload_to=upload_to)
+
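+    # For reference, a checkpoint directory written by `save_checkpoint()` above looks
+    # roughly like this for a world size of N (directory name illustrative):
+    #
+    #   step1000/
+    #       model/rank0.pt ... model/rank{N-1}.pt   # local FSDP flat-param shards
+    #       optim/rank0.pt ... optim/rank{N-1}.pt   # local optimizer state
+    #       train/rank0.pt ... train/rank{N-1}.pt   # per-rank trainer state
+    #       metadata.yaml                           # world size, for compatibility checks
+    #       config.yaml                             # written last; marks the checkpoint as complete
+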
+    def restore_checkpoint(
+        self,
+        load_path: PathOrStr,
+        fsdp_model: FSDP,
+        optim: Optimizer,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+        load_optimizer_state: bool = True,
+    ) -> Dict[str, Any]:
+        # Load metadata and make sure checkpoint is compatible.
+        metadata = self._load_metadata(load_path, local_cache=local_cache)
+        assert metadata.world_size == get_world_size()
+
+        # Load local FSDP flat param data.
+        log.info("Loading local FSDP flat params data...")
+        model_state = load_state_dict(
+            load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
+        )
+        self._load_flat_param_state(fsdp_model, model_state)
+        del model_state
+
+        # Load local optim state.
+        if load_optimizer_state:
+            log.info("Loading local optimizer state...")
+            optim_state = load_state_dict(
+                load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
+            )
+            # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
+            # in every rank, and keep this in the optimizer state. But this causes issues when loading the
+            # state since torch sees the state is non-empty for some params which would normally be empty,
+            # and then assumes it should have all of the other state tensors for that param, which it doesn't.
+            # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
+            # Not the end of the world, but there's probably a better way around this without resetting
+            # the metric.
+            for param_id in list(optim_state["state"].keys()):
+                state = optim_state["state"][param_id]
+                if "grad_norm_exp_avg" in state:
+                    del state["grad_norm_exp_avg"]
+                if len(state) == 0:
+                    del optim_state["state"][param_id]
+            optim.load_state_dict(optim_state)
+            del optim_state
+
+        # Load local trainer state.
+        log.info("Loading local trainer state...")
+        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
+        barrier()
+        return trainer_state
+
+    def _iter_flat_param_shards(
+        self, model_state: Dict[str, Any]
+    ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
+        for module_data in model_state["modules"]:
+            module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
+            for handle in module_data["handles"]:
+                flat_data = handle["flat_param.data"]
+                if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
+                    # If there's padding in the flat param it should be on the right.
+                    assert (flat_data[-num_padding:] == 0).all()
+                # NOTE: this changes depending on the torch version, but we don't do a version
+                # check since we might be trying to unshard an old checkpoint that was stored
+                # with a different torch version than we're currently running with.
+                if "flat_param._shard_indices" in handle:
+                    # torch <=2.0.1
+                    param_start = handle["flat_param._shard_indices"][0]
+                    current_flat_index = 0
+                    for relative_fqn, full_shape, (offset_start, offset_end) in zip(
+                        handle["flat_param._fqns"][param_start:],
+                        handle["flat_param._shapes"][param_start:],
+                        handle["flat_param._shard_param_offsets"],
+                    ):
+                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
+                        numel_shard = offset_end - offset_start + 1
+                        flat_param_shard = _FlatParamShard(
+                            full_shape=full_shape,
+                            shard_offsets=(offset_start, offset_end),
+                            shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
+                        )
+                        current_flat_index += numel_shard
+                        yield root_fqn, flat_param_shard
+                else:
+                    # torch >=2.1.0
+                    for relative_fqn, full_shape, shard_param_info in zip(
+                        handle["flat_param._fqns"],
+                        handle["flat_param._shapes"],
+                        handle["flat_param._shard_param_infos"],
+                    ):
+                        if not shard_param_info.in_shard:
+                            continue
+                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
+                        flat_param_shard = _FlatParamShard(
+                            full_shape=full_shape,
+                            shard_offsets=(
+                                shard_param_info.intra_param_start_idx,
+                                shard_param_info.intra_param_end_idx,
+                            ),
+                            shard_data=flat_data[
+                                shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
+                                + shard_param_info.numel_in_shard
+                            ],
+                        )
+                        yield root_fqn, flat_param_shard
+
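+    # A worked example of the inclusive offset convention in the torch <=2.0.1 branch above
+    # (numbers are illustrative): a parameter with full shape (3, 4) has 12 elements. If
+    # rank 0's shard covers offsets (0, 5) and rank 1's covers (6, 11), then
+    # `numel_shard = offset_end - offset_start + 1` is 6 on each rank, and those elements
+    # are read consecutively out of `flat_data` starting at `current_flat_index`.
+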
+    def unshard_checkpoint(
+        self,
+        load_path: PathOrStr,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+        load_optimizer_state: bool = True,
+        load_trainer_state: bool = True,
+        device: Optional[torch.device] = None,
+    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
+        device = device or torch.device("cpu")
+        metadata = self._load_metadata(load_path, local_cache=local_cache)
+
+        # Gather paths to model state, potentially downloading them.
+        log.info("Gathering model state dicts...")
+        model_state_paths = self._gather_state_dict_paths(
+            load_path, "model", metadata.world_size, local_cache=local_cache
+        )
+
+        # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
+        log.info("Materializing full parameters...")
+        full_model_state: Dict[str, torch.Tensor] = {}
+        # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
+        # the full optimizer state below without having to reload the model state dicts.
+        flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
+        for rank, path in enumerate(model_state_paths):
+            log.info(f"Loading shards from rank {rank}...")
+            model_state = torch.load(path, map_location="cpu")
+            for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
+                if root_fqn not in full_model_state:
+                    log.info(
+                        f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
+                    )
+                    assert flat_param_shard.shard_data is not None
+                    full_model_state[root_fqn] = torch.empty(
+                        flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
+                    )
+                    # Fill with NaNs so we can validate that the whole parameter has been populated
+                    # afterwards.
+                    full_model_state[root_fqn].fill_(torch.nan)
+                # Copy over the local shard to the relevant part of the full parameter.
+                full_param = full_model_state[root_fqn]
+                log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
+                flat_param_shard.copy_into(full_param)
+                flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
+
+        log.info("Validating full parameters...")
+        for key, tensor in full_model_state.items():
+            if torch.isnan(tensor).any():
+                raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
+
+        trainer_state: Optional[Dict[str, Any]] = None
+        if load_trainer_state:
+            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+
+        if not load_optimizer_state:
+            return full_model_state, None, trainer_state
+
+        log.info("Gathering optim state dicts...")
+        optim_state_paths = self._gather_state_dict_paths(
+            load_path, "optim", metadata.world_size, local_cache=local_cache
+        )
+
+        log.info("Materializing full optim state...")
+        full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
+        fqn_to_id: Dict[str, int] = {}
+        id_to_fqn: Dict[int, str] = {}
+        for rank, path in enumerate(optim_state_paths):
+            log.info(f"Loading sharded optim state from rank {rank}...")
+            optim_state = torch.load(path, map_location="cpu")
+
+            # Initialize param groups.
+            # We assume parameter groups are the same across all ranks.
+            # The only thing that differs across ranks is the state for each local sharded param.
+            if "param_groups" not in full_optim_state:
+                full_optim_state["param_groups"] = optim_state["param_groups"]
+            else:
+                assert full_optim_state["param_groups"] == optim_state["param_groups"]
+
+            # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
+            if not fqn_to_id or not id_to_fqn:
+                for group in full_optim_state["param_groups"]:
+                    for fqn, id in zip(group["param_names"], group["params"]):
+                        fqn = fqn.replace("_fsdp_wrapped_module.", "")
+                        fqn_to_id[fqn] = id
+                        id_to_fqn[id] = fqn
+
+            # Iterate over local shard state and copy into the full state.
+            for id, shard_state in optim_state["state"].items():
+                fqn = id_to_fqn[id]
+                flat_param_shard = flat_params_data[rank].get(fqn)  # type: ignore[assignment]
+                full_state = full_optim_state["state"][id]
+                for key, shard_value in shard_state.items():
+                    assert isinstance(shard_value, torch.Tensor)
+                    if shard_value.shape == torch.Size([]):
+                        # Add singleton tensors directly to full state. These should be the same across
+                        # all ranks.
+                        assert key in ("step", "grad_norm_exp_avg")  # sanity check
+                        if key not in full_state:
+                            full_state[key] = shard_value.to(device)
+                        else:
+                            assert full_state[key] == shard_value
+                    else:
+                        # Otherwise we have a sharded param state.
+                        # If the corresponding full param state hasn't been materialized yet, do so now.
+                        assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
+                        if key not in full_state:
+                            log.info(
+                                f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
+                            )
+                            full_state[key] = torch.empty(
+                                flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
+                            )
+                        full_state_value = full_state[key]
+
+                        # Copy over the local shard state to the relevant part of the full parameter state.
+                        log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
+                        replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
+
+        # Lastly, clean up the parameter names in param groups.
+        for group in full_optim_state["param_groups"]:
+            group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
+
+        return full_model_state, full_optim_state, trainer_state
+
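+    # A minimal offline unsharding sketch (paths are hypothetical, and `TrainConfig.load`
+    # is assumed to behave as elsewhere in this repo):
+    #
+    #   checkpointer = LocalShardedCheckpointer(TrainConfig.load("run/config.yaml"))
+    #   model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint("run/step1000")
+    #   torch.save({"model": model_state, "optim": optim_state, "train": trainer_state}, "unsharded.pt")
+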
+    def _get_state_dict_path(
+        self,
+        load_path: PathOrStr,
+        state_dict_type: str,
+        rank: int,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+        progress=None,
+    ) -> Tuple[int, Path]:
+        fname = f"{state_dict_type}/rank{rank}.pt"
+        return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
+
+    def _gather_state_dict_paths(
+        self,
+        load_path: PathOrStr,
+        state_dict_type: str,
+        world_size: int,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+    ) -> List[Path]:
+        progress = get_progress_bar()
+        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
+            futures = []
+            for rank in range(world_size):
+                future = executor.submit(
+                    self._get_state_dict_path,
+                    load_path,
+                    state_dict_type,
+                    rank,
+                    local_cache=local_cache,
+                    progress=progress,
+                )
+                futures.append(future)
+
+            results: Dict[int, Path] = {}
+            for future in as_completed(futures):
+                rank, path = future.result()
+                results[rank] = path
+
+            return [results[rank] for rank in range(world_size)]
+
+
+def build_sharded_checkpointer(
+    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
+) -> Checkpointer:
+    name = name or cfg.sharded_checkpointer
+    if name == ShardedCheckpointerType.torch_new:
+        return TorchNewStyleShardedCheckpointer(cfg)
+    elif name == ShardedCheckpointerType.torch_legacy:
+        return TorchLegacyShardedCheckpointer(cfg)
+    elif name == ShardedCheckpointerType.local:
+        return LocalShardedCheckpointer(cfg)
+    else:
+        raise NotImplementedError(name)
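+
+
+# A minimal end-to-end usage sketch (illustrative only; `cfg`, `fsdp_model`, `optim`, and
+# `trainer_state` are assumed to come from the caller's training loop):
+#
+#   checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)
+#   checkpointer.save_checkpoint("checkpoints/step1000", fsdp_model, optim, trainer_state)
+#   trainer_state = checkpointer.restore_checkpoint("checkpoints/step1000", fsdp_model, optim)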