| import itertools |
| from typing import Any, List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from omegaconf.dictconfig import DictConfig |
| from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin |
| from pytorch_lightning.trainer.trainer import Trainer |
|
|
| from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( |
| MegatronPretrainingRandomSampler, |
| MegatronPretrainingSampler, |
| ) |
| from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets |
| from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel |
| from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel |
| from nemo.collections.nlp.modules.common.megatron.module import Float16Module |
| from nemo.collections.nlp.modules.common.megatron.utils import ( |
| average_losses_across_data_parallel_group, |
| get_all_params_for_weight_decay_optimization, |
| get_params_for_weight_decay_optimization, |
| ) |
| from nemo.collections.nlp.modules.common.text_generation_utils import ( |
| generate, |
| get_computeprob_response, |
| get_default_length_params, |
| get_default_sampling_params, |
| megatron_gpt_generate, |
| ) |
| from nemo.collections.nlp.modules.common.transformer.text_generation import ( |
| LengthParam, |
| OutputType, |
| SamplingParam, |
| TextGeneration, |
| ) |
| from nemo.collections.nlp.parts.nlp_overrides import GradScaler |
| from nemo.collections.nlp.parts.utils_funcs import get_last_rank |
| from nemo.core.classes.common import PretrainedModelInfo |
| from nemo.utils import logging |
|
|
| try: |
| from apex.transformer import parallel_state |
| from apex.transformer.pipeline_parallel.schedules.common import build_model |
| from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining |
| from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( |
| _forward_backward_pipelining_with_interleaving, |
| ) |
| from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( |
| forward_backward_pipelining_without_interleaving, |
| ) |
|
|
| HAVE_APEX = True |
| except (ImportError, ModuleNotFoundError): |
| HAVE_APEX = False |
|
|
| try: |
| import transformer_engine |
|
|
| HAVE_TE = True |
|
|
| except (ImportError, ModuleNotFoundError): |
| HAVE_TE = False |
|
|
|
|
| class MegatronGPTModel(MegatronBaseModel, TextGeneration): |
| """ |
| Megatron GPT pretraining |
| """ |
|
|
| def __init__(self, cfg: DictConfig, trainer: Trainer): |
| if not HAVE_APEX: |
| raise ImportError( |
| "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." |
| ) |
        # this prevents the base constructor from initializing the tokenizer; it is built later
| self.tokenizer = None |
| super().__init__(cfg, trainer=trainer, no_lm_init=True) |
|
|
| self._validate_trainer() |
|
|
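        # O2-style mixed precision: the model itself runs in half precision (see the
        # Float16Module wrapping below) while the optimizer keeps FP32 main parameters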
| self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False) |
|
|
| if not self.megatron_amp_o2 and self.cfg.get('virtual_pipeline_model_parallel_size', None): |
| raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2') |
|
|
        # build_model returns a list of modules, which is used for interleaved pipeline parallelism
| self.model = build_model( |
| model_provider_func=self.model_provider_func, |
| wrap_with_ddp=False, |
| virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), |
| ) |
|
|
        # if we are not using interleaved pipeline parallelism, self.model is a single module
| if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: |
| self.model = self.model[0] |
|
|
| if self.megatron_amp_o2: |
|
|
| if not self.with_distributed_adam: |
                # pre-allocate the model on GPU so the main parameters end up on the same device
| if isinstance(self.model, list): |
| for module in self.model: |
| module.cuda(torch.cuda.current_device()) |
| else: |
| self.model.cuda(torch.cuda.current_device()) |
|
|
            # wrap the model so that both the model and its inputs are converted to half precision
| if isinstance(self.model, list): |
| converted_model = [] |
| for module in self.model: |
| converted_model.append(Float16Module(module=module, precision=cfg.precision)) |
| self.model = converted_model |
| else: |
| self.model = Float16Module(module=self.model, precision=cfg.precision) |
|
|
| if self.trainer.precision == 'bf16': |
| self.autocast_dtype = torch.bfloat16 |
| elif int(self.trainer.precision) == 32: |
| self.autocast_dtype = torch.float |
| elif int(self.trainer.precision) == 16: |
| self.autocast_dtype = torch.half |
| else: |
| raise ValueError('precision must be in [32, 16, "bf16"]') |
|
|
| self.transformer_engine = cfg.get('transformer_engine', False) |
|
|
        # configuration used for inference
| self._inference_config = None |
|
|
        # nsys profiling ranges are configured in global steps; scale them by the number of micro-batches per global batch
| if hasattr(self, '_nsys_profile_enabled'): |
| mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1) |
| data_parallel_world_size = trainer.world_size // mp_size |
| grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size) |
| self._nsys_profile_start_step *= grad_accum_steps |
| self._nsys_profile_end_step *= grad_accum_steps |
|
|
| self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True) |
|
|
| def set_inference_config(self, inference_config): |
| self._inference_config = inference_config |
|
|
| def get_inference_config(self): |
| return self._inference_config |
|
|
| def model_provider_func(self, pre_process, post_process): |
| """Model depends on pipeline paralellism.""" |
| model = GPTModel( |
| vocab_size=self.padded_vocab_size, |
| hidden_size=self.cfg.hidden_size, |
| max_position_embeddings=self.cfg.max_position_embeddings, |
| num_layers=self.cfg.num_layers, |
| num_attention_heads=self.cfg.num_attention_heads, |
| apply_query_key_layer_scaling=self.cfg.get('apply_query_key_layer_scaling', True), |
| kv_channels=self.cfg.get('kv_channels', None), |
| ffn_hidden_size=self.cfg.ffn_hidden_size, |
| num_tokentypes=0, |
| parallel_output=True, |
| pre_process=pre_process, |
| post_process=post_process, |
| init_method_std=self.cfg.get('init_method_std', 0.02), |
| use_scaled_init_method=self.cfg.get('use_scaled_init_method', True), |
| fp16_lm_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), |
| use_cpu_initialization=self.cfg.get('use_cpu_initialization', False), |
| hidden_dropout=self.cfg.get('hidden_dropout', 0.1), |
| attention_dropout=self.cfg.get('attention_dropout', 0.1), |
| ffn_dropout=self.cfg.get('ffn_dropout', 0.0), |
| precision=self.cfg.get('precision', 16), |
| fp32_residual_connection=self.cfg.get('fp32_residual_connection', False), |
| activations_checkpoint_granularity=self.cfg.get('activations_checkpoint_granularity', None), |
| activations_checkpoint_method=self.cfg.get('activations_checkpoint_method', None), |
| activations_checkpoint_num_layers=self.cfg.get('activations_checkpoint_num_layers', 1), |
| activations_checkpoint_layers_per_pipeline=self.cfg.get( |
| 'activations_checkpoint_layers_per_pipeline', None |
| ), |
| normalization=self.cfg.get('normalization', 'layernorm'), |
| layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5), |
| onnx_safe=self.cfg.get('onnx_safe', False), |
| bias=self.cfg.get('bias', True), |
| bias_activation_fusion=self.cfg.get('bias_activation_fusion', True), |
| bias_dropout_add_fusion=self.cfg.get('bias_dropout_add_fusion', True), |
| activation=self.cfg.get('activation', 'gelu'), |
| headscale=self.cfg.get('headscale', False), |
| transformer_block_type=self.cfg.get('transformer_block_type', 'pre_ln'), |
| openai_gelu=self.cfg.get('openai_gelu', False), |
| normalize_attention_scores=self.cfg.get('normalize_attention_scores', True), |
| position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'), |
| rotary_percentage=self.cfg.get('rotary_percentage', 1.0), |
| share_embeddings_and_output_weights=self.cfg.get('share_embeddings_and_output_weights', True), |
| attention_type=self.cfg.get('attention_type', 'multihead'), |
| masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True), |
| gradient_accumulation_fusion=self.cfg.get('gradient_accumulation_fusion', False), |
| persist_layer_norm=self.cfg.get('persist_layer_norm', False), |
| sequence_parallel=self.cfg.get('sequence_parallel', False), |
| transformer_engine=self.cfg.get('transformer_engine', False), |
| fp8=self.cfg.get('fp8', False), |
| fp8_e4m3=self.cfg.get('fp8_e4m3', False), |
| fp8_hybrid=self.cfg.get('fp8_hybrid', False), |
| fp8_margin=self.cfg.get('fp8_margin', 0), |
| fp8_interval=self.cfg.get('fp8_interval', 1), |
| fp8_amax_history_len=self.cfg.get('fp8_amax_history_len', 1), |
| fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'), |
| reduce_amax=self.cfg.get('reduce_amax', True), |
| use_emha=self.cfg.get('use_emha', False), |
| ) |
|
|
| return model |
|
|
| def setup_optimizer_param_groups(self): |
| """ModelPT override. Optimizer will get self._optimizer_param_groups""" |
| if self.cfg.get('do_layer_norm_weight_decay', False): |
| if isinstance(self.model, list): |
| self._optimizer_param_groups = get_all_params_for_weight_decay_optimization(self.model) |
| else: |
| self._optimizer_param_groups = get_all_params_for_weight_decay_optimization([self.model]) |
|
|
| else: |
| self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model) |
|
|
| def configure_optimizers(self): |
|
|
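        # With the distributed Adam optimizer, mark parameters whose grad syncs must not be
        # overlapped and build the gradient buckets before the base class creates the
        # optimizer and schedulers.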
| if self.with_distributed_adam: |
|
|
            # Disable overlapped grad sync for the shared embedding grad when
            # pipeline parallelism is enabled
| if parallel_state.get_pipeline_model_parallel_world_size() > 1: |
| if parallel_state.is_pipeline_first_stage(ignore_virtual=True): |
| if isinstance(self.model, list): |
| module = self.model[0] |
| else: |
| module = self.model |
| if module.share_token_embeddings: |
| param = module.word_embeddings_weight() |
| param._disable_greedy_grad_copy = not self.megatron_amp_o2 |
| param._disable_overlap_grad_sync = True |
| if parallel_state.is_pipeline_last_stage(ignore_virtual=True): |
| if isinstance(self.model, list): |
| module = self.model[-1] |
| else: |
| module = self.model |
| if module.share_token_embeddings: |
| param = module.word_embeddings_weight() |
| param._disable_greedy_grad_copy = not self.megatron_amp_o2 |
| param._disable_overlap_grad_sync = True |
|
|
            # Disable overlapped grad sync for layer-norm grads when
            # sequence parallelism is enabled
| for param in self.parameters(): |
| if getattr(param, 'sequence_parallel_enabled', False): |
| param._disable_greedy_grad_copy = not self.megatron_amp_o2 |
| param._disable_overlap_grad_sync = True |
|
|
            # Initialize parameter buckets for overlapped grad and param syncs
            # Note: params with overlapping disabled are put in the last bucket
| buckets = [] |
| if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: |
                # one bucket per virtual pipeline stage
| for module in self.model: |
| if isinstance(module, Float16Module): |
| module = module.module |
| stage_bucket = [] |
| for layer in module.language_model.encoder.layers: |
| stage_bucket.extend( |
| p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False) |
| ) |
| buckets.append(stage_bucket) |
| else: |
                # one bucket per transformer layer
| modules = self.model if isinstance(self.model, list) else [self.model] |
| for module in modules: |
| if isinstance(module, Float16Module): |
| module = module.module |
| for layer in module.language_model.encoder.layers: |
| buckets.append( |
| [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] |
| ) |
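            # reverse the bucket order so it roughly matches the order in which gradients
            # are produced during the backward pass (last layers first)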
| buckets.reverse() |
| used_params = set() |
| for bucket in buckets: |
| used_params.update(bucket) |
| buckets[-1].extend(p for p in self.parameters() if p not in used_params) |
| self.distributed_adam_buckets = buckets |
|
|
| return super().configure_optimizers() |
|
|
| def forward(self, tokens, text_position_ids, attention_mask, labels): |
| output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels) |
| return output_tensor |
|
|
| def _get_fwd_bwd_function(self): |
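        """Select the Apex forward-backward schedule based on the pipeline-parallel config:
        interleaved when a virtual pipeline size is set, non-interleaved for plain pipeline
        parallelism, and the no-pipelining schedule when pipeline parallelism is disabled."""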
| fwd_bwd_function = None |
| if self.cfg.get('pipeline_model_parallel_size', 1) > 1: |
| if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: |
| fwd_bwd_function = _forward_backward_pipelining_with_interleaving |
| else: |
| fwd_bwd_function = forward_backward_pipelining_without_interleaving |
| else: |
| fwd_bwd_function = forward_backward_no_pipelining |
| return fwd_bwd_function |
|
|
| def training_step(self, dataloader_iter, batch_idx): |
| """ |
| We pass the dataloader iterator function to the micro-batch scheduler. |
| The input batch to each micro-batch is fetched using the dataloader function |
| in the micro-batch fwd function. |
| """ |
|
|
        # we zero grads here because we also call backward in the Apex fwd/bwd functions
| self._optimizer.zero_grad() |
|
|
| if self.with_distributed_adam: |
            # hack to enable overlapping param sync and forward compute:
            # the distributed optimizer patches each parameter's attribute access so that it
            # can launch the parameter all-gather the first time the parameter is accessed
            # after the optimizer step. However, PyTorch passes embedding parameters directly
            # into a C++ function, bypassing that hook, so we touch the embedding parameters
            # manually here to trigger the all-gather.
| modules = self.model if isinstance(self.model, list) else [self.model] |
| for module in modules: |
| if isinstance(module, Float16Module): |
| module = module.module |
| module = module.language_model |
| if hasattr(module, 'embedding'): |
| for param in module.embedding.parameters(): |
| param.data_ptr() |
|
|
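        # shape of the activation tensors exchanged between pipeline stages: [seq_len, micro_batch, hidden]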
| tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] |
|
|
        # handle asynchronous grad reduction
| custom_sync_context_handler = None |
| custom_grad_sync_func = None |
| custom_param_sync_func = None |
| if self.with_distributed_adam: |
| if self.megatron_amp_o2: |
                # copy grads to main grads eagerly
| custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) |
| else: |
                # keep grad tensors around
| custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False) |
| custom_grad_sync_func = self.reduce_overlap_gradients |
| custom_param_sync_func = self.sync_overlap_parameters |
| else: |
| if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False): |
| custom_sync_context_handler = self._optimizer.no_sync |
| else: |
                # async grad allreduce is not supported for O1/autocast mixed precision, so use the default handler
| custom_sync_context_handler = None |
|
|
        # run forward and backward passes for an entire global batch
        # we do this inside training_step to support pipeline parallelism
| fwd_bwd_function = self._get_fwd_bwd_function() |
|
|
| losses_reduced_per_micro_batch = fwd_bwd_function( |
| forward_step_func=self.get_forward_output_and_loss_func(), |
| batch=dataloader_iter, |
| model=self.model, |
| forward_only=False, |
| tensor_shape=tensor_shape, |
| dtype=self.autocast_dtype, |
| grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, |
| custom_sync_context_handler=custom_sync_context_handler, |
| custom_grad_sync_func=custom_grad_sync_func, |
| custom_param_sync_func=custom_param_sync_func, |
| sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), |
| sync_batch_comm=self.cfg.get('sync_batch_comm', False), |
| num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( |
| 'num_micro_batches_with_partial_activation_checkpoints', None |
| ), |
| ) |
|
|
        # only the last stages of the pipeline return losses
| if losses_reduced_per_micro_batch: |
            # average loss across micro-batches
| loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] |
| loss_tensor = torch.concat(loss_tensors_list) |
| loss_mean = loss_tensor.mean() |
| else: |
| loss_mean = torch.tensor(0.0).cuda() |
|
|
        # when using sequence parallelism, the layer-norm grads must be all-reduced across the tensor-parallel group
| if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): |
| self.allreduce_sequence_parallel_gradients() |
|
|
| if self.with_distributed_adam: |
            # finish any bucketed gradient reductions launched by the distributed optimizer
| self._optimizer._finish_bucket_grad_sync() |
| elif self.megatron_amp_o2: |
            # when using pipeline parallelism or sequence parallelism, grads must be all-reduced after the pipeline (not asynchronously)
| if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): |
                # all-reduce the FP32 main grads held by the O2 optimizer wrapper
| self._optimizer.allreduce_main_grads() |
| else: |
            # async grad allreduce is not currently implemented for O1/autocast mixed precision,
            # so we all-reduce gradients after the pipeline
| self.allreduce_gradients() |
|
|
| if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and self.cfg.get( |
| 'share_embeddings_and_output_weights', True |
| ): |
            # when using pipeline parallelism, the first and last stages must keep their shared embeddings in sync
| self.allreduce_first_last_embeddings() |
|
|
        ## logging
        # the loss is only available on the last pipeline stage, so we broadcast it from the
        # last rank in order to log it on rank zero
| torch.distributed.broadcast(loss_mean, get_last_rank()) |
|
|
| if self.cfg.precision == 16: |
| loss_scale = self.trainer.precision_plugin.scaler._scale |
| if loss_scale is not None: |
| self.log('loss_scale', loss_scale, batch_size=1) |
|
|
| self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) |
| lr = self._optimizer.param_groups[0]['lr'] |
| self.log('lr', lr, rank_zero_only=True, batch_size=1) |
| self.log( |
| 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1, |
| ) |
|
|
| |
| self.log( |
| 'consumed_samples', |
| self.compute_consumed_samples(self.trainer.global_step - self.init_global_step), |
| prog_bar=True, |
| rank_zero_only=True, |
| batch_size=1, |
| ) |
|
|
| return loss_mean |
|
|
| def backward(self, *args, **kwargs): |
| """ LightningModule hook to do backward. |
| We want this to do nothing since we run backward in the fwd/bwd functions from apex. |
| No need to call it here. |
| """ |
| return |
|
|
| def optimizer_zero_grad(self, *args, **kwargs): |
| """ LightningModule hook to zero grad. |
| We want this to do nothing as we are zeroing grads during the training_step. |
| """ |
| return |
|
|
| def _append_sequence_parallel_module_grads(self, module, grads): |
| """ Helper method for allreduce_sequence_parallel_gradients""" |
|
|
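        # Transformer Engine tags sequence-parallel params with `sequence_parallel`,
        # while the Apex modules use `sequence_parallel_enabled`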
| for param in module.parameters(): |
| if getattr(self, 'transformer_engine', False): |
| sequence_parallel_param = getattr(param, 'sequence_parallel', False) |
| else: |
| sequence_parallel_param = getattr(param, 'sequence_parallel_enabled', False) |
| if sequence_parallel_param: |
| if self.megatron_amp_o2: |
| grad = param.main_grad |
| else: |
| grad = param.grad |
| grads.append(grad.data) |
|
|
| def allreduce_sequence_parallel_gradients(self): |
| """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. |
| Modified from megatron-lm: |
| https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 |
| """ |
|
|
| grads = [] |
| if isinstance(self.model, list): |
| for module in self.model: |
| self._append_sequence_parallel_module_grads(module, grads) |
| else: |
| self._append_sequence_parallel_module_grads(self.model, grads) |
|
|
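        # flatten the sequence-parallel grads into one buffer, all-reduce it across the
        # tensor-parallel group, then copy the reduced values back into the original grads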
| coalesced = torch._utils._flatten_dense_tensors(grads) |
| torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) |
| for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): |
| buf.copy_(synced) |
|
|
| def allreduce_first_last_embeddings(self): |
|
|
        # All-reduce the shared word-embedding gradients across the first and last pipeline
        # stages so the input and output embeddings stay in sync.
        # Only relevant when pipeline model parallelism is used.
| if parallel_state.get_pipeline_model_parallel_world_size() > 1 and ( |
| parallel_state.is_pipeline_first_stage(ignore_virtual=True) |
| or parallel_state.is_pipeline_last_stage(ignore_virtual=True) |
| ): |
| if parallel_state.is_pipeline_first_stage(ignore_virtual=True): |
| if isinstance(self.model, list): |
| module = self.model[0] |
| else: |
| module = self.model |
| if parallel_state.is_pipeline_last_stage(ignore_virtual=True): |
| if isinstance(self.model, list): |
| module = self.model[-1] |
| else: |
| module = self.model |
| if module.share_token_embeddings: |
| word_embeddings_weight = module.word_embeddings_weight() |
| if self.megatron_amp_o2: |
| |
| grad = word_embeddings_weight.main_grad |
| else: |
| grad = word_embeddings_weight.grad |
| torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) |
|
|
| def get_forward_output_and_loss_func(self, validation_step=False): |
| def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): |
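            # With pipeline parallelism, only the stages that actually consume data pull a batch
            # from the dataloader; keys a stage does not need are set to None.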
| if parallel_state.get_pipeline_model_parallel_world_size() == 1: |
| batch = next(dataloader_iter) |
| for k in batch.keys(): |
| if self.get_attention_mask_from_fusion: |
| batch[k] = batch[k].cuda(non_blocking=True) if k not in ['attention_mask'] else None |
| else: |
| batch[k] = batch[k].cuda(non_blocking=True) |
| else: |
| if parallel_state.is_pipeline_first_stage(): |
| batch = next(dataloader_iter) |
                    # the first pipeline stage only needs tokens and position_ids (plus the attention mask when it is not fused)
| for k in batch.keys(): |
| if self.get_attention_mask_from_fusion: |
| batch[k] = batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids'] else None |
| else: |
| batch[k] = ( |
| batch[k].cuda(non_blocking=True) |
| if k in ['tokens', 'position_ids', 'attention_mask'] |
| else None |
| ) |
| elif parallel_state.is_pipeline_last_stage(): |
| batch = next(dataloader_iter) |
                    # the last pipeline stage only needs labels and loss_mask (plus the attention mask when it is not fused)
| for k in batch.keys(): |
| if self.get_attention_mask_from_fusion: |
| batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None |
| else: |
| batch[k] = ( |
| batch[k].cuda(non_blocking=True) |
| if k in ['labels', 'loss_mask', 'attention_mask'] |
| else None |
| ) |
| else: |
                    # intermediate pipeline stages receive activations from the previous stage and need no batch inputs
| batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels']} |
|
|
| output_tensor = model( |
| batch['tokens'], |
| batch['position_ids'], |
| batch['attention_mask'], |
| batch['labels'], |
| checkpoint_activations_all_layers=checkpoint_activations_all_layers, |
| ) |
|
|
| def loss_func(output_tensor): |
                # loss for a single micro-batch (ub)
| loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor) |
| if validation_step and not self.cfg.data.get('validation_drop_last', True): |
| num_valid_tokens_in_ub = batch['loss_mask'].sum() |
| if loss_for_ub.isnan(): |
| assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' |
| loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) |
| else: |
| loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub |
|
|
| loss_sum_and_ub_size_all_gpu = torch.cat( |
| [ |
| loss_sum_for_ub.clone().detach().view(1), |
| torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), |
| ] |
| ) |
                    # sum the loss and the number of valid tokens across the data-parallel group
| torch.distributed.all_reduce( |
| loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() |
| ) |
| return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} |
| else: |
| reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) |
| return loss_for_ub, {'avg': reduced_loss} |
|
|
| return output_tensor, loss_func |
|
|
| return fwd_output_and_loss_func |
|
|
| def get_forward_output_only_func(self): |
| def fwd_output_only_func(batch, model): |
| extra_arg = {} |
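            # Two batch layouts are supported: (tokens, attention_mask, position_ids) for plain
            # scoring, or a 5-tuple that additionally carries the KV-cache controls
            # (set_inference_key_value_memory, inference_max_sequence_len) used during generation.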
| if len(batch) == 3: |
| batch = [x.cuda() for x in batch] |
| tokens, attention_mask, position_ids = batch |
| attention_mask = attention_mask[0:1] |
| else: |
| ( |
| tokens, |
| attention_mask, |
| position_ids, |
| set_inference_key_value_memory, |
| inference_max_sequence_len, |
| ) = batch |
| tokens = tokens.cuda() |
| attention_mask = attention_mask.cuda() |
| position_ids = position_ids.cuda() |
| attention_mask = attention_mask[0:1] |
| extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() |
| extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() |
| output_tensor = model(tokens, position_ids, attention_mask, **extra_arg) |
|
|
| def id_func(output_tensor): |
| return output_tensor, {'logits': output_tensor} |
|
|
| return output_tensor, id_func |
|
|
| return fwd_output_only_func |
|
|
| def validation_step(self, dataloader_iter, batch_idx): |
| """ |
| Our dataloaders produce a micro-batch and then we fetch |
| a number of microbatches depending on the global batch size and model parallel size |
| from the dataloader to produce a list of microbatches. |
| The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. |
| """ |
|
|
| tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] |
|
|
        # run forward passes for an entire global batch
        # we do this inside validation_step to support pipeline parallelism
| fwd_bwd_function = self._get_fwd_bwd_function() |
|
|
| losses_reduced_per_micro_batch = fwd_bwd_function( |
| forward_step_func=self.get_forward_output_and_loss_func(validation_step=True), |
| batch=dataloader_iter, |
| model=self.model, |
| forward_only=True, |
| tensor_shape=tensor_shape, |
| dtype=self.autocast_dtype, |
| sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), |
| sync_batch_comm=self.cfg.get('sync_batch_comm', False), |
| ) |
|
|
        # only the last stages of the pipeline return losses
| if losses_reduced_per_micro_batch: |
| if self.cfg.data.get('validation_drop_last', True): |
                # average loss across micro-batches
| loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] |
| return torch.concat(loss_tensors_list).mean() |
| else: |
                # compute the total loss since micro-batch sizes are not uniform
| loss_sum_tensors_list = [ |
| loss_sum['loss_sum_and_ub_size'] |
| for loss_sum in losses_reduced_per_micro_batch |
| if loss_sum['loss_sum_and_ub_size'][1] > 0 |
| ] |
| loss_sum = ( |
| torch.vstack(loss_sum_tensors_list).sum(axis=0) |
| if len(loss_sum_tensors_list) > 0 |
| else torch.tensor([0.0, 0.0]).cuda() |
| ) |
| return loss_sum |
| else: |
            # we are not on the last pipeline stage, so there are no losses to return
| return [] |
|
|
| def validation_epoch_end(self, outputs): |
| if parallel_state.is_pipeline_last_stage(): |
            # only the last pipeline-parallel stages return losses
| if self.cfg.data.get('validation_drop_last', True): |
| averaged_loss = torch.stack(outputs).mean() |
| else: |
                # average loss = total loss summed over samples / total number of samples
| total_loss_and_total_samples = torch.vstack(outputs).sum(axis=0) |
| avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1] |
| averaged_loss = avg_loss.type(torch.float32).cuda() |
| else: |
| averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda() |
|
|
        # broadcast the averaged loss from the last pipeline stage so rank zero can log it
| torch.distributed.broadcast(averaged_loss, get_last_rank()) |
|
|
| self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) |
|
|
| def test_step(self, batch, batch_idx): |
| return self.validation_step(batch, batch_idx) |
|
|
| def test_epoch_end(self, outputs): |
| averaged_loss = average_losses_across_data_parallel_group(outputs) |
| logging.info(f'test_loss: {averaged_loss[0]}') |
|
|
| def loss_func(self, loss_mask, output_tensor): |
| losses = output_tensor.float() |
| loss_mask = loss_mask.view(-1).float() |
        # token-level NLL averaged over the non-masked tokens
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
| return loss |
|
|
| def build_train_valid_test_datasets(self): |
| logging.info('Building GPT datasets.') |
| if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): |
| raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") |
| global_batch_size = self.cfg.global_batch_size |
| max_train_steps = self.trainer.max_steps |
| eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches |
| test_iters = self.trainer.limit_test_batches |
|
|
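        # number of samples per split = number of iterations * global batch size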
| train_valid_test_num_samples = [ |
| max_train_steps * global_batch_size, |
| eval_iters * global_batch_size, |
| test_iters * global_batch_size, |
| ] |
|
|
        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
            train_valid_test_num_samples[1] = 1
|
|
| self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( |
| cfg=self.cfg, |
| trainer=self.trainer, |
| data_prefix=self.cfg.data.data_prefix, |
| data_impl=self.cfg.data.data_impl, |
| splits_string=self.cfg.data.splits_string, |
| train_valid_test_num_samples=train_valid_test_num_samples, |
| seq_length=self.cfg.data.seq_length, |
| seed=self.cfg.seed, |
| skip_warmup=self.cfg.data.get('skip_warmup', True), |
| tokenizer=self.tokenizer, |
| ) |
| if self._train_ds is not None: |
| logging.info(f'Length of train dataset: {len(self._train_ds)}') |
| if self._validation_ds is not None: |
| logging.info(f'Length of val dataset: {len(self._validation_ds)}') |
| if self._test_ds is not None: |
| logging.info(f'Length of test dataset: {len(self._test_ds)}') |
| logging.info(f'Finished building GPT datasets.') |
|
|
| return self._train_ds, self._validation_ds, self._test_ds |
|
|
| def build_pretraining_data_loader( |
| self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False |
| ): |
| """Buld dataloader given an input dataset.""" |
|
|
| logging.info(f'Building dataloader with consumed samples: {consumed_samples}') |
        # Megatron samplers handle data-parallel sharding and resuming from consumed samples
| if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None: |
| if self.cfg.data.dataloader_type == 'single': |
| batch_sampler = MegatronPretrainingSampler( |
| total_samples=len(dataset), |
| consumed_samples=consumed_samples, |
| micro_batch_size=self.cfg.micro_batch_size, |
| data_parallel_rank=parallel_state.get_data_parallel_rank(), |
| data_parallel_size=parallel_state.get_data_parallel_world_size(), |
| drop_last=drop_last, |
| global_batch_size=self.cfg.global_batch_size, |
| pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, |
| ) |
| elif self.cfg.data.dataloader_type == 'cyclic': |
| batch_sampler = MegatronPretrainingRandomSampler( |
| total_samples=len(dataset), |
| consumed_samples=consumed_samples, |
| micro_batch_size=self.cfg.micro_batch_size, |
| data_parallel_rank=parallel_state.get_data_parallel_rank(), |
| data_parallel_size=parallel_state.get_data_parallel_world_size(), |
| drop_last=self.cfg.get('drop_last', True), |
| ) |
| else: |
| raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"') |
| else: |
| raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"') |
|
|
| return torch.utils.data.DataLoader( |
| dataset, |
| batch_sampler=batch_sampler, |
| num_workers=self.cfg.data.num_workers, |
| pin_memory=True, |
| persistent_workers=True, |
| ) |
|
|
| def setup(self, stage=None): |
| """ PTL hook that is executed after DDP spawns. |
| We setup datasets here as megatron datasets require DDP to instantiate. |
| See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. |
| Args: |
| stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. |
| """ |
| num_parameters_on_device, total_num_parameters = self._get_total_params_across_model_parallel_groups_gpt_bert( |
| self.model |
| ) |
|
|
| logging.info( |
| f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' |
| f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' |
| f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' |
| f'Total number of model parameters: {total_num_parameters:.2e}.' |
| ) |
|
|
| resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path |
| if resume_checkpoint_path: |
| init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) |
| else: |
| init_consumed_samples = 0 |
| self.init_consumed_samples = init_consumed_samples |
| self.init_global_step = self.trainer.global_step |
|
|
| if stage == 'predict': |
| return |
| else: |
            # build datasets and dataloaders here (rather than in __init__) because
            # the Megatron datasets require the distributed environment to be initialized
| self.build_train_valid_test_datasets() |
| self.setup_training_data(self.cfg.data) |
| self.setup_validation_data(self.cfg.data) |
| self.setup_test_data(self.cfg.data) |
|
|
        # when using pipeline model parallelism, the first and last stages need to synchronize the shared word embeddings
| if parallel_state.get_pipeline_model_parallel_world_size() > 1: |
| if isinstance(self.model, list): |
| for i, module in enumerate(self.model): |
| parallel_state.set_virtual_pipeline_model_parallel_rank(i) |
| if self.cfg.get('share_embeddings_and_output_weights', True): |
| module.sync_initial_word_embeddings() |
| parallel_state.set_virtual_pipeline_model_parallel_rank(0) |
| else: |
| if self.cfg.get('share_embeddings_and_output_weights', True): |
| self.model.sync_initial_word_embeddings() |
|
|
| if self.cfg.get('transformer_engine', False): |
| self.setup_transformer_engine_tp_groups() |
|
|
| def setup_training_data(self, cfg): |
| if hasattr(self, '_train_ds'): |
| consumed_samples = self.compute_consumed_samples(0) |
            logging.info(
                f'Setting up train dataloader with len(self._train_ds): {len(self._train_ds)} and consumed samples: {consumed_samples}'
            )
| self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples) |
|
|
| def setup_validation_data(self, cfg): |
| if hasattr(self, '_validation_ds'): |
| consumed_samples = 0 |
            logging.info(
                f'Setting up validation dataloader with len(self._validation_ds): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
            )
|
|
| drop_last = True |
| if not self.cfg.data.get('validation_drop_last', True): |
| logging.info(f'Drop last in validation dataset is set to False') |
| drop_last = False |
| pad_samples_to_global_batch_size = False |
| if self.cfg.data.get('pad_samples_to_global_batch_size', False): |
| logging.info('pad_samples_to_global_batch_size set to True') |
| pad_samples_to_global_batch_size = True |
|
|
| self._validation_dl = self.build_pretraining_data_loader( |
| self._validation_ds, consumed_samples, "validation", drop_last, pad_samples_to_global_batch_size |
| ) |
|
|
| def setup_test_data(self, cfg): |
| if hasattr(self, '_test_ds'): |
| consumed_samples = 0 |
            logging.info(
                f'Setting up test dataloader with len(self._test_ds): {len(self._test_ds)} and consumed samples: {consumed_samples}'
            )
| self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) |
|
|
| def generate( |
| self, |
| inputs: Union[List[str], torch.Tensor, List[dict]], |
| length_params: LengthParam, |
| sampling_params: SamplingParam = None, |
| ) -> OutputType: |
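        """Generate text for the given prompts using the Megatron text-generation utilities.

        Illustrative usage sketch (the prompt and length values below are examples only):
            model.generate(
                inputs=["Deep learning is"],
                length_params={"max_length": 30, "min_length": 0},
                sampling_params=None,  # None falls back to get_default_sampling_params()
            )
        """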
|
|
        # check whether the distributed / model-parallel state is initialized; if not, set it up
| if parallel_state.is_unitialized(): |
|
|
| def dummy(): |
| return |
|
|
| if self.trainer.strategy.launcher is not None: |
| self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer) |
| self.trainer.strategy.setup_environment() |
|
|
| if self.cfg.get('transformer_engine', False): |
| self.setup_transformer_engine_tp_groups() |
|
|
        # use the default sampling parameters when none are provided
| if sampling_params is None: |
| sampling_params = get_default_sampling_params() |
|
|
        # use the default length parameters when none are provided
| if length_params is None: |
| length_params = get_default_length_params() |
|
|
| return megatron_gpt_generate(self.cuda(), inputs, self.tokenizer, length_params, sampling_params) |
|
|
| def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: |
| inference_config = self.get_inference_config() |
| if inference_config is None: |
| return None |
| else: |
            # work on a copy so the stored inference config is not mutated
| inference_config = inference_config.copy() |
| compute_logprob = inference_config['compute_logprob'] |
| if compute_logprob: |
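                # to compute log-probs we generate a single token with all_probs=True and greedy
                # decoding, then let get_computeprob_response score the input tokens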
| del inference_config['compute_logprob'] |
| inference_config['inputs'] = batch |
| inference_config['tokens_to_generate'] = 1 |
| inference_config['all_probs'] = True |
| inference_config["add_BOS"] = False |
| inference_config['greedy'] = True |
| response = generate(self, **inference_config) |
| compute_prob_response = get_computeprob_response(self.tokenizer, response, batch) |
| return compute_prob_response |
| else: |
| del inference_config['compute_logprob'] |
| inference_config['inputs'] = batch |
| return generate(self, **inference_config) |
|
|
| def list_available_models(self): |
| return None |
|
|
| def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: |
| """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device |
| When using pipeline parallelism, we need the global batch to remain on the CPU, |
| since the memory overhead will be too high when using a large number of microbatches. |
| Microbatches are transferred from CPU to GPU inside the pipeline. |
| """ |
| return batch |
|
|
| def _validate_trainer(self): |
| """ Certain trainer configurations can break training. |
| Here we try to catch them and raise an error. |
| """ |
| if self.trainer.accumulate_grad_batches > 1: |
| raise ValueError( |
| f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' |
| ) |
|
|
| @classmethod |
| def list_available_models(cls) -> Optional[PretrainedModelInfo]: |
| """ |
| This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. |
| Returns: |
| List of available pre-trained models. |
| """ |
| result = [] |
| result.append( |
| PretrainedModelInfo( |
| pretrained_model_name="megatron_gpt_345m", |
| location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/megatron_gpt_345m/versions/1/files/megatron_gpt_345m.nemo", |
| description="345M parameter GPT generative Megatron model.", |
| ) |
| ) |
| return result |
|
|
| def _set_tp_groups(self, module): |
| """ Helper method to set tp groups for transformer engine""" |
|
|
| if self.cfg.get('transformer_engine', False): |
| logging.info(f'Setting up transformer engine modules for tensor parallelism.') |
            if self.cfg.get('megatron_amp_O2', False):
                # with O2 the model is wrapped in Float16Module, so the layers live under module.module
| for layer in module.module.language_model.encoder.layers: |
| layer.set_tensor_parallel_group(parallel_state.get_tensor_model_parallel_group()) |
|
|
| else: |
| for layer in module.language_model.encoder.layers: |
| layer.set_tensor_parallel_group(parallel_state.get_tensor_model_parallel_group()) |
|
|
| def setup_transformer_engine_tp_groups(self): |
| """ This should be called after model parallel groups have been initialized |
| and only needs to be called when using Transformer Engine. |
| """ |
| if isinstance(self.model, list): |
| for module in self.model: |
| self._set_tp_groups(module) |
| else: |
| self._set_tp_groups(self.model) |
|
|
| def on_save_checkpoint(self, checkpoint) -> None: |
| """LightningModule hook: |
| https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-save-checkpoint |
| """ |
| if isinstance(self.model, list): |
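            # with interleaved (virtual) pipeline parallelism, each model chunk is saved under its own key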
| for i in range(len(self.model)): |
| parallel_state.set_virtual_pipeline_model_parallel_rank(i) |
| checkpoint[f'model{i}'] = self.model[i].module.state_dict_for_save_checkpoint() |
| parallel_state.set_virtual_pipeline_model_parallel_rank(0) |
|
|
| def on_load_checkpoint(self, checkpoint) -> None: |
| """LightningModule hook: |
| https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint |
| """ |
| if isinstance(self.model, list): |
| for i in range(len(self.model)): |
| parallel_state.set_virtual_pipeline_model_parallel_rank(i) |
| self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True) |
| parallel_state.set_virtual_pipeline_model_parallel_rank(0) |
|
|
| def parameters(self): |
| if isinstance(self.model, list): |
| return itertools.chain.from_iterable(module.parameters() for module in self.model) |
| else: |
| return self.model.parameters() |
|
|
| def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None: |
| super().on_train_batch_end(outputs, dataloader_iter, batch_idx) |
|
|
        # When the AMP grad scaler skips the optimizer step (due to inf/NaN grads), roll back
        # the LR schedulers so the schedule stays aligned with the number of applied steps.
| if self.trainer.precision_plugin is not None and isinstance( |
| self.trainer.precision_plugin, NativeMixedPrecisionPlugin |
| ): |
| precision_plugin = self.trainer.precision_plugin |
|
|
| if ( |
| hasattr(precision_plugin, 'scaler') |
| and precision_plugin.scaler is not None |
| and isinstance(precision_plugin.scaler, GradScaler) |
| ): |
| grad_scaler = precision_plugin.scaler |
|
|
                # if the grad scaler skipped its optimizer step due to inf/NaN gradients,
                # decrement the step of all schedulers
| if grad_scaler.optimizer_update_skipped is not None and grad_scaler.optimizer_update_skipped is True: |
| scheduler_cfgs = self.trainer.lr_scheduler_configs |
|
|
| if not scheduler_cfgs or not self.trainer.lightning_module.automatic_optimization: |
| return |
|
|
| for scheduler_cfg in scheduler_cfgs: |
                        # decrement the scheduler counter by 2, then call step() so the step count
                        # is effectively unchanged while the learning rate is refreshed
| scheduler_cfg.scheduler.last_epoch -= 2 |
| scheduler_cfg.scheduler.step() |
|
|

                    # reset the flag so subsequent (non-skipped) steps are handled normally
| grad_scaler.optimizer_update_skipped = None |
|
|