Prompt48 committed on
Commit
15045f9
·
verified ·
1 Parent(s): fb750ba

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\parallelism_config.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//parallelism_config.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import warnings
17
+ from dataclasses import dataclass
18
+ from typing import TYPE_CHECKING, Literal, Optional, Union
19
+
20
+ from accelerate.utils.dataclasses import (
21
+ DeepSpeedSequenceParallelConfig,
22
+ DistributedType,
23
+ TorchContextParallelConfig,
24
+ TorchTensorParallelConfig,
25
+ )
26
+ from accelerate.utils.versions import is_torch_version
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from accelerate import Accelerator
31
+
32
+
33
@dataclass
class ParallelismConfig:
    """
    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py

    Every size/backend argument defaults to `None`; in `__post_init__` an unset value is read from the matching
    `PARALLELISM_CONFIG_*` environment variable and falls back to the default documented below.

    Args:
        dp_replicate_size (`int`, defaults to `1`):
            The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
            group will not be used.
        dp_shard_size (`int`, defaults to `1`):
            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
            be greater than 1, as composing DDP + TP is currently not supported.
        tp_size (`int`, defaults to `1`):
            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
            used.
        cp_size (`int`, defaults to `1`):
            The size of the context parallel group. If `cp_size` is set to `1`, the context parallel group will not
            be used.
        cp_backend (`str`, defaults to `torch`):
            Which CP backend to use: `torch` (FSDP2)
        sp_size (`int`, defaults to `1`):
            The size of the sequence parallel group. If `sp_size` is set to `1`, the sequence parallel group will not
            be used.
        sp_backend (`str`, defaults to `deepspeed`):
            Which SP backend to use: `deepspeed` (ALST/Ulysses)
        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
            The handler for the tensor parallel group. Created automatically when `tp_size > 1` and none is given.
        cp_handler (`~utils.TorchContextParallelConfig`, defaults to `None`):
            The handler for the context parallel group. Created automatically when `cp_size > 1` and none is given.
        sp_handler (`~utils.DeepSpeedSequenceParallelConfig`, defaults to `None`):
            The handler for the sequence parallel group. Created automatically when `sp_size > 1` and none is given.

    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
    together:
        - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
        - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
        - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration, to use pure DP, use
          `DistributedDataParallelKwargs` instead.

    """

    dp_replicate_size: Optional[int] = None
    dp_shard_size: Optional[int] = None
    tp_size: Optional[int] = None
    cp_size: Optional[int] = None
    cp_backend: Optional[Literal["torch"]] = None
    sp_size: Optional[int] = None
    sp_backend: Optional[Literal["deepspeed"]] = None

    # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
    tp_handler: Union[None, TorchTensorParallelConfig] = None
    cp_handler: Union[None, TorchContextParallelConfig] = None
    sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None

    # Left un-annotated on purpose: this keeps `device_mesh` a plain class attribute rather than a dataclass
    # field, so it is not an `__init__` argument. `get_device_mesh` stores the built mesh on the instance.
    device_mesh = None

    def __repr__(self):
        """Return a multi-line, human readable summary of all sizes, backends and handlers."""
        return (
            "ParallelismConfig(\n"
            f"\tdp_replicate_size={self.dp_replicate_size},\n"
            f"\tdp_shard_size={self.dp_shard_size},\n"
            f"\ttp_size={self.tp_size},\n"
            f"\tcp_size={self.cp_size},\n"
            f"\tcp_backend={self.cp_backend},\n"
            f"\tsp_size={self.sp_size},\n"
            f"\tsp_backend={self.sp_backend},\n"
            f"\ttotal_size={self.total_size},\n"
            f"\ttp_handler={self.tp_handler},\n"
            f"\tcp_handler={self.cp_handler},\n"
            f"\tsp_handler={self.sp_handler})\n"
        )

    def to_json(self):
        """
        Return a deep-copied dictionary of the instance attributes.

        Attributes that expose a `__dict__` (e.g. the handlers) are replaced by a copy of that `__dict__`. The
        `device_mesh` attribute is skipped because it is not serializable.
        """
        import copy

        _non_serializable_fields = ["device_mesh"]

        return copy.deepcopy(
            {
                k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
                for k, v in self.__dict__.items()
                if k not in _non_serializable_fields
            }
        )

    @property
    def dp_dim_names(self):
        """Names of enabled dimensions across which data parallelism is applied."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        return dims

    @property
    def non_dp_dim_names(self):
        """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
        dims = []
        if self.tp_enabled:
            dims += ["tp"]
        if self.cp_enabled:
            dims += ["cp"]
        if self.sp_enabled:
            dims += ["sp"]
        return dims

    @property
    def dp_shard_cp_dim_names(self):
        """Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
        dims = []
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def dp_cp_dim_names(self):
        """Names of enabled dimensions across which loss should be averaged"""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def fsdp_dim_names(self):
        """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        # The joint `dp_shard_cp` name is always appended, regardless of which sizes are enabled.
        dims += ["dp_shard_cp"]
        return dims

    @property
    def total_size(self):
        """The total size of the parallelism configuration, which is the product of all sizes."""
        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size

    @property
    def non_data_parallel_size(self):
        """The size of the non-data parallel dimensions, which is the product of tensor, context and sequence parallel sizes."""
        return self.tp_size * self.cp_size * self.sp_size

    @property
    def data_parallel_size(self):
        """The size of the data parallel dimensions, which is the product of data parallel replication and sharding sizes."""
        return self.dp_replicate_size * self.dp_shard_size

    @property
    def dp_replicate_enabled(self):
        """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
        return self.dp_replicate_size > 1

    @property
    def dp_shard_enabled(self):
        """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
        return self.dp_shard_size > 1

    @property
    def tp_enabled(self):
        """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
        return self.tp_size > 1

    @property
    def cp_enabled(self):
        """True if context parallelism is enabled, i.e. `cp_size > 1`."""
        return self.cp_size > 1

    @property
    def sp_enabled(self):
        """True if sequence parallelism is enabled, i.e. `sp_size > 1`."""
        return self.sp_size > 1

    @property
    def active_mesh_dims(self):
        """Names of all active mesh dimensions."""
        return self.dp_dim_names + self.non_dp_dim_names

    def build_device_mesh(self, device_type: str):
        """Builds a device mesh for the given device type based on the parallelism configuration.
        This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).

        Args:
            device_type (`str`): The type of device for which to build the mesh, e.g. `cuda`.

        Returns:
            The mesh returned by `torch.distributed.device_mesh.init_device_mesh`, or `None` when no parallelism
            dimension is enabled.

        Raises:
            RuntimeError: If the installed torch version is older than 2.2.0.
        """
        if not is_torch_version(">=", "2.2.0"):
            raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
        from torch.distributed.device_mesh import init_device_mesh

        mesh = self._get_mesh()
        if not mesh:
            # Every size is 1, so there is nothing to build a mesh over.
            return None
        mesh_dim_names, mesh_shape = mesh
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape,
            mesh_dim_names=mesh_dim_names,
        )
        # Register the flattened joint sub-meshes under their names.
        if self.dp_dim_names:
            device_mesh[self.dp_dim_names]._flatten("dp")
        if self.dp_shard_cp_dim_names:
            device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
        if self.dp_cp_dim_names:
            device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")

        return device_mesh

    def get_device_mesh(self, device_type: Optional[str] = None):
        """
        Return the cached device mesh, building and caching it on first use.

        Args:
            device_type (`str`, *optional*):
                Device type (e.g. `cuda`) used to build the mesh. Required when no mesh has been built yet; when a
                mesh already exists it is only used to check consistency.

        Raises:
            ValueError: If no mesh exists yet and `device_type` is `None`, or if `device_type` differs from the
                device type of the existing mesh.
        """
        if self.device_mesh is None:
            if device_type is None:
                raise ValueError("You need to pass a device_type e.g cuda to build the device mesh")
            self.device_mesh = self.build_device_mesh(device_type)
        elif device_type is not None and self.device_mesh.device_type != device_type:
            raise ValueError(
                f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
            )
        return self.device_mesh

    def _get_mesh(self) -> Union[tuple[()], tuple[tuple[str, ...], tuple[int, ...]]]:
        """
        Generate dimension names and mesh shape for torch.distributed.init_device_mesh().

        Returns:
            A `(mesh_dim_names, mesh_shape)` pair, or an empty tuple when no mesh dimension is active.
        """

        # Build mesh dimensions dictionary
        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}

        # Apply canonical ordering
        mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"]
        sorted_items = sorted(
            mesh_dims.items(),
            key=lambda x: (mesh_order.index(x[0])),
        )
        # zip(*[]) yields nothing, so an all-ones configuration produces `()`.
        return tuple(zip(*sorted_items))

    def __post_init__(self):
        """
        Resolve unset values from the environment, validate them and create the default handlers.

        Raises:
            ValueError: If a size is smaller than 1, a backend is unknown, `cp_handler` does not match `cp_backend`,
                or TP/CP is combined with pure data parallel replication.
        """
        # Fill every unset value from its environment variable, falling back to the documented default.
        if self.dp_replicate_size is None:
            self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
        if self.dp_shard_size is None:
            self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
        if self.tp_size is None:
            self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
        if self.cp_size is None:
            self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
        if self.cp_backend is None:
            self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
        if self.sp_size is None:
            self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1"))
        if self.sp_backend is None:
            self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed")

        # Basic size and backend validation. This runs before any handler is created or looked up so that an
        # unknown backend is reported as a ValueError instead of a KeyError from the backend map below.
        if self.dp_replicate_size < 1:
            raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
        if self.dp_shard_size < 1:
            raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
        if self.tp_size < 1:
            raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
        if self.cp_size < 1:
            raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
        valid_cp_backends = ["torch"]
        if self.cp_backend not in valid_cp_backends:
            raise ValueError(f"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}")

        if self.sp_size < 1:
            raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}")
        valid_sp_backends = ["deepspeed"]
        if self.sp_backend not in valid_sp_backends:
            raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")

        # Create default handlers for every enabled parallelism that did not receive one.
        if self.tp_size > 1:
            if self.tp_handler is None:
                self.tp_handler = TorchTensorParallelConfig()

        if self.cp_size > 1:
            if self.cp_handler is None:
                self.cp_handler = TorchContextParallelConfig()
            else:
                cp_backends_config_map = dict(
                    torch=TorchContextParallelConfig,
                )
                if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]):
                    raise ValueError(
                        f"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}"
                    )

        if self.sp_size > 1:
            if self.sp_handler is None:
                self.sp_handler = DeepSpeedSequenceParallelConfig()

        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
            raise ValueError(
                "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
                "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
                "or set dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
            )
        # Name -> size lookup used by `_get_mesh`, `_set_size` and `_validate_accelerator`.
        self._sizes = {
            "dp_replicate": self.dp_replicate_size,
            "dp_shard": self.dp_shard_size,
            "tp": self.tp_size,
            "cp": self.cp_size,
            "sp": self.sp_size,
        }

    def _set_size(self, parallelism: str, size: int):
        """
        Update the size of one parallelism in both `_sizes` and the matching `<parallelism>_size` attribute.

        Raises:
            ValueError: If `parallelism` is not one of the keys of `_sizes`.
        """
        # An explicit raise (instead of `assert`) keeps the check active under `python -O`.
        if parallelism not in self._sizes:
            raise ValueError(f"Parallelism must be one of {self._sizes.keys()}")
        self._sizes[parallelism] = size
        setattr(self, f"{parallelism}_size", size)

    def _validate_accelerator(self, accelerator: "Accelerator") -> None:
        """
        Check that this configuration is compatible with the given accelerator.

        When every size is 1 but the accelerator runs on multiple devices, `dp_replicate_size` is set to
        `accelerator.num_processes`. Handlers configured for a parallelism of size 1 are reported through a
        `UserWarning` on the main process.

        Raises:
            ValueError: If `total_size` does not match `accelerator.num_processes`, or the accelerator's distributed
                type is not supported.
        """
        if not accelerator.multi_device and self.total_size == 1:
            # No distributed setup, valid parallelism config
            return

        # We need this to ensure DDP works
        if self.total_size == 1:
            self._set_size("dp_replicate", accelerator.num_processes)

        if self.total_size != accelerator.num_processes:
            raise ValueError(
                f"ParallelismConfig total_size ({self.total_size}) does not match "
                f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
                f"dp_shard_size/tp_size/cp_size/sp_size."
            )

        if self.total_size > 1 and not (
            accelerator.is_fsdp2
            or accelerator.multi_device
            or accelerator.distributed_type == DistributedType.DEEPSPEED
        ):
            raise ValueError(
                f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}."
            )

        # `_sizes` has unique keys, so a list keeps the messages duplicate-free and in a stable order.
        _warnings = [
            f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
            for parallelism, size in self._sizes.items()
            if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None
        ]

        if _warnings and accelerator.is_main_process:
            warnings.warn(
                "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
                UserWarning,
            )