Prompt48 commited on
Commit
0c52c15
·
verified ·
1 Parent(s): 7680122

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\checkpointing.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//checkpointing.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ from safetensors.torch import load_model
22
+
23
+ from .utils import (
24
+ MODEL_NAME,
25
+ OPTIMIZER_NAME,
26
+ RNG_STATE_NAME,
27
+ SAFE_MODEL_NAME,
28
+ SAFE_WEIGHTS_NAME,
29
+ SAMPLER_NAME,
30
+ SCALER_NAME,
31
+ SCHEDULER_NAME,
32
+ WEIGHTS_NAME,
33
+ get_pretty_name,
34
+ is_cuda_available,
35
+ is_hpu_available,
36
+ is_mlu_available,
37
+ is_musa_available,
38
+ is_sdaa_available,
39
+ is_torch_version,
40
+ is_torch_xla_available,
41
+ is_xpu_available,
42
+ load,
43
+ save,
44
+ )
45
+
46
+
47
+ if is_torch_version(">=", "2.4.0"):
48
+ from torch.amp import GradScaler
49
+ else:
50
+ from torch.cuda.amp import GradScaler
51
+
52
+ if is_torch_xla_available():
53
+ import torch_xla.core.xla_model as xm
54
+
55
+ from .logging import get_logger
56
+ from .state import PartialState
57
+
58
+
59
+ logger = get_logger(__name__)
60
+
61
+
62
+ def save_accelerator_state(
63
+ output_dir: str,
64
+ model_states: list[dict],
65
+ optimizers: list,
66
+ schedulers: list,
67
+ dataloaders: list,
68
+ process_index: int,
69
+ step: int,
70
+ scaler: Optional[GradScaler] = None,
71
+ save_on_each_node: bool = False,
72
+ safe_serialization: bool = True,
73
+ ):
74
+ """
75
+ Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
76
+
77
+ <Tip>
78
+
79
+ If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
80
+ `pickle`.
81
+
82
+ </Tip>
83
+
84
+ Args:
85
+ output_dir (`str` or `os.PathLike`):
86
+ The name of the folder to save all relevant weights and states.
87
+ model_states (`List[torch.nn.Module]`):
88
+ A list of model states
89
+ optimizers (`List[torch.optim.Optimizer]`):
90
+ A list of optimizer instances
91
+ schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
92
+ A list of learning rate schedulers
93
+ dataloaders (`List[torch.utils.data.DataLoader]`):
94
+ A list of dataloader instances to save their sampler states
95
+ process_index (`int`):
96
+ The current process index in the Accelerator state
97
+ step (`int`):
98
+ The current step in the internal step tracker
99
+ scaler (`torch.amp.GradScaler`, *optional*):
100
+ An optional gradient scaler instance to save;
101
+ save_on_each_node (`bool`, *optional*):
102
+ Whether to save on every node, or only the main node.
103
+ safe_serialization (`bool`, *optional*, defaults to `True`):
104
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
105
+ """
106
+ output_dir = Path(output_dir)
107
+ # Model states
108
+ for i, state in enumerate(model_states):
109
+ weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
110
+ if i > 0:
111
+ weights_name = weights_name.replace(".", f"_{i}.")
112
+ output_model_file = output_dir.joinpath(weights_name)
113
+ save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
114
+ logger.info(f"Model weights saved in {output_model_file}")
115
+ # Optimizer states
116
+ for i, opt in enumerate(optimizers):
117
+ state = opt.state_dict()
118
+ optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
119
+ output_optimizer_file = output_dir.joinpath(optimizer_name)
120
+ save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
121
+ logger.info(f"Optimizer state saved in {output_optimizer_file}")
122
+ # Scheduler states
123
+ for i, scheduler in enumerate(schedulers):
124
+ state = scheduler.state_dict()
125
+ scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
126
+ output_scheduler_file = output_dir.joinpath(scheduler_name)
127
+ save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
128
+ logger.info(f"Scheduler state saved in {output_scheduler_file}")
129
+ # DataLoader states
130
+ for i, dataloader in enumerate(dataloaders):
131
+ sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
132
+ output_sampler_file = output_dir.joinpath(sampler_name)
133
+ # Only save if we have our custom sampler
134
+ from .data_loader import IterableDatasetShard, SeedableRandomSampler
135
+
136
+ if isinstance(dataloader.dataset, IterableDatasetShard):
137
+ sampler = dataloader.get_sampler()
138
+ if isinstance(sampler, SeedableRandomSampler):
139
+ save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
140
+ if getattr(dataloader, "use_stateful_dataloader", False):
141
+ dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
142
+ output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
143
+ state_dict = dataloader.state_dict()
144
+ torch.save(state_dict, output_dataloader_state_dict_file)
145
+ logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
146
+
147
+ # GradScaler state
148
+ if scaler is not None:
149
+ state = scaler.state_dict()
150
+ output_scaler_file = output_dir.joinpath(SCALER_NAME)
151
+ torch.save(state, output_scaler_file)
152
+ logger.info(f"Gradient scaler state saved in {output_scaler_file}")
153
+ # Random number generator states
154
+ states = {}
155
+ states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
156
+ states["step"] = step
157
+ states["random_state"] = random.getstate()
158
+ states["numpy_random_seed"] = np.random.get_state()
159
+ states["torch_manual_seed"] = torch.get_rng_state()
160
+ if is_xpu_available():
161
+ states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
162
+ if is_mlu_available():
163
+ states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
164
+ elif is_sdaa_available():
165
+ states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
166
+ elif is_musa_available():
167
+ states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
168
+ if is_hpu_available():
169
+ states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
170
+ if is_cuda_available():
171
+ states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
172
+ if is_torch_xla_available():
173
+ states["xm_seed"] = xm.get_rng_state()
174
+ output_states_file = output_dir.joinpath(states_name)
175
+ torch.save(states, output_states_file)
176
+ logger.info(f"Random states saved in {output_states_file}")
177
+ return output_dir
178
+
179
+
180
+ def load_accelerator_state(
181
+ input_dir,
182
+ models,
183
+ optimizers,
184
+ schedulers,
185
+ dataloaders,
186
+ process_index,
187
+ scaler=None,
188
+ map_location=None,
189
+ load_kwargs=None,
190
+ **load_model_func_kwargs,
191
+ ):
192
+ """
193
+ Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
194
+
195
+ Args:
196
+ input_dir (`str` or `os.PathLike`):
197
+ The name of the folder to load all relevant weights and states.
198
+ models (`List[torch.nn.Module]`):
199
+ A list of model instances
200
+ optimizers (`List[torch.optim.Optimizer]`):
201
+ A list of optimizer instances
202
+ schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
203
+ A list of learning rate schedulers
204
+ process_index (`int`):
205
+ The current process index in the Accelerator state
206
+ scaler (`torch.amp.GradScaler`, *optional*):
207
+ An optional *GradScaler* instance to load
208
+ map_location (`str`, *optional*):
209
+ What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
210
+ load_kwargs (`dict`, *optional*):
211
+ Additional arguments that can be passed to the `load` function.
212
+ load_model_func_kwargs (`dict`, *optional*):
213
+ Additional arguments that can be passed to the model's `load_state_dict` method.
214
+
215
+ Returns:
216
+ `dict`: Contains the `Accelerator` attributes to override while loading the state.
217
+ """
218
+ # stores the `Accelerator` attributes to override
219
+ override_attributes = dict()
220
+ if map_location not in [None, "cpu", "on_device"]:
221
+ raise TypeError(
222
+ "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
223
+ )
224
+ if map_location is None:
225
+ map_location = "cpu"
226
+ elif map_location == "on_device":
227
+ map_location = PartialState().device
228
+
229
+ if load_kwargs is None:
230
+ load_kwargs = {}
231
+
232
+ input_dir = Path(input_dir)
233
+ # Model states
234
+ for i, model in enumerate(models):
235
+ ending = f"_{i}" if i > 0 else ""
236
+ input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
237
+ if input_model_file.exists():
238
+ load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
239
+ else:
240
+ # Load with torch
241
+ input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
242
+ state_dict = load(input_model_file, map_location=map_location)
243
+ model.load_state_dict(state_dict, **load_model_func_kwargs)
244
+ logger.info("All model weights loaded successfully")
245
+
246
+ # Optimizer states
247
+ for i, opt in enumerate(optimizers):
248
+ optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
249
+ input_optimizer_file = input_dir.joinpath(optimizer_name)
250
+ optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
251
+ optimizers[i].load_state_dict(optimizer_state)
252
+ logger.info("All optimizer states loaded successfully")
253
+
254
+ # Scheduler states
255
+ for i, scheduler in enumerate(schedulers):
256
+ scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
257
+ input_scheduler_file = input_dir.joinpath(scheduler_name)
258
+ scheduler_state = load(input_scheduler_file, **load_kwargs)
259
+ scheduler.load_state_dict(scheduler_state)
260
+ logger.info("All scheduler states loaded successfully")
261
+
262
+ for i, dataloader in enumerate(dataloaders):
263
+ sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
264
+ input_sampler_file = input_dir.joinpath(sampler_name)
265
+ # Only load if we have our custom sampler
266
+ from .data_loader import IterableDatasetShard, SeedableRandomSampler
267
+
268
+ if isinstance(dataloader.dataset, IterableDatasetShard):
269
+ sampler = dataloader.get_sampler()
270
+ if isinstance(sampler, SeedableRandomSampler):
271
+ sampler = dataloader.set_sampler(load(input_sampler_file))
272
+ if getattr(dataloader, "use_stateful_dataloader", False):
273
+ dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
274
+ input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
275
+ if input_dataloader_state_dict_file.exists():
276
+ state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
277
+ dataloader.load_state_dict(state_dict)
278
+ logger.info("All dataloader sampler states loaded successfully")
279
+
280
+ # GradScaler state
281
+ if scaler is not None:
282
+ input_scaler_file = input_dir.joinpath(SCALER_NAME)
283
+ scaler_state = load(input_scaler_file)
284
+ scaler.load_state_dict(scaler_state)
285
+ logger.info("GradScaler state loaded successfully")
286
+
287
+ # Random states
288
+ try:
289
+ states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
290
+ if "step" in states:
291
+ override_attributes["step"] = states["step"]
292
+ random.setstate(states["random_state"])
293
+ np.random.set_state(states["numpy_random_seed"])
294
+ torch.set_rng_state(states["torch_manual_seed"])
295
+ if is_xpu_available():
296
+ torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
297
+ if is_mlu_available():
298
+ torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
299
+ elif is_sdaa_available():
300
+ torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
301
+ elif is_musa_available():
302
+ torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
303
+ else:
304
+ torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
305
+ if is_torch_xla_available():
306
+ xm.set_rng_state(states["xm_seed"])
307
+ logger.info("All random states loaded successfully")
308
+ except Exception:
309
+ logger.info("Could not load random states")
310
+
311
+ return override_attributes
312
+
313
+
314
+ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
315
+ """
316
+ Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
317
+ """
318
+ # Should this be the right way to get a qual_name type value from `obj`?
319
+ save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
320
+ logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
321
+ save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
322
+
323
+
324
+ def load_custom_state(obj, path, index: int = 0):
325
+ """
326
+ Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
327
+ loading the state.
328
+ """
329
+ load_location = f"{path}/custom_checkpoint_{index}.pkl"
330
+ logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
331
+ obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))