from dataclasses import dataclass, field
from typing import Optional

import torch

from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)
from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.projection import Projection, ProjectionConfig
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

logger = get_log_writer(__name__)

BASE_LCM_MODEL_TYPE = "base_lcm"


@dataclass
class BaseLCModelConfig(AbstractLCModelConfig):
    model_type: str = BASE_LCM_MODEL_TYPE

    max_seq_len: int = 2048

    model_dim: int = 1024

    model_output_dim: Optional[int] = None
    """If ``None``, use the SONAR dimension as the output dimension."""

    frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig())
    """The frontend config. This module maps from ``sonar_embed_dim`` to
    ``model_dim`` and potentially adds positional embeddings."""

    lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig())
    """The core LCM config. This is a causal Transformer decoder."""

    postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
    """The postnet config. A module mapping the output of the core LCM
    back to ``sonar_embed_dim``."""


lcm_archs = ConfigRegistry[BaseLCModelConfig]()

lcm_arch = lcm_archs.decorator
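
# The decorator above registers named architectures in ``lcm_archs``. A minimal
# sketch (the "toy_base_lcm" name and its settings are hypothetical, not an
# architecture shipped with this registry):
#
#   @lcm_arch("toy_base_lcm")
#   def _toy_base_lcm() -> BaseLCModelConfig:
#       config = BaseLCModelConfig()
#       config.model_dim = 512
#       config.max_seq_len = 128
#       return config
#
# The registered config can then be looked up by name, e.g.
# ``lcm_archs.get("toy_base_lcm")``.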


class BaseLCModel(AbstractLCModel):
    """Base class for LCM models."""

    config: BaseLCModelConfig

    def __init__(
        self,
        config: BaseLCModelConfig,
        lcm: LCMTransformerDecoder,
        frontend: LCMFrontend,
        postnet: Projection,
    ) -> None:
        """
        Basic LCM model with:
          - frontend
          - lcm
          - postnet
        """
        super().__init__(config)

        self.frontend = frontend
        self.lcm = lcm
        self.postnet = postnet
        self.model_dim = lcm.model_dim
        self.sonar_embed_dim = config.sonar_embed_dim

    def forward(
        self,
        batch: EmbeddingsBatch,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        Scaling + positions, then the causal decoder and the postnet.

        If a normalizer is provided, the features are normalized either in the
        frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM).
        """
        # Map SONAR embeddings to ``model_dim`` and add positional information.
        seqs, padding_mask = self.frontend(
            batch.seqs,
            batch.padding_mask,
            diffusion_timesteps=batch.diffusion_timesteps,
            state_bag=state_bag,
            **kwargs,
        )

        # Run the causal Transformer decoder.
        seqs, padding_mask = self.lcm(
            seqs,
            padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        # Project back to the output (SONAR) dimension.
        seqs = self.postnet(seqs)

        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def predict_next_sentence(
        self,
        batch: EmbeddingsBatch,
        sample: bool = False,
        temperature: float = 1.0,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        Predict the next sentence embeddings.

        In the basic LCM, this is equivalent to the forward method, but derived
        architectures may have a different implementation. E.g. in VAE LCM, we
        run the VAE decoder on top of the `forward` results.

        Args:
            batch (EmbeddingsBatch): the sequence of concepts which
                the model should continue.
            sample (bool): whether to predict the single most probable next
                sentence or to sample from the predicted distribution.
            temperature (float): a positive float indicating the degree of
                diversity for the sampling (active only if `sample is True`).

        Returns:
            EmbeddingsBatch: the batch with the predicted SONAR sentences.
        """
        if self.frontend.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.frontend.sonar_normalizer)

        predicted_means = self.forward(batch, state_bag=state_bag, **kwargs)

        if sample and temperature > 0:
            # Sample from an isotropic Gaussian centered at the predicted means,
            # with standard deviation equal to ``temperature``.
            noise = torch.randn_like(predicted_means.seqs) * temperature
            predicted_means.seqs = predicted_means.seqs + noise

        if self.frontend.sonar_normalizer is not None:
            predicted_means = predicted_means.denormalize_seqs(
                self.frontend.sonar_normalizer
            )

        return predicted_means
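
# Illustrative usage (a sketch, not executed here; assumes ``model`` is a built
# ``BaseLCModel`` with a 1024-dim SONAR space and that ``EmbeddingsBatch``
# accepts these keyword arguments):
#
#   prefix = EmbeddingsBatch(seqs=torch.randn(2, 5, 1024), padding_mask=None)
#   with torch.inference_mode():
#       out = model.predict_next_sentence(prefix, sample=True, temperature=0.5)
#   # ``out.seqs[:, -1]`` holds the predicted embedding of the next sentence
#   # for each of the two sequences in the batch.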


class BaseLCModelBuilder(AbstractLCModelBuilder):
    """Builds modules of a base LCM model."""

    config: BaseLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: BaseLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__(config=config, device=device, dtype=dtype)

        self.lcm_factory = TransformerFactory(
            model_dim=self.config.model_dim,
            max_seq_len=self.config.max_seq_len,
            config=self.config.lcm,
            device=device,
            dtype=dtype,
        )

        if config.model_output_dim is None:
            self.model_output_dim = self.config.sonar_embed_dim
        else:
            self.model_output_dim = config.model_output_dim

    def build_model(self) -> BaseLCModel:
        """Build a model."""
        frontend = self.build_frontend()

        lcm = self.build_core_lcm()

        postnet = self.build_postnet()

        return BaseLCModel(
            config=self.config,
            frontend=frontend,
            lcm=lcm,
            postnet=postnet,
        )

    def build_frontend(self) -> LCMFrontend:
        """Build the LCM front-end (i.e., the prenet)."""
        return LCMFrontend(
            sonar_embed_dim=self.config.sonar_embed_dim,
            model_dim=self.config.model_dim,
            config=self.config.frontend,
            pos_encoder=self.lcm_factory.build_pos_encoder(),
            sonar_normalizer=self.build_sonar_normalizer(),
            device=self.device,
            dtype=self.dtype,
        )

    def build_postnet(self) -> Projection:
        """Build the postnet mapping back to the output dimension."""
        return Projection(
            output_dim=self.model_output_dim,
            input_dim=self.config.model_dim,
            config=self.config.postnet,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention_mask_factory(self) -> AttentionMaskFactory:
        """Build the self-attention mask factory (causal for the base LCM)."""
        self_attn_mask_factory: AttentionMaskFactory = CausalAttentionMaskFactory()

        return self_attn_mask_factory
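
    # The causal mask prevents each position from attending to future positions;
    # schematically, for a length-3 sequence the additive float mask is
    #
    #     [[0, -inf, -inf],
    #      [0,    0, -inf],
    #      [0,    0,    0]]
    #
    # so concept t is predicted from concepts <= t only. The indirection above
    # suggests subclasses may override this factory for non-causal variants
    # (an assumption about intended use).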

    def build_core_lcm(self) -> LCMTransformerDecoder:
        """Build the core LCM module."""
        config = self.config.lcm

        layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)]

        self_attn_mask_factory = self.build_attention_mask_factory()

        # Fall back to the layer-wise norm order style if no final style is set.
        if config.final_norm_order_style is None:
            final_norm_order = parse_norm_order(config.norm_order_style)
        else:
            final_norm_order = parse_norm_order(config.final_norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)

        return LCMTransformerDecoder(
            layers,
            self_attn_mask_factory=self_attn_mask_factory,
            norm_order=final_norm_order,
            layer_norm_factory=layer_norm_factory,
            dropout_p=config.final_dropout_p,
            device=self.device,
            dtype=self.dtype,
        )


def create_base_lcm_model(
    config: BaseLCModelConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> BaseLCModel:
    """Create a base LCM model.

    :param config:
        The configuration.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model()
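
# Illustrative end-to-end usage (a sketch; the architecture name is the
# hypothetical one from the registry example above, and the ``EmbeddingsBatch``
# keyword arguments are assumed):
#
#   config = lcm_archs.get("toy_base_lcm")  # or construct BaseLCModelConfig()
#   model = create_base_lcm_model(config, dtype=torch.float32)
#   batch = EmbeddingsBatch(seqs=torch.randn(1, 4, config.sonar_embed_dim),
#                           padding_mask=None)
#   out = model(batch)  # out.seqs has shape (1, 4, config.sonar_embed_dim)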