# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer


# Megatron-LM's model-parallel utilities (mpu) live in an optional git
# submodule; degrade gracefully when it has not been checked out.
try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block parallelized over multiple GPUs.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError  # quant_noise is not supported with model parallelism
        # fc1 is column-parallel: its weight is split along the output
        # dimension, so each rank produces a shard of the FFN hidden state
        # (gather_output=False keeps those activations sharded for fc2).
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError  # quant_noise is not supported with model parallelism
        # fc2 is row-parallel: it consumes the sharded hidden state
        # (input_is_parallel=True) and all-reduces the partial outputs back
        # into the full embedding dimension. A toy illustration of the
        # column/row split appears at the bottom of this file.
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        # Self-attention with the heads partitioned across model-parallel ranks.
        return ModelParallelMultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )


class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block parallelized over multiple GPUs.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError  # quant_noise is not supported with model parallelism
        # Column-parallel fc1: activations stay sharded across ranks.
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError  # quant_noise is not supported with model parallelism
        # Row-parallel fc2: partial outputs are all-reduced back together.
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=not getattr(args, "cross_self_attention", False),
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
        # Encoder-decoder (cross) attention, with key/value projections sized
        # to the encoder's embedding dimension.
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
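

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module, and not used by
# fairseq): a single-process demo of why the FFN is built as a column-parallel
# fc1 followed by a row-parallel fc2. Splitting fc1's weight along its output
# dimension and fc2's weight along its input dimension lets each rank hold one
# shard; the per-rank fc2 outputs then sum to the full result, which Megatron
# realizes with an all-reduce. Here the "two ranks" are faked with tensor
# slicing on a single device, and all shapes are made-up toy sizes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8)  # (tokens, embed_dim)
    w1 = torch.randn(16, 8)  # fc1 weight: embed_dim -> ffn_dim
    w2 = torch.randn(8, 16)  # fc2 weight: ffn_dim -> embed_dim

    # Reference: the unsharded feed-forward block.
    full = torch.relu(x @ w1.t()) @ w2.t()

    # "Rank 0" and "rank 1" each own half of fc1's output rows (column
    # parallel) and the matching half of fc2's input columns (row parallel).
    h0 = torch.relu(x @ w1[:8].t())
    h1 = torch.relu(x @ w1[8:].t())
    partial = h0 @ w2[:, :8].t() + h1 @ w2[:, 8:].t()  # the "all-reduce"

    assert torch.allclose(full, partial, atol=1e-5)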