File size: 15,137 Bytes
e062359 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 | #!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import tempfile
import warnings
from unittest.mock import Mock
import torch
from torch.utils.data import (
BatchSampler,
DataLoader,
Dataset,
IterableDataset,
RandomSampler,
TensorDataset,
default_collate,
)
from accelerate.accelerator import Accelerator, DataLoaderConfiguration
from accelerate.utils.dataclasses import DistributedType
# Number of items in DummyDataset. 22 is not a multiple of BATCH_SIZE * 2 processes, so the last global
# batch is partial (presumably deliberate, to exercise gather_for_metrics de-duplication — confirm).
NUM_ELEMENTS = 22
# Worker processes used by the DataLoaders built in main().
NUM_WORKERS = 4
# Per-process batch size used by the DataLoaders built in main().
BATCH_SIZE = 4
class DummyDataset(Dataset):
    """
    Map-style dataset of ``len(self)`` items (NUM_ELEMENTS by default).

    Each item is a dict holding the item's ``index``, a binary ``label`` (``index % 2``) and a fresh
    ``random_augmentation`` value drawn with ``torch.rand`` on every access.

    ``__getitem__`` accepts:
      * an ``int``: returns a single dict;
      * a ``slice``: returns a list of dicts;
      * any other iterable of indices: returns a list of dicts.
    The iterable form lets it be driven by a ``BatchSampler`` passed as a plain ``sampler`` (as main() does).
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        squeeze = False
        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length (the class has no `size` attribute).
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)
        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]
        if squeeze:
            # A single int index yields a single item rather than a one-element list.
            batch = batch[0]
        return batch
class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that replays the wrapped sequence, item by item, on every iteration."""

    def __init__(self, data):
        # Any iterable works; the items are yielded unchanged.
        self.data = data

    def __iter__(self):
        for item in self.data:
            yield item
def create_accelerator(even_batches=True):
    """
    Build an Accelerator whose dataloader configuration uses the given ``even_batches`` flag.

    Fails with an AssertionError unless exactly two processes are running.
    """
    config = DataLoaderConfiguration(even_batches=even_batches)
    acc = Accelerator(dataloader_config=config)
    assert acc.num_processes == 2, "this script expects that two GPUs are available"
    return acc
