#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
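
"""
End-to-end test script for core Accelerate features: accelerator state
initialization, process-restricted execution helpers, RNG synchronization
across processes, DataLoader preparation (with and without split or
dispatched batches), and parity between single-process and distributed
training, including fp16/bf16 mixed precision.

Meant to run on every process at once, e.g. launched with
`accelerate launch`, or spawned through `_mp_fn` on TPUs.
"""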
import contextlib
import io

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
from accelerate.utils import (
    DistributedType,
    gather,
    is_bf16_available,
    is_torch_version,
    set_seed,
    synchronize_rng_states,
)


def print_main(state):
    print(f"Printing from the main process {state.process_index}")


def print_local_main(state):
    print(f"Printing from the local main process {state.local_process_index}")


def print_last(state):
    print(f"Printing from the last process {state.process_index}")


def print_on(state, process_idx):
    print(f"Printing from process {process_idx}: {state.process_index}")


def process_execution_check():
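    """
    Checks that `main_process_first` lets the main process go first, and that the
    `on_main_process`, `on_local_main_process`, `on_last_process` and `on_process`
    decorators restrict execution to the intended process.
    """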
    accelerator = Accelerator()
    num_processes = accelerator.num_processes
    with accelerator.main_process_first():
        idx = torch.tensor(accelerator.process_index).to(accelerator.device)
    idxs = accelerator.gather(idx)
    if num_processes > 1:
        assert idxs[0] == 0, "Main process was not first."

    # Test the decorators: capture stdout and check that only the targeted
    # process actually printed something.
    f = io.StringIO()
    with contextlib.redirect_stdout(f):
        accelerator.on_main_process(print_main)(accelerator.state)
    result = f.getvalue().rstrip()
    if accelerator.is_main_process:
        assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
    else:
        assert f.getvalue().rstrip() == "", f'{result} != ""'
    # Clear the capture buffer between checks
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_local_main_process(print_local_main)(accelerator.state)
    if accelerator.is_local_main_process:
        assert f.getvalue().rstrip() == "Printing from the local main process 0"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_last_process(print_last)(accelerator.state)
    if accelerator.is_last_process:
        assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    for process_idx in range(num_processes):
        with contextlib.redirect_stdout(f):
            accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx)
        if accelerator.process_index == process_idx:
            assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}"
        else:
            assert f.getvalue().rstrip() == ""
        f.truncate(0)
        f.seek(0)


def init_state_check():
    # Test we can instantiate this twice in a row.
    state = AcceleratorState()
    if state.local_process_index == 0:
        print("Testing, testing. 1, 2, 3.")
        print(state)


def rng_sync_check():
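    """
    Checks that `synchronize_rng_states` leaves identical torch (and, on
    multi-GPU, CUDA) RNG states on every process, including for a dedicated
    `torch.Generator`.
    """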
    state = AcceleratorState()
    synchronize_rng_states(["torch"])
    assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."
    if state.distributed_type == DistributedType.MULTI_GPU:
        synchronize_rng_states(["cuda"])
        assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."

    generator = torch.Generator()
    synchronize_rng_states(["generator"], generator=generator)
    assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

    if state.local_process_index == 0:
| print("All rng are properly synched.") | |
def dl_preparation_check():
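    """
    Checks that a DataLoader prepared with `prepare_data_loader` yields each
    sample of the dataset exactly once across all processes, with and without
    shuffling and with `split_batches` on or off.
    """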
    state = AcceleratorState()
    length = 32 * state.num_processes

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    print(state.process_index, result, type(dl))
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")


def central_dl_preparation_check():
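    """
    Same checks as `dl_preparation_check`, but with `dispatch_batches=True`,
    i.e. batches are built on the main process and dispatched to the others.
    """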
    state = AcceleratorState()
    length = 32 * state.num_processes

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")


def mock_training(length, batch_size, generator):
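    """
    Runs a short, fully seeded regression training on a single process and
    returns the dataset and resulting model to serve as a reference.
    """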
    set_seed(42)
    generator.manual_seed(42)
    train_set = RegressionDataset(length=length)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            loss.backward()
            optimizer.step()
    return train_set, model


def training_check():
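    """
    Checks that training with the `Accelerator` (plain, with `split_batches=True`,
    and under fp16/bf16 mixed precision when available) reproduces the reference
    model trained by `mock_training`.
    """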
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    accelerator = Accelerator()
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    model = accelerator.unwrap_model(model).cpu()
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    accelerator = Accelerator(split_batches=True)
    train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    model = accelerator.unwrap_model(model).cpu()
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

| accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.") | |
    if torch.cuda.is_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")
        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
        set_seed(42)
        generator.manual_seed(42)
        for _ in range(3):
            for batch in train_dl:
                model.zero_grad()
                output = model(batch["x"])
                loss = torch.nn.functional.mse_loss(output, batch["y"])
                accelerator.backward(loss)
                optimizer.step()

        model = accelerator.unwrap_model(model).cpu()
        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    # BF16 support is only for CPU + TPU, and some GPU
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="bf16")
        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
        set_seed(42)
        generator.manual_seed(42)
        for _ in range(3):
            for batch in train_dl:
                model.zero_grad()
                output = model(batch["x"])
                loss = torch.nn.functional.mse_loss(output, batch["y"])
                accelerator.backward(loss)
                optimizer.step()

        model = accelerator.unwrap_model(model).cpu()
        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."


def main():
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    init_state_check()

    if state.local_process_index == 0:
        print("\n**Test process execution**")
    process_execution_check()

    if state.local_process_index == 0:
        print("\n**Test random number generator synchronization**")
    rng_sync_check()

    if state.local_process_index == 0:
        print("\n**DataLoader integration test**")
    dl_preparation_check()
    if state.distributed_type != DistributedType.TPU and is_torch_version(">=", "1.8.0"):
        central_dl_preparation_check()

    # Trainings are not exactly the same in DeepSpeed and CPU mode
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    if state.local_process_index == 0:
        print("\n**Training integration test**")
    training_check()


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()