File size: 2,446 Bytes
be611b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import (
TransformerSentenceEncoderLayer
)
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
"""
Implements a Model Parallel Transformer Encoder Layer used in
BERT/XLM style pre-trained models.
"""
def build_fc1(self, input_dim, output_dim, **unused):
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, **unused):
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(
self,
embed_dim,
num_attention_heads,
dropout,
**kwargs,
):
return ModelParallelMultiheadAttention(
embed_dim,
num_attention_heads,
dropout=dropout,
self_attention=True
)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer imlementation.
"""
residual = x
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
return x, None
|