# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from omegaconf import DictConfig
from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
ALiBiRelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.language_model import Embedding
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
from nemo.collections.nlp.modules.common.megatron.megatron_decoders import get_decoder_model
from nemo.collections.nlp.modules.common.megatron.megatron_encoder_decoder import (
MegatronTransformerEncoderDecoderModule,
)
from nemo.collections.nlp.modules.common.megatron.megatron_encoders import get_encoder_model
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.t5_relative_position_embedding import T5RelativePositionEmbedding
from nemo.collections.nlp.modules.common.megatron.utils import (
ApexGuardDefaults,
build_position_ids,
init_method_normal,
parallel_lm_logits,
scaled_init_method_normal,
)
from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy

try:
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType, ModelType
HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
# fake missing classes with None attributes
AttnMaskType = ApexGuardDefaults()
ModelType = ApexGuardDefaults()
__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"]
class MegatronTokenLevelHead(MegatronModule):
"""Masked LM head for token-based encoder-decoder models (e.g., T5)
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
parallel_output: wether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, parallel_output, bias=True):
super(MegatronTokenLevelHead, self).__init__()
if bias:
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
else:
self.bias = None
self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
async_tensor_model_parallel_allreduce = parallel_state.get_tensor_model_parallel_world_size() > 1
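        # parallel_lm_logits computes hidden_states @ word_embeddings_weight^T (+ bias);
        # under tensor parallelism each rank produces its own vocabulary shard, which is
        # gathered across ranks only when parallel_output is False.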
output = parallel_lm_logits(
hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias,
async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce,
)
return output


# TODO: add soft prompts as an Embedding sub-class
class MegatronTokenLevelEncoderDecoderModule(MegatronModule):
"""Token-based (input/output is tokens) encoder-decoder model (e.g. T5 Language model.)"""
def __init__(
self,
encoder_cfg: DictConfig,
decoder_cfg: DictConfig,
vocab_size: int, # TODO: This should eventually go inside encoder_cfg and decoder_cfg when separate enc/dec tokenizers are supported.
max_position_embeddings,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
fp16_cross_entropy=False,
use_cpu_initialization=False,
precision=16,
embedding_init_method_std=0.02,
embedding_dropout=0.1,
label_smoothing=0.0,
add_encoder=True,
add_decoder=True,
share_token_embeddings=True,
share_decoder_tokens_head_embeddings=True,
tokens_head_bias=True,
):
super(MegatronTokenLevelEncoderDecoderModule, self).__init__()
self.encoder_cfg = encoder_cfg
self.decoder_cfg = decoder_cfg
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_cross_entropy = fp16_cross_entropy
self.precision = precision
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.label_smoothing = label_smoothing
self.share_token_embeddings = share_token_embeddings
self.share_decoder_tokens_head_embeddings = share_decoder_tokens_head_embeddings
self.tokens_head_bias = tokens_head_bias
encoder_kv_channels, decoder_kv_channels = self._validate_config()
encoder, decoder = None, None
if add_encoder:
if pre_process:
self.encoder_embedding = Embedding(
hidden_size=encoder_cfg.hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_position_embeddings,
init_method=init_method_normal(embedding_init_method_std),
num_tokentypes=num_tokentypes,
use_cpu_initialization=use_cpu_initialization,
embedding_dropout_prob=embedding_dropout,
position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'),
)
self._encoder_embedding_key = "encoder_embedding"
if self.encoder_cfg.get('position_embedding_type', 'learned_absolute') == 'relative':
self.encoder_relative_position_embedding = T5RelativePositionEmbedding(
init_method=init_method_normal(embedding_init_method_std),
num_attention_heads=encoder_cfg.num_attention_heads,
relative_position_num_buckets=encoder_cfg.relative_attention_num_buckets,
relative_position_max_distance=encoder_cfg.relative_attention_max_distance,
bidirectional=True,
layer_type=LayerType.encoder,
)
self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
# Pipeline model parallel rank 0 will have the actual RPE weights. We zero it out on all other ranks and then sync them on setup.
if parallel_state.get_pipeline_model_parallel_rank() != 0:
self.encoder_relative_position_embeddings_weight().data.fill_(0)
self.encoder_relative_position_embeddings_weight().shared = True
elif self.encoder_cfg.get('position_embedding_type', 'learned_absolute') == 'alibi':
self.encoder_relative_position_embedding = ALiBiRelativePositionEmbedding(
bidirectional=True,
num_attention_heads=encoder_cfg.num_attention_heads,
layer_type=LayerType.encoder,
num_attention_heads_alibi=None,
max_seq_len=max_position_embeddings,
)
self._encoder_relative_position_embedding_key = "encoder_relative_position_embedding"
else:
self.encoder_relative_position_embedding = None
encoder = get_encoder_model(
arch=encoder_cfg.arch,
hidden_size=encoder_cfg.hidden_size,
ffn_hidden_size=encoder_cfg.ffn_hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
apply_query_key_layer_scaling=encoder_cfg.get('apply_query_key_layer_scaling', True),
kv_channels=encoder_kv_channels,
init_method=init_method_normal(encoder_cfg.get('init_method_std', 0.02)),
scaled_init_method=scaled_init_method_normal(
encoder_cfg.get('init_method_std', 0.02), encoder_cfg.num_layers
),
encoder_attn_mask_type=AttnMaskType.padding,
pre_process=pre_process,
post_process=post_process,
init_method_std=encoder_cfg.get('init_method_std', 0.02),
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=encoder_cfg.get('hidden_dropout', 0.1),
attention_dropout=encoder_cfg.get('attention_dropout', 0.1),
ffn_dropout=encoder_cfg.get('ffn_dropout', 0.0),
precision=precision,
fp32_residual_connection=encoder_cfg.get('fp32_residual_connection', False),
activations_checkpoint_method=encoder_cfg.get('activations_checkpoint_method', None),
activations_checkpoint_num_layers=encoder_cfg.get('activations_checkpoint_num_layers', 1),
activations_checkpoint_granularity=encoder_cfg.get('activations_checkpoint_granularity', None),
layernorm_epsilon=encoder_cfg.get('layernorm_epsilon', 1e-5),
bias_activation_fusion=encoder_cfg.get('bias_activation_fusion', True),
bias_dropout_add_fusion=encoder_cfg.get('bias_dropout_add_fusion', True),
masked_softmax_fusion=encoder_cfg.get('masked_softmax_fusion', True),
persist_layer_norm=encoder_cfg.get('persist_layer_norm', True),
openai_gelu=encoder_cfg.get('openai_gelu', False),
onnx_safe=encoder_cfg.get('onnx_safe', False),
hidden_steps=encoder_cfg.get('hidden_steps', -1),
activation=encoder_cfg.get('activation', 'gelu'),
bias=encoder_cfg.get('bias', True),
normalization=encoder_cfg.get('normalization', 'layernorm'),
transformer_block_type=encoder_cfg.get('transformer_block_type', 'pre_ln'),
headscale=encoder_cfg.get('headscale', False),
parent_model_type=ModelType.encoder_and_decoder,
num_self_attention_per_cross_attention=encoder_cfg.get('num_self_attention_per_cross_attention', 1),
megatron_legacy=encoder_cfg.get('megatron_legacy', False),
normalize_attention_scores=encoder_cfg.get('normalize_attention_scores', True),
num_moe_experts=encoder_cfg.get('num_moe_experts', 1),
moe_frequency=encoder_cfg.get('moe_frequency', 1),
moe_dropout=encoder_cfg.get('moe_dropout', 0.0),
)
if add_decoder:
# If this is the decoder first stage
if pre_process:
# If the encoder also lies on this rank (PP = 1), then just assign embeddings directly.
if hasattr(self, 'encoder_embedding') and share_token_embeddings:
self.decoder_embedding = self.encoder_embedding
else:
                    # This is the case where PP > 1 and this is the decoder's first stage, or where embeddings are not shared with the encoder.
self.decoder_embedding = Embedding(
hidden_size=decoder_cfg.hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_position_embeddings,
init_method=init_method_normal(embedding_init_method_std),
num_tokentypes=num_tokentypes,
use_cpu_initialization=use_cpu_initialization,
embedding_dropout_prob=embedding_dropout,
position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'),
)
                    # We initialize decoder embeddings, but set them to zero since they're tied with the encoder embeddings.
# A later initialize_embedding call will synchronize the embeddings.
if share_token_embeddings:
self.decoder_embedding.zero_parameters()
self._decoder_embedding_key = "decoder_embedding"
if self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'relative':
self.decoder_relative_position_embedding = T5RelativePositionEmbedding(
init_method=init_method_normal(embedding_init_method_std),
num_attention_heads=decoder_cfg.num_attention_heads,
relative_position_num_buckets=decoder_cfg.relative_attention_num_buckets,
relative_position_max_distance=decoder_cfg.relative_attention_max_distance,
bidirectional=False,
layer_type=LayerType.decoder,
)
self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
# Pipeline model parallel rank == split_rank will have the actual RPE weights. We zero it out on all other ranks and then sync them on setup.
if (
parallel_state.get_pipeline_model_parallel_rank()
!= parallel_state.get_pipeline_model_parallel_split_rank()
):
self.decoder_relative_position_embeddings_weight().data.fill_(0)
self.decoder_relative_position_embeddings_weight().shared = True
if not self.decoder_cfg.relative_position_bias_self_attention_only:
self.decoder_cross_attention_relative_position_embedding = T5RelativePositionEmbedding(
init_method=init_method_normal(embedding_init_method_std),
num_attention_heads=decoder_cfg.num_attention_heads,
relative_position_num_buckets=decoder_cfg.relative_attention_num_buckets,
relative_position_max_distance=decoder_cfg.relative_attention_max_distance,
bidirectional=True,
layer_type=LayerType.decoder,
)
self._decoder_cross_attention_relative_position_embedding_key = (
"decoder_cross_attention_relative_position_embedding"
)
if (
parallel_state.get_pipeline_model_parallel_rank()
!= parallel_state.get_pipeline_model_parallel_split_rank()
):
self.decoder_cross_attention_relative_position_embeddings_weight().data.fill_(0)
self.decoder_cross_attention_relative_position_embeddings_weight().shared = True
elif self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'alibi':
self.decoder_relative_position_embedding = ALiBiRelativePositionEmbedding(
bidirectional=False,
num_attention_heads=decoder_cfg.num_attention_heads,
layer_type=LayerType.decoder,
num_attention_heads_alibi=None,
max_seq_len=max_position_embeddings,
)
self._decoder_relative_position_embedding_key = "decoder_relative_position_embedding"
else:
self.decoder_relative_position_embedding = None
decoder = get_decoder_model(
arch=decoder_cfg.arch,
hidden_size=decoder_cfg.hidden_size,
ffn_hidden_size=decoder_cfg.ffn_hidden_size,
num_layers=decoder_cfg.num_layers,
num_attention_heads=decoder_cfg.num_attention_heads,
apply_query_key_layer_scaling=decoder_cfg.get('apply_query_key_layer_scaling', True),
kv_channels=decoder_kv_channels,
init_method=init_method_normal(decoder_cfg.get('init_method_std', 0.02)),
scaled_init_method=scaled_init_method_normal(
decoder_cfg.get('init_method_std', 0.02), decoder_cfg.num_layers
),
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=pre_process,
post_process=post_process,
init_method_std=decoder_cfg.get('init_method_std', 0.02),
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=decoder_cfg.get('hidden_dropout', 0.1),
attention_dropout=decoder_cfg.get('attention_dropout', 0.1),
ffn_dropout=decoder_cfg.get('ffn_dropout', 0.0),
precision=precision,
fp32_residual_connection=decoder_cfg.get('fp32_residual_connection', False),
activations_checkpoint_method=decoder_cfg.get('activations_checkpoint_method', None),
activations_checkpoint_num_layers=decoder_cfg.get('activations_checkpoint_num_layers', 1),
activations_checkpoint_granularity=decoder_cfg.get('activations_checkpoint_granularity', None),
layernorm_epsilon=decoder_cfg.get('layernorm_epsilon', 1e-5),
bias_activation_fusion=decoder_cfg.get('bias_activation_fusion', True),
bias_dropout_add_fusion=decoder_cfg.get('bias_dropout_add_fusion', True),
masked_softmax_fusion=decoder_cfg.get('masked_softmax_fusion', True),
persist_layer_norm=decoder_cfg.get('persist_layer_norm', True),
openai_gelu=decoder_cfg.get('openai_gelu', False),
onnx_safe=decoder_cfg.get('onnx_safe', False),
hidden_steps=decoder_cfg.get('hidden_steps', -1),
activation=decoder_cfg.get('activation', 'gelu'),
bias=decoder_cfg.get('bias', True),
normalization=decoder_cfg.get('normalization', 'layernorm'),
transformer_block_type=decoder_cfg.get('transformer_block_type', 'pre_ln'),
headscale=decoder_cfg.get('headscale', False),
parent_model_type=ModelType.encoder_and_decoder,
megatron_legacy=decoder_cfg.get('megatron_legacy', False),
normalize_attention_scores=decoder_cfg.get('normalize_attention_scores', True),
num_moe_experts=decoder_cfg.get('num_moe_experts', 1),
moe_frequency=decoder_cfg.get('moe_frequency', 1),
moe_dropout=decoder_cfg.get('moe_dropout', 0.0),
)
self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
encoder=encoder, decoder=decoder, hidden_steps=encoder_cfg.get('hidden_steps', -1),
)
self._enc_dec_model_key = "enc_dec_model"
if self.share_token_embeddings:
# This is only relevant for PP > 1.
self.initialize_word_embeddings(
init_method=init_method_normal(embedding_init_method_std),
vocab_size=vocab_size,
hidden_size=encoder_cfg.hidden_size,
)
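            # With PP > 1, the embedding-bearing pipeline stages each keep a copy of the
            # shared word embeddings; keeping those copies identical (typically by
            # all-reducing their gradients during training) is handled by MegatronModule.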
if add_decoder and post_process:
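            # Two output-head variants: tied (logits projected through the shared
            # word-embedding matrix, as in T5) or an independent column-parallel linear.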
if share_decoder_tokens_head_embeddings:
self.tokens_head = MegatronTokenLevelHead(
self.word_embeddings_weight().size(0), parallel_output, bias=tokens_head_bias
)
else:
self.tokens_head = tensor_parallel.ColumnParallelLinear(
input_size=decoder_cfg.hidden_size,
output_size=vocab_size,
bias=tokens_head_bias,
gather_output=not self.parallel_output,
init_method=init_method_normal(decoder_cfg.init_method_std),
use_cpu_initialization=use_cpu_initialization,
)
self._tokens_head_key = 'tokens_head'

    def _validate_kv_channels(self, cfg):
kv_channels = cfg.kv_channels
if cfg.kv_channels is None:
assert (
cfg.hidden_size % cfg.num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = cfg.hidden_size // cfg.num_attention_heads
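            # e.g. hidden_size=768 with num_attention_heads=12 yields kv_channels=64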
return kv_channels

    def _validate_enc_dec_hidden_size(self, encoder_cfg, decoder_cfg):
if encoder_cfg.hidden_size != decoder_cfg.hidden_size:
raise ValueError(
f"Encoder and decoder hidden_size must be equal, but got encoder: {encoder_cfg.hidden_size} and decoder: {decoder_cfg.hidden_size}"
)

    def _validate_perceiver_config(self, cfg):
if (
cfg.get("position_embedding_type", "learned_absolute") == "relative"
and cfg.get("arch", "transformer") == "perceiver"
):
raise ValueError(f"Perceivers with relative position embeddings are not supported")

    def _validate_config(self):
encoder_kv_channels = self._validate_kv_channels(self.encoder_cfg)
decoder_kv_channels = self._validate_kv_channels(self.decoder_cfg)
self._validate_enc_dec_hidden_size(self.encoder_cfg, self.decoder_cfg)
self._validate_perceiver_config(self.encoder_cfg)
self._validate_perceiver_config(self.decoder_cfg)
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
assert (
self.share_token_embeddings
), "Token embeddings must be shared when using pipeline model parallel size > 1"
assert (
self.share_decoder_tokens_head_embeddings
), "Decoder token embeddings and the outputlayer must be shared when using pipeline model parallel size > 1"
return encoder_kv_channels, decoder_kv_channels

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with both encoder and decoder'
self.enc_dec_model.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for stage with only encoder'
self.enc_dec_model.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.enc_dec_model.decoder.set_input_tensor(input_tensor[0])
self.enc_dec_model.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.enc_dec_model.decoder.set_input_tensor(None)
self.enc_dec_model.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')

    def forward(
self,
enc_input_ids=None,
enc_attn_mask=None,
dec_input_ids=None,
dec_attn_mask=None,
token_type_ids=None,
labels=None,
enc_output=None, # Result of running the entire encoder
enc_output_attn_mask=None,
enc_input=None, # Result of running encoder embedding only
output_enc_hidden_only=False,
):
"""
        Return value is per token / per dimension (i.e., a non-collapsed loss value).
"""
(
encoder_self_attention_relative_position_bias,
decoder_self_attention_relative_position_bias,
decoder_cross_attention_relative_position_bias,
) = (None, None, None)
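        # When populated below, each of these is an additive attention bias (a T5-style
        # relative position bias or ALiBi) built from the query/key sequence lengths.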
if enc_input is not None and enc_output is not None:
raise ValueError(
"""Both enc_input and enc_output are not None.
You should only be passing one of them.
enc_input is the result of the encoder embedding layer
enc_output is the result of running the entire transformer encoder."""
)
# In order of precedence, we use enc_output, enc_input, and then enc_input_ids to determine the encoder sequence length.
if enc_output is not None:
# If enc_output is provided in `batch_for_pipeline`, we need to transpose it from [B x S x H] -> [S x B x H].
enc_output = enc_output.transpose(0, 1)
enc_seq_length = enc_output.size(0)
elif enc_input is not None:
# If enc_input is provided, we need to transpose it from [B x S x H] -> [S x B x H].
enc_input = enc_input.transpose(0, 1)
enc_seq_length = enc_input.size(0)
# Only need to run encoder embedding and position ids if enc_input or enc_output is not provided.
elif enc_input_ids is not None:
enc_seq_length = enc_input_ids.size(1)
if self.pre_process and self.add_encoder:
# We don't need position ids for RPE, because the embedding layer does not have position embeddings.
if self.encoder_relative_position_embedding is None:
enc_position_ids = build_position_ids(enc_input_ids)
else:
enc_position_ids = None
enc_input = self.encoder_embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids)
else:
enc_input = None
else:
# This should only happen with PP > 1 for enc-dec prompt learning models
enc_seq_length = enc_attn_mask.size(1)
if self.add_encoder and self.encoder_relative_position_embedding is not None:
encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding(
query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,
)
if output_enc_hidden_only:
            # When pipeline parallel > 1 we need to make sure the encoder exists on this rank (it is absent on decoder-only ranks).
if enc_output is None and self.enc_dec_model.encoder is not None:
enc_output = self.enc_dec_model.encode(
enc_input=enc_input,
enc_attn_mask=enc_attn_mask,
enc_layer_past=None,
enc_get_key_value=False,
enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias,
)
else:
enc_output = self.enc_dec_model.encoder_hidden_state
return enc_output
else:
if enc_output_attn_mask is None:
enc_output_attn_mask = enc_attn_mask
if self.pre_process and self.add_decoder:
# We don't need position ids for RPE, because the embedding layer does not have position embeddings.
if self.decoder_relative_position_embedding is None:
dec_position_ids = build_position_ids(dec_input_ids)
else:
dec_position_ids = None
dec_input = self.decoder_embedding(dec_input_ids, dec_position_ids, token_type_ids=token_type_ids)
else:
# Note: This is when the decoder itself is split across PP ranks.
dec_input = None
if self.add_decoder and self.decoder_relative_position_embedding is not None:
decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding(
query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1)
)
if not self.decoder_cfg.relative_position_bias_self_attention_only:
decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
)
else:
decoder_cross_attention_relative_position_bias = None
output = self.enc_dec_model(
enc_input=enc_input,
enc_attn_mask=enc_attn_mask,
dec_input=dec_input,
dec_attn_mask=dec_attn_mask,
enc_layer_past=None,
enc_get_key_value=False,
enc_output=enc_output,
enc_output_attn_mask=enc_output_attn_mask,
dec_layer_past=None,
dec_get_key_value=False,
enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias,
dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias,
dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias,
)
if self.post_process and self.add_decoder:
dec_output, enc_output = output # [s, b, h]
# project decoder output to vocabulary-size dimensions
if self.share_decoder_tokens_head_embeddings:
token_logits = self.tokens_head(dec_output, self.word_embeddings_weight())
else:
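                    # tensor_parallel.ColumnParallelLinear returns an (output, bias) tuple, hence the [0].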
token_logits = self.tokens_head(dec_output)[0]
if labels is not None:
# [b, s] -> [s, b]
labels = labels.transpose(0, 1).contiguous()
# Set label smoothing to 0 if in eval mode.
label_smoothing = self.label_smoothing if self.training else 0.0
                    # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and returns log p(x_i|z) per token i
if self.fp16_cross_entropy:
assert token_logits.dtype == torch.half
tokens_loss = vocab_parallel_cross_entropy(token_logits, labels, label_smoothing)
else:
tokens_loss = vocab_parallel_cross_entropy(token_logits.float(), labels, label_smoothing)
# [s, b] -> [b, s]
tokens_loss = tokens_loss.transpose(0, 1).contiguous()
return tokens_loss
else:
# [s, b, h] -> [b, s, h]
token_logits = token_logits.transpose(0, 1).contiguous()
return token_logits
elif self.add_decoder and not self.add_encoder:
decoder_output, _ = output
return decoder_output
else:
encoder_output = output
return encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._encoder_embedding_key] = self.encoder_embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
state_dict_[self._decoder_embedding_key] = self.decoder_embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
state_dict_[self._enc_dec_model_key] = self.enc_dec_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
state_dict_[self._tokens_head_key] = self.tokens_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
return state_dict_

    def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
        self.encoder_embedding.load_state_dict(state_dict[self._encoder_embedding_key], strict=strict)
self.decoder_embedding.load_state_dict(state_dict[self._decoder_embedding_key], strict=strict)
self.enc_dec_model.load_state_dict(state_dict[self._enc_dec_model_key], strict=strict)
self.tokens_head.load_state_dict(state_dict[self._tokens_head_key], strict=strict)
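

# Minimal construction sketch (hypothetical config values; assumes Apex is installed and
# megatron parallel state has been initialized before the model is built):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create(
#       {'arch': 'transformer', 'hidden_size': 512, 'ffn_hidden_size': 2048,
#        'num_layers': 6, 'num_attention_heads': 8, 'kv_channels': None}
#   )
#   model = MegatronTokenLevelEncoderDecoderModule(
#       encoder_cfg=cfg, decoder_cfg=cfg, vocab_size=32128, max_position_embeddings=512
#   )
#   loss = model(enc_input_ids, enc_attn_mask, dec_input_ids, dec_attn_mask, labels=labels)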