|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import logging |
|
|
|
|
|
from fairseq.models import ( |
|
|
FairseqEncoderModel, |
|
|
FairseqLanguageModel, |
|
|
register_model, |
|
|
register_model_architecture, |
|
|
) |
|
|
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder |
|
|
from fairseq.models.speech_to_speech.modules.transformer_encoder import ( |
|
|
TransformerEncoderNoEmb, |
|
|
) |
|
|
from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel |
|
|
from fairseq.models.speech_to_speech.s2s_conformer_unity import ( |
|
|
multitask_text_transformer_decoder_arch, |
|
|
) |
|
|
from fairseq.models.speech_to_speech.s2s_transformer import ( |
|
|
base_multitask_text_transformer_decoder_arch, |
|
|
s2spect_architecture_base, |
|
|
) |
|
|
from fairseq.models.text_to_speech import TTSTransformerDecoder |
|
|
from fairseq.models.transformer import TransformerDecoder, TransformerModelBase |
|
|
|
|
|
# Module-level logger, one per module per the fairseq/stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@register_model("s2spect2_conformer")
class S2SpecT2ConformerModel(S2SpecTConformerModel):
    """
    Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder

    Two-pass architecture: the shared Conformer encoder feeds a first-pass
    text (MT) decoder built as a multitask decoder; the hidden states of that
    decoder are then consumed (optionally via an extra text encoder) by a
    second-pass TTS Transformer decoder that predicts spectrogram frames.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific CLI arguments on top of the base Conformer S2S args."""
        S2SpecTConformerModel.add_args(parser)
        parser.add_argument(
            "--translation-decoder-layers",
            type=int,
            default=4,
            metavar="N",
            help="num decoder layers in the first-pass translation module",
        )
        # Only "transformer" is currently supported; kept as a choice flag so
        # alternative synthesizers can be added without changing the interface.
        parser.add_argument(
            "--synthesizer",
            default="transformer",
            choices=["transformer"],
            help="",
        )
        # 0 disables the extra encoder between the MT decoder and the TTS
        # decoder (see build_model: synthesizer_encoder is then None).
        parser.add_argument(
            "--synthesizer-encoder-layers",
            type=int,
            default=0,
            metavar="N",
            help="num encoder layers in the second-pass synthesizer module",
        )

    @classmethod
    def build_multitask_decoder(
        cls,
        args,
        tgt_dict,
        in_dim,
        is_mt_decoder,
        decoder_layers,
        decoder_embed_dim,
        decoder_attention_heads,
    ):
        """
        Build one auxiliary (multitask) decoder.

        Args:
            args: the multitask task's argument namespace; ``args.decoder_args``
                is mutated in place (``encoder_embed_dim`` is overwritten).
            tgt_dict: target dictionary for this task.
            in_dim: feature dim of the states this decoder attends over
                (encoder or decoder output dim, chosen by the caller).
            is_mt_decoder: True for the first-pass translation decoder, which
                gets the dedicated MT architecture instead of the base one.
            decoder_layers / decoder_embed_dim / decoder_attention_heads:
                architecture overrides applied only to the MT decoder.

        Returns:
            A ``TransformerDecoder`` or ``CTCDecoder`` instance.

        Raises:
            NotImplementedError: for any ``args.decoder_type`` other than
                "transformer" or "ctc".
        """
        decoder_args = args.decoder_args
        decoder_args.encoder_embed_dim = in_dim
        if args.decoder_type == "transformer":
            if is_mt_decoder:
                multitask_text_transformer_decoder_arch(
                    decoder_args,
                    decoder_layers,
                    decoder_embed_dim,
                    decoder_attention_heads,
                )  # 4L
            else:
                base_multitask_text_transformer_decoder_arch(decoder_args)  # 2L
            task_decoder = TransformerDecoder(
                decoder_args,
                tgt_dict,
                embed_tokens=TransformerModelBase.build_embedding(
                    decoder_args,
                    tgt_dict,
                    decoder_args.decoder_embed_dim,
                ),
            )
        elif args.decoder_type == "ctc":
            task_decoder = CTCDecoder(
                dictionary=tgt_dict,
                in_dim=in_dim,
            )
        else:
            raise NotImplementedError(
                "currently only support multitask decoder_type 'transformer', 'ctc'"
            )

        return task_decoder

    @classmethod
    def build_decoder(cls, args):
        """
        Build the second-pass TTS synthesizer decoder.

        The decoder attends over first-pass decoder states, so its
        ``encoder_embed_dim`` is set to ``args.decoder_embed_dim`` on a deep
        copy of ``args`` (the original namespace is left untouched).
        """
        _args = copy.deepcopy(args)
        _args.encoder_embed_dim = args.decoder_embed_dim

        if args.synthesizer == "transformer":
            # padding_idx=1 — presumably matches the fairseq dictionary's
            # default pad index; confirm against the task's dictionary setup.
            return TTSTransformerDecoder(_args, None, padding_idx=1)
        else:
            raise NotImplementedError(args.synthesizer)

    @classmethod
    def build_model(cls, args, task):
        """
        Assemble encoder, TTS decoder, all multitask decoders and the optional
        synthesizer text encoder into one model.

        Exactly one multitask task must be flagged as the first-pass decoder
        (its name is stored in ``mt_task_name`` and used by ``forward``).
        """
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args)
        base_model = cls(encoder, decoder)

        # set up multitask decoders
        base_model.mt_task_name = None
        base_model.multitask_decoders = {}
        has_first_pass_decoder = False
        for task_name, task_obj in task.multitask_tasks.items():
            if task_obj.is_first_pass_decoder:
                has_first_pass_decoder = True
                base_model.mt_task_name = task_name

            # A task either attends over encoder states or over (TTS) decoder
            # states; pick the matching feature dimension.
            in_dim = (
                args.encoder_embed_dim
                if task_obj.args.input_from == "encoder"
                else args.decoder_embed_dim
            )
            task_decoder = cls.build_multitask_decoder(
                task_obj.args,
                task_obj.target_dictionary,
                in_dim,
                task_obj.is_first_pass_decoder,
                getattr(args, "translation_decoder_layers", 4),
                getattr(args, "decoder_embed_dim", 256),
                getattr(args, "decoder_attention_heads", 4),
            )

            # Registered as an attribute so parameters are tracked by the
            # model; the dict holds the wrapped model used for generation.
            setattr(base_model, f"{task_name}_decoder", task_decoder)
            decoder_model_cls = (
                FairseqEncoderModel
                if task_obj.args.decoder_type == "ctc"
                else FairseqLanguageModel
            )
            base_model.multitask_decoders[task_name] = decoder_model_cls(
                getattr(base_model, f"{task_name}_decoder")
            )

        assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"

        # set up encoder on top of the auxiliary MT decoder
        if getattr(args, "synthesizer_encoder_layers", 0) > 0:
            base_model.synthesizer_encoder = cls.build_text_encoder(args)
        else:
            base_model.synthesizer_encoder = None

        return base_model

    @classmethod
    def build_text_encoder(cls, args):
        """
        Build the optional encoder between the MT decoder and the TTS decoder.

        Its width/heads/FFN size mirror the *decoder* hyperparameters because
        it operates on first-pass decoder states, not on speech features.
        """
        _args = copy.deepcopy(args)
        _args.encoder_layers = args.synthesizer_encoder_layers
        _args.encoder_embed_dim = args.decoder_embed_dim
        _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
        _args.encoder_attention_heads = args.decoder_attention_heads
        _args.encoder_normalize_before = True
        return TransformerEncoderNoEmb(_args)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        prev_output_tokens_mt,
        tgt_speaker=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        return_all_hiddens=False,
    ):
        """
        Run the full two-pass pipeline.

        Steps: speech encoder -> first-pass MT decoder (teacher-forced with
        ``prev_output_tokens_mt``) -> optional synthesizer encoder -> TTS
        decoder (teacher-forced with ``prev_output_tokens``).

        Returns the TTS decoder output; when ``return_all_hiddens`` is set,
        encoder states and the MT decoder output are attached to the last
        element of the returned tuple for logging/analysis.
        """
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            tgt_speaker=tgt_speaker,
            return_all_hiddens=return_all_hiddens,
        )

        # 1. MT decoder
        mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
        mt_decoder_out = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_out,
        )
        # Last inner state is the pre-projection hidden sequence the second
        # pass consumes (presumably T x B x C, fairseq's usual layout —
        # confirm against TransformerDecoder.extract_features).
        x = mt_decoder_out[1]["inner_states"][-1]
        if mt_decoder.layer_norm is not None:
            x = mt_decoder.layer_norm(x)

        # Only materialize a padding mask when some padding actually exists.
        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. TTS encoder
        if self.synthesizer_encoder is not None:
            tts_encoder_out = self.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
                return_all_hiddens=return_all_hiddens,
            )
        else:
            # No extra encoder: pass MT decoder states through unchanged,
            # shaped like a fairseq encoder-out dict.
            tts_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask],  # B x T
            }

        # 3. TTS decoder
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=tts_encoder_out,
            incremental_state=incremental_state,
            target_lengths=target_lengths,
            speaker=speaker,
        )
        if return_all_hiddens:
            decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"]
            decoder_out[-1]["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ]
        decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
        return decoder_out
|
|
|
|
|
|
|
|
@register_model_architecture(
    model_name="s2spect2_conformer", arch_name="s2spect2_conformer"
)
def s2spect2_conformer_architecture_base(args):
    """Fill in Conformer-encoder defaults on *args*, then apply the base
    s2spect architecture for everything not set here.

    Each attribute is only set when the user did not already provide it
    (standard fairseq ``getattr``-with-default pattern).
    """
    _encoder_defaults = (
        ("conv_version", "convtransformer"),
        ("attn_type", None),
        ("pos_enc_type", "abs"),
        ("max_source_positions", 6000),
        ("encoder_embed_dim", 256),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_attention_heads", 4),
        ("dropout", 0.1),
        ("encoder_layers", 16),
        ("depthwise_conv_kernel_size", 31),
    )
    for attr_name, default_value in _encoder_defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))
    s2spect_architecture_base(args)
|
|
|
|
|
|
|
|
|
|
|
@register_model_architecture(
    model_name="s2spect2_conformer", arch_name="s2spect_conformer_translatotron2"
)
def s2spect2_conformer_architecture_base_legacy(args):
    """Alias registration under the older "s2spect_conformer_translatotron2"
    arch name; delegates unchanged to the base architecture."""
    s2spect2_conformer_architecture_base(args)
|
|
|