| | """Transformer based language model.""" |
| | import torch |
| |
|
| | from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule |
| | from nemo.collections.nlp.modules.common.megatron.module import MegatronModule |
| | from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults |
| |
|
| | try: |
| | from apex.transformer.enums import AttnMaskType |
| |
|
| | HAVE_APEX = True |
| | except (ImportError, ModuleNotFoundError): |
| | HAVE_APEX = False |
| | |
| | AttnMaskType = ApexGuardDefaults() |
| |
|
| |
|
| | __all__ = ["MegatronTransformerEncoderDecoderModule"] |
| |
|
| |
|
class MegatronTransformerEncoderDecoderModule(MegatronModule):
    """Transformer encoder-decoder model."""

    def __init__(
        self,
        encoder,
        decoder,
        # Attention mask types (AttnMaskType); if None, they are inferred from the encoder/decoder below.
        encoder_attn_mask_type: AttnMaskType = None,
        decoder_attn_mask_type: AttnMaskType = None,
        hidden_steps: int = None,
    ):
        super(MegatronTransformerEncoderDecoderModule, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.hidden_steps = hidden_steps
        if isinstance(encoder, MegatronPerceiverEncoderModule) and hidden_steps is None:
            raise ValueError(
                "hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross-attention mask."
            )

        # Try to infer the attention mask types when they are not given explicitly.
        if encoder_attn_mask_type is None:
            if encoder is None:
                encoder_attn_mask_type = None
            # Perceiver encoders attend over a fixed set of latents, so a padding mask is used.
            elif isinstance(encoder, MegatronPerceiverEncoderModule):
                encoder_attn_mask_type = AttnMaskType.padding
            elif hasattr(encoder.model, 'self_attn_mask_type'):
                encoder_attn_mask_type = encoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for encoder self_attn_mask_type; make sure it is set when "
                    "instantiating the encoder, or pass it to the constructor of this class."
                )
        if decoder_attn_mask_type is None:
            if decoder is None:
                decoder_attn_mask_type = None
            elif hasattr(decoder.model, 'self_attn_mask_type'):
                decoder_attn_mask_type = decoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for decoder self_attn_mask_type; make sure it is set when "
                    "instantiating the decoder, or pass it to the constructor of this class."
                )

        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.decoder_attn_mask_type = decoder_attn_mask_type

        # Keys used to namespace the encoder and decoder entries in the checkpoint state dict.
        self._encoder_key = "encoder"
        self._decoder_key = "decoder"

    def encode(
        self,
        enc_input,
        enc_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_self_attention_relative_position_bias=None,
    ):
        """Encodes the embedded input using the encoder."""
        if self.encoder is None:
            raise ValueError("Cannot call .encode(...) when self.encoder is None.")

        enc_output = self.encoder(
            enc_input=enc_input,
            enc_attn_mask=enc_attn_mask,
            layer_past=enc_layer_past,
            get_key_value=enc_get_key_value,
            enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
        )

        return enc_output

    def decode(
        self,
        dec_input,
        dec_attn_mask,
        enc_output,
        enc_attn_mask,
        dec_layer_past=None,
        dec_get_key_value=False,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Decodes the embedded input using the decoder, cross-attending to the encoder output."""
        if self.decoder is None:
            raise ValueError("Cannot call .decode(...) when self.decoder is None.")

        dec_output = self.decoder(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            layer_past=dec_layer_past,
            get_key_value=dec_get_key_value,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output

    def forward(
        self,
        enc_input,
        enc_attn_mask,
        dec_input,
        dec_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_output=None,
        enc_output_attn_mask=None,
        dec_layer_past=None,
        dec_get_key_value=False,
        output_enc_hidden_only=False,
        enc_self_attention_relative_position_bias=None,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Runs the encoder (unless enc_output is given) followed by the decoder.

        Returns only the encoder output when output_enc_hidden_only is True or when this
        module has no decoder; otherwise returns (dec_output, enc_output).
        """
        # Encode, unless precomputed encoder hidden states were passed in.
        if enc_output is None:
            if self.encoder is not None:
                enc_output = self.encode(
                    enc_input=enc_input,
                    enc_attn_mask=enc_attn_mask,
                    enc_layer_past=enc_layer_past,
                    enc_get_key_value=enc_get_key_value,
                    enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
                )
            else:
                # No encoder on this stage: fall back to a previously cached encoder hidden state.
                assert self.encoder_hidden_state is not None
                enc_output = self.encoder_hidden_state
        else:
            # A precomputed encoder output was supplied; use the attention mask that matches it.
            enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask)

        if self.decoder is None or output_enc_hidden_only:
            return enc_output

        # Perceiver encoders output a fixed number of latent positions (hidden_steps),
        # so the encoder attention mask used for cross-attention must match that length.
        if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule):
            enc_attn_mask = torch.ones(enc_output.size(1), self.hidden_steps).to(enc_output.device)

        dec_output = self.decode(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_layer_past=dec_layer_past,
            dec_get_key_value=dec_get_key_value,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output, enc_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}

        state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict)
        self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
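

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the NeMo API). The dummy
# encoder/decoder classes below are hypothetical stand-ins that exist solely to
# show the interface this module expects: callables accepting the keyword
# arguments used in encode()/decode() and exposing `model.self_attn_mask_type`
# so the mask types can be inferred in __init__. Real usage would pass Megatron
# encoder/decoder modules constructed elsewhere in NeMo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    class _DummyEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Mimics the `.model.self_attn_mask_type` attribute probed in __init__.
            self.model = types.SimpleNamespace(self_attn_mask_type="padding")

        def forward(
            self, enc_input, enc_attn_mask, layer_past=None, get_key_value=False,
            enc_self_attention_relative_position_bias=None,
        ):
            return enc_input  # identity "encoder", for illustration only

    class _DummyDecoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = types.SimpleNamespace(self_attn_mask_type="causal")

        def forward(
            self, dec_input, dec_attn_mask, enc_output, enc_attn_mask,
            layer_past=None, get_key_value=False,
            dec_self_attention_relative_position_bias=None,
            dec_cross_attention_relative_position_bias=None,
        ):
            return dec_input  # identity "decoder", for illustration only

    enc_dec = MegatronTransformerEncoderDecoderModule(encoder=_DummyEncoder(), decoder=_DummyDecoder())

    # Placeholder tensors; the dummy modules above ignore the masks and echo their inputs.
    enc_input = torch.zeros(8, 2, 16)
    dec_input = torch.zeros(4, 2, 16)
    enc_attn_mask = torch.ones(2, 8)
    dec_attn_mask = torch.ones(2, 4)

    dec_output, enc_output = enc_dec(
        enc_input=enc_input,
        enc_attn_mask=enc_attn_mask,
        dec_input=dec_input,
        dec_attn_mask=dec_attn_mask,
    )
    print(dec_output.shape, enc_output.shape)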