# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from dataclasses import dataclass, field
from typing import Optional
import torch.nn
from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device
from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
AbstractLCModel,
AbstractLCModelBuilder,
AbstractLCModelConfig,
)
from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.projection import Projection, ProjectionConfig
from lcm.nn.transformer import (
LCMTransformerDecoder,
TransformerConfig,
TransformerFactory,
)
logger = get_log_writer(__name__)
BASE_LCM_MODEL_TYPE = "base_lcm"
@dataclass
class BaseLCModelConfig(AbstractLCModelConfig):
model_type: str = BASE_LCM_MODEL_TYPE
max_seq_len: int = 2048
model_dim: int = 1024
model_output_dim: Optional[int] = None
"""If ``None`` use SONAR dimension as output_dim."""
frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig())
"""The fronted config. This module maps from `sonar_embed_dim` to `model_dim`
and potentially adds positional embeddings"""
lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig())
"""The core lcm config. This is causal Transformer decoder"""
postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
"""The postnet config. A module mapping the output of the core lcm
back to `sonar_embed_dim`"""
lcm_archs = ConfigRegistry[BaseLCModelConfig]()
lcm_arch = lcm_archs.decorator
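# New architectures are expected to be registered through the `lcm_arch`
# decorator, which wraps a function returning a `BaseLCModelConfig`. A minimal,
# hypothetical sketch (the arch name and overridden fields are illustrative
# only, and assume fairseq2's ConfigRegistry decorator/get API):
#
#   @lcm_arch("tiny_example_lcm")
#   def _tiny_example_lcm() -> BaseLCModelConfig:
#       return BaseLCModelConfig(model_dim=256, max_seq_len=512)
#
#   config = lcm_archs.get("tiny_example_lcm")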
class BaseLCModel(AbstractLCModel):
"""Base class for LCM models"""
config: BaseLCModelConfig
def __init__(
self,
config: BaseLCModelConfig,
lcm: LCMTransformerDecoder,
frontend: LCMFrontend,
postnet: Projection,
) -> None:
"""
Basic LCM model with:
- frontend
- lcm
- postnet
"""
super().__init__(config)
self.frontend = frontend
self.lcm = lcm
self.postnet = postnet
self.model_dim = lcm.model_dim
self.sonar_embed_dim = config.sonar_embed_dim
def forward(
self,
batch: EmbeddingsBatch,
state_bag: Optional[IncrementalStateBag] = None,
**kwargs,
) -> EmbeddingsBatch:
"""
Frontend scaling + positional embeddings, then the core decoder and postnet.
If a normalizer is provided, the features will be normalized in the
frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM).
"""
seqs, padding_mask = self.frontend(
batch.seqs,
batch.padding_mask,
diffusion_timesteps=batch.diffusion_timesteps,
state_bag=state_bag,
**kwargs,
)
# Core LCM
seqs, padding_mask = self.lcm(
seqs,
padding_mask,
state_bag=state_bag,
**kwargs,
)
# Postnet:
seqs = self.postnet(seqs) # type: ignore
return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)
def predict_next_sentence(
self,
batch: EmbeddingsBatch,
sample: bool = False,
temperature: float = 1.0,
state_bag: Optional[IncrementalStateBag] = None,
**kwargs,
) -> EmbeddingsBatch:
"""
Predict the next sentence embeddings.
In the basic LCM, this is equivalent to the plain forward method,
but derived architectures may implement it differently.
E.g. in the VAE LCM, we run the VAE decoder on top of the `forward` results.
Args:
batch (EmbeddingsBatch): the sequence of concepts which
the model should continue.
sample (bool): whether to predict the single most probable next sentence
or to sample from the predicted distribution.
temperature (float): a positive float controlling the degree of diversity
of the sampling (active only if `sample` is True).
Returns:
EmbeddingsBatch: the batch with predicted SONAR sentences.
"""
# Normalize the input embeddings if we're expected to
# normalize outside of the model's forward pass
if self.frontend.sonar_normalizer is not None:
batch = batch.normalize_seqs(self.frontend.sonar_normalizer)
# TODO: implement efficient sampling of multiple candidates
predicted_means = self.forward(batch, state_bag=state_bag, **kwargs)
if sample and temperature > 0:
noise = torch.randn_like(predicted_means.seqs) * temperature
predicted_means.seqs = predicted_means.seqs + noise
if self.frontend.sonar_normalizer is not None:
predicted_means = predicted_means.denormalize_seqs(
self.frontend.sonar_normalizer
)
return predicted_means
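# A rough inference sketch (the shapes and the `padding_mask=None` call are
# assumptions about the EmbeddingsBatch API, not guarantees): given a trained
# `model: BaseLCModel` and a context of `n` SONAR embeddings,
#
#   context = torch.randn(1, n, model.sonar_embed_dim)  # [batch, seq, sonar_dim]
#   out = model.predict_next_sentence(
#       EmbeddingsBatch(seqs=context, padding_mask=None),
#       sample=True,
#       temperature=0.5,
#   )
#   next_embedding = out.seqs[:, -1]  # last position = next-sentence prediction
#
# The last position of the output is taken as the next-sentence prediction
# since the decoder is causal.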
class BaseLCModelBuilder(AbstractLCModelBuilder):
"""Builds modules of a base LCM model"""
config: BaseLCModelConfig
device: Optional[Device]
dtype: Optional[DataType]
def __init__(
self,
config: BaseLCModelConfig,
*,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> None:
super().__init__(config=config, device=device, dtype=dtype)
self.lcm_factory = TransformerFactory(
model_dim=self.config.model_dim,
max_seq_len=self.config.max_seq_len,
config=self.config.lcm,
device=device,
dtype=dtype,
)
if config.model_output_dim is None:
self.model_output_dim = self.config.sonar_embed_dim
else:
self.model_output_dim = config.model_output_dim
def build_model(self) -> BaseLCModel:
"""Build a model."""
frontend = self.build_frontend()
lcm = self.build_core_lcm()
postnet = self.build_postnet()
return BaseLCModel(
config=self.config,
frontend=frontend,
lcm=lcm,
postnet=postnet,
)
def build_frontend(self) -> LCMFrontend:
"""Build the LCM front-end (i.e., prenet)."""
return LCMFrontend(
sonar_embed_dim=self.config.sonar_embed_dim,
model_dim=self.config.model_dim,
config=self.config.frontend,
pos_encoder=self.lcm_factory.build_pos_encoder(),
sonar_normalizer=self.build_sonar_normalizer(),
device=self.device,
dtype=self.dtype,
)
def build_postnet(self) -> Projection:
return Projection(
output_dim=self.model_output_dim,
input_dim=self.config.model_dim,
config=self.config.postnet,
device=self.device,
dtype=self.dtype,
)
def build_attention_mask_factory(self) -> AttentionMaskFactory:
self_attn_mask_factory: AttentionMaskFactory = CausalAttentionMaskFactory()
return self_attn_mask_factory
def build_core_lcm(self) -> LCMTransformerDecoder:
"""Build the core LCM module."""
config = self.config.lcm
layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)]
self_attn_mask_factory = self.build_attention_mask_factory()
if config.final_norm_order_style is None:
# The final norm order style will be that of the layer-level norm order
final_norm_order = parse_norm_order(config.norm_order_style)
else:
final_norm_order = parse_norm_order(config.final_norm_order_style)
layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)
return LCMTransformerDecoder(
layers, # type: ignore
self_attn_mask_factory=self_attn_mask_factory,
norm_order=final_norm_order,
layer_norm_factory=layer_norm_factory,
dropout_p=config.final_dropout_p,
device=self.device,
dtype=self.dtype,
)
def create_base_lcm_model(
config: BaseLCModelConfig,
*,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> BaseLCModel:
"""Create an LCM model.
:param config:
The configuration.
:param device:
The device on which to initialize modules.
:param dtype:
The data type of module parameters and buffers.
"""
return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model()
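# A minimal end-to-end sketch, assuming the defaults inherited from
# `AbstractLCModelConfig` (e.g. `sonar_embed_dim`) are sufficient to build a
# small CPU model; illustrative only, not part of the library API.
if __name__ == "__main__":
    example_config = BaseLCModelConfig(model_dim=256, max_seq_len=128)
    example_model = create_base_lcm_model(example_config, device=Device("cpu"))
    example_batch = EmbeddingsBatch(
        seqs=torch.randn(2, 4, example_config.sonar_embed_dim),
        padding_mask=None,
    )
    output = example_model(example_batch)
    logger.info(f"Predicted embeddings shape: {output.seqs.shape}")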