STAR / fairseq /models /speech_to_text /multi_modality_model.py
Yixuan Li
add fairseq folder
85ba398
raw
history blame
1.88 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import FairseqDecoder, FairseqEncoder
# A container for several encoders, each handling training samples from a different modality.
# On each forward call, exactly one encoder is selected.
class MultiModalityEncoder(FairseqEncoder):
    """Container for several encoders, one per input modality.

    Subclasses own the concrete encoders and implement :meth:`select_encoder`.
    On every :meth:`forward` call exactly one encoder is chosen according to
    ``mode`` and applied to the inputs.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def select_encoder(self, mode, **kwargs):
        """Return the encoder to use for ``mode``.

        Subclasses must override this method.

        Args:
            mode: selector for the modality of the current batch.
            **kwargs: extra keyword arguments received by :meth:`forward`.

        Returns:
            A ``(encoder, kwargs)`` pair: the callable encoder to run and the
            keyword arguments that :meth:`forward` should pass to it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Model must implement the select_encoder method")

    def forward(self, src_tokens, src_lengths=None, mode="", **kwargs):
        """Encode a batch with the encoder selected for ``mode``.

        According to the original author's note, the inputs are sample data
        produced by ``JointSpeechTextDataset`` (NOTE(review): confirm against
        the dataset implementation).

        Args:
            src_tokens: source inputs, passed to the selected encoder unchanged.
            src_lengths: optional source lengths, passed through unchanged.
            mode: modality selector handed to :meth:`select_encoder`.
            **kwargs: extra arguments; the set actually forwarded to the
                encoder is the one returned by :meth:`select_encoder`.

        Returns:
            Whatever the selected encoder returns.
        """
        encoder, encoder_kwargs = self.select_encoder(mode, **kwargs)
        return encoder(src_tokens, src_lengths, **encoder_kwargs)
# A container for several decoders, each handling training samples from a different modality.
# On each forward call, exactly one decoder is selected.
class MultiInputDecoder(FairseqDecoder):
    """Container for several decoders, one per input modality.

    Subclasses own the concrete decoders and implement :meth:`select_decoder`.
    On every :meth:`forward` call exactly one decoder is chosen according to
    ``mode`` and applied to the inputs.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def select_decoder(self, mode, **kwargs):
        """Return the decoder to use for ``mode``.

        Subclasses must override this method.

        Args:
            mode: selector for the modality of the current batch.
            **kwargs: extra keyword arguments received by :meth:`forward`.

        Returns:
            A ``(decoder, kwargs)`` pair: the callable decoder to run and the
            keyword arguments that :meth:`forward` should pass to it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Model must implement the select_decoder method")

    def forward(
        self, prev_output_tokens, encoder_out, incremental_state=None, mode="", **kwargs
    ):
        """Decode with the decoder selected for ``mode``.

        Args:
            prev_output_tokens: previously generated/target-side inputs,
                passed to the selected decoder unchanged.
            encoder_out: encoder output, passed through unchanged.
            incremental_state: optional incremental-decoding state, forwarded
                to the decoder as a keyword argument.
            mode: modality selector handed to :meth:`select_decoder`.
            **kwargs: extra arguments; the set actually forwarded to the
                decoder is the one returned by :meth:`select_decoder`.

        Returns:
            Whatever the selected decoder returns.
        """
        decoder, decoder_kwargs = self.select_decoder(mode, **kwargs)
        return decoder(
            prev_output_tokens,
            encoder_out,
            incremental_state=incremental_state,
            **decoder_kwargs,
        )