def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases.

    The data are the integers ``0 .. dataset_size - 1``. When ``shuffle`` is set, they are permuted once with
    ``torch.randperm``.

    Args:
        accelerator: the Accelerator used to ``prepare`` the returned DataLoader.
        dataset_size: number of integer samples.
        batch_size: batch size passed to the DataLoader.
        iterable: wrap the data in ``DummyIterableDataset`` instead of a map-style ``TensorDataset``.
        shuffle: permute the data before building the dataset (honoured for both dataset kinds).

    Returns:
        The DataLoader after ``accelerator.prepare``.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        # The permutation is drawn once, at construction time.
        values = values[torch.randperm(values.size(0))]
    # Build both dataset kinds from `values` so that `shuffle` is not silently ignored for map-style data.
    if iterable:
        dataset = DummyIterableDataset(values)
    else:
        dataset = TensorDataset(values)
    dl = DataLoader(dataset, batch_size=batch_size)
    return accelerator.prepare(dl)
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    Prepare a fresh dataloader and check, on the current rank, the size of every batch it yields.

    Rank 0 and rank 1 each compare against their own expected list; any other rank only iterates.
    """
    expected_by_rank = {
        0: process_0_expected_batch_sizes,
        1: process_1_expected_batch_sizes,
    }
    prepared = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)
    # Each batch is a 1-tuple from TensorDataset; its first element holds the samples.
    observed = [len(batch[0]) for batch in prepared]
    rank = accelerator.process_index
    if rank in expected_by_rank:
        assert observed == expected_by_rank[rank]
def test_default_ensures_even_batch_sizes():
    """With the default even_batches=True, both ranks receive the same number and size of batches."""
    accelerator = create_accelerator()
    scenarios = (
        # without padding, we would expect a different number of batches
        {
            "dataset_size": 3,
            "batch_size": 1,
            "process_0_expected_batch_sizes": [1, 1],
            "process_1_expected_batch_sizes": [1, 1],
        },
        # without padding, we would expect the same number of batches, but different sizes
        {
            "dataset_size": 7,
            "batch_size": 2,
            "process_0_expected_batch_sizes": [2, 2],
            "process_1_expected_batch_sizes": [2, 2],
        },
    )
    for scenario in scenarios:
        verify_dataloader_batch_sizes(accelerator, **scenario)
def test_can_disable_even_batches():
    """With even_batches=False, ranks may receive a different number, or size, of batches."""
    accelerator = create_accelerator(even_batches=False)
    # (dataset_size, batch_size, rank-0 batch sizes, rank-1 batch sizes)
    scenarios = [
        (3, 1, [1, 1], [1]),
        (7, 2, [2, 2], [2, 1]),
    ]
    for dataset_size, batch_size, rank0_sizes, rank1_sizes in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=rank0_sizes,
            process_1_expected_batch_sizes=rank1_sizes,
        )
def test_can_join_uneven_inputs():
    """Inside `join_uneven_inputs`, ranks with uneven batch counts can each finish their own training loop."""
    accelerator = create_accelerator(even_batches=False)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_steps = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for step, batch in enumerate(dl):
            ddp_model(batch[0].float()).sum().backward()
            seen_steps.append(step)

    accelerator.wait_for_everyone()

    # 3 samples over 2 ranks with even_batches disabled: rank 0 gets two batches, rank 1 gets one.
    expected_steps = {0: [0, 1], 1: [0]}
    rank = accelerator.process_index
    if rank in expected_steps:
        assert seen_steps == expected_steps[rank]
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """
    Check that `join_uneven_inputs` emits a UserWarning mentioning multi-GPU support.

    Args:
        accelerator: an Accelerator whose `state.distributed_type` is expected to be set by the caller to a
            non-DDP value (main() temporarily sets FSDP).
    """
    with warnings.catch_warnings(record=True) as w:
        # Record every UserWarning, even if an active filter would otherwise ignore or de-duplicate it.
        warnings.simplefilter("always", UserWarning)
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    expected = "only supported for multi-GPU"
    # Search all recorded warnings rather than only the last one, so an unrelated trailing warning
    # cannot cause a false failure and an empty record fails with a clear message instead of IndexError.
    assert any(
        issubclass(warning.category, UserWarning) and expected in str(warning.message) for warning in w
    ), f"expected a UserWarning containing {expected!r}, got {[str(warning.message) for warning in w]}"
def test_join_can_override_even_batches():
    """
    `join_uneven_inputs(even_batches=...)` overrides each prepared dataloader's `even_batches` inside the
    context, and the original value is back once the context exits.
    """
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # A "train" and a "valid" dataloader, both map-style.
    dataloaders = [create_dataloader(accelerator, dataset_size=3, batch_size=1) for _ in range(2)]

    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
        values_inside = [dl.batch_sampler.even_batches for dl in dataloaders]

    for value in values_inside:
        assert value == overridden_even_batches
    for dl in dataloaders:
        assert dl.batch_sampler.even_batches == default_even_batches
def test_join_can_override_for_mixed_type_dataloaders():
    """
    Check that overriding `even_batches` still works when an iterable-style dataloader has also been prepared.

    Two things are asserted:
      * the iterable dataloader must not make `join_uneven_inputs` raise AttributeError;
      * the map-style dataloader is overridden inside the context and restored afterwards.
    """
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    # Prepared (and discarded) so the accelerator has an iterable-style dataloader to process.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        # Overriding even_batches with an iterable dataloader warns; that warning is not under test here.
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                batch_dl_overridden_value = batch_dl.batch_sampler.even_batches
        except AttributeError as err:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError(
                f"join_uneven_inputs raised AttributeError with an iterable dataloader prepared: {err}"
            ) from err

    assert batch_dl_overridden_value == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """
    Overriding `even_batches` in `join_uneven_inputs` should emit a UserWarning once an iterable-style
    dataloader has been prepared, because the override is only supported for map-style datasets.
    """
    accelerator = create_accelerator()
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as w:
        # Record every UserWarning, even if an active filter would otherwise ignore or de-duplicate it.
        warnings.simplefilter("always", UserWarning)
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    expected = "only supported for map-style datasets"
    # Search all recorded warnings rather than only the last one, so an unrelated trailing warning
    # cannot cause a false failure and an empty record fails with a clear message instead of IndexError.
    assert any(
        issubclass(warning.category, UserWarning) and expected in str(warning.message) for warning in w
    ), f"expected a UserWarning containing {expected!r}, got {[str(warning.message) for warning in w]}"
def test_pickle_accelerator():
    """An Accelerator that has prepared a dataloader survives a pickle round-trip with an equal state dict."""
    accelerator = create_accelerator()
    loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    # NOTE(review): create_dataloader already ran accelerator.prepare on this loader; this prepares it a
    # second time — confirm that is intended.
    _ = accelerator.prepare(loader)
    round_tripped = pickle.loads(pickle.dumps(accelerator))
    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == round_tripped.state.__dict__
def test_data_loader(data_loader, accelerator):
    """
    Check that gathered indices cover every dataset item.

    Prepares ``data_loader`` with ``accelerator``, gathers each batch's ``index`` (and ``label``) across
    processes with ``gather_for_metrics``, and asserts that every one of the NUM_ELEMENTS dataset indices
    was seen.

    The assertion is on the *set* of gathered indices, so it detects missing items; duplicated indices
    collapse inside the set and are not detected by it.
    """
    # Prepare the DataLoader
    data_loader = accelerator.prepare(data_loader)
    all_examples = []
    for batch in data_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        all_examples.extend(index.detach().cpu().numpy().tolist())
    # Check if all elements are present among the iterated samples (order is irrelevant, so no sort is needed)
    assert len(set(all_examples)) == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
def test_stateful_dataloader(accelerator):
    """
    Check that a stateful dataloader can be saved mid-epoch and resumed.

    The dataloader is iterated, its `state_dict` is snapshotted just before the last batch, and it is
    resumed via `load_state_dict`. The batches produced after resuming must match, in number and content,
    the batches that followed the snapshot in the original iteration.

    The accelerator's `dataloader_config` is temporarily switched to `use_stateful_dataloader=True` and
    restored on exit.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        samples_per_process = 32
        batch_size = 4
        prepared_dl = create_dataloader(
            accelerator,
            dataset_size=samples_per_process * accelerator.num_processes,
            batch_size=batch_size,
            iterable=True,
            shuffle=True,
        )
        # Batches per process, assuming each step consumes `batch_size` samples on every process.
        total_batches = samples_per_process * accelerator.num_processes // (batch_size * accelerator.num_processes)
        last_batch_num = total_batches - 1
        state_dict = None
        untrained_batches = []
        for step, batch in enumerate(prepared_dl):
            if step == last_batch_num - 1:
                # Snapshot right after the second-to-last batch has been consumed.
                state_dict = prepared_dl.state_dict()
            if step >= last_batch_num:
                # The batches after the snapshot are the ones a resumed run should reproduce.
                untrained_batches.append(batch)
        assert state_dict is not None, "The dataloader ended before a state_dict snapshot could be taken"
        assert untrained_batches, "No batches were produced after the state_dict snapshot"
        not_skipped_batches = accelerator.gather(untrained_batches)

        prepared_dl.load_state_dict(state_dict)
        resumed_batches = []
        for batch in prepared_dl:
            resumed_batches.append(batch)
        resumed_batches = accelerator.gather(resumed_batches)

        # `zip` silently truncates, so compare lengths first; otherwise a resume that yields fewer
        # (or zero) batches would pass vacuously.
        assert len(not_skipped_batches) == len(resumed_batches), (
            f"Expected {len(not_skipped_batches)} batches after resuming, got {len(resumed_batches)}"
        )
        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config
