from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor


EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Optional[Tensor]),  # B x T
        ("encoder_embedding", Optional[Tensor]),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T
        ("src_lengths", Optional[Tensor]),  # B x 1
    ],
)
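# NOTE: by convention ``encoder_out`` uses the time-first layout ``T x B x C``,
# while ``encoder_padding_mask`` is a ``B x T`` bool mask that is True at
# padding positions. See the illustrative ``_ExampleEncoder`` sketch at the
# bottom of this module for one way a concrete encoder fills these fields.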


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            # when scripted, call forward() with the standard arguments only
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        # in eager mode, forward everything except the decoder-side
        # prev_output_tokens
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        # recursively propagate the update count to any submodules that track it
        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
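

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original fairseq API): a minimal encoder
# showing how a subclass typically implements ``forward`` and
# ``reorder_encoder_out``. The class name ``_ExampleEncoder``, the embedding
# size, and the padding index of 1 are assumptions made for this example only.
# ---------------------------------------------------------------------------
class _ExampleEncoder(FairseqEncoder):
    def __init__(self, dictionary, embed_dim=8, padding_idx=1):
        super().__init__(dictionary)
        self.padding_idx = padding_idx
        self.embed = nn.Embedding(len(dictionary), embed_dim, padding_idx=padding_idx)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        x = self.embed(src_tokens)  # B x T x C
        x = x.transpose(0, 1)  # T x B x C, the layout EncoderOut conventionally uses
        return EncoderOut(
            encoder_out=x,
            encoder_padding_mask=src_tokens.eq(self.padding_idx),  # B x T, True at pads
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # Select the requested batch elements so that beam search can reorder
        # or expand the batch during incremental decoding.
        return encoder_out._replace(
            encoder_out=encoder_out.encoder_out.index_select(1, new_order),
            encoder_padding_mask=(
                None
                if encoder_out.encoder_padding_mask is None
                else encoder_out.encoder_padding_mask.index_select(0, new_order)
            ),
        )


# Example usage, assuming ``d`` is any dictionary-like object supporting ``len(d)``:
#
#     enc = _ExampleEncoder(d)
#     out = enc(torch.LongTensor([[4, 5, 1]]), src_lengths=torch.LongTensor([2]))
#     out.encoder_out.shape  # torch.Size([3, 1, 8])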