File size: 8,093 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""
import torch

from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
    from apex.transformer.enums import AttnMaskType

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False
    # fake missing classes with None attributes
    AttnMaskType = ApexGuardDefaults()


__all__ = ["MegatronTransformerEncoderDecoderModule"]


class MegatronTransformerEncoderDecoderModule(MegatronModule):
    """Transformer encoder-decoder model.

    Wraps an encoder and a decoder module behind a single ``forward`` pass.
    Either sub-module may be ``None`` (e.g. when the model is split across
    pipeline stages); attention-mask types are inferred from the sub-modules
    when not supplied explicitly.
    """

    def __init__(
        self,
        encoder,
        decoder,
        # AttnMaskType enum mask type (e.g., padding, causal)
        encoder_attn_mask_type: AttnMaskType = None,
        decoder_attn_mask_type: AttnMaskType = None,
        hidden_steps: int = None,
    ):
        """
        Args:
            encoder: encoder module (or None if this rank holds no encoder).
            decoder: decoder module (or None if this rank holds no decoder).
            encoder_attn_mask_type: self-attention mask type for the encoder;
                inferred from ``encoder.model.self_attn_mask_type`` when None.
            decoder_attn_mask_type: self-attention mask type for the decoder;
                inferred from ``decoder.model.self_attn_mask_type`` when None.
            hidden_steps: number of latent vectors produced by a perceiver
                encoder; required when ``encoder`` is a perceiver.

        Raises:
            ValueError: if ``encoder`` is a perceiver and ``hidden_steps`` is None.
            AttributeError: if a mask type is not given and cannot be inferred.
        """
        super(MegatronTransformerEncoderDecoderModule, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.hidden_steps = hidden_steps
        # A perceiver encoder outputs a fixed number of latent vectors, so
        # hidden_steps is needed later to build the cross-attention mask.
        if isinstance(encoder, MegatronPerceiverEncoderModule) and hidden_steps is None:
            raise ValueError(
                "hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross attention mask."
            )

        # try to infer mask_type if not given
        if encoder_attn_mask_type is None:
            if encoder is None:
                encoder_attn_mask_type = None
            # Perceiver does not have a `.model` attribute, assume it always uses padding mask.
            elif isinstance(encoder, MegatronPerceiverEncoderModule):
                encoder_attn_mask_type = AttnMaskType.padding
            elif hasattr(encoder.model, 'self_attn_mask_type'):
                encoder_attn_mask_type = encoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for encoder self_attn_mask_type, make sure it is set when instantiating the encoder or pass it to the constructor of this class."
                )
        if decoder_attn_mask_type is None:
            if decoder is None:
                decoder_attn_mask_type = None
            elif hasattr(decoder.model, 'self_attn_mask_type'):
                decoder_attn_mask_type = decoder.model.self_attn_mask_type
            else:
                raise AttributeError(
                    "Could not find an attribute for decoder self_attn_mask_type, make sure it is set when instantiating the decoder or pass it to the constructor of this class."
                )

        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.decoder_attn_mask_type = decoder_attn_mask_type

        # Keys used to namespace encoder/decoder weights in checkpoints.
        self._encoder_key = "encoder"
        self._decoder_key = "decoder"

    def encode(
        self,
        enc_input,
        enc_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_self_attention_relative_position_bias=None,
    ):
        """Encodes embedder input using encoder.

        Raises:
            ValueError: if this module has no encoder.
        """
        # NOTE: the docstring used to sit after this guard, where it was a
        # dead string expression rather than a docstring.
        if self.encoder is None:
            raise ValueError("Cannot call .encode(...) when self.encoder is None.")
        enc_output = self.encoder(
            enc_input=enc_input,
            enc_attn_mask=enc_attn_mask,
            layer_past=enc_layer_past,
            get_key_value=enc_get_key_value,
            enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
        )

        return enc_output

    def decode(
        self,
        dec_input,
        dec_attn_mask,
        enc_output,
        enc_attn_mask,
        dec_layer_past=None,
        dec_get_key_value=False,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Decodes embedder input using decoder and encoder input.

        Raises:
            ValueError: if this module has no decoder.
        """
        if self.decoder is None:
            raise ValueError("Cannot call .decode(...) when self.decoder is None.")
        dec_output = self.decoder(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            layer_past=dec_layer_past,
            get_key_value=dec_get_key_value,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output

    def forward(
        self,
        enc_input,
        enc_attn_mask,
        dec_input,
        dec_attn_mask,
        enc_layer_past=None,
        enc_get_key_value=False,
        enc_output=None,
        enc_output_attn_mask=None,
        dec_layer_past=None,
        dec_get_key_value=False,
        output_enc_hidden_only=False,
        enc_self_attention_relative_position_bias=None,
        dec_self_attention_relative_position_bias=None,
        dec_cross_attention_relative_position_bias=None,
    ):
        """Runs the encoder (or reuses a cached output), then the decoder.

        Returns ``enc_output`` alone when ``output_enc_hidden_only`` is True
        or there is no decoder; otherwise ``(dec_output, enc_output)``.
        """
        # encoder: compute the hidden states unless the caller supplied them.
        if enc_output is None:
            if self.encoder is not None:
                enc_output = self.encode(
                    enc_input=enc_input,
                    enc_attn_mask=enc_attn_mask,
                    enc_layer_past=enc_layer_past,
                    enc_get_key_value=enc_get_key_value,
                    enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias,
                )
            else:
                # NOTE(review): encoder_hidden_state is not set in __init__;
                # presumably assigned externally (e.g. pipeline parallelism) —
                # TODO confirm against callers.
                assert self.encoder_hidden_state is not None
                enc_output = self.encoder_hidden_state
        else:
            # Cast the caller-provided mask to the dtype/device of enc_attn_mask.
            enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask)

        if self.decoder is None or output_enc_hidden_only:
            return enc_output

        # decoder
        # Adjust encoder attention mask if encoder is a perceiver.
        if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule):
            # Attention mask is expected to be of shape [B x S] and enc_output is of size [S x B x H].
            enc_attn_mask = torch.ones(enc_output.size(1), self.hidden_steps).to(enc_output.device)

        dec_output = self.decode(
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            enc_output=enc_output,
            enc_attn_mask=enc_attn_mask,
            dec_layer_past=dec_layer_past,
            dec_get_key_value=dec_get_key_value,
            dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias,
            dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias,
        )

        return dec_output, enc_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict with the encoder/decoder state dicts under their
        respective namespace keys. Assumes both sub-modules are present.
        """

        state_dict_ = {}

        state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Expects the namespaced layout produced by
        ``state_dict_for_save_checkpoint``. Assumes both sub-modules are present.
        """

        self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict)
        self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)