# nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
try:
from apex.transformer.enums import AttnMaskType
HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
# fake missing classes with None attributes
AttnMaskType = ApexGuardDefaults()
__all__ = ["MegatronTransformerEncoderDecoderModule"]
class MegatronTransformerEncoderDecoderModule(MegatronModule):
    """Transformer encoder-decoder model.

    Wraps an encoder and a decoder sub-module and routes activations and
    attention masks between them, including the special-case handling needed
    for perceiver-style encoders (which compress the sequence dimension to
    ``hidden_steps``). Either sub-module may be ``None``, e.g. for
    pipeline-parallel stages that only hold one half of the model.
    """

    def __init__(
        self,
        encoder,
        decoder,
        # AttnMaskType enum mask type (e.g., padding, causal)
        encoder_attn_mask_type: AttnMaskType = None,
        decoder_attn_mask_type: AttnMaskType = None,
        hidden_steps: int = None,
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.hidden_steps = hidden_steps
        # Perceiver encoders emit a fixed number of latent steps; without it the
        # encoder-decoder cross-attention mask in forward() cannot be built.
        if isinstance(encoder, MegatronPerceiverEncoderModule) and hidden_steps is None:
            raise ValueError(
                "hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross attention mask."
            )

        # Try to infer the mask type when the caller did not provide one.
        if encoder_attn_mask_type is None:
            if encoder is None:
                encoder_attn_mask_type = None
            # Perceiver does not have a `.model` attribute, assume it always uses padding mask.
            elif isinstance(encoder, MegatronPerceiverEncoderModule):
                encoder_attn_mask_type = AttnMaskType.padding
            elif hasattr(encoder.model, 'self_attn_mask_type'):
                encoder_attn_mask_type = encoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for encoder self_attn_mask_type, make sure it is set when instantiating the encoder or pass it to the constructor of this class."
                )

        if decoder_attn_mask_type is None:
            if decoder is None:
                decoder_attn_mask_type = None
            elif hasattr(decoder.model, 'self_attn_mask_type'):
                decoder_attn_mask_type = decoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for decoder self_attn_mask_type, make sure it is set when instantiating the decoder or pass it to the constructor of this class."
                )

        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.decoder_attn_mask_type = decoder_attn_mask_type

        # Keys under which the sub-module state dicts are stored in checkpoints.
        self._encoder_key = "encoder"
        self._decoder_key = "decoder"

    def encode(
        self,
        enc_input,
        enc_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_self_attention_relative_position_bias=None,
    ):
        """Encode embedder input using the encoder.

        Raises:
            ValueError: if this module holds no encoder.
        """
        if self.encoder is None:
            raise ValueError("Cannot call .encode(...) when self.encoder is None.")

        enc_output = self.encoder(
            enc_input=enc_input,
            enc_attn_mask=enc_attn_mask,
            layer_past=enc_layer_past,
            get_key_value=enc_get_key_value,
            enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
        )

        return enc_output

    def decode(
        self,
        dec_input,
        dec_attn_mask,
        enc_output,
        enc_attn_mask,
        dec_layer_past=None,
        dec_get_key_value=False,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Decode embedder input using the decoder, cross-attending to encoder output.

        Raises:
            ValueError: if this module holds no decoder.
        """
        if self.decoder is None:
            raise ValueError("Cannot call .decode(...) when self.decoder is None.")

        dec_output = self.decoder(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            layer_past=dec_layer_past,
            get_key_value=dec_get_key_value,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output

    def forward(
        self,
        enc_input,
        enc_attn_mask,
        dec_input,
        dec_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_output=None,
        enc_output_attn_mask=None,
        dec_layer_past=None,
        dec_get_key_value=False,
        output_enc_hidden_only=False,
        enc_self_attention_relative_position_bias=None,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Run the full encoder-decoder pass.

        If ``enc_output`` is given, the local encoder is skipped and
        ``enc_output_attn_mask`` replaces ``enc_attn_mask``. If
        ``output_enc_hidden_only`` is True (or there is no decoder), only the
        encoder output is returned; otherwise returns
        ``(dec_output, enc_output)``.
        """
        # --- encoder ---
        if enc_output is None:
            if self.encoder is not None:
                enc_output = self.encode(
                    enc_input=enc_input,
                    enc_attn_mask=enc_attn_mask,
                    enc_layer_past=enc_layer_past,
                    enc_get_key_value=enc_get_key_value,
                    enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
                )
            else:
                # No local encoder: a cached hidden state must have been set
                # externally (presumably by a pipeline-parallel schedule) —
                # TODO confirm; it is never assigned in this class.
                assert self.encoder_hidden_state is not None
                enc_output = self.encoder_hidden_state
        else:
            # Caller provided encoder output; cast its mask to the dtype/device
            # of enc_attn_mask and use it from here on.
            enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask)

        if self.decoder is None or output_enc_hidden_only:
            return enc_output

        # --- decoder ---
        # Adjust encoder attention mask if encoder is a perceiver: it compresses
        # the sequence to hidden_steps latents, all of which may be attended to.
        if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule):
            # Attention mask is expected to be of shape [B x S] and enc_output is of size [S x B x H].
            enc_attn_mask = torch.ones(enc_output.size(1), self.hidden_steps).to(enc_output.device)

        dec_output = self.decode(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_layer_past=dec_layer_past,
            dec_get_key_value=dec_get_key_value,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output, enc_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load.

        NOTE(review): assumes both ``self.encoder`` and ``self.decoder`` are
        non-None, even though __init__ permits None — confirm against callers.
        """
        state_dict_ = {}

        state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: restores each sub-module from its checkpoint key."""
        self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict)
        self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)