| | |
| | |
| | |
| | |
| |
|
| | from .fairseq_encoder import FairseqEncoder |
| |
|
| |
|
class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    We run forward on each encoder and return a dictionary of outputs. The first
    encoder's dictionary is used for initialization.

    Args:
        encoders (dict): a non-empty dictionary of :class:`FairseqEncoder`
            objects. Each key is also used as the sub-module name.

    Raises:
        ValueError: if *encoders* is empty.
    """

    def __init__(self, encoders):
        if not encoders:
            # Without this check, ``next(iter(...))`` below would leak a bare
            # StopIteration, which is an unhelpful error and becomes a
            # RuntimeError if raised inside a generator (PEP 479).
            raise ValueError("CompositeEncoder requires at least one encoder")
        # The base class receives the dictionary of the first encoder in the
        # dict's iteration (insertion) order.
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        # Register each sub-encoder directly under its own key so its
        # parameters are tracked by this module. Keeping a plain dict plus
        # add_module (rather than nn.ModuleDict) preserves the existing
        # state_dict key layout "<key>.<param>".
        for key, encoder in self.encoders.items():
            self.add_module(key, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict:
                the outputs from each Encoder, keyed like ``self.encoders``
        """
        return {
            key: encoder(src_tokens, src_lengths)
            for key, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to *new_order*.

        Each sub-encoder reorders its own entry of *encoder_out*.

        Args:
            encoder_out (dict): output of :meth:`forward`, keyed like
                ``self.encoders``
            new_order: the order passed through unchanged to each
                sub-encoder's ``reorder_encoder_out``

        Returns:
            dict: a new dict with every sub-encoder's entry reordered. Keys in
            *encoder_out* that do not belong to any sub-encoder are carried
            over unchanged. The input dict itself is not modified.
        """
        # Shallow copy so the caller's dict is not mutated as a side effect.
        reordered = dict(encoder_out)
        for key, encoder in self.encoders.items():
            reordered[key] = encoder.reorder_encoder_out(encoder_out[key], new_order)
        return reordered

    def max_positions(self):
        """Return the minimum of the sub-encoders' ``max_positions()`` values.

        The composite can only accept inputs that every sub-encoder accepts.
        The values are compared with the builtin ``min``, so they must be
        mutually comparable.
        """
        return min(encoder.max_positions() for encoder in self.encoders.values())

    def upgrade_state_dict(self, state_dict):
        """Run *state_dict* through each sub-encoder's ``upgrade_state_dict``.

        The sub-encoders' return values are ignored, so any upgrade they
        perform must happen in place on *state_dict*.

        Returns:
            the same *state_dict* object that was passed in.
        """
        for encoder in self.encoders.values():
            encoder.upgrade_state_dict(state_dict)
        return state_dict
| |
|