# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from fairseq2.logging import get_log_writer
from fairseq2.nn import PositionEncoder
from fairseq2.nn.position_encoder import (
LearnedPositionEncoder,
RotaryEncoder,
SinusoidalPositionEncoder,
)
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import (
FeedForwardNetwork,
GLUFeedForwardNetwork,
MultiheadAttention,
StandardFeedForwardNetwork,
TransformerDecoderLayer,
create_default_sdpa,
)
from fairseq2.typing import DataType, Device
from lcm.nn.initialization import (
SUPPORTED_INIT_TYPES,
get_init_fn,
parse_activation_fn,
parse_norm_order,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES, parse_layer_norm_factory
from lcm.nn.transformer import LCMStandardTransformerDecoderLayer
from lcm.nn.transformer.attention import (
FullAttentionState,
QKNormMultiheadAttention,
)
SUPPORTED_NORM_ORDERS = Literal["pre", "post", "normformer"]
logger = get_log_writer(__name__)
@dataclass
class TransformerConfig:
"""A config object to group all config
hyper-parameters of a LCMTransformerDecoder"""
num_layers: int = 2
num_attn_heads: int = 8
# Dropout rates
dropout_p: float = 0.1
""" The dropout probability outputs of the attention layers and the
feed-forward network (before joining the residual stream)"""
final_dropout_p: float = 0.1
""" The dropout probability on decoder outputs"""
attention_dropout_p: float = 0.0
"""the dropout rate on attention weights in SDPA"""
# FFN
ffn_inner_dim: int = 1024 * 4
use_swiglu: bool = False
"""Use GLUFeedForwardNetwork instead of regular FFN blocks"""
ffn_inner_activation_name: str = "relu"
"""The activation to apply to outputs of the FFN inner projection layer.
Default is `relu `i.e., `torch.nn.ReLU`. This is only relevant when `use_swiglu= False`"""
# positional embedding
pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "learned"
"""If `rope`: a rotary positional encoder in used in the attention layers.
If `sine`: Sinusoidal positional embeddings will be added in
the frontend before heading into the decoder
If `learned`: Learned positional embeddings will be added in
the frontend before heading into the decoder.
If `None`: no positional embeddings will be used (e.g. in the case
of unconditional diffusion of a single vector)."""
rope_theta: float = 10_000.0
""" The coefficient of the long-term decay of RoPE embeddings."""
# Normalization
layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
norm_order_style: SUPPORTED_NORM_ORDERS = "pre"
"""LayerNorm order in the transformer decoder,
default is pre-normalization (`pre`). Other options are post-normalization (`post`)
and normformer-style normalization (`normformer`)"""
final_norm_order_style: Optional[SUPPORTED_NORM_ORDERS] = None
"""Controls lcm-level norm-order, using ``post`` here with a ``pre`` layer-level norm-order
means that we will skip the last layernorm in the stack"""
enable_qk_layernorm: bool = False
"""If ``True``, LayerNorms will be applied to queries and keys in self-attention layers
QK-LayerNorm described in https://arxiv.org/pdf/2302.05442 and subsequent work
is recommended to alleviate Transformer training instabilities
"""
mha_qkv_weight_normalization: bool = False
"""if ``True`` wrap the K/Q/V linears of MHA in weight normalization"""
mha_output_weight_normalization: bool = False
"""if ``True`` wrap the output projection of MHA with weight normalization.
This is a temporary fix to resume training some models and will be removed"""
# Miscellaneous
mha_output_proj_bias: bool = False
"""If ``True`` add a bias term to the MHA output projection"""
scale_residual: Optional[float] = None
"""scale to multiply the residual in the Transformer decoder"""
attention_output_init_fn: SUPPORTED_INIT_TYPES = "xavier"
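    """The initialization scheme for the MHA output projection weights;
    parsed by ``get_init_fn`` and used in ``build_attention`` below."""


# Illustrative usage (not part of the library): a small pre-norm decoder
# config with RoPE, SwiGLU FFNs, and QK-LayerNorm. All field names are
# defined above; the values are arbitrary.
#
#   config = TransformerConfig(
#       num_layers=4,
#       num_attn_heads=8,
#       pos_embedding_style="rope",
#       use_swiglu=True,
#       enable_qk_layernorm=True,
#   )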
class TransformerFactory:
def __init__(
self,
model_dim: int,
max_seq_len: int,
config: TransformerConfig,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> None:
"""
:param model_dim:
The hidden model dimension of the Transformer
        :param max_seq_len:
            The maximum sequence length supported by the model.
:param config:
The configuration.
:param device:
The device on which to initialize modules.
:param dtype:
The data type of module parameters and buffers.
"""
self.model_dim = model_dim
self.max_seq_len = max_seq_len
self.config = config
self.device, self.dtype = device, dtype
def build_layer(self) -> TransformerDecoderLayer:
"""Build a Transformer decoder layer based on the provided config."""
self_attn = self.build_attention()
ffn = self.build_ffn()
norm_order = parse_norm_order(self.config.norm_order_style)
layer_norm_factory = parse_layer_norm_factory(
self.config.layer_normalization_style
)
layer = LCMStandardTransformerDecoderLayer(
self_attn=self_attn,
encoder_decoder_attn=None,
ffn=ffn,
dropout_p=self.config.dropout_p,
norm_order=norm_order,
layer_norm_factory=layer_norm_factory,
scale_residual=self.config.scale_residual is not None,
device=self.device,
dtype=self.dtype,
)
# reset residual_scale
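        # (Illustrative: with config.scale_residual=0.5, each layer gets a
        # learnable residual scale initialized to the constant 0.5.)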
if layer.residual_scale is not None:
assert self.config.scale_residual is not None, (
f"Layer has a resiudal scale but scale={self.config.scale_residual}"
)
torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual)
logger.info(
f"Initializing the residual scale at {self.config.scale_residual}"
)
return layer
def build_pos_encoder(self) -> Optional[PositionEncoder]:
"""Build the positional encoder (learned or sinusoidal, if any)
that will be used in the frontend"""
pos_encoder: Optional[PositionEncoder]
if self.config.pos_embedding_style == "learned":
pos_encoder = LearnedPositionEncoder(
self.model_dim,
self.max_seq_len,
device=self.device,
dtype=self.dtype,
)
elif self.config.pos_embedding_style == "sine":
pos_encoder = SinusoidalPositionEncoder(
self.model_dim,
self.max_seq_len,
device=self.device,
)
else:
pos_encoder = None
return pos_encoder
def build_attention_pos_encoder(self) -> Optional[PositionEncoder]:
"""Build the position encoder that can
potentially be used in the MHA module"""
pos_encoder: Optional[PositionEncoder]
if self.config.pos_embedding_style == "rope":
pos_encoder = RotaryEncoder(
encoding_dim=self.model_dim // self.config.num_attn_heads,
max_seq_len=self.max_seq_len,
theta=self.config.rope_theta,
device=self.device,
)
else:
pos_encoder = None
return pos_encoder
def build_attention(self) -> MultiheadAttention:
"""Build a Transformer multi-head attention layer."""
# allow for a different kv_dim
kv_dim = self.model_dim
        # create_default_sdpa returns fairseq2.nn.transformer.attention.TorchSDPA
        # unless another default SDPA factory has been set.
sdpa = create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p)
init_fn = get_init_fn(self.config.attention_output_init_fn)
        # TODO: clarify how RoPE interacts with encoder-decoder attention.
pos_encoder = self.build_attention_pos_encoder()
layer_norm_factory = parse_layer_norm_factory(
self.config.layer_normalization_style
)
# build output_proj:
output_proj = Linear(
self.model_dim,
self.model_dim,
bias=self.config.mha_output_proj_bias,
init_fn=init_fn,
device=self.device,
dtype=self.dtype,
)
if self.config.mha_output_weight_normalization:
output_proj = torch.nn.utils.parametrizations.weight_norm(output_proj)
return QKNormMultiheadAttention(
self.model_dim,
self.config.num_attn_heads,
kv_dim=kv_dim,
pos_encoder=pos_encoder,
sdpa=sdpa,
output_proj=output_proj,
enable_qk_layernorm=self.config.enable_qk_layernorm,
weight_normalization=self.config.mha_qkv_weight_normalization,
layer_norm_factory=layer_norm_factory,
state_factory=FullAttentionState,
device=self.device,
dtype=self.dtype,
)
def build_ffn(self) -> FeedForwardNetwork:
"""Build a Transformer feed-forward network."""
if self.config.use_swiglu:
# Default gate_activation is torch.nn.SiLU
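            # Illustrative sizing (fairseq2's Llama-style convention): the
            # effective GLU inner dim is ffn_inner_dim * inner_dim_scale rounded
            # up to a multiple of inner_dim_to_multiple, e.g. with the default
            # ffn_inner_dim=4096: 4096 * 2/3 -> 2730 -> 2816.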
return GLUFeedForwardNetwork(
self.model_dim,
self.config.ffn_inner_dim,
bias=True,
inner_dim_scale=2 / 3,
inner_dim_to_multiple=256,
device=self.device,
dtype=self.dtype,
)
ffn_inner_activation = parse_activation_fn(
self.config.ffn_inner_activation_name
)
norm_order = parse_norm_order(self.config.norm_order_style)
return StandardFeedForwardNetwork(
self.model_dim,
self.config.ffn_inner_dim,
inner_activation=ffn_inner_activation,
bias=True,
norm_order=norm_order,
device=self.device,
dtype=self.dtype,
)
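

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the library): build one
    # decoder layer on CPU with the default hyper-parameters. The dimensions
    # below are arbitrary demo values.
    demo_config = TransformerConfig()
    factory = TransformerFactory(
        model_dim=512,
        max_seq_len=128,
        config=demo_config,
        device=Device("cpu"),
        dtype=torch.float32,
    )
    layer = factory.build_layer()
    logger.info(f"Built a decoder layer: {layer}")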