Prompt48 commited on
Commit
904b4f7
·
verified ·
1 Parent(s): 0c52c15

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\data_loader.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//data_loader.py ADDED
@@ -0,0 +1,1451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import math
17
+ from contextlib import suppress
18
+ from typing import Callable, Optional, Union
19
+
20
+ import torch
21
+ from packaging import version
22
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
23
+
24
+ from .logging import get_logger
25
+ from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
26
+ from .utils import (
27
+ RNGType,
28
+ broadcast,
29
+ broadcast_object_list,
30
+ compare_versions,
31
+ concatenate,
32
+ find_batch_size,
33
+ get_data_structure,
34
+ initialize_tensors,
35
+ is_datasets_available,
36
+ is_torch_version,
37
+ is_torchdata_stateful_dataloader_available,
38
+ send_to_device,
39
+ slice_tensors,
40
+ synchronize_rng_states,
41
+ )
42
+
43
+
44
+ logger = get_logger(__name__)
45
+
46
+ # kwargs of the DataLoader in min version 2.0
47
+ _PYTORCH_DATALOADER_KWARGS = {
48
+ "batch_size": 1,
49
+ "shuffle": False,
50
+ "sampler": None,
51
+ "batch_sampler": None,
52
+ "num_workers": 0,
53
+ "collate_fn": None,
54
+ "pin_memory": False,
55
+ "drop_last": False,
56
+ "timeout": 0,
57
+ "worker_init_fn": None,
58
+ "multiprocessing_context": None,
59
+ "generator": None,
60
+ "prefetch_factor": 2,
61
+ "persistent_workers": False,
62
+ "pin_memory_device": "",
63
+ }
64
+
65
+ # kwargs added after by version
66
+ _PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}
67
+
68
+ for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
69
+ if is_torch_version(">=", v):
70
+ _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
71
+
72
+
73
+ class SeedableRandomSampler(RandomSampler):
74
+ """
75
+ Same as a random sampler, except that in `__iter__` a seed can be used.
76
+
77
+ Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
78
+ and be fully reproducible on multiple iterations.
79
+
80
+ If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
81
+ (stored in `self.epoch`).
82
+ """
83
+
84
+ def __init__(self, *args, **kwargs):
85
+ data_seed = kwargs.pop("data_seed", None)
86
+ super().__init__(*args, **kwargs)
87
+
88
+ self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
89
+ self.epoch = 0
90
+
91
+ def __iter__(self):
92
+ if self.generator is None:
93
+ self.generator = torch.Generator(
94
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
95
+ )
96
+ self.generator.manual_seed(self.initial_seed)
97
+
98
+ # Allow `self.epoch` to modify the seed of the generator
99
+ seed = self.epoch + self.initial_seed
100
+ # print("Setting seed at epoch", self.epoch, seed)
101
+ self.generator.manual_seed(seed)
102
+ yield from super().__iter__()
103
+ self.set_epoch(self.epoch + 1)
104
+
105
+ def set_epoch(self, epoch: int):
106
+ "Sets the current iteration of the sampler."
107
+ self.epoch = epoch
108
+
109
+
110
+ class BatchSamplerShard(BatchSampler):
111
+ """
112
+ Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
113
+ always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
114
+ Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
115
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
116
+
117
+ Args:
118
+ batch_sampler (`torch.utils.data.sampler.BatchSampler`):
119
+ The batch sampler to split in several shards.
120
+ num_processes (`int`, *optional*, defaults to 1):
121
+ The number of processes running concurrently.
122
+ process_index (`int`, *optional*, defaults to 0):
123
+ The index of the current process.
124
+ split_batches (`bool`, *optional*, defaults to `False`):
125
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
126
+ yielding different full batches on each process.
127
+
128
+ On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
129
+
130
+ - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
131
+ this argument is set to `False`.
132
+ - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
133
+ then `[6, 7]` if this argument is set to `True`.
134
+ even_batches (`bool`, *optional*, defaults to `True`):
135
+ Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
136
+ multiple of (original batch size / number of processes).
137
+
138
+ <Tip warning={true}>
139
+
140
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
141
+ equal to `False`
142
+
143
+ </Tip>"""
144
+
145
+ def __init__(
146
+ self,
147
+ batch_sampler: BatchSampler,
148
+ num_processes: int = 1,
149
+ process_index: int = 0,
150
+ split_batches: bool = False,
151
+ even_batches: bool = True,
152
+ ):
153
+ if split_batches and batch_sampler.batch_size % num_processes != 0:
154
+ raise ValueError(
155
+ f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
156
+ f"needs to be a round multiple of the number of processes ({num_processes})."
157
+ )
158
+ self.batch_sampler = batch_sampler
159
+ self.num_processes = num_processes
160
+ self.process_index = process_index
161
+ self.split_batches = split_batches
162
+ self.even_batches = even_batches
163
+ self.batch_size = getattr(batch_sampler, "batch_size", None)
164
+ self.drop_last = getattr(batch_sampler, "drop_last", False)
165
+ if self.batch_size is None and self.even_batches:
166
+ raise ValueError(
167
+ "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
168
+ "are not calling this method directly, set `accelerator.even_batches=False` instead."
169
+ )
170
+
171
+ @property
172
+ def total_length(self):
173
+ return len(self.batch_sampler)
174
+
175
+ def __len__(self):
176
+ if self.split_batches:
177
+ # Split batches does not change the length of the batch sampler
178
+ return len(self.batch_sampler)
179
+ if len(self.batch_sampler) % self.num_processes == 0:
180
+ # If the length is a round multiple of the number of processes, it's easy.
181
+ return len(self.batch_sampler) // self.num_processes
182
+ length = len(self.batch_sampler) // self.num_processes
183
+ if self.drop_last:
184
+ # Same if we drop the remainder.
185
+ return length
186
+ elif self.even_batches:
187
+ # When we even batches we always get +1
188
+ return length + 1
189
+ else:
190
+ # Otherwise it depends on the process index.
191
+ return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
192
+
193
+ def __iter__(self):
194
+ return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
195
+
196
+ def _iter_with_split(self):
197
+ initial_data = []
198
+ batch_length = self.batch_sampler.batch_size // self.num_processes
199
+ for idx, batch in enumerate(self.batch_sampler):
200
+ if idx == 0:
201
+ initial_data = batch
202
+ if len(batch) == self.batch_size:
203
+ # If the batch is full, we yield the part of it this process is responsible of.
204
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
205
+
206
+ # If drop_last is True of the last batch was full, iteration is over, otherwise...
207
+ if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
208
+ if not self.even_batches:
209
+ if len(batch) > batch_length * self.process_index:
210
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
211
+ else:
212
+ # For degenerate cases where the dataset has less than num_process * batch_size samples
213
+ while len(initial_data) < self.batch_size:
214
+ initial_data += initial_data
215
+ batch = batch + initial_data
216
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
217
+
218
+ def _iter_with_no_split(self):
219
+ initial_data = []
220
+ batch_to_yield = []
221
+ for idx, batch in enumerate(self.batch_sampler):
222
+ # We gather the initial indices in case we need to circle back at the end.
223
+ if not self.drop_last and idx < self.num_processes:
224
+ initial_data += batch
225
+ # We identify the batch to yield but wait until we ar sure every process gets a full batch before actually
226
+ # yielding it.
227
+ if idx % self.num_processes == self.process_index:
228
+ batch_to_yield = batch
229
+ if idx % self.num_processes == self.num_processes - 1 and (
230
+ self.batch_size is None or len(batch) == self.batch_size
231
+ ):
232
+ yield batch_to_yield
233
+ batch_to_yield = []
234
+
235
+ # If drop_last is True, iteration is over, otherwise...
236
+ if not self.drop_last and len(initial_data) > 0:
237
+ if not self.even_batches:
238
+ if len(batch_to_yield) > 0:
239
+ yield batch_to_yield
240
+ else:
241
+ # ... we yield the complete batch we had saved before if it has the proper length
242
+ if len(batch_to_yield) == self.batch_size:
243
+ yield batch_to_yield
244
+
245
+ # For degenerate cases where the dataset has less than num_process * batch_size samples
246
+ while len(initial_data) < self.num_processes * self.batch_size:
247
+ initial_data += initial_data
248
+
249
+ # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
250
+ if len(batch) == self.batch_size:
251
+ batch = []
252
+ idx += 1
253
+
254
+ # Make sure we yield a multiple of self.num_processes batches
255
+ cycle_index = 0
256
+ while idx % self.num_processes != 0 or len(batch) > 0:
257
+ end_index = cycle_index + self.batch_size - len(batch)
258
+ batch += initial_data[cycle_index:end_index]
259
+ if idx % self.num_processes == self.process_index:
260
+ yield batch
261
+ cycle_index = end_index
262
+ batch = []
263
+ idx += 1
264
+
265
+
266
+ class IterableDatasetShard(IterableDataset):
267
+ """
268
+ Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
269
+ always yield a number of samples that is a round multiple of the actual batch size (depending of the value of
270
+ `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
271
+ `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
272
+ be too small or loop with indices from the beginning.
273
+
274
+ Args:
275
+ dataset (`torch.utils.data.dataset.IterableDataset`):
276
+ The batch sampler to split in several shards.
277
+ batch_size (`int`, *optional*, defaults to 1):
278
+ The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
279
+ `split_batches=True`).
280
+ drop_last (`bool`, *optional*, defaults to `False`):
281
+ Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
282
+ beginning.
283
+ num_processes (`int`, *optional*, defaults to 1):
284
+ The number of processes running concurrently.
285
+ process_index (`int`, *optional*, defaults to 0):
286
+ The index of the current process.
287
+ split_batches (`bool`, *optional*, defaults to `False`):
288
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
289
+ yielding different full batches on each process.
290
+
291
+ On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:
292
+
293
+ - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
294
+ argument is set to `False`.
295
+ - the shard on process 0 to yield `[0, 1, 4, 5]` and the sampler on process 1 to yield `[2, 3, 6, 7]` if
296
+ this argument is set to `True`.
297
+ """
298
+
299
+ def __init__(
300
+ self,
301
+ dataset: IterableDataset,
302
+ batch_size: int = 1,
303
+ drop_last: bool = False,
304
+ num_processes: int = 1,
305
+ process_index: int = 0,
306
+ split_batches: bool = False,
307
+ ):
308
+ if split_batches and batch_size > 1 and batch_size % num_processes != 0:
309
+ raise ValueError(
310
+ f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
311
+ f"needs to be a round multiple of the number of processes ({num_processes})."
312
+ )
313
+ self.dataset: IterableDataset = dataset
314
+ self.batch_size = batch_size
315
+ self.drop_last = drop_last
316
+ self.num_processes = num_processes
317
+ self.process_index = process_index
318
+ self.split_batches = split_batches
319
+
320
+ def set_epoch(self, epoch):
321
+ self.epoch = epoch
322
+ if hasattr(self.dataset, "set_epoch"):
323
+ self.dataset.set_epoch(epoch)
324
+
325
+ def __len__(self):
326
+ # We will just raise the downstream error if the underlying dataset is not sized
327
+ if self.drop_last:
328
+ return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
329
+ else:
330
+ return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
331
+
332
+ def __iter__(self):
333
+ if (
334
+ not hasattr(self.dataset, "set_epoch")
335
+ and hasattr(self.dataset, "generator")
336
+ and isinstance(self.dataset.generator, torch.Generator)
337
+ ):
338
+ self.dataset.generator.manual_seed(self.epoch)
339
+ real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
340
+ process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
341
+ process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
342
+
343
+ first_batch = None
344
+ current_batch = []
345
+ for element in self.dataset:
346
+ current_batch.append(element)
347
+ # Wait to have a full batch before yielding elements.
348
+ if len(current_batch) == real_batch_size:
349
+ for i in process_slice:
350
+ yield current_batch[i]
351
+ if first_batch is None:
352
+ first_batch = current_batch.copy()
353
+ current_batch = []
354
+
355
+ # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
356
+ if not self.drop_last and len(current_batch) > 0:
357
+ if first_batch is None:
358
+ first_batch = current_batch.copy()
359
+ while len(current_batch) < real_batch_size:
360
+ current_batch += first_batch
361
+ for i in process_slice:
362
+ yield current_batch[i]
363
+
364
+
365
+ class DataLoaderStateMixin:
366
+ """
367
+ Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
368
+ end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
369
+ useful information that might be needed.
370
+
371
+ **Available attributes:**
372
+
373
+ - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
374
+ - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
375
+ batch size
376
+
377
+ <Tip warning={true}>
378
+
379
+ Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in
380
+ `self.gradient_state`.
381
+
382
+ </Tip>
383
+
384
+ """
385
+
386
+ def __init_subclass__(cls, **kwargs):
387
+ cls.end_of_dataloader = False
388
+ cls.remainder = -1
389
+
390
+ def reset(self):
391
+ self.end_of_dataloader = False
392
+ self.remainder = -1
393
+
394
+ def begin(self):
395
+ "Prepares the gradient state for the current dataloader"
396
+ self.reset()
397
+ with suppress(Exception):
398
+ if not self._drop_last:
399
+ length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
400
+ self.remainder = length % self.total_batch_size
401
+ self.gradient_state._add_dataloader(self)
402
+
403
+ def end(self):
404
+ "Cleans up the gradient state after exiting the dataloader"
405
+ self.gradient_state._remove_dataloader(self)
406
+
407
+
408
+ class DataLoaderAdapter:
409
+ """
410
+ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
411
+ compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
412
+ """
413
+
414
+ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
415
+ self.use_stateful_dataloader = use_stateful_dataloader
416
+ if is_torchdata_stateful_dataloader_available():
417
+ from torchdata.stateful_dataloader import StatefulDataLoader
418
+
419
+ if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
420
+ raise ImportError(
421
+ "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
422
+ )
423
+ if use_stateful_dataloader:
424
+ torchdata_version = version.parse(importlib.metadata.version("torchdata"))
425
+ if (
426
+ "in_order" in kwargs
427
+ and compare_versions(torchdata_version, "<", "0.11")
428
+ and is_torch_version(">=", "2.6.0")
429
+ ):
430
+ kwargs.pop("in_order")
431
+ self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
432
+ else:
433
+ self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
434
+
435
+ if hasattr(self.base_dataloader, "state_dict"):
436
+ self.dl_state_dict = self.base_dataloader.state_dict()
437
+
438
+ def __getattr__(self, name):
439
+ # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
440
+ if name == "base_dataloader":
441
+ raise AttributeError()
442
+ # Delegate attribute access to the internal dataloader
443
+ return getattr(self.base_dataloader, name)
444
+
445
+ def state_dict(self):
446
+ return self.dl_state_dict
447
+
448
+ def load_state_dict(self, state_dict):
449
+ self.base_dataloader.load_state_dict(state_dict)
450
+
451
+ @property
452
+ def __class__(self):
453
+ """
454
+ In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
455
+ returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
456
+ object.
457
+ """
458
+ return self.base_dataloader.__class__
459
+
460
+ def __len__(self):
461
+ return len(self.base_dataloader)
462
+
463
+ def adjust_state_dict_for_prefetch(self):
464
+ """
465
+ Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
466
+ `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
467
+ overridden.
468
+
469
+ This should modify `self.dl_state_dict` directly
470
+ """
471
+ # The state dict will be off by a factor of `n-1` batch too many during DDP,
472
+ # so we need to adjust it here
473
+ if PartialState().distributed_type != DistributedType.NO:
474
+ factor = PartialState().num_processes - 1
475
+ if self.dl_state_dict["_sampler_iter_yielded"] > 0:
476
+ self.dl_state_dict["_sampler_iter_yielded"] -= factor
477
+ if self.dl_state_dict["_num_yielded"] > 0:
478
+ self.dl_state_dict["_num_yielded"] -= factor
479
+ if self.dl_state_dict["_index_sampler_state"] is not None:
480
+ if (
481
+ "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
482
+ and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
483
+ ):
484
+ self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
485
+
486
+ def _update_state_dict(self):
487
+ # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
488
+ # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
489
+ # what it wants to yield.
490
+ #
491
+ # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
492
+ if hasattr(self.base_dataloader, "state_dict"):
493
+ self.dl_state_dict = self.base_dataloader.state_dict()
494
+ # Potentially modify the state_dict to adjust for prefetching
495
+ self.adjust_state_dict_for_prefetch()
496
+ # Then tag if we are at the end of the dataloader
497
+ self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
498
+
499
+
500
+ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
501
+ """
502
+ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
503
+
504
+ Args:
505
+ dataset (`torch.utils.data.dataset.Dataset`):
506
+ The dataset to use to build this dataloader.
507
+ device (`torch.device`, *optional*):
508
+ If passed, the device to put all batches on.
509
+ rng_types (list of `str` or [`~utils.RNGType`]):
510
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
511
+ several of:
512
+
513
+ - `"torch"`: the base torch random number generator
514
+ - `"cuda"`: the CUDA random number generator (GPU only)
515
+ - `"xla"`: the XLA random number generator (TPU only)
516
+ - `"generator"`: an optional `torch.Generator`
517
+ synchronized_generator (`torch.Generator`, *optional*):
518
+ A random number generator to keep synchronized across processes.
519
+ skip_batches (`int`, *optional*, defaults to 0):
520
+ The number of batches to skip at the beginning.
521
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
522
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
523
+ **kwargs (additional keyword arguments, *optional*):
524
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
525
+
526
+ **Available attributes:**
527
+
528
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
529
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
530
+ number of processes
531
+
532
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
533
+ """
534
+
535
+ def __init__(
536
+ self,
537
+ dataset,
538
+ device=None,
539
+ rng_types=None,
540
+ synchronized_generator=None,
541
+ skip_batches=0,
542
+ use_stateful_dataloader=False,
543
+ _drop_last: bool = False,
544
+ _non_blocking: bool = False,
545
+ torch_device_mesh=None,
546
+ **kwargs,
547
+ ):
548
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
549
+ self.device = device
550
+ self.rng_types = rng_types
551
+ self.synchronized_generator = synchronized_generator
552
+ self.skip_batches = skip_batches
553
+ self.gradient_state = GradientState()
554
+ self._drop_last = _drop_last
555
+ self._non_blocking = _non_blocking
556
+ self.iteration = 0
557
+
558
+ def __iter__(self):
559
+ if self.rng_types is not None:
560
+ synchronize_rng_states(self.rng_types, self.synchronized_generator)
561
+ self.begin()
562
+
563
+ self.set_epoch(self.iteration)
564
+ dataloader_iter = self.base_dataloader.__iter__()
565
+ # We iterate one batch ahead to check when we are at the end
566
+ try:
567
+ current_batch = next(dataloader_iter)
568
+ except StopIteration:
569
+ self.end()
570
+ return
571
+
572
+ batch_index = 0
573
+ while True:
574
+ try:
575
+ # But we still move it to the device so it is done before `StopIteration` is reached
576
+ if self.device is not None:
577
+ current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
578
+ self._update_state_dict()
579
+ next_batch = next(dataloader_iter)
580
+ if batch_index >= self.skip_batches:
581
+ yield current_batch
582
+ batch_index += 1
583
+ current_batch = next_batch
584
+ except StopIteration:
585
+ self.end_of_dataloader = True
586
+ self._update_state_dict()
587
+ if batch_index >= self.skip_batches:
588
+ yield current_batch
589
+ break
590
+
591
+ self.iteration += 1
592
+ self.end()
593
+
594
+ def __reduce__(self):
595
+ """
596
+ Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
597
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
598
+ `__class__` member.
599
+ """
600
+ args = super().__reduce__()
601
+ return (DataLoaderShard, *args[1:])
602
+
603
+ def set_epoch(self, epoch: int):
604
+ # In case it is manually passed in, the user can set it to what they like
605
+ if self.iteration != epoch:
606
+ self.iteration = epoch
607
+ if hasattr(self.batch_sampler, "set_epoch"):
608
+ self.batch_sampler.set_epoch(epoch)
609
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
610
+ self.batch_sampler.sampler.set_epoch(epoch)
611
+ if (
612
+ hasattr(self.batch_sampler, "batch_sampler")
613
+ and hasattr(self.batch_sampler.batch_sampler, "sampler")
614
+ and hasattr(self.batch_sampler.batch_sampler.sampler, "set_epoch")
615
+ ):
616
+ self.batch_sampler.batch_sampler.sampler.set_epoch(epoch)
617
+ # We support if a custom `Dataset` implementation has `set_epoch`
618
+ # or in general HF datasets `Datasets`
619
+ elif hasattr(self.dataset, "set_epoch"):
620
+ self.dataset.set_epoch(epoch)
621
+
622
+ @property
623
+ def total_batch_size(self):
624
+ batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
625
+ return (
626
+ batch_sampler.batch_size
627
+ if getattr(batch_sampler, "split_batches", False)
628
+ else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
629
+ )
630
+
631
+ @property
632
+ def total_dataset_length(self):
633
+ if hasattr(self.dataset, "total_length"):
634
+ return self.dataset.total_length
635
+ else:
636
+ return len(self.dataset)
637
+
638
+ def get_sampler(self):
639
+ return get_sampler(self)
640
+
641
+ def set_sampler(self, sampler):
642
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
643
+ if sampler_is_batch_sampler:
644
+ self.sampler.sampler = sampler
645
+ else:
646
+ self.batch_sampler.sampler = sampler
647
+ if hasattr(self.batch_sampler, "batch_sampler"):
648
+ self.batch_sampler.batch_sampler.sampler = sampler
649
+
650
+
651
+ if is_torch_xla_available():
652
+ import torch_xla.distributed.parallel_loader as xpl
653
+
654
+ class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
655
+ """
656
+ Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
657
+
658
+ XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
659
+ prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
660
+ thread only.
661
+
662
+ **Available attributes:**
663
+
664
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
665
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
666
+ number of processes
667
+
668
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
669
+ """
670
+
671
+ def __init__(self, dataloader: DataLoaderShard, device: torch.device):
672
+ super().__init__(dataloader, device)
673
+ self._rng_types = self._loader.rng_types
674
+ self._loader.rng_types = None
675
+ self.device = device
676
+
677
+ def __iter__(self):
678
+ if self._rng_types is not None:
679
+ synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)
680
+
681
+ return super().__iter__()
682
+
683
+ def set_epoch(self, epoch: int):
684
+ if hasattr(self.dataloader, "set_epoch"):
685
+ self.dataloader.set_epoch(epoch)
686
+
687
+ @property
688
+ def total_batch_size(self):
689
+ return self._loader.total_batch_size
690
+
691
+ @property
692
+ def total_dataset_length(self):
693
+ return self._loader.total_dataset_length
694
+
695
+ @property
696
+ def batch_sampler(self):
697
+ return self._loader.batch_sampler
698
+
699
+ @property
700
+ def dataloader(self):
701
+ return self._loader
702
+
703
+
704
+ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
705
+ """
706
+ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
707
+ their part of the batch.
708
+
709
+ Args:
710
+ split_batches (`bool`, *optional*, defaults to `False`):
711
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
712
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
713
+ `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
714
+ the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
715
+ `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
716
+ size of the `dataloader` is a round multiple of `batch_size`.
717
+ skip_batches (`int`, *optional*, defaults to 0):
718
+ The number of batches to skip at the beginning of an iteration.
719
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
720
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
721
+
722
+ **Available attributes:**
723
+
724
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
725
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
726
+ number of processes
727
+
728
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
729
+ """
730
+
731
+ def __init__(
732
+ self,
733
+ dataset,
734
+ split_batches: bool = False,
735
+ skip_batches=0,
736
+ use_stateful_dataloader=False,
737
+ _drop_last: bool = False,
738
+ _non_blocking: bool = False,
739
+ slice_fn=None,
740
+ torch_device_mesh=None,
741
+ **kwargs,
742
+ ):
743
+ shuffle = False
744
+ from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
745
+
746
+ # We need to save the shuffling state of the DataPipe
747
+ if isinstance(dataset, ShufflerIterDataPipe):
748
+ shuffle = dataset._shuffle_enabled
749
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
750
+ self.split_batches = split_batches
751
+ if shuffle:
752
+ torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
753
+
754
+ self.gradient_state = GradientState()
755
+ self.state = PartialState()
756
+ self._drop_last = _drop_last
757
+ self._non_blocking = _non_blocking
758
+ self.skip_batches = skip_batches
759
+ self.torch_device_mesh = torch_device_mesh
760
+
761
+ self.slice_fn = slice_tensors if slice_fn is None else slice_fn
762
+ self.iteration = 0
763
+
764
+ # if a device mesh is provided extract each dimension (dp, fsdp, tp)
765
+ # device mesh may hold any number of dimensions, however,
766
+ # below code is for targeted support for dp, fsdp and tp
767
+
768
+ # device mesh will be used only if there is tp involved
769
+ # or any multi-dimensional parallelism involving tp
770
+ # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
771
+ # otherwise the default behaviour not using device mesh should be sufficient
772
+ # since multi dimensional parallelism devoid of tp would anyway need
773
+ # different batches for each process irrespective of dp or fsdp
774
+ self.submesh_tp = None
775
+ self.submesh_dp = None
776
+ self.submesh_fsdp = None
777
+ if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
778
+ self.submesh_tp = self.torch_device_mesh["tp"]
779
+ if "dp" in self.torch_device_mesh.mesh_dim_names:
780
+ self.submesh_dp = self.torch_device_mesh["dp"]
781
+ if "fsdp" in self.torch_device_mesh.mesh_dim_names:
782
+ self.submesh_fsdp = self.torch_device_mesh["fsdp"]
783
+ if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
784
+ raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
785
+
786
+ def _fetch_batches(self, iterator):
787
+ batches, batch = None, None
788
+ # On process 0, we gather the batch to dispatch.
789
+ if self.state.process_index == 0:
790
+ # Procedure to support TP only is simpler
791
+ # since we want to dispatch the same batch of samples across all ranks
792
+ # this removes complexity of handling multiple tp rank groups when TP + DP
793
+ # combination is involved.
794
+
795
+ try:
796
+ # for TP case avoid using split_batches
797
+ # since it would mean that the dataloader should be spilling out
798
+ # duplicates of batches.
799
+ if self.split_batches:
800
+ # One batch of the main iterator is dispatched and split.
801
+ if self.submesh_tp:
802
+ logger.warning(
803
+ "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
804
+ "otherwise, use dispatch_batches=True instead."
805
+ )
806
+ self._update_state_dict()
807
+ batch = next(iterator)
808
+ else:
809
+ # num_processes batches of the main iterator are concatenated then dispatched and split.
810
+ # We add the batches one by one so we have the remainder available when drop_last=False.
811
+ batches = []
812
+ if self.submesh_tp:
813
+ # when tp, extract single batch and then replicate
814
+ self._update_state_dict()
815
+ batch = next(iterator)
816
+ batches = [batch] * self.state.num_processes
817
+ else:
818
+ for _ in range(self.state.num_processes):
819
+ self._update_state_dict()
820
+ batches.append(next(iterator))
821
+ try:
822
+ batch = concatenate(batches, dim=0)
823
+ except RuntimeError as e:
824
+ raise RuntimeError(
825
+ "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
826
+ "either pass `dispatch_batches=False` and have each process fetch its own batch "
827
+ " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
828
+ "slice it into `num_processes` batches for each process."
829
+ ) from e
830
+ # In both cases, we need to get the structure of the batch that we will broadcast on other
831
+ # processes to initialize the tensors with the right shape.
832
+ # data_structure, stop_iteration
833
+ batch_info = [get_data_structure(batch), False]
834
+ except StopIteration:
835
+ batch_info = [None, True]
836
+ else:
837
+ batch_info = [None, self._stop_iteration]
838
+ # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
839
+ broadcast_object_list(batch_info)
840
+ self._stop_iteration = batch_info[1]
841
+ if self._stop_iteration:
842
+ # If drop_last is False and split_batches is False, we may have a remainder to take care of.
843
+ if not self.split_batches and not self._drop_last:
844
+ if self.state.process_index == 0 and len(batches) > 0:
845
+ batch = concatenate(batches, dim=0)
846
+ batch_info = [get_data_structure(batch), False]
847
+ else:
848
+ batch_info = [None, True]
849
+ broadcast_object_list(batch_info)
850
+ return batch, batch_info
851
+
852
+ def __iter__(self):
853
+ self.begin()
854
+ self.set_epoch(self.iteration)
855
+ main_iterator = None
856
+ if is_torch_version(">=", "2.0.1"):
857
+ # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
858
+ # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
859
+ # But, we only iterate through the DataLoader on process 0.
860
+ main_iterator = self.base_dataloader.__iter__()
861
+ elif self.state.process_index == 0:
862
+ main_iterator = self.base_dataloader.__iter__()
863
+ stop_iteration = False
864
+ self._stop_iteration = False
865
+ first_batch = None
866
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
867
+ batch_index = 0
868
+ while not stop_iteration:
869
+ batch, batch_info = next_batch, next_batch_info
870
+
871
+ if self.state.process_index != 0:
872
+ # Initialize tensors on other processes than process 0.
873
+ batch = initialize_tensors(batch_info[0])
874
+ batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
875
+ # Broadcast the batch before splitting it.
876
+ batch = broadcast(batch, from_process=0)
877
+
878
+ if not self._drop_last and first_batch is None:
879
+ # We keep at least num processes elements of the first batch to be able to complete the last batch
880
+ first_batch = self.slice_fn(
881
+ batch,
882
+ slice(0, self.state.num_processes),
883
+ process_index=self.state.process_index,
884
+ num_processes=self.state.num_processes,
885
+ )
886
+
887
+ if batch is None:
888
+ raise ValueError(
889
+ f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
890
+ )
891
+
892
+ observed_batch_size = find_batch_size(batch)
893
+ batch_size = observed_batch_size // self.state.num_processes
894
+
895
+ stop_iteration = self._stop_iteration
896
+ if not stop_iteration:
897
+ # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
898
+ # the dataloader since the number of batches is a round multiple of the number of processes.
899
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
900
+ # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
901
+ if self._stop_iteration and next_batch_info[0] is None:
902
+ stop_iteration = True
903
+
904
+ if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
905
+ # If the last batch is not complete, let's add the first batch to it.
906
+ batch = concatenate([batch, first_batch], dim=0)
907
+ # Batch size computation above is wrong, it's off by 1 so we fix it.
908
+ batch_size += 1
909
+
910
+ data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
911
+ batch = self.slice_fn(
912
+ batch,
913
+ data_slice,
914
+ process_index=self.state.process_index,
915
+ num_processes=self.state.num_processes,
916
+ )
917
+
918
+ if stop_iteration:
919
+ self.end_of_dataloader = True
920
+ self._update_state_dict()
921
+ self.remainder = observed_batch_size
922
+ if batch_index >= self.skip_batches:
923
+ yield batch
924
+ batch_index += 1
925
+ self.iteration += 1
926
+ self.end()
927
+
928
+ def set_epoch(self, epoch: int):
929
+ # In case it is manually passed in, the user can set it to what they like
930
+ if self.iteration != epoch:
931
+ self.iteration = epoch
932
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
933
+ self.batch_sampler.sampler.set_epoch(epoch)
934
+ elif hasattr(self.dataset, "set_epoch"):
935
+ self.dataset.set_epoch(epoch)
936
+
937
+ def __len__(self):
938
+ whole_length = len(self.base_dataloader)
939
+ if self.split_batches:
940
+ return whole_length
941
+ elif self._drop_last:
942
+ return whole_length // self.state.num_processes
943
+ else:
944
+ return math.ceil(whole_length / self.state.num_processes)
945
+
946
+ def __reduce__(self):
947
+ """
948
+ Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
949
+ be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
950
+ `__class__` member.
951
+ """
952
+ args = super().__reduce__()
953
+ return (DataLoaderDispatcher, *args[1:])
954
+
955
+ @property
956
+ def total_batch_size(self):
957
+ return (
958
+ self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
959
+ )
960
+
961
+ @property
962
+ def total_dataset_length(self):
963
+ return len(self.dataset)
964
+
965
+ def get_sampler(self):
966
+ return get_sampler(self)
967
+
968
+ def set_sampler(self, sampler):
969
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
970
+ if sampler_is_batch_sampler:
971
+ self.sampler.sampler = sampler
972
+ else:
973
+ self.batch_sampler.sampler = sampler
974
+ if hasattr(self.batch_sampler, "batch_sampler"):
975
+ self.batch_sampler.batch_sampler.sampler = sampler
976
+
977
+
978
+ def get_sampler(dataloader):
979
+ """
980
+ Get the sampler associated to the dataloader
981
+
982
+ Args:
983
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
984
+ The data loader to split across several devices.
985
+ Returns:
986
+ `torch.utils.data.Sampler`: The sampler associated to the dataloader
987
+ """
988
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
989
+ if sampler_is_batch_sampler:
990
+ sampler = getattr(dataloader.sampler, "sampler", None)
991
+ else:
992
+ sampler = getattr(dataloader.batch_sampler, "sampler", None)
993
+ return sampler
994
+
995
+
996
+ def prepare_data_loader(
997
+ dataloader: DataLoader,
998
+ device: Optional[torch.device] = None,
999
+ num_processes: Optional[int] = None,
1000
+ process_index: Optional[int] = None,
1001
+ split_batches: bool = False,
1002
+ put_on_device: bool = False,
1003
+ rng_types: Optional[list[Union[str, RNGType]]] = None,
1004
+ dispatch_batches: Optional[bool] = None,
1005
+ even_batches: bool = True,
1006
+ slice_fn_for_dispatch: Optional[Callable] = None,
1007
+ use_seedable_sampler: bool = False,
1008
+ data_seed: Optional[int] = None,
1009
+ non_blocking: bool = False,
1010
+ use_stateful_dataloader: bool = False,
1011
+ torch_device_mesh=None,
1012
+ ) -> DataLoader:
1013
+ """
1014
+ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
1015
+
1016
+ Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
1017
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
1018
+
1019
+ Args:
1020
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
1021
+ The data loader to split across several devices.
1022
+ device (`torch.device`):
1023
+ The target device for the returned `DataLoader`.
1024
+ num_processes (`int`, *optional*):
1025
+ The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
1026
+ process_index (`int`, *optional*):
1027
+ The index of the current process. Will default to the value given by [`~state.PartialState`].
1028
+ split_batches (`bool`, *optional*, defaults to `False`):
1029
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
1030
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
1031
+ `num_processes` batches at each iteration).
1032
+
1033
+ Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
1034
+ this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
1035
+ otherwise.
1036
+
1037
+ Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
1038
+ `batch_size`.
1039
+ put_on_device (`bool`, *optional*, defaults to `False`):
1040
+ Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
1041
+ dictionaries of tensors).
1042
+ rng_types (list of `str` or [`~utils.RNGType`]):
1043
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
1044
+ several of:
1045
+
1046
+ - `"torch"`: the base torch random number generator
1047
+ - `"cuda"`: the CUDA random number generator (GPU only)
1048
+ - `"xla"`: the XLA random number generator (TPU only)
1049
+ - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
1050
+ dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
1051
+
1052
+ dispatch_batches (`bool`, *optional*):
1053
+ If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
1054
+ are split and broadcast to each process. Will default to `True` when the underlying dataset is an
1055
+ `IterableDataset`, `False` otherwise.
1056
+ even_batches (`bool`, *optional*, defaults to `True`):
1057
+ If set to `True`, in cases where the total batch size across all processes does not exactly divide the
1058
+ dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
1059
+ all workers.
1060
+ slice_fn_for_dispatch (`Callable`, *optional*`):
1061
+ If passed, this function will be used to slice tensors across `num_processes`. Will default to
1062
+ [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
1063
+ ignored otherwise.
1064
+ use_seedable_sampler (`bool`, *optional*, defaults to `False`):
1065
+ Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
1066
+ reproducibility. Comes at a cost of potentially different performances due to different shuffling
1067
+ algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
1068
+ `self.set_epoch`
1069
+ data_seed (`int`, *optional*, defaults to `None`):
1070
+ The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
1071
+ will use the current default seed from torch.
1072
+ non_blocking (`bool`, *optional*, defaults to `False`):
1073
+ If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
1074
+ `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
1075
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
1076
+ "If set to true, the dataloader prepared by the Accelerator will be backed by "
1077
+ "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
1078
+ This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
1079
+ torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
1080
+ PyTorch device mesh.
1081
+
1082
+
1083
+ Returns:
1084
+ `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
1085
+
1086
+ <Tip warning={true}>
1087
+
1088
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
1089
+ equal to `False`
1090
+
1091
+ </Tip>
1092
+ """
1093
+ if dispatch_batches is None:
1094
+ if not put_on_device:
1095
+ dispatch_batches = False
1096
+ else:
1097
+ dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
1098
+
1099
+ if dispatch_batches and not put_on_device:
1100
+ raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
1101
+ # Grab defaults from PartialState
1102
+ state = PartialState()
1103
+ if num_processes is None:
1104
+ num_processes = state.num_processes
1105
+
1106
+ if process_index is None:
1107
+ process_index = state.process_index
1108
+
1109
+ if torch_device_mesh:
1110
+ if state.distributed_type == DistributedType.DEEPSPEED:
1111
+ # In DeepSpeed, the optimizer sharing level in DP is determined by the config file.
1112
+ # Only considers "dp" and "tp".
1113
+ # Given a device mesh (dp, tp) = (2, 3):
1114
+ # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
1115
+ # - Processes with the same DP rank will receive the same batch.
1116
+ submesh_tp_size = 1
1117
+ if "tp" in torch_device_mesh.mesh_dim_names:
1118
+ submesh_tp_size = torch_device_mesh["tp"].size()
1119
+ process_index = process_index // submesh_tp_size
1120
+ num_processes = num_processes // submesh_tp_size
1121
+ else:
1122
+ # when device mesh is used, specifically with TP
1123
+ # then there is need to update process_index and num_processes
1124
+ # to bring in the effect of generating same batch across TP ranks
1125
+ # and different batch across FSDP and DP ranks.
1126
+ # Example:
1127
+ # if device mesh is (dp,fsdp,tp) = (2, 2, 3)
1128
+ # ranks would range from 0...11
1129
+ # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
1130
+ # processes with same ranks/ids would receive the same batch
1131
+ # for CP the same as TP applies
1132
+ submesh_fsdp_size = 1
1133
+ submesh_dp_size = 1
1134
+ submesh_tp_size = 1
1135
+ submesh_cp_size = 1
1136
+ if "tp" in torch_device_mesh.mesh_dim_names:
1137
+ submesh_tp_size = torch_device_mesh["tp"].size()
1138
+ if "cp" in torch_device_mesh.mesh_dim_names:
1139
+ submesh_cp_size = torch_device_mesh["cp"].size()
1140
+ if "dp_replicate" in torch_device_mesh.mesh_dim_names:
1141
+ submesh_dp_size = torch_device_mesh["dp_replicate"].size()
1142
+ if "dp_shard" in torch_device_mesh.mesh_dim_names:
1143
+ submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
1144
+ process_index = process_index // (submesh_tp_size * submesh_cp_size)
1145
+ num_processes = submesh_fsdp_size * submesh_dp_size
1146
+
1147
+ # Sanity check
1148
+ if split_batches:
1149
+ if dataloader.batch_size is not None:
1150
+ batch_size_for_check = dataloader.batch_size
1151
+ else:
1152
+ # For custom batch_sampler
1153
+ if hasattr(dataloader.batch_sampler, "batch_size"):
1154
+ batch_size_for_check = dataloader.batch_sampler.batch_size
1155
+ else:
1156
+ raise ValueError(
1157
+ "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
1158
+ "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
1159
+ "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
1160
+ f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
1161
+ )
1162
+
1163
+ if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
1164
+ raise ValueError(
1165
+ f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
1166
+ f"needs to be a round multiple of the number of processes ({num_processes})."
1167
+ )
1168
+
1169
+ new_dataset = dataloader.dataset
1170
+ # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
1171
+ new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
1172
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1173
+ synchronized_generator = None
1174
+
1175
+ sampler = get_sampler(dataloader)
1176
+ if isinstance(sampler, RandomSampler) and use_seedable_sampler:
1177
+ # When iterating through the dataloader during distributed processes
1178
+ # we want to ensure that on each process we are iterating through the same
1179
+ # samples in the same order if a seed is set. This requires a tweak
1180
+ # to the `torch.utils.data.RandomSampler` class (if used).
1181
+ sampler = SeedableRandomSampler(
1182
+ data_source=sampler.data_source,
1183
+ replacement=sampler.replacement,
1184
+ num_samples=sampler._num_samples,
1185
+ generator=getattr(
1186
+ sampler,
1187
+ "generator",
1188
+ torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
1189
+ ),
1190
+ data_seed=data_seed,
1191
+ )
1192
+
1193
+ if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
1194
+ # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
1195
+ generator = torch.Generator(
1196
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1197
+ )
1198
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
1199
+ generator.manual_seed(seed)
1200
+ dataloader.generator = generator
1201
+ dataloader.sampler.generator = generator
1202
+ # No change if no multiprocess
1203
+ if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
1204
+ if is_datasets_available():
1205
+ from datasets import IterableDataset as DatasetsIterableDataset
1206
+ if (
1207
+ is_datasets_available()
1208
+ and isinstance(new_dataset, DatasetsIterableDataset)
1209
+ and not split_batches
1210
+ and new_dataset.n_shards > num_processes
1211
+ ):
1212
+ new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
1213
+ elif isinstance(new_dataset, IterableDataset):
1214
+ if getattr(dataloader.dataset, "generator", None) is not None:
1215
+ synchronized_generator = dataloader.dataset.generator
1216
+ new_dataset = IterableDatasetShard(
1217
+ new_dataset,
1218
+ batch_size=dataloader.batch_size,
1219
+ drop_last=dataloader.drop_last,
1220
+ num_processes=num_processes,
1221
+ process_index=process_index,
1222
+ split_batches=split_batches,
1223
+ )
1224
+ else:
1225
+ if not use_seedable_sampler and hasattr(sampler, "generator"):
1226
+ if sampler.generator is None:
1227
+ sampler.generator = torch.Generator(
1228
+ device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1229
+ )
1230
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
1231
+ sampler.generator.manual_seed(seed)
1232
+ synchronized_generator = sampler.generator
1233
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1234
+ new_batch_sampler = BatchSamplerShard(
1235
+ batch_sampler,
1236
+ num_processes=num_processes,
1237
+ process_index=process_index,
1238
+ split_batches=split_batches,
1239
+ even_batches=even_batches,
1240
+ )
1241
+
1242
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1243
+ ignore_kwargs = [
1244
+ "batch_size",
1245
+ "shuffle",
1246
+ "sampler",
1247
+ "batch_sampler",
1248
+ "drop_last",
1249
+ ]
1250
+
1251
+ if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
1252
+ rng_types.remove("generator")
1253
+
1254
+ kwargs = {
1255
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1256
+ for k in _PYTORCH_DATALOADER_KWARGS
1257
+ if k not in ignore_kwargs
1258
+ }
1259
+
1260
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1261
+ if new_batch_sampler is None:
1262
+ kwargs["drop_last"] = dataloader.drop_last
1263
+ kwargs["batch_size"] = (
1264
+ dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
1265
+ )
1266
+ if dispatch_batches:
1267
+ kwargs.pop("generator")
1268
+ dataloader = DataLoaderDispatcher(
1269
+ new_dataset,
1270
+ split_batches=split_batches,
1271
+ batch_sampler=new_batch_sampler,
1272
+ _drop_last=dataloader.drop_last,
1273
+ _non_blocking=non_blocking,
1274
+ slice_fn=slice_fn_for_dispatch,
1275
+ use_stateful_dataloader=use_stateful_dataloader,
1276
+ torch_device_mesh=torch_device_mesh,
1277
+ **kwargs,
1278
+ )
1279
+ elif sampler_is_batch_sampler:
1280
+ dataloader = DataLoaderShard(
1281
+ new_dataset,
1282
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1283
+ sampler=new_batch_sampler,
1284
+ batch_size=dataloader.batch_size,
1285
+ rng_types=rng_types,
1286
+ _drop_last=dataloader.drop_last,
1287
+ _non_blocking=non_blocking,
1288
+ synchronized_generator=synchronized_generator,
1289
+ use_stateful_dataloader=use_stateful_dataloader,
1290
+ **kwargs,
1291
+ )
1292
+ else:
1293
+ dataloader = DataLoaderShard(
1294
+ new_dataset,
1295
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1296
+ batch_sampler=new_batch_sampler,
1297
+ rng_types=rng_types,
1298
+ synchronized_generator=synchronized_generator,
1299
+ _drop_last=dataloader.drop_last,
1300
+ _non_blocking=non_blocking,
1301
+ use_stateful_dataloader=use_stateful_dataloader,
1302
+ **kwargs,
1303
+ )
1304
+
1305
+ if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
1306
+ dataloader.set_sampler(sampler)
1307
+ if state.distributed_type == DistributedType.XLA:
1308
+ return MpDeviceLoaderWrapper(dataloader, device)
1309
+ return dataloader
1310
+
1311
+
1312
+ class SkipBatchSampler(BatchSampler):
1313
+ """
1314
+ A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
1315
+ Should not be used if the original dataloader is a `StatefulDataLoader`.
1316
+ """
1317
+
1318
+ def __init__(self, batch_sampler, skip_batches=0):
1319
+ self.batch_sampler = batch_sampler
1320
+ self.skip_batches = skip_batches
1321
+
1322
+ def __iter__(self):
1323
+ for index, samples in enumerate(self.batch_sampler):
1324
+ if index >= self.skip_batches:
1325
+ yield samples
1326
+
1327
+ @property
1328
+ def total_length(self):
1329
+ return len(self.batch_sampler)
1330
+
1331
+ def __len__(self):
1332
+ return len(self.batch_sampler) - self.skip_batches
1333
+
1334
+
1335
+ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
1336
+ """
1337
+ Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
1338
+ `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
1339
+
1340
+ Args:
1341
+ dataset (`torch.utils.data.dataset.Dataset`):
1342
+ The dataset to use to build this dataloader.
1343
+ skip_batches (`int`, *optional*, defaults to 0):
1344
+ The number of batches to skip at the beginning.
1345
+ kwargs:
1346
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
1347
+ """
1348
+
1349
+ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
1350
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
1351
+ self.skip_batches = skip_batches
1352
+ self.gradient_state = GradientState()
1353
+
1354
+ def __iter__(self):
1355
+ self.begin()
1356
+ for index, batch in enumerate(self.base_dataloader.__iter__()):
1357
+ if index >= self.skip_batches:
1358
+ self._update_state_dict()
1359
+ yield batch
1360
+ self.end()
1361
+
1362
+ def __len__(self):
1363
+ return len(self.base_dataloader) - self.skip_batches
1364
+
1365
+ def __reduce__(self):
1366
+ """
1367
+ Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
1368
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
1369
+ `__class__` member.
1370
+ """
1371
+ args = super().__reduce__()
1372
+ return (SkipDataLoader, *args[1:])
1373
+
1374
+
1375
+ def skip_first_batches(dataloader, num_batches=0):
1376
+ """
1377
+ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
1378
+ the original dataloader is a `StatefulDataLoader`.
1379
+ """
1380
+ state = PartialState()
1381
+ if state.distributed_type == DistributedType.XLA:
1382
+ device = dataloader.device
1383
+ dataloader = dataloader.dataloader
1384
+
1385
+ dataset = dataloader.dataset
1386
+ sampler_is_batch_sampler = False
1387
+ if isinstance(dataset, IterableDataset):
1388
+ new_batch_sampler = None
1389
+ else:
1390
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1391
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1392
+ new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
1393
+
1394
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1395
+ ignore_kwargs = [
1396
+ "batch_size",
1397
+ "shuffle",
1398
+ "sampler",
1399
+ "batch_sampler",
1400
+ "drop_last",
1401
+ ]
1402
+
1403
+ kwargs = {
1404
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1405
+ for k in _PYTORCH_DATALOADER_KWARGS
1406
+ if k not in ignore_kwargs
1407
+ }
1408
+
1409
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1410
+ if new_batch_sampler is None:
1411
+ kwargs["drop_last"] = dataloader.drop_last
1412
+ kwargs["batch_size"] = dataloader.batch_size
1413
+
1414
+ if isinstance(dataloader, DataLoaderDispatcher):
1415
+ if new_batch_sampler is None:
1416
+ # Need to manually skip batches in the dataloader
1417
+ kwargs["skip_batches"] = num_batches
1418
+ dataloader = DataLoaderDispatcher(
1419
+ dataset,
1420
+ split_batches=dataloader.split_batches,
1421
+ batch_sampler=new_batch_sampler,
1422
+ _drop_last=dataloader._drop_last,
1423
+ **kwargs,
1424
+ )
1425
+ elif isinstance(dataloader, DataLoaderShard):
1426
+ if new_batch_sampler is None:
1427
+ # Need to manually skip batches in the dataloader
1428
+ kwargs["skip_batches"] = num_batches
1429
+ elif sampler_is_batch_sampler:
1430
+ kwargs["sampler"] = new_batch_sampler
1431
+ kwargs["batch_size"] = dataloader.batch_size
1432
+ else:
1433
+ kwargs["batch_sampler"] = new_batch_sampler
1434
+ dataloader = DataLoaderShard(
1435
+ dataset,
1436
+ device=dataloader.device,
1437
+ rng_types=dataloader.rng_types,
1438
+ synchronized_generator=dataloader.synchronized_generator,
1439
+ **kwargs,
1440
+ )
1441
+ else:
1442
+ if new_batch_sampler is None:
1443
+ # Need to manually skip batches in the dataloader
1444
+ dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
1445
+ else:
1446
+ dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
1447
+
1448
+ if state.distributed_type == DistributedType.XLA:
1449
+ dataloader = MpDeviceLoaderWrapper(dataloader, device)
1450
+
1451
+ return dataloader