import torch
from omegaconf import DictConfig

from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
    ALiBiRelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.language_model import Embedding
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
from nemo.collections.nlp.modules.common.megatron.megatron_decoders import get_decoder_model
from nemo.collections.nlp.modules.common.megatron.megatron_encoder_decoder import (
    MegatronTransformerEncoderDecoderModule,
)
from nemo.collections.nlp.modules.common.megatron.megatron_encoders import get_encoder_model
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.t5_relative_position_embedding import T5RelativePositionEmbedding
from nemo.collections.nlp.modules.common.megatron.utils import (
    ApexGuardDefaults,
    build_position_ids,
    init_method_normal,
    parallel_lm_logits,
    scaled_init_method_normal,
)
from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy

try:
    from apex.transformer import parallel_state, tensor_parallel
    from apex.transformer.enums import AttnMaskType, ModelType

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False

    AttnMaskType = ApexGuardDefaults()
    ModelType = ApexGuardDefaults()

__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"]


class MegatronTokenLevelHead(MegatronModule):
    """Masked LM head for token-level encoder-decoder models (e.g., T5).

    Arguments:
        mpu_vocab_size: vocabulary size of the model-parallel partition.
        parallel_output: whether the output logits remain split across tensor-model-parallel ranks
            instead of being gathered onto every rank.
    """

    def __init__(self, mpu_vocab_size, parallel_output, bias=True):
        super(MegatronTokenLevelHead, self).__init__()

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
            self.bias.model_parallel = True
            self.bias.partition_dim = 0
            self.bias.stride = 1
        else:
            self.bias = None
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        async_tensor_model_parallel_allreduce = parallel_state.get_tensor_model_parallel_world_size() > 1
        output = parallel_lm_logits(
            hidden_states,
            word_embeddings_weight,
            self.parallel_output,
            bias=self.bias,
            async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce,
        )
        return output
|
| |
|
| | |
| |
|
| |
|
| | class MegatronTokenLevelEncoderDecoderModule(MegatronModule): |
| | """Token-based (input/output is tokens) encoder-decoder model (e.g. T5 Language model.)""" |
| |
|
    def __init__(
        self,
        encoder_cfg: DictConfig,
        decoder_cfg: DictConfig,
        vocab_size: int,
        max_position_embeddings,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        precision=16,
        embedding_init_method_std=0.02,
        embedding_dropout=0.1,
        label_smoothing=0.0,
        add_encoder=True,
        add_decoder=True,
        share_token_embeddings=True,
        share_decoder_tokens_head_embeddings=True,
        tokens_head_bias=True,
    ):
        super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

        self.encoder_cfg = encoder_cfg
        self.decoder_cfg = decoder_cfg
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.label_smoothing = label_smoothing
        self.share_token_embeddings = share_token_embeddings
        self.share_decoder_tokens_head_embeddings = share_decoder_tokens_head_embeddings
        self.tokens_head_bias = tokens_head_bias

        encoder_kv_channels, decoder_kv_channels = self._validate_config()

        encoder, decoder = None, None
        if add_encoder:
            if pre_process:
                self.encoder_embedding = Embedding(
                    hidden_size=encoder_cfg.hidden_size,
                    vocab_size=vocab_size,
                    max_sequence_length=max_position_embeddings,
                    init_method=init_method_normal(embedding_init_method_std),
                    num_tokentypes=num_tokentypes,
                    use_cpu_initialization=use_cpu_initialization,
                    embedding_dropout_prob=embedding_dropout,
                    position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'),
                )
                self._encoder_embedding_key = "encoder_embedding"

            if self.encoder_cfg.get('position_embedding_type', 'learned_absolute') == 'relative':
                self.encoder_relative_position_embedding = T5RelativePositionEmbedding(
                    init_method=init_method_normal(embedding_init_method_std),
                    num_attention_heads=encoder_cfg.num_attention_heads,
                    relative_position_num_buckets=encoder_cfg.relative_attention_num_buckets,
                    relative_position_max_distance=encoder_cfg.relative_attention_max_distance,
                    bidirectional=True,
                    layer_type=LayerType.encoder,
                )
                self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
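                # With pipeline parallelism, only the first pipeline stage trains the encoder relative
                # position embedding; later stages keep a zeroed copy that is marked as shared.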
                if parallel_state.get_pipeline_model_parallel_rank() != 0:
                    self.encoder_relative_position_embeddings_weight().data.fill_(0)
                    self.encoder_relative_position_embeddings_weight().shared = True
            elif self.encoder_cfg.get('position_embedding_type', 'learned_absolute') == 'alibi':
                self.encoder_relative_position_embedding = ALiBiRelativePositionEmbedding(
                    bidirectional=True,
                    num_attention_heads=encoder_cfg.num_attention_heads,
                    layer_type=LayerType.encoder,
                    num_attention_heads_alibi=None,
                    max_seq_len=max_position_embeddings,
                )
                self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
            else:
                self.encoder_relative_position_embedding = None

            encoder = get_encoder_model(
                arch=encoder_cfg.arch,
                hidden_size=encoder_cfg.hidden_size,
                ffn_hidden_size=encoder_cfg.ffn_hidden_size,
                num_layers=encoder_cfg.num_layers,
                num_attention_heads=encoder_cfg.num_attention_heads,
                apply_query_key_layer_scaling=encoder_cfg.get('apply_query_key_layer_scaling', True),
                kv_channels=encoder_kv_channels,
                init_method=init_method_normal(encoder_cfg.get('init_method_std', 0.02)),
                scaled_init_method=scaled_init_method_normal(
                    encoder_cfg.get('init_method_std', 0.02), encoder_cfg.num_layers
                ),
                encoder_attn_mask_type=AttnMaskType.padding,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=encoder_cfg.get('init_method_std', 0.02),
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=encoder_cfg.get('hidden_dropout', 0.1),
                attention_dropout=encoder_cfg.get('attention_dropout', 0.1),
                ffn_dropout=encoder_cfg.get('ffn_dropout', 0.0),
                precision=precision,
                fp32_residual_connection=encoder_cfg.get('fp32_residual_connection', False),
                activations_checkpoint_method=encoder_cfg.get('activations_checkpoint_method', None),
                activations_checkpoint_num_layers=encoder_cfg.get('activations_checkpoint_num_layers', 1),
                activations_checkpoint_granularity=encoder_cfg.get('activations_checkpoint_granularity', None),
                layernorm_epsilon=encoder_cfg.get('layernorm_epsilon', 1e-5),
                bias_activation_fusion=encoder_cfg.get('bias_activation_fusion', True),
                bias_dropout_add_fusion=encoder_cfg.get('bias_dropout_add_fusion', True),
                masked_softmax_fusion=encoder_cfg.get('masked_softmax_fusion', True),
                persist_layer_norm=encoder_cfg.get('persist_layer_norm', True),
                openai_gelu=encoder_cfg.get('openai_gelu', False),
                onnx_safe=encoder_cfg.get('onnx_safe', False),
                hidden_steps=encoder_cfg.get('hidden_steps', -1),
                activation=encoder_cfg.get('activation', 'gelu'),
                bias=encoder_cfg.get('bias', True),
                normalization=encoder_cfg.get('normalization', 'layernorm'),
                transformer_block_type=encoder_cfg.get('transformer_block_type', 'pre_ln'),
                headscale=encoder_cfg.get('headscale', False),
                parent_model_type=ModelType.encoder_and_decoder,
                num_self_attention_per_cross_attention=encoder_cfg.get('num_self_attention_per_cross_attention', 1),
                megatron_legacy=encoder_cfg.get('megatron_legacy', False),
                normalize_attention_scores=encoder_cfg.get('normalize_attention_scores', True),
                num_moe_experts=encoder_cfg.get('num_moe_experts', 1),
                moe_frequency=encoder_cfg.get('moe_frequency', 1),
                moe_dropout=encoder_cfg.get('moe_dropout', 0.0),
            )

        if add_decoder:
            if pre_process:
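                # If the encoder embedding was built on this rank and embeddings are tied,
                # reuse the same module for the decoder.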
                if hasattr(self, 'encoder_embedding') and share_token_embeddings:
                    self.decoder_embedding = self.encoder_embedding
                else:
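                    # Either pipeline parallelism places the encoder on a different stage, or the
                    # embeddings are untied: build a separate decoder embedding table.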
                    self.decoder_embedding = Embedding(
                        hidden_size=decoder_cfg.hidden_size,
                        vocab_size=vocab_size,
                        max_sequence_length=max_position_embeddings,
                        init_method=init_method_normal(embedding_init_method_std),
                        num_tokentypes=num_tokentypes,
                        use_cpu_initialization=use_cpu_initialization,
                        embedding_dropout_prob=embedding_dropout,
                        position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'),
                    )
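                    # The decoder copy of tied embeddings starts out zeroed; the shared weights are
                    # kept in sync with the encoder copy (see initialize_word_embeddings below).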
                    if share_token_embeddings:
                        self.decoder_embedding.zero_parameters()

                self._decoder_embedding_key = "decoder_embedding"

            if self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'relative':
                self.decoder_relative_position_embedding = T5RelativePositionEmbedding(
                    init_method=init_method_normal(embedding_init_method_std),
                    num_attention_heads=decoder_cfg.num_attention_heads,
                    relative_position_num_buckets=decoder_cfg.relative_attention_num_buckets,
                    relative_position_max_distance=decoder_cfg.relative_attention_max_distance,
                    bidirectional=False,
                    layer_type=LayerType.decoder,
                )
                self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
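                # Only the first decoder pipeline stage (the split rank) trains these weights;
                # other decoder stages keep a zeroed copy marked as shared.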
                if (
                    parallel_state.get_pipeline_model_parallel_rank()
                    != parallel_state.get_pipeline_model_parallel_split_rank()
                ):
                    self.decoder_relative_position_embeddings_weight().data.fill_(0)
                    self.decoder_relative_position_embeddings_weight().shared = True

                if not self.decoder_cfg.relative_position_bias_self_attention_only:
                    self.decoder_cross_attention_relative_position_embedding = T5RelativePositionEmbedding(
                        init_method=init_method_normal(embedding_init_method_std),
                        num_attention_heads=decoder_cfg.num_attention_heads,
                        relative_position_num_buckets=decoder_cfg.relative_attention_num_buckets,
                        relative_position_max_distance=decoder_cfg.relative_attention_max_distance,
                        bidirectional=True,
                        layer_type=LayerType.decoder,
                    )
                    self._decoder_cross_attention_relative_position_embedding_key = (
                        "decoder_cross_attention_relative_position_embedding"
                    )
                    if (
                        parallel_state.get_pipeline_model_parallel_rank()
                        != parallel_state.get_pipeline_model_parallel_split_rank()
                    ):
                        self.decoder_cross_attention_relative_position_embeddings_weight().data.fill_(0)
                        self.decoder_cross_attention_relative_position_embeddings_weight().shared = True
            elif self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'alibi':
                self.decoder_relative_position_embedding = ALiBiRelativePositionEmbedding(
                    bidirectional=False,
                    num_attention_heads=decoder_cfg.num_attention_heads,
                    layer_type=LayerType.decoder,
                    num_attention_heads_alibi=None,
                    max_seq_len=max_position_embeddings,
                )
                self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
            else:
                self.decoder_relative_position_embedding = None

            decoder = get_decoder_model(
                arch=decoder_cfg.arch,
                hidden_size=decoder_cfg.hidden_size,
                ffn_hidden_size=decoder_cfg.ffn_hidden_size,
                num_layers=decoder_cfg.num_layers,
                num_attention_heads=decoder_cfg.num_attention_heads,
                apply_query_key_layer_scaling=decoder_cfg.get('apply_query_key_layer_scaling', True),
                kv_channels=decoder_kv_channels,
                init_method=init_method_normal(decoder_cfg.get('init_method_std', 0.02)),
                scaled_init_method=scaled_init_method_normal(
                    decoder_cfg.get('init_method_std', 0.02), decoder_cfg.num_layers
                ),
                decoder_attn_mask_type=AttnMaskType.causal,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=decoder_cfg.get('init_method_std', 0.02),
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=decoder_cfg.get('hidden_dropout', 0.1),
                attention_dropout=decoder_cfg.get('attention_dropout', 0.1),
                ffn_dropout=decoder_cfg.get('ffn_dropout', 0.0),
                precision=precision,
                fp32_residual_connection=decoder_cfg.get('fp32_residual_connection', False),
                activations_checkpoint_method=decoder_cfg.get('activations_checkpoint_method', None),
                activations_checkpoint_num_layers=decoder_cfg.get('activations_checkpoint_num_layers', 1),
                activations_checkpoint_granularity=decoder_cfg.get('activations_checkpoint_granularity', None),
                layernorm_epsilon=decoder_cfg.get('layernorm_epsilon', 1e-5),
                bias_activation_fusion=decoder_cfg.get('bias_activation_fusion', True),
                bias_dropout_add_fusion=decoder_cfg.get('bias_dropout_add_fusion', True),
                masked_softmax_fusion=decoder_cfg.get('masked_softmax_fusion', True),
                persist_layer_norm=decoder_cfg.get('persist_layer_norm', True),
                openai_gelu=decoder_cfg.get('openai_gelu', False),
                onnx_safe=decoder_cfg.get('onnx_safe', False),
                hidden_steps=decoder_cfg.get('hidden_steps', -1),
                activation=decoder_cfg.get('activation', 'gelu'),
                bias=decoder_cfg.get('bias', True),
                normalization=decoder_cfg.get('normalization', 'layernorm'),
                transformer_block_type=decoder_cfg.get('transformer_block_type', 'pre_ln'),
                headscale=decoder_cfg.get('headscale', False),
                parent_model_type=ModelType.encoder_and_decoder,
                megatron_legacy=decoder_cfg.get('megatron_legacy', False),
                normalize_attention_scores=decoder_cfg.get('normalize_attention_scores', True),
                num_moe_experts=decoder_cfg.get('num_moe_experts', 1),
                moe_frequency=decoder_cfg.get('moe_frequency', 1),
                moe_dropout=decoder_cfg.get('moe_dropout', 0.0),
            )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
            encoder=encoder, decoder=decoder, hidden_steps=encoder_cfg.get('hidden_steps', -1),
        )
        self._enc_dec_model_key = "enc_dec_model"

        if self.share_token_embeddings:
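            # Set up the word embeddings shared between the embedding layers and the tokens head
            # (handled by MegatronModule.initialize_word_embeddings).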
            self.initialize_word_embeddings(
                init_method=init_method_normal(embedding_init_method_std),
                vocab_size=vocab_size,
                hidden_size=encoder_cfg.hidden_size,
            )

        if add_decoder and post_process:
            if share_decoder_tokens_head_embeddings:
                self.tokens_head = MegatronTokenLevelHead(
                    self.word_embeddings_weight().size(0), parallel_output, bias=tokens_head_bias
                )
            else:
                self.tokens_head = tensor_parallel.ColumnParallelLinear(
                    input_size=decoder_cfg.hidden_size,
                    output_size=vocab_size,
                    bias=tokens_head_bias,
                    gather_output=not self.parallel_output,
                    init_method=init_method_normal(decoder_cfg.init_method_std),
                    use_cpu_initialization=use_cpu_initialization,
                )

            self._tokens_head_key = 'tokens_head'
|
| | def _validate_kv_channels(self, cfg): |
| | kv_channels = cfg.kv_channels |
| | if cfg.kv_channels is None: |
| | assert ( |
| | cfg.hidden_size % cfg.num_attention_heads == 0 |
| | ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' |
| | kv_channels = cfg.hidden_size // cfg.num_attention_heads |
| |
|
| | return kv_channels |

    def _validate_enc_dec_hidden_size(self, encoder_cfg, decoder_cfg):
        if encoder_cfg.hidden_size != decoder_cfg.hidden_size:
            raise ValueError(
                f"Encoder and decoder hidden_size must be equal, but got encoder: {encoder_cfg.hidden_size} and decoder: {decoder_cfg.hidden_size}"
            )

    def _validate_perceiver_config(self, cfg):
        if (
            cfg.get("position_embedding_type", "learned_absolute") == "relative"
            and cfg.get("arch", "transformer") == "perceiver"
        ):
            raise ValueError("Perceivers with relative position embeddings are not supported")

    def _validate_config(self):
        encoder_kv_channels = self._validate_kv_channels(self.encoder_cfg)
        decoder_kv_channels = self._validate_kv_channels(self.decoder_cfg)
        self._validate_enc_dec_hidden_size(self.encoder_cfg, self.decoder_cfg)
        self._validate_perceiver_config(self.encoder_cfg)
        self._validate_perceiver_config(self.decoder_cfg)
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            assert (
                self.share_token_embeddings
            ), "Token embeddings must be shared when using pipeline model parallel size > 1"
            assert (
                self.share_decoder_tokens_head_embeddings
            ), "Decoder token embeddings and the output layer must be shared when using pipeline model parallel size > 1"
        return encoder_kv_channels, decoder_kv_channels

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
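        # The pipeline schedule may hand us a bare tensor or None; normalize to a list first.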
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert (
                len(input_tensor) == 1
            ), 'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.enc_dec_model.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, 'input_tensor should only be length 1 for stage with only encoder'
            self.enc_dec_model.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.enc_dec_model.decoder.set_input_tensor(input_tensor[0])
                self.enc_dec_model.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.enc_dec_model.decoder.set_input_tensor(None)
                self.enc_dec_model.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception('Stage must have at least either encoder or decoder')
|
| | def forward( |
| | self, |
| | enc_input_ids=None, |
| | enc_attn_mask=None, |
| | dec_input_ids=None, |
| | dec_attn_mask=None, |
| | token_type_ids=None, |
| | labels=None, |
| | enc_output=None, |
| | enc_output_attn_mask=None, |
| | enc_input=None, |
| | output_enc_hidden_only=False, |
| | ): |
| | """ |
| | Return value is per token / per dimension (i.e., non collapsed loss value) |
| | """ |
| | ( |
| | encoder_self_attention_relative_position_bias, |
| | decoder_self_attention_relative_position_bias, |
| | decoder_cross_attention_relative_position_bias, |
| | ) = (None, None, None) |
| |

        if enc_input is not None and enc_output is not None:
            raise ValueError(
                """Both enc_input and enc_output are not None.
                You should only be passing one of them.
                enc_input is the result of the encoder embedding layer
                enc_output is the result of running the entire transformer encoder."""
            )

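        # Determine the encoder sequence length from enc_output, enc_input, or enc_input_ids, in that order of precedence.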
        if enc_output is not None:
            # Provided encoder output is [batch, seq, hidden]; transpose to [seq, batch, hidden].
            enc_output = enc_output.transpose(0, 1)
            enc_seq_length = enc_output.size(0)
        elif enc_input is not None:
            # Provided encoder input (embeddings) is [batch, seq, hidden]; transpose to [seq, batch, hidden].
            enc_input = enc_input.transpose(0, 1)
            enc_seq_length = enc_input.size(0)
        # Run the encoder embedding only when neither enc_input nor enc_output was provided.
        elif enc_input_ids is not None:
            enc_seq_length = enc_input_ids.size(1)
            if self.pre_process and self.add_encoder:
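                # Position ids are only needed for learned absolute position embeddings;
                # relative/ALiBi embeddings compute their bias from sequence lengths instead.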
                if self.encoder_relative_position_embedding is None:
                    enc_position_ids = build_position_ids(enc_input_ids)
                else:
                    enc_position_ids = None
                enc_input = self.encoder_embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids)
            else:
                enc_input = None
        else:
            # Neither ids, embeddings, nor encoder output were given; infer the length from the attention mask.
            enc_seq_length = enc_attn_mask.size(1)

        if self.add_encoder and self.encoder_relative_position_embedding is not None:
            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding(
                query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,
            )

        if output_enc_hidden_only:
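            # With pipeline parallelism the encoder may live on another stage; in that case return
            # the cached encoder hidden state instead of running the encoder here.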
            if enc_output is None and self.enc_dec_model.encoder is not None:
                enc_output = self.enc_dec_model.encode(
                    enc_input=enc_input,
                    enc_attn_mask=enc_attn_mask,
                    enc_layer_past=None,
                    enc_get_key_value=False,
                    enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias,
                )
            else:
                enc_output = self.enc_dec_model.encoder_hidden_state
            return enc_output
        else:
            if enc_output_attn_mask is None:
                enc_output_attn_mask = enc_attn_mask

            if self.pre_process and self.add_decoder:
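                # As with the encoder, position ids are only needed for learned absolute position embeddings.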
                if self.decoder_relative_position_embedding is None:
                    dec_position_ids = build_position_ids(dec_input_ids)
                else:
                    dec_position_ids = None
                dec_input = self.decoder_embedding(dec_input_ids, dec_position_ids, token_type_ids=token_type_ids)
            else:
                # Either this stage has no decoder embedding, or the embedded decoder input arrives
                # from a previous pipeline stage via set_input_tensor.
                dec_input = None

            if self.add_decoder and self.decoder_relative_position_embedding is not None:
                decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding(
                    query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1)
                )
                if not self.decoder_cfg.relative_position_bias_self_attention_only:
                    decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
                        query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
                    )
                else:
                    decoder_cross_attention_relative_position_bias = None

            output = self.enc_dec_model(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                dec_input=dec_input,
                dec_attn_mask=dec_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
                enc_output=enc_output,
                enc_output_attn_mask=enc_output_attn_mask,
                dec_layer_past=None,
                dec_get_key_value=False,
                enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias,
                dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias,
                dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias,
            )

            if self.post_process and self.add_decoder:
                dec_output, enc_output = output
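                # Project decoder hidden states onto the vocabulary, reusing the shared embedding
                # weights when the tokens head is tied to them.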
                if self.share_decoder_tokens_head_embeddings:
                    token_logits = self.tokens_head(dec_output, self.word_embeddings_weight())
                else:
                    token_logits = self.tokens_head(dec_output)[0]

                if labels is not None:
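                    # [b, s] -> [s, b]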
                    labels = labels.transpose(0, 1).contiguous()

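                    # Disable label smoothing at evaluation time.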
                    label_smoothing = self.label_smoothing if self.training else 0.0

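                    # vocab_parallel_cross_entropy applies log_softmax over the tensor-parallel
                    # vocabulary shards and returns the per-token cross entropy.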
                    if self.fp16_cross_entropy:
                        assert token_logits.dtype == torch.half
                        tokens_loss = vocab_parallel_cross_entropy(token_logits, labels, label_smoothing)
                    else:
                        tokens_loss = vocab_parallel_cross_entropy(token_logits.float(), labels, label_smoothing)

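                    # [s, b] -> [b, s]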
                    tokens_loss = tokens_loss.transpose(0, 1).contiguous()

                    return tokens_loss
                else:
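                    # [s, b, h] -> [b, s, h]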
                    token_logits = token_logits.transpose(0, 1).contiguous()
                    return token_logits

            elif self.add_decoder and not self.add_encoder:
                decoder_output, _ = output
                return decoder_output
            else:
                encoder_output = output
                return encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when the model is combined with other heads, add an extra key."""

        state_dict_ = {}

        state_dict_[self._encoder_embedding_key] = self.encoder_embedding.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        state_dict_[self._decoder_embedding_key] = self.decoder_embedding.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        state_dict_[self._enc_dec_model_key] = self.enc_dec_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )
        state_dict_[self._tokens_head_key] = self.tokens_head.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )

        return state_dict_
|
| | def load_state_dict(self, state_dict, strict=True): |
| | """Customized load.""" |
| |
|
| | self.encoder_embedding.encoder_embeddingload_state_dict(state_dict[self._encoder_embedding_key], strict=strict) |
| | self.decoder_embedding.load_state_dict(state_dict[self._decoder_embedding_key], strict=strict) |
| | self.enc_dec_model.load_state_dict(state_dict[self._enc_dec_model_key], strict=strict) |
| | self.tokens_head.load_state_dict(state_dict[self._tokens_head_key], strict=strict) |
| |
|