def test_stateful_dataloader_save_state(accelerator):
    """
    Check that a stateful dataloader can be saved mid-epoch and resumed through the Accelerator.

    The dataloader is iterated, `Accelerator.save_state` is called just before the last batch, and
    `Accelerator.load_state` is used to resume. The batches produced after resuming must match, in number
    and content, the batches that followed the save in the original iteration.

    The accelerator's `dataloader_config` is temporarily switched to `use_stateful_dataloader=True` and
    restored on exit. The checkpoint lives in a temporary directory that is removed afterwards.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            samples_per_process = 32
            batch_size = 4
            prepared_dl = create_dataloader(
                accelerator,
                dataset_size=samples_per_process * accelerator.num_processes,
                batch_size=batch_size,
                iterable=True,
                shuffle=True,
            )
            # Batches per process, assuming each step consumes `batch_size` samples on every process.
            total_batches = samples_per_process * accelerator.num_processes // (
                batch_size * accelerator.num_processes
            )
            last_batch_num = total_batches - 1
            saved = False
            untrained_batches = []
            for step, batch in enumerate(prepared_dl):
                if step == last_batch_num - 1:
                    # Checkpoint right after the second-to-last batch has been consumed.
                    accelerator.save_state(tmpdir)
                    saved = True
                if step >= last_batch_num:
                    # The batches after the checkpoint are the ones a resumed run should reproduce.
                    untrained_batches.append(batch)
            assert saved, "The dataloader ended before a checkpoint could be saved"
            assert untrained_batches, "No batches were produced after the checkpoint"
            not_skipped_batches = accelerator.gather(untrained_batches)

            accelerator.load_state(tmpdir)
            resumed_batches = []
            for batch in prepared_dl:
                resumed_batches.append(batch)
            resumed_batches = accelerator.gather(resumed_batches)

            # `zip` silently truncates, so compare lengths first; otherwise a resume that yields fewer
            # (or zero) batches would pass vacuously.
            assert len(not_skipped_batches) == len(resumed_batches), (
                f"Expected {len(not_skipped_batches)} batches after resuming, got {len(resumed_batches)}"
            )
            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config
def main():
    """Run every check in this script in order, printing a heading via `accelerator.print` before each one."""
    accelerator = create_accelerator()
    # Seed by rank: runs are reproducible, while each rank draws different random numbers.
    torch.manual_seed(accelerator.process_index)

    accelerator.print("Test that even_batches variable ensures uniform batches across processes")
    test_default_ensures_even_batch_sizes()
    accelerator.print("Run tests with even_batches disabled")
    test_can_disable_even_batches()
    accelerator.print("Test joining uneven inputs")
    test_can_join_uneven_inputs()
    accelerator.print("Test overriding even_batches when joining uneven inputs")
    test_join_can_override_even_batches()
    accelerator.print("Test overriding even_batches for mixed dataloader types")
    test_join_can_override_for_mixed_type_dataloaders()
    accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders")
    test_join_raises_warning_for_iterable_when_overriding_even_batches()

    accelerator.print("Test join with non DDP distributed raises warning")
    original_state = accelerator.state.distributed_type
    # Temporarily report FSDP so join_uneven_inputs sees a non-DDP distributed type.
    accelerator.state.distributed_type = DistributedType.FSDP
    try:
        test_join_raises_warning_for_non_ddp_distributed(accelerator)
    finally:
        # Always restore the real distributed type, even if the check above fails.
        accelerator.state.distributed_type = original_state

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()
    accelerator.print("Test DataLoader with shuffle=False")
    loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)
    accelerator.print("Test DataLoader with shuffle=True")
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)
    accelerator.print("Test DataLoader with batch_sampler")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)
    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test stateful DataLoader resumed via state_dict/load_state_dict")
    test_stateful_dataloader(accelerator)
    accelerator.print("Test stateful DataLoader resumed via Accelerator.save_state/load_state")
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()
if __name__ == "__main__":
    # Run the full suite only when executed as a script, not on import.
    main()
|