# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import FairseqDecoder, FairseqEncoder


class MultiModalityEncoder(FairseqEncoder):
    """Container for several encoders fed by samples of different modalities.

    Exactly one encoder is selected per forward call, based on ``mode``.
    Subclasses provide the selection logic by overriding
    :meth:`select_encoder`.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def select_encoder(self, mode, **kwargs):
        """Pick the encoder to run for ``mode``.

        Subclasses must override this and return a pair
        ``(encoder, kwargs)``: the callable that :meth:`forward` will invoke
        and the keyword arguments to forward to it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("Model must implement the select_encoder method")

    # NOTE: an optional ``post_encoder(encoder_out, src_tokens, src_lengths,
    # mode, **kwargs)`` hook (default: return ``encoder_out`` unchanged) was
    # sketched here but is currently disabled; ``forward`` returns the selected
    # encoder's output directly.

    def forward(self, src_tokens, src_lengths=None, mode="", **kwargs):
        """Run the encoder selected for ``mode`` and return its output.

        The original source notes that the sample data comes from
        JointSpeechTextDataset.

        Args:
            src_tokens: passed positionally to the selected encoder.
            src_lengths: passed positionally to the selected encoder.
            mode: selector handed to :meth:`select_encoder`.
            **kwargs: handed to :meth:`select_encoder`; the keyword arguments
                it returns are forwarded to the selected encoder.

        Returns:
            Whatever the selected encoder returns.
        """
        encoder, kwargs = self.select_encoder(mode, **kwargs)
        return encoder(src_tokens, src_lengths, **kwargs)


class MultiInputDecoder(FairseqDecoder):
    """Container for several decoders fed by samples of different modalities.

    Exactly one decoder is selected per forward call, based on ``mode``.
    Subclasses provide the selection logic by overriding
    :meth:`select_decoder`.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def select_decoder(self, mode, **kwargs):
        """Pick the decoder to run for ``mode``.

        Subclasses must override this and return a pair
        ``(decoder, kwargs)``: the callable that :meth:`forward` will invoke
        and the keyword arguments to forward to it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("Model must implement the select_decoder method")

    def forward(
        self, prev_output_tokens, encoder_out, incremental_state=None, mode="", **kwargs
    ):
        """Run the decoder selected for ``mode`` and return its output.

        Args:
            prev_output_tokens: passed positionally to the selected decoder.
            encoder_out: passed positionally to the selected decoder.
            incremental_state: forwarded to the selected decoder as the
                ``incremental_state`` keyword argument.
            mode: selector handed to :meth:`select_decoder`.
            **kwargs: handed to :meth:`select_decoder`; the keyword arguments
                it returns are forwarded to the selected decoder.

        Returns:
            Whatever the selected decoder returns.
        """
        decoder, kwargs = self.select_decoder(mode, **kwargs)
        return decoder(
            prev_output_tokens,
            encoder_out,
            incremental_state=incremental_state,
            **kwargs
        )