|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from accelerate.accelerator import Accelerator |
|
|
from accelerate.state import GradientState |
|
|
from accelerate.test_utils import RegressionDataset, RegressionModel |
|
|
from accelerate.utils import DistributedType, is_torch_version, set_seed |
|
|
|
|
|
|
|
|
def check_model_parameters(model_a, model_b, did_step, iteration):
    """Assert that gradients of ``model_a`` and ``model_b`` are (de)synchronized as expected.

    Args:
        model_a: Reference model whose parameter grads are compared.
        model_b: Second model (typically the DDP-wrapped copy).
        did_step: If True, gradients must match (`torch.allclose`); if False they must differ.
        iteration: Current iteration, used only in the failure message.

    Raises:
        AssertionError: When the sync state does not match `did_step`.
    """
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        # Frozen parameters carry no gradient to compare
        if not param.requires_grad:
            continue
        if not did_step:
            # No optimizer step yet (e.g. inside `no_sync`): grads must still be local/unequal
            assert not torch.allclose(
                param.grad, grad_param.grad
            ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
        else:
            # A real step happened: gradients must have been synchronized
            assert torch.allclose(
                param.grad, grad_param.grad
            ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
|
|
|
|
|
|
|
|
def step_model(model, input, target, accelerator, do_backward=True):
    """Run one training forward/backward pass on ``model`` using an MSE loss.

    When ``do_backward`` is True, the backward pass is delegated to
    ``accelerator.backward``; otherwise the loss is scaled by the gradient
    accumulation steps by hand and ``loss.backward()`` is called directly.
    """
    model.train()
    prediction = model(input)
    loss = F.mse_loss(prediction, target.to(prediction.device))
    if do_backward:
        # Let Accelerate handle loss scaling/accumulation
        accelerator.backward(loss)
    else:
        # Manual path: mimic accumulation by scaling the loss ourselves
        loss = loss / accelerator.gradient_accumulation_steps
        loss.backward()
|
|
|
|
|
|
|
|
def get_training_setup(accelerator, sched=False):
    """Returns everything needed to perform basic training.

    Builds a reference `RegressionModel` plus a deep-copied twin that gets
    prepared by `accelerator`, along with a `DataLoader` over an 80-sample
    `RegressionDataset`.

    Args:
        accelerator: The `Accelerator` used to prepare the DDP objects.
        sched: If True, also build and return AdamW optimizers and LambdaLR
            schedulers for both models.

    Returns:
        `(model, ddp_model, dataloader)` when ``sched`` is False, otherwise
        `(model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)`.
    """
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dset = RegressionDataset(length=80)
    dataloader = DataLoader(dset, batch_size=16)
    model.to(accelerator.device)
    if not sched:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
        return model, ddp_model, dataloader
    # NOTE: use a distinct name for the scheduler instead of shadowing the
    # boolean `sched` parameter (the original relied on LambdaLR being truthy).
    opt = AdamW(params=model.parameters(), lr=1e-3)
    ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
    lr_sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    return (model, opt, lr_sched, dataloader, ddp_model, ddp_opt, ddp_sched)
|
|
|
|
|
|
|
|
def test_noop_sync(accelerator):
    """On a single process, `no_sync` must be a no-op: gradients always stay in sync."""
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Reuse one fixed batch, reshuffled at the end of every iteration
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for step in range(3):
        # The reference model trains on the gathered (here: identical) batch
        batch_input, batch_target = accelerator.gather((ddp_input, ddp_target))
        batch_input, batch_target = batch_input.to(accelerator.device), batch_target.to(accelerator.device)
        step_model(model, batch_input, batch_target, accelerator)
        # Alternate between stepping inside `no_sync` and stepping normally
        if step % 2:
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        # Without real DDP, gradients must match on every iteration
        check_model_parameters(model, ddp_model, True, step)
        for p, ddp_p in zip(model.parameters(), ddp_model.parameters()):
            if not p.requires_grad:
                continue
            assert torch.allclose(
                p.grad, ddp_p.grad
            ), f"Gradients not in sync when they should be:\nModel grad ({p.grad}) != DDP grad ({ddp_p.grad})"
        # Shuffle ddp_input deterministically for the next iteration
        torch.manual_seed(1337 + step)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
|
|
|
|
|
|
|
|
def test_distributed_sync(accelerator):
    """Under real distributed training, `no_sync` must delay gradient all-reduce."""
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Reuse one fixed batch, reshuffled at the end of every iteration
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for step in range(3):
        # The reference model always trains on the full gathered batch
        batch_input, batch_target = accelerator.gather((ddp_input, ddp_target))
        batch_input, batch_target = batch_input.to(accelerator.device), batch_target.to(accelerator.device)
        step_model(model, batch_input, batch_target, accelerator)
        skipped_sync = step % 2 == 0
        if skipped_sync:
            # Gradients should stay local inside `no_sync`
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Normal step: DDP all-reduces the gradients
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        for p, ddp_p in zip(model.parameters(), ddp_model.parameters()):
            if not p.requires_grad:
                continue
            if skipped_sync:
                assert not torch.allclose(
                    p.grad, ddp_p.grad
                ), f"Gradients in sync when they should not be:\nModel grad ({p.grad}) == DDP grad ({ddp_p.grad})"
            else:
                assert torch.allclose(
                    p.grad, ddp_p.grad
                ), f"Gradients not in sync when they should be:\nModel grad ({p.grad}) != DDP grad ({ddp_p.grad})"
        # Shuffle ddp_input deterministically for the next iteration
        torch.manual_seed(1337 + step)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
|
|
|
|
|
|
|
|
def test_gradient_accumulation(split_batches=False, dispatch_batches=False):
    """`accelerator.accumulate` should only sync gradients every second batch (steps=2)."""
    accelerator = Accelerator(
        split_batches=split_batches, dispatch_batches=dispatch_batches, gradient_accumulation_steps=2
    )
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for step, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Reference model: manual accumulation with a hand-scaled loss
        batch_input, batch_target = accelerator.gather((ddp_input, ddp_target))
        batch_input, batch_target = batch_input.to(accelerator.device), batch_target.to(accelerator.device)
        step_model(model, batch_input, batch_target, accelerator, False)
        # DDP model: accumulation managed by the `accumulate` context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        # A sync is expected on every 2nd batch and on the final batch
        synced = ((step + 1) % 2 == 0) or (step == len(dataloader) - 1)
        for p, ddp_p in zip(model.parameters(), ddp_model.parameters()):
            if not p.requires_grad:
                continue
            if synced:
                assert torch.allclose(
                    p.grad, ddp_p.grad
                ), f"Gradients not in sync when they should be at iteration {step}:\nModel grad ({p.grad}) != DDP grad ({ddp_p.grad})"
            else:
                assert not torch.allclose(
                    p.grad, ddp_p.grad
                ), f"Gradients in sync when they should not be at iteration {step}:\nModel grad ({p.grad}) == DDP grad ({ddp_p.grad})"
        # Shuffle ddp_input deterministically for the next batch
        torch.manual_seed(1337 + step)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    GradientState._reset_state()
|
|
|
|
|
|
|
|
def test_gradient_accumulation_with_opt_and_scheduler(split_batches=False, dispatch_batches=False):
    """Check `accelerator.accumulate` together with an optimizer and LR scheduler.

    A reference (model, opt, sched) trio is stepped by hand with manually
    scaled losses, while the prepared DDP trio relies on the `accumulate`
    context manager. Learning rates must stay aligned, and in multi-process
    runs the gradient (de)sync state is verified per iteration.
    """
    accelerator = Accelerator(
        split_batches=split_batches, dispatch_batches=dispatch_batches, gradient_accumulation_steps=2
    )
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the sharded batch so the reference model sees all the data
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Reference ("ground truth") step with manual loss scaling
        model.train()
        ddp_model.train()
        step_model(model, input, target, accelerator, False)
        opt.step()
        # Only step the reference scheduler on a "real" optimizer step:
        # every `gradient_accumulation_steps`-th batch, or the final batch.
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):
            if split_batches:
                sched.step()
            else:
                # Without split_batches the prepared scheduler advances once per
                # process, so the reference must match that step count.
                for _ in range(accelerator.num_processes):
                    sched.step()
        opt.zero_grad()
        # DDP side: the `accumulate` context manager decides internally whether
        # the optimizer/scheduler calls below actually take effect this batch.
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()
            ddp_opt.zero_grad()
        # Learning rates must stay aligned between reference and DDP optimizers
        assert (
            opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
        ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
        if accelerator.num_processes > 1:
            # Gradient (de)sync expectations only hold with more than one process
            check_model_parameters(model, ddp_model, did_step, iteration)
        # Reseed for the next iteration's shuffle
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()
|
|
|
|
|
|
|
|
def test_dataloader_break():
    """The gradient state must always track the currently active prepared dataloader."""
    accelerator = Accelerator()
    first_dset = RegressionDataset(length=80)
    first_dataloader = DataLoader(first_dset, batch_size=16)
    second_dset = RegressionDataset(length=96)
    second_dataloader = DataLoader(second_dset, batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
    last_outer = len(first_dataloader) - 1
    for outer_step, _ in enumerate(first_dataloader):
        if outer_step >= last_outer:
            # Outer dataloader is exhausted -> nothing should be active anymore
            assert accelerator.gradient_state.active_dataloader is None
            continue
        assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
        if outer_step == 1:
            # Start a nested dataloader mid-epoch: it becomes the active one,
            # and control returns to the outer loader once it finishes.
            last_inner = len(second_dataloader) - 1
            for inner_step, _ in enumerate(second_dataloader):
                if inner_step < last_inner:
                    assert id(accelerator.gradient_state.active_dataloader) == id(
                        second_dataloader
                    ), f"First dataloader: {id(first_dataloader)}\nSecond dataloader: {id(second_dataloader)}\nActive dataloader: {id(accelerator.gradient_state.active_dataloader)}\n"
                else:
                    assert id(accelerator.gradient_state.active_dataloader) == id(
                        first_dataloader
                    ), f"First dataloader: {id(first_dataloader)}\nSecond dataloader: {id(second_dataloader)}\nActive dataloader: {id(accelerator.gradient_state.active_dataloader)}\n"
|
|
|
|
|
|
|
|
def main():
    """Run each sync/accumulation test applicable to the current distributed setup."""
    accelerator = Accelerator()
    state = accelerator.state
    # Only the local main process prints test banners
    verbose = state.local_process_index == 0
    if verbose:
        print("**Test `accumulate` gradient accumulation with dataloader break**")
    test_dataloader_break()
    if state.distributed_type == DistributedType.NO:
        if verbose:
            print("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)
    if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_CPU):
        if verbose:
            print("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
    if state.distributed_type == DistributedType.MULTI_GPU:
        # Exercise every split/dispatch combination for plain accumulation
        for split in (True, False):
            for dispatch in (True, False):
                if verbose:
                    print(
                        "**Test `accumulate` gradient accumulation, ",
                        f"`split_batches={split}` and `dispatch_batches={dispatch}`**",
                    )
                test_gradient_accumulation(split, dispatch)
    if is_torch_version("<", "2.0") or state.distributed_type == DistributedType.NO:
        if verbose:
            print(
                "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                "`split_batches=False`, `dispatch_batches=False`**",
            )
        test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type == DistributedType.MULTI_GPU:
        for split in (True, False):
            for dispatch in (True, False):
                # The default combination was already covered just above
                if not split and not dispatch:
                    continue
                if verbose:
                    print(
                        "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                        f"`split_batches={split}` and `dispatch_batches={dispatch}`**",
                    )
                test_gradient_accumulation_with_opt_and_scheduler(split, dispatch)
|
|
|
|
|
|
|
|
def _mp_fn(index):
    """Spawn-launcher entry point; the process `index` argument is unused."""
    main()
|
|
|
|
|
|
|
|
# Script entry point when executed directly (e.g. via a launcher)
if __name__ == "__main__":
    main()
|
|
|