| | """Megatron Module""" |
| |
|
| | import torch |
| | from torch.autograd import Variable |
| | from torch.nn.parameter import Parameter |
| |
|
| | from nemo.utils import logging |
| |
|
| | try: |
| | from apex.transformer import parallel_state, tensor_parallel |
| |
|
| | HAVE_APEX = True |
| |
|
| | except (ImportError, ModuleNotFoundError): |
| |
|
| | HAVE_APEX = False |
| |
|
| |
|
| | _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) |
| | _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) |
| | _BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) |
| |
|
| |
|


def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_token_embeddings=True):
        if not HAVE_APEX:
            raise ImportError(
                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )
        super(MegatronModule, self).__init__()
        self.share_token_embeddings = share_token_embeddings

    def word_embeddings_weight(self):
        if self.pre_process:
            if hasattr(self, 'language_model'):
                return self.language_model.embedding.word_embeddings.weight
            elif hasattr(self, 'encoder_embedding'):
                return self.encoder_embedding.word_embeddings.weight
            elif hasattr(self, 'decoder_embedding'):
                return self.decoder_embedding.word_embeddings.weight
            else:
                raise ValueError(
                    "pre_process is True, but no embedding was found on this rank. "
                    "Looked for language_model.embedding, encoder_embedding, and decoder_embedding."
                )
        else:
            # Last pipeline stage: the weight lives in the duplicated copy created by
            # initialize_word_embeddings().
            if not self.share_token_embeddings:
                raise Exception(
                    'word_embeddings_weight() called for last stage, but share_token_embeddings is false'
                )
            return self.word_embeddings.weight

    def position_embeddings_weight(self):
        if self.pre_process:
            if hasattr(self, 'language_model'):
                return self.language_model.embedding.position_embeddings.weight
            elif hasattr(self, 'encoder_embedding'):
                return self.encoder_embedding.position_embeddings.weight
            elif hasattr(self, 'decoder_embedding'):
                return self.decoder_embedding.position_embeddings.weight
            else:
                raise ValueError(
                    "pre_process is True, but no embedding was found on this rank. "
                    "Looked for language_model.embedding, encoder_embedding, and decoder_embedding."
                )
        else:
            raise ValueError("pre_process is False; there is no position embedding on this rank.")

    def encoder_relative_position_embeddings_weight(self):
        if hasattr(self, 'encoder_relative_position_embedding'):
            return self.encoder_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No encoder_relative_position_embedding found on this rank. "
                "Expected encoder_relative_position_embedding.relative_position_embedding.weight."
            )

    def decoder_relative_position_embeddings_weight(self):
        if hasattr(self, 'decoder_relative_position_embedding'):
            return self.decoder_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No decoder_relative_position_embedding found on this rank. "
                "Expected decoder_relative_position_embedding.relative_position_embedding.weight."
            )

    def decoder_cross_attention_relative_position_embeddings_weight(self):
        if hasattr(self, 'decoder_cross_attention_relative_position_embedding'):
            return self.decoder_cross_attention_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No decoder_cross_attention_relative_position_embedding found on this rank. "
                "Expected decoder_cross_attention_relative_position_embedding.relative_position_embedding.weight."
            )

    def initialize_word_embeddings(self, init_method, vocab_size, hidden_size):
        if not self.share_token_embeddings:
            raise Exception('initialize_word_embeddings() was called but share_token_embeddings is false')

        # Nothing to do without pipeline parallelism: the input embedding and the
        # output head already live on the same rank and can share the weight directly.
        if parallel_state.get_pipeline_model_parallel_world_size() == 1:
            return
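
        # With pipeline parallelism the input embedding (first pipeline stage) and the
        # output head (last stage) live on different ranks but share their weights.
        # To keep the two copies tied:
        #   1. the last stage creates a second word_embeddings module below, with its
        #      weights initialized to zero;
        #   2. sync_initial_word_embeddings() all-reduces the two copies so that both
        #      stages start from the same values;
        #   3. during training, the embedding gradients of the two stages are
        #      all-reduced so every weight update stays identical on both copies.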
        if parallel_state.is_pipeline_last_stage() and not self.pre_process:
            assert not parallel_state.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
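            # Create the duplicated embedding with zeroed weights here; the first
            # stage's values are copied in by the all-reduce performed in
            # sync_initial_word_embeddings().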
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=init_method
            )
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
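
        # Ranks that are not the true first pipeline stage but still build an input
        # embedding (e.g. the decoder's first stage of an encoder-decoder model) zero
        # their copy here; the sync_initial_* all-reduces then bring in the first
        # stage's values.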
        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process:
            if hasattr(self, 'language_model'):
                # GPT-style models keep the embedding under language_model.
                self.language_model.embedding.zero_parameters()
            else:
                # Encoder-decoder models keep it under decoder_embedding.
                assert hasattr(self, 'decoder_embedding')
                self.decoder_embedding.zero_parameters()

    def sync_initial_word_embeddings(self):
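        # Ensure the tied word embeddings on the first and last pipeline stages start
        # from identical values: all-reduce the first stage's weights with the
        # zero-initialized copy created on the last stage.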
        if torch.distributed.is_initialized():
            if parallel_state.is_rank_in_embedding_group() and self.share_token_embeddings:
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
                )
        else:
            logging.warning(
                "WARNING! Distributed processes aren't initialized, so "
                "word embeddings in the last layer are not synchronized. "
                "If you are just manipulating a model this is fine, but "
                "this needs to be handled manually. If you are training, "
                "something is definitely wrong."
            )

    def sync_initial_position_embeddings(self):
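        # For encoder-decoder models split across the pipeline, ensure the encoder's
        # and decoder's first stages start from identical position embedding values.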
        if (
            parallel_state.is_rank_in_position_embedding_group()
            and parallel_state.get_pipeline_model_parallel_split_rank() is not None
        ):
            position_embeddings = self.position_embeddings_weight()
            torch.distributed.all_reduce(
                position_embeddings.data, group=parallel_state.get_position_embedding_group()
            )

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)

    def sync_initial_encoder_relative_position_embeddings(self):
        if parallel_state.is_rank_in_encoder_relative_position_embedding_group():
            position_embeddings = self.encoder_relative_position_embeddings_weight()
            torch.distributed.all_reduce(
                position_embeddings.data, group=parallel_state.get_encoder_relative_position_embedding_group()
            )

    def sync_initial_decoder_relative_position_embeddings(self):
        if parallel_state.is_rank_in_decoder_relative_position_embedding_group():
            position_embeddings = self.decoder_relative_position_embeddings_weight()
            torch.distributed.all_reduce(
                position_embeddings.data, group=parallel_state.get_decoder_relative_position_embedding_group()
            )

    def sync_initial_decoder_cross_attention_relative_position_embeddings(self):
        if parallel_state.is_rank_in_decoder_relative_position_embedding_group():
            position_embeddings = self.decoder_cross_attention_relative_position_embeddings_weight()
            torch.distributed.all_reduce(
                position_embeddings.data, group=parallel_state.get_decoder_relative_position_embedding_group()
            )


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_float16(val, float16_converter):
    """Convert fp32 `val` to fp16/bf16."""

    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_converter(val)
        return val

    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32."""

    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val

    return conversion_helper(val, float_conversion)


class Float16Module(MegatronModule):
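    """Run a wrapped MegatronModule in fp16 or bf16.

    fp32 inputs are cast to half precision on the first pipeline stage and outputs
    are cast back to fp32 on the last stage (see forward() below).

    Illustrative usage (a minimal sketch; ``build_megatron_module`` and the forward
    arguments are placeholders for whatever the wrapped model actually expects):

        module = build_megatron_module(cfg)
        module = Float16Module(module, precision='bf16')
        output = module(tokens, position_ids, attention_mask)
    """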

    def __init__(self, module, precision):
        if not HAVE_APEX:
            raise ImportError(
                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )
        super().__init__()
        self.precision = precision

        if precision == 16:
            self.add_module('module', module.half())

            def float16_converter(val):
                return val.half()

        elif precision == 'bf16':
            self.add_module('module', module.bfloat16())

            def float16_converter(val):
                return val.bfloat16()

        else:
            raise Exception(
                f'precision {precision} is not supported. Float16Module (megatron_amp_O2) supports '
                'only fp16 and bf16.'
            )

        self.float16_converter = float16_converter

    def set_input_tensor(self, input_tensor):
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
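        # Only the first pipeline stage receives fp32 inputs from the data pipeline,
        # so only it casts them down; later stages already receive half-precision
        # activations via set_input_tensor. The matching cast of the outputs back to
        # fp32 happens on the last stage below.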
        if getattr(self.module, 'pre_process', True):
            inputs = fp32_to_float16(inputs, self.float16_converter)
        outputs = self.module(*inputs, **kwargs)
        if parallel_state.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

    def word_embeddings_weight(self):
        if self.module.pre_process:
            if hasattr(self.module, 'language_model'):
                return self.module.language_model.embedding.word_embeddings.weight
            elif hasattr(self.module, 'encoder_embedding'):
                return self.module.encoder_embedding.word_embeddings.weight
            elif hasattr(self.module, 'decoder_embedding'):
                return self.module.decoder_embedding.word_embeddings.weight
            else:
                raise ValueError(
                    "pre_process is True, but no embedding was found on this rank. "
                    "Looked for language_model.embedding, encoder_embedding, and decoder_embedding."
                )
        else:
            # Last pipeline stage: the weight lives in the duplicated copy created by
            # initialize_word_embeddings().
            if not self.share_token_embeddings:
                raise Exception(
                    'word_embeddings_weight() called for last stage, but share_token_embeddings is false'
                )
            return self.module.word_embeddings.weight

    def position_embeddings_weight(self):
        if self.module.pre_process:
            if hasattr(self.module, 'language_model'):
                return self.module.language_model.embedding.position_embeddings.weight
            elif hasattr(self.module, 'encoder_embedding'):
                return self.module.encoder_embedding.position_embeddings.weight
            elif hasattr(self.module, 'decoder_embedding'):
                return self.module.decoder_embedding.position_embeddings.weight
            else:
                raise ValueError(
                    "pre_process is True, but no embedding was found on this rank. "
                    "Looked for language_model.embedding, encoder_embedding, and decoder_embedding."
                )
        else:
            raise ValueError("pre_process is False; there is no position embedding on this rank.")

    def encoder_relative_position_embeddings_weight(self):
        if hasattr(self.module, 'encoder_relative_position_embedding'):
            return self.module.encoder_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No encoder_relative_position_embedding found on this rank. "
                "Expected encoder_relative_position_embedding.relative_position_embedding.weight."
            )

    def decoder_relative_position_embeddings_weight(self):
        if hasattr(self.module, 'decoder_relative_position_embedding'):
            return self.module.decoder_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No decoder_relative_position_embedding found on this rank. "
                "Expected decoder_relative_position_embedding.relative_position_embedding.weight."
            )

    def decoder_cross_attention_relative_position_embeddings_weight(self):
        if hasattr(self.module, 'decoder_cross_attention_relative_position_embedding'):
            return self.module.decoder_cross_attention_relative_position_embedding.relative_position_embedding.weight
        else:
            raise ValueError(
                "No decoder_cross_attention_relative_position_embedding found on this rank. "
                "Expected decoder_cross_attention_relative_position_embedding.relative_position_embedding.weight."
            )