from typing import List, Optional, Tuple

from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    AttentionMaskFactory,
    LayerNormFactory,
    StandardTransformerDecoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerNormOrder,
)
from fairseq2.typing import DataType, Device, override
from torch import Generator, Tensor
from torch.nn import Dropout, ModuleList

from lcm.nn.incremental_state import LCMIncrementalStateBag


class LCMStandardTransformerDecoderLayer(StandardTransformerDecoderLayer):
    """A `StandardTransformerDecoderLayer` whose `forward` threads a typed
    `LCMIncrementalStateBag` through self-attention."""

    @override
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        # Self-attention over the sequence, with cached keys/values read from
        # the state bag during incremental decoding.
        seqs = self._forward_self_attn(
            seqs,
            padding_mask,
            self_attn_mask,
            state_bag,
        )

        # Encoder-decoder (cross-)attention into `encoder_output`, when the
        # layer is configured with it.
        seqs = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
        )

        # Position-wise feed-forward network.
        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask

    @override
    def _forward_self_attn(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[LCMIncrementalStateBag],
    ) -> Tensor:
        residual = seqs

        # Pre-norm: normalize before attention.
        if self.norm_order != TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_norm is not None:
            seqs = self.self_attn_norm(seqs)

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

        seqs = seqs + residual

        # Post-norm: normalize after the residual connection.
        if self.norm_order == TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

        return seqs
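

# A minimal usage sketch (illustrative only, not part of the library): one way
# to assemble a decoder-only `LCMStandardTransformerDecoderLayer` from fairseq2
# building blocks. The constructor names and signatures below assume a
# fairseq2 v0.2-style API; check your installed version before relying on them.
#
#   from fairseq2.nn.transformer import (
#       StandardFeedForwardNetwork,
#       StandardMultiheadAttention,
#   )
#
#   model_dim, num_heads = 1024, 16
#
#   layer = LCMStandardTransformerDecoderLayer(
#       StandardMultiheadAttention(model_dim, num_heads),
#       None,  # decoder-only: no encoder-decoder attention
#       StandardFeedForwardNetwork(model_dim, 4 * model_dim, bias=True),
#       norm_order=TransformerNormOrder.PRE,
#   )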


class LCMTransformerDecoder(TransformerDecoder):
    """A `TransformerDecoder` variant that passes an `LCMIncrementalStateBag`
    to each of its layers."""

    def __init__(
        self,
        layers: List[TransformerDecoderLayer],
        layer_norm_factory: LayerNormFactory,
        self_attn_mask_factory: Optional[AttentionMaskFactory],
        use_causal_attn_mask: bool = True,
        generator: Optional[Generator] = None,
        dropout_p: float = 0.0,
        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        layer_list = ModuleList(layers)

        if not layer_list:
            raise ValueError("`layers` must be non-empty.")

        # All layers share the same model dimension; take it from the first.
        model_dim = layer_list[0].model_dim

        super().__init__(model_dim)

        self.self_attn_mask_factory = self_attn_mask_factory
        self.layers = layer_list
        self.generator = generator

        # With post-norm, every layer normalizes its own output, so no final
        # layer norm is needed; other norm orders apply one after the stack.
        if norm_order != TransformerNormOrder.POST:
            self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
        else:
            self.register_module("layer_norm", None)

        if dropout_p > 0.0:
            self.dropout = Dropout(dropout_p)
        else:
            self.register_module("dropout", None)

        self.norm_order = norm_order

    @override
    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """Decode `seqs`, passing the `LCMIncrementalStateBag` on to every
        layer. Extra keyword arguments are accepted for interface
        compatibility and ignored."""
        num_layers = len(self.layers)

        # Build the self-attention mask once and share it across all layers.
        self_attn_mask: Optional[AttentionMask] = None
        if self.self_attn_mask_factory is not None:
            self_attn_mask = self.self_attn_mask_factory(
                seqs,
                keys=seqs,
                training=self.training,
                state_bag=state_bag,
            )

        for layer_idx, layer in enumerate(self.layers):
            layer_output, layer_padding_mask = layer(
                seqs,
                padding_mask,
                self_attn_mask,
                encoder_output,
                encoder_padding_mask,
                state_bag=state_bag,
            )

            seqs, padding_mask = layer_output, layer_padding_mask

            # Run registered layer-output hooks; a hook returning ``False``
            # skips the remaining hooks for this layer.
            for hook in self._layer_output_hooks.values():
                if not hook(layer_idx, seqs, padding_mask, num_layers):
                    break

        if self.layer_norm is not None:
            seqs = self.layer_norm(seqs)

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs, padding_mask
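

# A minimal end-to-end sketch (illustrative only): building a two-layer
# decoder-only `LCMTransformerDecoder` and running a forward pass. The
# fairseq2 helpers used here (`CausalAttentionMaskFactory`,
# `create_standard_layer_norm`, `StandardMultiheadAttention`,
# `StandardFeedForwardNetwork`) and their signatures are assumptions based on
# a fairseq2 v0.2-style API.
if __name__ == "__main__":
    import torch

    from fairseq2.nn.transformer import (
        CausalAttentionMaskFactory,
        StandardFeedForwardNetwork,
        StandardMultiheadAttention,
        create_standard_layer_norm,
    )

    model_dim, num_heads, num_layers = 64, 4, 2

    layers = [
        LCMStandardTransformerDecoderLayer(
            StandardMultiheadAttention(model_dim, num_heads),
            None,  # decoder-only: no encoder-decoder attention
            StandardFeedForwardNetwork(model_dim, 4 * model_dim, bias=True),
            norm_order=TransformerNormOrder.PRE,
        )
        for _ in range(num_layers)
    ]

    decoder = LCMTransformerDecoder(
        layers,
        layer_norm_factory=create_standard_layer_norm,
        self_attn_mask_factory=CausalAttentionMaskFactory(),
        norm_order=TransformerNormOrder.PRE,
    )

    # A batch of 2 sequences of 8 "concept" embeddings each.
    seqs = torch.randn(2, 8, model_dim)

    output, _ = decoder(seqs, padding_mask=None)

    print(output.shape)  # -> torch.Size([2, 8, 64])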