from typing import Iterable, List, Optional, Tuple, final

import torch
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.module_list import ModuleList
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMaskFactory,
    CausalAttentionMaskFactory,
    create_standard_layer_norm,
)
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import Module

from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
    MonotonicTransformerDecoderLayer,
)


@final
class MonotonicTransformerDecoder(Module):
    """Represents a Monotonic Transformer decoder."""

    model_dim: int
    self_attn_mask_factory: AttentionMaskFactory
    layers: ModuleList
    layer_norm: LayerNorm

    def __init__(
        self,
        layers: Iterable[MonotonicTransformerDecoderLayer],
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param layers:
            The decoder layers.
        """
        super().__init__()

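        # fairseq2's `ModuleList` is used instead of `torch.nn.ModuleList` so
        # that the layers can be iterated with `drop_iter()` in `forward()`.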
        layer_list = ModuleList(layers)

        if not layer_list:
            raise ValueError("`layers` must be non-empty.")

        self.model_dim = layer_list[0].model_dim

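        # The decoder always applies a causal (autoregressive) self-attention mask.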
        self.self_attn_mask_factory = CausalAttentionMaskFactory()

        self.layers = layer_list

        self.layer_norm = create_standard_layer_norm(
            self.model_dim, device=device, dtype=dtype
        )

    @finaloverride
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
        self_attn_mask = self.self_attn_mask_factory(
            seqs, keys=seqs, training=self.training, state_bag=state_bag
        )

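        # Collect the p_choose tensor produced by each decoder layer; they are
        # concatenated into a single tensor before being returned.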
        p_choose_list: List[Tensor] = []

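        # Run the layers in order, threading the incremental state bag through
        # for step-wise (incremental) decoding.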
        for layer in self.layers.drop_iter():
            seqs, padding_mask, p_choose = layer(
                seqs,
                padding_mask,
                self_attn_mask,
                encoder_output,
                encoder_padding_mask,
                state_bag=state_bag,
            )

            p_choose_list.append(p_choose)

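        # Apply the final layer normalization to the decoder output.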
        seqs = self.layer_norm(seqs)

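        # Concatenate the per-layer p_choose tensors along the first dimension,
        # then flatten the first two dimensions into one.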
        p_choose = torch.cat(p_choose_list, dim=0)

        p_choose = p_choose.flatten(0, 1)

        return seqs, padding_mask, p_choose