from fairseq import checkpoint_utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import (
    ConvTransformerModel,
    convtransformer_espnet,
    ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
    augmented_memory,
    SequenceEncoder,
    AugmentedMemoryConvTransformerEncoder,
)

from torch import nn, Tensor
from typing import Dict, List
from fairseq.models.speech_to_text.modules.emformer import (
    NoSegAugmentedMemoryTransformerEncoderLayer,
)


@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
    """
    Implementation of the paper:

    SimulMT to SimulST: Adapting Simultaneous Text Translation to
    End-to-End Simultaneous Speech Translation

    https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
    """

    @staticmethod
    def add_args(parser):
        super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
        parser.add_argument(
            "--train-monotonic-only",
            action="store_true",
            default=False,
            help="Only train monotonic attention",
        )

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        tgt_dict = task.tgt_dict

        # Imported here rather than at module level to avoid a circular
        # import among the simultaneous-translation model definitions.
        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder,
        )

        decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder


@register_model_architecture(
    "convtransformer_simul_trans", "convtransformer_simul_trans_espnet"
)
def convtransformer_simul_trans_espnet(args):
    convtransformer_espnet(args)
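

# A minimal usage sketch, not taken from this file: the registered architecture
# is selected with fairseq-train's --arch flag. ${DATA_ROOT} is a placeholder,
# and the remaining required speech_to_text / monotonic-attention flags are
# omitted for brevity.
#
#   fairseq-train ${DATA_ROOT} \
#     --task speech_to_text \
#     --arch convtransformer_simul_trans_espnet \
#     --train-monotonic-only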


@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
    @classmethod
    def build_encoder(cls, args):
        # SequenceEncoder feeds the utterance to the wrapped encoder segment
        # by segment, which is what makes the encoder usable for streaming.
        encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))

        if getattr(args, "load_pretrained_encoder_from", None) is not None:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )

        return encoder


@register_model_architecture(
    "convtransformer_augmented_memory", "convtransformer_augmented_memory"
)
def augmented_memory_convtransformer_espnet(args):
    convtransformer_espnet(args)
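

# A minimal usage sketch, under the assumption that the @augmented_memory
# decorator contributes the segmenting flags (e.g. --segment-size and
# --max-memory-size); values are hypothetical and the remaining required
# speech_to_text flags are omitted.
#
#   fairseq-train ${DATA_ROOT} \
#     --task speech_to_text \
#     --arch convtransformer_augmented_memory \
#     --segment-size 40 \
#     --max-memory-size 5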


# ConvTransformer with an Emformer encoder; the decoder is the monotonic
# attention decoder inherited from SimulConvTransformerModel.
class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
    def __init__(self, args):
        super().__init__(args)
        # The segment contexts are given in input frames; convert them to the
        # frame rate seen by the transformer layers, which is reduced by the
        # convolutional front-end's stride.
        stride = self.conv_layer_stride(args)
        trf_left_context = args.segment_left_context // stride
        trf_right_context = args.segment_right_context // stride
        context_config = [trf_left_context, trf_right_context]
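        # Worked example (hypothetical values): --segment-left-context 32 with
        # the fixed stride of 4 gives 32 // 4 = 8 frames of left context per
        # layer.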
        self.transformer_layers = nn.ModuleList(
            [
                NoSegAugmentedMemoryTransformerEncoderLayer(
                    input_dim=args.encoder_embed_dim,
                    num_heads=args.encoder_attention_heads,
                    ffn_dim=args.encoder_ffn_embed_dim,
                    num_layers=args.encoder_layers,
                    dropout_in_attn=args.dropout,
                    dropout_on_attn=args.dropout,
                    dropout_on_fc1=args.dropout,
                    dropout_on_fc2=args.dropout,
                    activation_fn=args.activation_fn,
                    context_config=context_config,
                    segment_size=args.segment_length,
                    max_memory_size=args.max_memory_size,
                    scaled_init=True,
                    tanh_on_mem=args.amtrf_tanh_on_mem,
                )
            ]
        )
        self.conv_transformer_encoder = ConvTransformerEncoder(args)

    def forward(self, src_tokens, src_lengths):
        encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(
            src_tokens, src_lengths.to(src_tokens.device)
        )
        output = encoder_out["encoder_out"][0]
        encoder_padding_masks = encoder_out["encoder_padding_mask"]

        return {
            "encoder_out": [output],
            # Trim the padding mask to the output length: the encoder output
            # does not include the right context of the final segment.
            "encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]]
            if len(encoder_padding_masks) > 0
            else [],
            "encoder_embedding": [],
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

    @staticmethod
    def conv_layer_stride(args):
        # The convolutional front-end subsamples with two stride-2 layers,
        # hence a fixed overall stride of 4.
        return 4


@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
    @staticmethod
    def add_args(parser):
        super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)

        parser.add_argument(
            "--segment-length",
            type=int,
            metavar="N",
            help="length of each segment (not including left context / right context)",
        )
        parser.add_argument(
            "--segment-left-context",
            type=int,
            help="length of left context in a segment",
        )
        parser.add_argument(
            "--segment-right-context",
            type=int,
            help="length of right context in a segment",
        )
        parser.add_argument(
            "--max-memory-size",
            type=int,
            default=-1,
            help="maximum size of the memory bank (-1 for unbounded)",
        )
        parser.add_argument(
            "--amtrf-tanh-on-mem",
            default=False,
            action="store_true",
            help="whether to apply tanh to the memory vectors",
        )

    @classmethod
    def build_encoder(cls, args):
        encoder = ConvTransformerEmformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder


@register_model_architecture(
    "convtransformer_emformer",
    "convtransformer_emformer",
)
def convtransformer_emformer_base(args):
    convtransformer_espnet(args)
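

# A minimal usage sketch with hypothetical values; ${DATA_ROOT} and the
# remaining required speech_to_text flags are placeholders. All flags below
# are registered by ConvtransformerEmformer.add_args above.
#
#   fairseq-train ${DATA_ROOT} \
#     --task speech_to_text \
#     --arch convtransformer_emformer \
#     --segment-length 40 \
#     --segment-left-context 40 \
#     --segment-right-context 40 \
#     --max-memory-size 5 \
#     --amtrf-tanh-on-mem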