| |
| |
|
|
| import torch.nn as nn |
|
|
| from vlmo.torchscale.architecture.decoder import Decoder |
| from vlmo.torchscale.architecture.encoder import Encoder |
|
|
|
|
class EncoderDecoder(nn.Module):
    """Sequence-to-sequence model that chains an ``Encoder`` into a ``Decoder``.

    Both sub-modules are built from the same ``args`` object and are always
    constructed with ``is_encoder_decoder=True``.
    """

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        """Build the encoder and decoder stacks.

        Args:
            args: Configuration object. ``share_all_embeddings`` is read here;
                when it is truthy, ``share_decoder_input_output_embed`` is set
                to ``True`` on this same object (the caller's ``args`` is
                modified in place).
            encoder_embed_tokens: Forwarded positionally to ``Encoder``.
            encoder_embed_positions: Forwarded positionally to ``Encoder``.
            decoder_embed_tokens: Forwarded positionally to ``Decoder``. When
                ``None`` and ``args.share_all_embeddings`` is truthy, the
                encoder's ``embed_tokens`` attribute is used instead.
            decoder_embed_positions: Forwarded positionally to ``Decoder``.
            output_projection: Forwarded positionally to ``Decoder``.
            **kwargs: Extra keyword arguments handed unchanged to both the
                ``Encoder`` and the ``Decoder`` constructors.
        """
        super().__init__()
        self.args = args

        # Sharing every embedding implies tying the decoder's input and
        # output embeddings; the flag must be set before either sub-module
        # is constructed from ``args``.
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            is_encoder_decoder=True,
            **kwargs
        )

        # Fall back to the encoder's token embedding only when the caller
        # did not supply a decoder embedding of their own.
        if args.share_all_embeddings:
            if decoder_embed_tokens is None:
                decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            is_encoder_decoder=True,
            **kwargs
        )

    def forward(
        self,
        src_tokens,
        prev_output_tokens,
        return_all_hiddens=False,
        features_only=False,
        **kwargs
    ):
        """Encode ``src_tokens`` and decode ``prev_output_tokens`` against it.

        Args:
            src_tokens: Input passed to the encoder.
            prev_output_tokens: Input passed to the decoder.
            return_all_hiddens: Forwarded to both the encoder and the decoder.
            features_only: Forwarded to the decoder only.
            **kwargs: Accepted for call-site compatibility; not forwarded to
                either sub-module.

        Returns:
            Whatever the decoder returns, unmodified.
        """
        encoded = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        return self.decoder(
            prev_output_tokens,
            encoder_out=encoded,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
|
|