from typing import List, Tuple

import unittest

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
    build_model,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
    _forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.utils import (
    _reconfigure_microbatch_calculator,
    get_num_microbatches,
    setup_microbatch_calculator,
    update_num_microbatches,
)
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import (
    fwd_step_func,
    model_provider_func,
    print_separator,
)
from apex.transformer.log_util import get_transformer_logger
from apex.transformer._data import MegatronPretrainingRandomSampler, MegatronPretrainingSampler
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase

_logger = get_transformer_logger("pipeline_parallel_test")

NUM_ITERATIONS = 20
NUM_SAMPLES = 16384 // 2
HIDDEN_SIZE = 16


def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
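    """Return ``num_samples`` pairs of random (input, target) tensors."""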
    return [
        (
            torch.randn(HIDDEN_SIZE, HIDDEN_SIZE),
            torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2),
        )
        for _ in range(num_samples)
    ]


def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls,
) -> None:
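    """Drive the interleaved pipeline schedule while the batch size ramps up.

    Each iteration re-reads the current number of microbatches from the
    microbatch calculator and resizes the sampler's local minibatch to match.
    """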
    args = global_vars.get_args()
    # Reset the microbatch calculator so this run's ``rampup_batch_size``
    # takes effect; the final argument is the data parallel size.
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,
    )
    # The interleaved schedule needs a virtual pipeline, i.e. more than one
    # model chunk per pipeline stage.
    virtual_pipeline_model_parallel_size = 2
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
    )
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size()
    )

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    # ``build_model`` returns a list with one module per virtual pipeline stage.
    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * args.micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,  # no samples consumed yet
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        # Recursively count the samples in a (possibly nested) batch.
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [args.micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        # Recompute the number of microbatches for the current point of the
        # ramp-up, then resize the sampler's local minibatch to match.
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * args.micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(
            f"iter: {i} / {NUM_ITERATIONS} "
            f"local batch size: {get_num_samples(local_mini_batch)} "
            f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
        )
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        # One schedule run consumes a full global batch across data parallel ranks.
        consumed_samples += (
            parallel_state.get_data_parallel_world_size()
            * get_num_microbatches()
            * args.micro_batch_size
        )

        if not forward_only:
            # Every parameter should have received a gradient from the backward pass.
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            optimizer.zero_grad(set_to_none=True)

    torch.cuda.synchronize()


class DynamicBatchsizeTestBase:
    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 GPUs")
    def test_dynamic_batchsize(self):
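        """Interleaved pipelining should run for every sampler/forward_only combination."""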
        n_tests = 0
        failures = []

        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "use_cpu_initialization": True,
            "world_size": self.world_size,
            "rank": self.rank,
        }

        # ``rampup_batch_size`` is [start, increment, ramp-up samples]: the global
        # batch size starts at 64 and grows in steps of 64 over the first 1000 samples.
        global_vars.set_global_variables(
            args_defaults={
                "global_batch_size": 512,
                "rampup_batch_size": [64, 64, 1000],
            },
            ignore_unknown_args=True,
            override_args=override_args,
        )

        args = global_vars.get_args()

        setup_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            1,
        )
        for BatchSamplerCls in (
            MegatronPretrainingSampler,
            MegatronPretrainingRandomSampler,
        ):
            for forward_only in (False, True):
                n_tests += 1
                pipeline_model_parallel_size = self.world_size
                try:
                    run_interleaved_with_dynamic_batch_size(
                        pipeline_model_parallel_size, forward_only, BatchSamplerCls,
                    )
                except Exception as e:
                    # Record the failure so the remaining combinations still run;
                    # the collected messages are reported after the loop.
                    failures.append(
                        f"\tBatchSamplerCls: {BatchSamplerCls.__name__}, "
                        f"forward_only: {forward_only}\n"
                        f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
                        f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
                        f"{str(e)}"
                    )
                finally:
                    parallel_state.destroy_model_parallel()
        if failures:
            print_separator("TEST FAILED:")
            print("\n".join(failures))
            raise RuntimeError(f"{len(failures)} / {n_tests} cases failed")
        elif torch.distributed.get_rank() == 0:
            print_separator("TEST RESULT: ### PASS!")


class NcclDynamicBatchsizeTest(DynamicBatchsizeTestBase, NcclDistributedTestBase):
    pass


if __name__ == "__main__":
    # Disable TF32 matmuls so results do not depend on hardware matmul defaults.
    torch.backends.cuda.matmul.allow_tf32 = False
    common_utils.run_tests()
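# Run this file directly with ``python <this file>``; NcclDistributedTestBase is
# expected to spawn one process per visible GPU (an assumption about apex's
# testing harness, not something configured here).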