import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention

try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
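# NOTE: the flag above lets this module be imported even when the Megatron-LM
# submodule has not been checked out; callers can test ``has_megatron_submodule``
# and raise a descriptive error before instantiating the layer.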


class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Model Parallel Transformer Encoder Layer used in
    BERT/XLM style pre-trained models.
    """
    def build_fc1(self, input_dim, output_dim, **unused):
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, **unused):
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
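
    # Attention is parallelized the same way: ModelParallelMultiheadAttention
    # partitions the attention heads across model-parallel ranks.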
    def build_self_attention(
        self,
        embed_dim,
        num_attention_heads,
        dropout,
        **kwargs,
    ):
        return ModelParallelMultiheadAttention(
            embed_dim,
            num_attention_heads,
            dropout=dropout,
            self_attention=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
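
        # Position-wise feed-forward sub-block, using the column-/row-parallel
        # projections built in build_fc1/build_fc2.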
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        return x, None