# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from typing import List, Optional, Tuple

from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    AttentionMaskFactory,
    LayerNormFactory,
    StandardTransformerDecoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerNormOrder,
)
from fairseq2.typing import DataType, Device, override
from torch import Generator, Tensor
from torch.nn import Dropout, ModuleList

from lcm.nn.incremental_state import LCMIncrementalStateBag


class LCMStandardTransformerDecoderLayer(StandardTransformerDecoderLayer):  # type: ignore
    """Pass an `LCMIncrementalStateBag` on to `StandardTransformerDecoderLayer`'s forward pass."""

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
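        # Self-attention block; the LCM incremental state bag carries cached
        # keys/values when decoding step by step.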
        seqs = self._forward_self_attn(
            seqs,
            padding_mask,
            self_attn_mask,
            state_bag,
        )

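        # Encoder-decoder (cross-)attention over the encoder output.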
        seqs = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
        )

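        # Position-wise feed-forward block.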
        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask

    @override
    def _forward_self_attn(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[LCMIncrementalStateBag],
    ) -> Tensor:
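        # Keep the input around for the residual connection below.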
        residual = seqs

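        # Pre-norm: apply layer norm before self-attention; post-norm applies it
        # after the residual connection instead.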
        if self.norm_order != TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

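        # Queries, keys, and values are all `seqs`; `state_bag` lets the attention
        # module reuse cached keys/values during incremental decoding.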
        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_norm is not None:
            seqs = self.self_attn_norm(seqs)

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

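        # Residual connection, followed by layer norm when running in post-norm order.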
        seqs = seqs + residual

        if self.norm_order == TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

        return seqs


class LCMTransformerDecoder(TransformerDecoder):
    def __init__(
        self,
        layers: List[TransformerDecoderLayer],
        layer_norm_factory: LayerNormFactory,
        self_attn_mask_factory: AttentionMaskFactory,
        use_causal_attn_mask: bool = True,
        generator: Optional[Generator] = None,
        dropout_p: float = 0.0,
        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
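        # Register the layers as submodules and infer the model dimension from the
        # first layer.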
        layer_list = ModuleList(layers)

        if not layer_list:
            raise ValueError("`layers` must be non-empty.")

        model_dim = layer_list[0].model_dim

        super().__init__(model_dim)
        self.self_attn_mask_factory = self_attn_mask_factory
        self.layers = layer_list
        self.generator = generator

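        # Only pre-norm decoders need a final layer norm; post-norm layers already
        # end with one.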
        if norm_order != TransformerNormOrder.POST:
            self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
        else:
            self.register_module("layer_norm", None)

        if dropout_p > 0.0:
            self.dropout = Dropout(dropout_p)
        else:
            self.register_module("dropout", None)

        self.norm_order = norm_order

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """Pass an `LCMIncrementalStateBag` on to each decoder layer's forward pass.

        Extra keyword arguments are accepted but not used.
        """
        num_layers = len(self.layers)

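        # Build the self-attention mask once and share it across all layers.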
        self_attn_mask: Optional[AttentionMask] = None

        if self.self_attn_mask_factory is not None:
            self_attn_mask = self.self_attn_mask_factory(
                seqs,
                keys=seqs,
                training=self.training,
                state_bag=state_bag,
            )

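        # Run the layers sequentially, threading the sequences, padding mask, and
        # incremental state bag through each one.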
        for layer_idx, layer in enumerate(self.layers):
            layer_output, layer_padding_mask = layer(
                seqs,
                padding_mask,
                self_attn_mask,
                encoder_output,
                encoder_padding_mask,
                state_bag=state_bag,
            )

            seqs, padding_mask = layer_output, layer_padding_mask

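            # Let registered layer-output hooks inspect this layer's output; a hook
            # that returns False short-circuits the remaining hooks.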
            for hook in self._layer_output_hooks.values():
                if not hook(layer_idx, seqs, padding_mask, num_layers):
                    break

        if self.layer_norm is not None:
            seqs = self.layer_norm(seqs)

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs, padding_mask
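

# A minimal construction sketch (illustrative only): `build_lcm_decoder_layer` is a
# hypothetical helper that assembles the attention and feed-forward modules of each
# LCMStandardTransformerDecoderLayer, and CausalAttentionMaskFactory /
# create_standard_layer_norm are assumed to come from fairseq2.nn.transformer.
#
#     layers = [build_lcm_decoder_layer(model_dim=1024) for _ in range(12)]
#     decoder = LCMTransformerDecoder(
#         layers,
#         layer_norm_factory=create_standard_layer_norm,
#         self_attn_mask_factory=CausalAttentionMaskFactory(),
#         norm_order=TransformerNormOrder.PRE,
#     )
#     seqs, padding_mask = decoder(seqs, padding_mask, state_bag=state_bag)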