from dataclasses import dataclass
from typing import Literal, Optional

import torch

from fairseq2.logging import get_log_writer
from fairseq2.nn import PositionEncoder
from fairseq2.nn.position_encoder import (
    LearnedPositionEncoder,
    RotaryEncoder,
    SinusoidalPositionEncoder,
)
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import (
    FeedForwardNetwork,
    GLUFeedForwardNetwork,
    MultiheadAttention,
    StandardFeedForwardNetwork,
    TransformerDecoderLayer,
    create_default_sdpa,
)
from fairseq2.typing import DataType, Device

from lcm.nn.initialization import (
    SUPPORTED_INIT_TYPES,
    get_init_fn,
    parse_activation_fn,
    parse_norm_order,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES, parse_layer_norm_factory
from lcm.nn.transformer import LCMStandardTransformerDecoderLayer
from lcm.nn.transformer.attention import (
    FullAttentionState,
    QKNormMultiheadAttention,
)

SUPPORTED_NORM_ORDERS = Literal["pre", "post", "normformer"]

logger = get_log_writer(__name__)

@dataclass
class TransformerConfig:
    """A config object grouping all the hyper-parameters of an LCMTransformerDecoder."""

    num_layers: int = 2
    """The number of Transformer decoder layers."""

    num_attn_heads: int = 8
    """The number of attention heads in each decoder layer."""

    dropout_p: float = 0.1
    """The dropout probability on the outputs of the attention layers and the
    feed-forward network (before joining the residual stream)."""

    final_dropout_p: float = 0.1
    """The dropout probability on decoder outputs."""

    attention_dropout_p: float = 0.0
    """The dropout probability on attention weights in SDPA."""

    ffn_inner_dim: int = 1024 * 4
    """The inner dimension of the feed-forward networks."""

    use_swiglu: bool = False
    """Use GLUFeedForwardNetwork instead of regular FFN blocks."""

    ffn_inner_activation_name: str = "relu"
    """The activation to apply to the outputs of the FFN inner projection layer.
    Default is `relu`, i.e., `torch.nn.ReLU`. This is only relevant when
    `use_swiglu=False`."""

    pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "learned"
    """If `rope`: a rotary position encoder is used in the attention layers.
    If `sine`: sinusoidal positional embeddings will be added in the frontend
    before heading into the decoder.
    If `learned`: learned positional embeddings will be added in the frontend
    before heading into the decoder.
    If `none`: no positional embeddings will be used (e.g. in the case of
    unconditional diffusion of a single vector)."""

    rope_theta: float = 10_000.0
    """The coefficient of the long-term decay of RoPE embeddings."""

    layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
    """The type of LayerNorm to use in the decoder (see ``SUPPORTED_LN_TYPES``)."""

    norm_order_style: SUPPORTED_NORM_ORDERS = "pre"
    """LayerNorm order in the transformer decoder. Default is pre-normalization
    (`pre`); other options are post-normalization (`post`) and normformer-style
    normalization (`normformer`)."""

    final_norm_order_style: Optional[SUPPORTED_NORM_ORDERS] = None
    """Controls the LCM-level norm order. Using ``post`` here with a ``pre``
    layer-level norm order means that the last LayerNorm in the stack is skipped."""

    enable_qk_layernorm: bool = False
    """If ``True``, LayerNorms will be applied to queries and keys in the
    self-attention layers. QK-LayerNorm, described in
    https://arxiv.org/pdf/2302.05442 and subsequent work, is recommended to
    alleviate Transformer training instabilities."""

    mha_qkv_weight_normalization: bool = False
    """If ``True``, wrap the K/Q/V linears of MHA in weight normalization."""

    mha_output_weight_normalization: bool = False
    """If ``True``, wrap the output projection of MHA with weight normalization.
    This is a temporary fix to resume training some models and will be removed."""

    mha_output_proj_bias: bool = False
    """If ``True``, add a bias term to the MHA output projection."""

    scale_residual: Optional[float] = None
    """Scale by which the residual is multiplied in the Transformer decoder."""

    attention_output_init_fn: SUPPORTED_INIT_TYPES = "xavier"
    """The initialization method for the MHA output projection."""


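# Example (illustrative only): a small decoder configuration using RoPE in the
# attention layers and SwiGLU feed-forward blocks. The values below are chosen
# for demonstration, not as recommended defaults.
#
#   config = TransformerConfig(
#       num_layers=4,
#       num_attn_heads=8,
#       ffn_inner_dim=4096,
#       use_swiglu=True,
#       pos_embedding_style="rope",
#       norm_order_style="pre",
#   )

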
class TransformerFactory:
    def __init__(
        self,
        model_dim: int,
        max_seq_len: int,
        config: TransformerConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The hidden model dimension of the Transformer.
        :param max_seq_len:
            The maximum sequence length supported by the model.
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.model_dim = model_dim
        self.max_seq_len = max_seq_len
        self.config = config
        self.device, self.dtype = device, dtype

    def build_layer(self) -> TransformerDecoderLayer:
        """Build a Transformer decoder layer based on the provided config."""

        self_attn = self.build_attention()

        ffn = self.build_ffn()

        norm_order = parse_norm_order(self.config.norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        layer = LCMStandardTransformerDecoderLayer(
            self_attn=self_attn,
            encoder_decoder_attn=None,
            ffn=ffn,
            dropout_p=self.config.dropout_p,
            norm_order=norm_order,
            layer_norm_factory=layer_norm_factory,
            scale_residual=self.config.scale_residual is not None,
            device=self.device,
            dtype=self.dtype,
        )

        if layer.residual_scale is not None:
            assert self.config.scale_residual is not None, (
                f"Layer has a residual scale but scale={self.config.scale_residual}"
            )
            torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual)
            logger.info(
                f"Initializing the residual scale at {self.config.scale_residual}"
            )
        return layer

    def build_pos_encoder(self) -> Optional[PositionEncoder]:
        """Build the position encoder (learned or sinusoidal, if any)
        that will be used in the frontend."""
        pos_encoder: Optional[PositionEncoder]

        if self.config.pos_embedding_style == "learned":
            pos_encoder = LearnedPositionEncoder(
                self.model_dim,
                self.max_seq_len,
                device=self.device,
                dtype=self.dtype,
            )
        elif self.config.pos_embedding_style == "sine":
            pos_encoder = SinusoidalPositionEncoder(
                self.model_dim,
                self.max_seq_len,
                device=self.device,
            )
        else:
            pos_encoder = None

        return pos_encoder

    def build_attention_pos_encoder(self) -> Optional[PositionEncoder]:
        """Build the position encoder that can
        potentially be used in the MHA module."""

        pos_encoder: Optional[PositionEncoder]

        if self.config.pos_embedding_style == "rope":
            pos_encoder = RotaryEncoder(
                encoding_dim=self.model_dim // self.config.num_attn_heads,
                max_seq_len=self.max_seq_len,
                theta=self.config.rope_theta,
                device=self.device,
            )
        else:
            pos_encoder = None
        return pos_encoder

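    # Note on the rotary encoder above: RoPE is applied per attention head, so
    # its encoding_dim is model_dim // num_attn_heads. For example, with
    # model_dim=1024 and num_attn_heads=8, each head rotates 128-dimensional
    # query/key vectors (model_dim is expected to be divisible by the number
    # of heads).
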
    def build_attention(self) -> MultiheadAttention:
        """Build a Transformer multi-head attention layer."""

        kv_dim = self.model_dim

        sdpa = create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p)

        init_fn = get_init_fn(self.config.attention_output_init_fn)

        pos_encoder = self.build_attention_pos_encoder()

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        output_proj = Linear(
            self.model_dim,
            self.model_dim,
            bias=self.config.mha_output_proj_bias,
            init_fn=init_fn,
            device=self.device,
            dtype=self.dtype,
        )
        if self.config.mha_output_weight_normalization:
            output_proj = torch.nn.utils.parametrizations.weight_norm(output_proj)

        return QKNormMultiheadAttention(
            self.model_dim,
            self.config.num_attn_heads,
            kv_dim=kv_dim,
            pos_encoder=pos_encoder,
            sdpa=sdpa,
            output_proj=output_proj,
            enable_qk_layernorm=self.config.enable_qk_layernorm,
            weight_normalization=self.config.mha_qkv_weight_normalization,
            layer_norm_factory=layer_norm_factory,
            state_factory=FullAttentionState,
            device=self.device,
            dtype=self.dtype,
        )

    def build_ffn(self) -> FeedForwardNetwork:
        """Build a Transformer feed-forward network."""
        if self.config.use_swiglu:
            return GLUFeedForwardNetwork(
                self.model_dim,
                self.config.ffn_inner_dim,
                bias=True,
                inner_dim_scale=2 / 3,
                inner_dim_to_multiple=256,
                device=self.device,
                dtype=self.dtype,
            )

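        # Sizing note on the GLU branch above (assuming fairseq2's
        # GLUFeedForwardNetwork semantics): the effective gated inner dimension
        # is ffn_inner_dim * inner_dim_scale, rounded up to a multiple of
        # inner_dim_to_multiple. E.g. with ffn_inner_dim=4096:
        # int(4096 * 2 / 3) = 2730, rounded up to 2816.
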
        ffn_inner_activation = parse_activation_fn(
            self.config.ffn_inner_activation_name
        )
        norm_order = parse_norm_order(self.config.norm_order_style)

        return StandardFeedForwardNetwork(
            self.model_dim,
            self.config.ffn_inner_dim,
            inner_activation=ffn_inner_activation,
            bias=True,
            norm_order=norm_order,
            device=self.device,
            dtype=self.dtype,
        )
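

# Usage sketch (illustrative; how a caller might use this factory):
#
#   factory = TransformerFactory(
#       model_dim=1024,
#       max_seq_len=4096,
#       config=TransformerConfig(num_layers=4),
#   )
#   layers = [factory.build_layer() for _ in range(factory.config.num_layers)]
#   frontend_pos_encoder = factory.build_pos_encoder()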