diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..62c51761a31ddc31357f95f909b19270212d17ea --- /dev/null +++ b/.gitignore @@ -0,0 +1,118 @@ + +# JetBrains PyCharm IDE +.idea/ + +# Byte-compiled / optimized / DLL files +**/*/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# macOS dir files +.DS_Store + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.pytest_cache +.ruff_cache + +# VSCODE +.vscode/ftp-sync.json +.vscode/settings.json +.vscode/launch.json + +# stopes logs +executor_logs/ +config_logs/ +outputs/ + +logs/ +**/dask_jobqueue_logs +core.* +mortimer_env.txt + +# datasets +_LexaLCM_Block0/Datasets/ + +# UV +uv.lock \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f287b169341770584b89ff0006108e641bc66804 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: + - repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.5.7 + hooks: + - id: uv-lock + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.2 + hooks: + # Lint + - id: ruff + args: [ --fix ] + # sort imports + - id: ruff + args: ["check", "--select", "I", "--fix"] + # format + - id: ruff-format \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000000000000000000000000000000000000..480bf931e9dcf2dad7249fe44025d710ec0f99bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Alexandra 'Lexa' Baldwin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
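The `.pre-commit-config.yaml` above wires three checks into git hooks: a `uv.lock` consistency check, Ruff linting (including import sorting), and Ruff formatting. As a minimal sketch of how these hooks might be exercised locally (assuming the `pre-commit` tool itself is already installed, e.g. via `uv tool install pre-commit`):

```bash
# Register the hooks defined in .pre-commit-config.yaml with this git repository
pre-commit install

# Run every configured hook once against all tracked files
pre-commit run --all-files
```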
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..04e9db8fc2dfa1816314c7860e6b81a49b989bde
--- /dev/null
+++ b/README.md
@@ -0,0 +1,29 @@
+# LexaLCM Pre0 288M Pre-trained Large Concept Model
+A pre-trained LCM model with 288M parameters based on Meta FAIR's LCM architecture.
+
+[[Paper]](https://ai.meta.com/research/publications/large-concept-models-language-modeling-in-a-sentence-representation-space/)
+
+Note: These instructions are for running the model on a single machine with a single GPU. If your system does not have a GPU that supports at least CUDA 12.1, or if you intend to run this in the cloud, you will need to modify the code to suit your requirements.
+
+## 1. Install the Intel MKL runtime
+```bash
+sudo apt update
+sudo apt install libmkl-rt
+export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH
+source ~/.bashrc
+```
+
+## 2. Install dependencies
+```bash
+uv sync --extra gpu --extra eval --extra data
+```
+
+## 3. Update the model cards' paths
+Update the paths in these two model cards so they point to the correct locations on your local filesystem:
+* '_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/model_card.yaml'
+* '_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml'
+
+## 4. Test the model's inference
+```bash
+uv run --extra gpu scripts/run_inference.py
+```
\ No newline at end of file
diff --git a/lcm/__init__.py b/lcm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e5cd77f3bdb7b40146e2dc1df6648d9a1846d8
--- /dev/null
+++ b/lcm/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+#
+
+"""
+LCM: Modular and Extensible Reasoning in an Embedding Space
+Code base for training different LCM models.
+"""
+
+from fairseq2 import setup_extensions
+from fairseq2.assets import default_asset_store
+
+__version__ = "0.1.0.dev0"
+
+
+def setup_fairseq2() -> None:
+    default_asset_store.add_package_metadata_provider("lcm.cards")
+
+
+# This call activates setup_fairseq2 and potentially other extensions.
+setup_extensions()
diff --git a/lcm/cards/Normalizer_Wikipedia_En_1M.pt b/lcm/cards/Normalizer_Wikipedia_En_1M.pt
new file mode 100644
index 0000000000000000000000000000000000000000..1fe79ef9b6a4729ced1426cd62fdbc6aec09aba1
Binary files /dev/null and b/lcm/cards/Normalizer_Wikipedia_En_1M.pt differ
diff --git a/lcm/cards/sonar_normalizer.yaml b/lcm/cards/sonar_normalizer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..571170f54686702441f659b21b095aac8a81ecf5
--- /dev/null
+++ b/lcm/cards/sonar_normalizer.yaml
@@ -0,0 +1,4 @@
+name: sonar_normalizer_wikipedia_en_1m
+model_family: sonar_normalizer
+model_arch: base
+checkpoint: Normalizer_Wikipedia_En_1M.pt
diff --git a/lcm/datacards/datacards.yaml b/lcm/datacards/datacards.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d29ba510e76a817d3422bce97cb120f6b054d3e4
--- /dev/null
+++ b/lcm/datacards/datacards.yaml
@@ -0,0 +1,5 @@
+name: "Data_Wikipedia_En_1M"
+parquet_path:
+  local: "./_LexaLCM_Pre0/Datasets/Wikipedia_En_1M"
+source_column: "text_sentences_sonar_emb"
+source_text_column: "text_sentences"
diff --git a/lcm/datasets/__init__.py b/lcm/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28faebf07e16031daaa4182323df5af116055ae2
--- /dev/null
+++ b/lcm/datasets/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc.
and affiliates +# All rights reserved. +# +# diff --git a/lcm/datasets/batch.py b/lcm/datasets/batch.py new file mode 100644 index 0000000000000000000000000000000000000000..eca937f551b8fe77f5bbb1bf3e86e41bd3835b1a --- /dev/null +++ b/lcm/datasets/batch.py @@ -0,0 +1,425 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# All rights reserved. +# +# + +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from fairseq2.logging import get_log_writer +from fairseq2.models.sequence import SequenceBatch +from fairseq2.nn.padding import PaddingMask, pad_seqs +from fairseq2.typing import Device +from torch import Tensor +from torch.nn import Module + +from lcm.utils.common import Batched + +logger = get_log_writer(__name__) + + +DOC_LENGTHS = "__doc_lengths" + + +class LCMStyle(Enum): + """Specifies a style for preparing the LCM input.""" + + SUPERVISED = 1 + """For when the model is fed supervised data with source & target sentences.""" + + UNSUPERVISED = 2 + """For when the model is fed unsupervised data with source sentences only.""" + + PACKED_UNSUPERVISED = 3 + """For when the model is fed ``packed`` unsupervised data with source sentences only. + This means that we will look for document_lengths and propagate them to the + packed causal masked attention and the packed position encoders""" + + +@dataclass +class EmbeddingsBatch: + """Represents a sequence of embeddings batch. + Resembles Fairseq2's SequenceBatch with additional properties""" + + seqs: Tensor + """The sequences. *Shape:* :math:`(B,S,D)`, where :math:`B` is the batch + size, :math:`S` is the sequence length (in sentences per document), + and :math:`D` the embedding dimension + """ + + padding_mask: Optional[PaddingMask] = None + """The padding mask of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is + the batch size and :math:`S` is the sequence length.""" + + diffusion_timesteps: Optional[Tensor] = None + """Diffusion timesteps of noising process of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is + the batch size and :math:`S` is the sequence length.""" + + document_lengths: Optional[Tensor] = None + """Lengths of the documents (in sentences) present in the batch + Shape: (len_doc, ) + """ + + source_lengths: Optional[Tensor] = None + """Lengths of source part for each element in batch, so that `seqs[i, :source_lengths[i]]` corresponds to source for each i in [0, batch_size). + Shape: (batch_size, ) + """ + + def __post_init__(self): + if self.document_lengths is not None: + assert self.document_lengths.sum() == self.seqs.size( + 1 + ) or 2 * self.document_lengths.sum() == self.seqs.size(1), ( + "The legnths do no sum up to the sequence length " + "(nor half the length for doubled diffusion sequences). 
" + f"We have seqs.size={self.seqs.size()} and lengths={self.document_lengths} " + f"summing to {self.document_lengths.sum()}" + ) + + def __len__(self) -> int: + return self.batch_size + + @property + def batch_size(self) -> int: + """The size of the batch.""" + return self.seqs.size(0) + + @property + def shape(self) -> torch.Size: + """The shape of the batch.""" + return self.seqs.shape + + @property + def device(self) -> Device: + """The device of the batch.""" + return self.seqs.device + + def clone(self): + return deepcopy(self) + + def __getitem__(self, i: int) -> Any: + raise NotImplementedError( + "Access to each item in EmbeddingsBatch not allowed yet" + ) + + def unbatch(self) -> List[Tensor]: + if self.padding_mask is None: + return list(self.seqs) + else: + return [ + tt[:length] for tt, length in zip(self.seqs, self.padding_mask.seq_lens) + ] + + def get_last_element(self) -> Tensor: + if self.padding_mask: + return self.seqs[ + torch.arange(len(self.padding_mask.seq_lens), device=self.seqs.device), + (self.padding_mask.seq_lens - 1), + ] + else: + return self.seqs[:, -1] + + def set_last_element(self, element: Tensor) -> None: + element = element.to(self.seqs.device) + if self.padding_mask: + for i, slen in enumerate(self.padding_mask.seq_lens): + self.seqs[i, slen - 1] = element[i] + else: + self.seqs[:, -1] = element + + def normalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch": + if normalizer is None: + logger.warning( + "The normalizer is None, as such, the features will remain unchanged" + ) + return self + + return EmbeddingsBatch( + seqs=normalizer.normalize(self.seqs), + padding_mask=self.padding_mask, + diffusion_timesteps=self.diffusion_timesteps, + document_lengths=self.document_lengths, + source_lengths=self.source_lengths, + ) + + def denormalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch": + if normalizer is None: + logger.warning( + "The normalizer is None, as such, the features will remain unchanged" + ) + return self + + return EmbeddingsBatch( + seqs=normalizer.denormalize(self.seqs), + padding_mask=self.padding_mask, + diffusion_timesteps=self.diffusion_timesteps, + document_lengths=self.document_lengths, + source_lengths=self.source_lengths, + ) + + def double_seqs(self) -> "EmbeddingsBatch": + """ + performs sequence elements repeatition in sequence dim : + 1, 2, 3 -> 1, 1, 2, 2, 3, 3 + x, y -> x, x, y, y + """ + if self.padding_mask is not None: + doubled_padding_mask = PaddingMask( + seq_lens=2 * self.padding_mask._seq_lens, + batch_seq_len=2 * self.padding_mask._batch_seq_len, + ) + else: + doubled_padding_mask = None + + return EmbeddingsBatch( + seqs=torch.repeat_interleave(self.seqs, 2, dim=1), + padding_mask=doubled_padding_mask, + diffusion_timesteps=self.diffusion_timesteps, + document_lengths=self.document_lengths, + source_lengths=( + torch.repeat_interleave(self.source_lengths, 2, dim=0) + if self.source_lengths is not None + else None + ), + ) + + def flatten_to_sentences(self) -> Tensor: + """Flatten the sequence of embeddings + from B, S, D to B*~S, D after removing the padded positions + """ + + embed_dim = self.seqs.size(-1) + + if self.padding_mask is not None: + seq_lens = self.padding_mask.seq_lens + + embeds_mask = self.padding_mask.materialize().unsqueeze(-1) + + # Remove padded positions and reshape as B*~S, D + flat_embeds = torch.masked_select(self.seqs, embeds_mask).reshape( + (-1, embed_dim) + ) + + # split per document/paragraph + flat_embeds_per_doc = list(torch.split(flat_embeds, 
seq_lens.tolist())) + + # Concatenate back + flat_embeds = torch.concat(flat_embeds_per_doc) + + else: + embeds = self.seqs + + flat_embeds = embeds.reshape((-1, embed_dim)) + + return flat_embeds + + +@dataclass +class LCMInput(Batched): + """Dataclass for a pair of source/target sequences of SONAR embeddings""" + + source: List[Tensor] + """source: SONAR embeddings of the source text + i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]""" + + target: Union[None, List[Tensor]] + """target: If supervised data: SONAR embeddings of the target text""" + + tokens: Union[None, SequenceBatch] = None + """tokens: Tokenized flattened sentences for the SONAR decoder + (see the dataloader `_prepare_subword_tokens`)""" + + target_tokens: Union[None, SequenceBatch] = None + """target_tokens: a sequence of the same shape as target_tokens, but shifted, to serve as the target. + (see the dataloader `_prepare_subword_tokens`)""" + + name: Optional[str] = None + """ + dataset name from which input is coming from + """ + batch: Optional[Dict[str, Any]] = None + """raw batch of dataloader used for tracking and debugging""" + + def __post_init__(self): + assert self.source is not None + + length = len(self.source) + + assert (self.target is None) or (len(self.target) == length), ( + f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}" + ) + + def __len__(self) -> int: + return len(self.source) + + def __getitem__(self, i: int) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """ + Return the content of item in the batch + """ + if self.target is None: + return self.source[i] + else: + return self.source[i], self.target[i] + + def prepare_input( + self, + style: LCMStyle = LCMStyle.UNSUPERVISED, + ) -> EmbeddingsBatch: + """ + Adds special tokens to the source (& target) and prepares + the EmbeddingsBatch (tensor & its padding mask) that will be + forwarded in the LCM model. 
+ + `style`: LCMStyle is either supervised or + unsupervised (requires target embeddings) + """ + + if style == LCMStyle.UNSUPERVISED: + return get_embeddings_sequence(src_seqs=self.source) + + elif style == LCMStyle.PACKED_UNSUPERVISED: + # If using PACKED_UNSUPERVISED, document_lengths will be added to `EmbeddingsBatch` + document_lengths = None + if self.batch is not None and self.batch.get(DOC_LENGTHS, None) is not None: + # document_lengths will only be consumed if the batch_size is 1 + assert len(self.batch[DOC_LENGTHS]) == 1, "Expecting batch size of 1" + + document_lengths = self.batch[DOC_LENGTHS][0].type(torch.int64) + + return get_embeddings_sequence( + src_seqs=self.source, + document_lengths=document_lengths, + ) + + elif style == LCMStyle.SUPERVISED: + assert self.target is not None, ( + "Missing target embeddings for a supervised batch" + ) + return get_embeddings_sequence( + src_seqs=self.source, + tgt_seqs=self.target, + ) + + raise ValueError(f"Unsupported style={style} - could not prepare input") + + def prepare_target_mask( + self, + embeddings: EmbeddingsBatch, + style: LCMStyle, + min_context_size: Optional[int] = None, + ) -> Tensor: + """Prepare a target mask signaling what positions + we should predict and optimize the model for + + Args: + - min_context_size: the minimum context used to predict the next + concept (only used for unuspervised training) + + """ + + batch_size, maxlen, _ = embeddings.seqs.size() + + device = embeddings.seqs.device + + if style == LCMStyle.UNSUPERVISED: + # A target mask for unsupervised next sentence prediction + # All positions are optimized/predicted starting from min_context_size + target_mask = torch.ones( + (batch_size, maxlen), + dtype=torch.bool, + device=device, + ) + if min_context_size is not None: + target_mask[:, : min(min_context_size, target_mask.size(1))] = False + + elif style == LCMStyle.PACKED_UNSUPERVISED: + # A target mask for unsupervised next sentence prediction when the data is packed + # All positions are optimized starting from min_context_size in each document + document_lengths = embeddings.document_lengths + if document_lengths is not None: # training + + def get_document_target_mask(doc_length): + mask = torch.ones(doc_length, dtype=torch.bool, device=device) + mask[: min(min_context_size, doc_length)] = False + return mask + + target_mask = torch.cat( + [get_document_target_mask(length) for length in document_lengths] + ).unsqueeze(0) + + else: # validation with unpacked data: + target_mask = torch.ones( + (batch_size, maxlen), + dtype=torch.bool, + device=device, + ) + if min_context_size is not None: + target_mask[:, : min(min_context_size, target_mask.size(1))] = False + + elif style == LCMStyle.SUPERVISED: + # A target mask for target prediction + indices = torch.arange(maxlen, device=device).expand(batch_size, -1) + + source_lengths = torch.tensor( + [seq.size(0) for seq in self.source], + device=device, + ) + + target_mask = indices >= source_lengths.unsqueeze(1).expand(-1, maxlen) + + # Factor in padded positions: + if embeddings.padding_mask is not None: + target_mask = target_mask * embeddings.padding_mask.materialize() + + return target_mask.detach() + + +def get_embeddings_sequence( + src_seqs: List[Tensor], + tgt_seqs: Optional[List[Tensor]] = None, + document_lengths: Optional[Tensor] = None, + double_target: bool = False, +) -> EmbeddingsBatch: + seqs_lst: List[Tensor] = [] + for src_seq, tgt_seq in zip(src_seqs, tgt_seqs or [None] * len(src_seqs)): # type: ignore + embeds: List[Tensor] = [] 
+ device, dtype = src_seq.device, src_seq.dtype + + # mandatory src_sec + embeds.append(src_seq) + + # supervised tgt_seq + if tgt_seq is not None: + tgt_seq = tgt_seq.to(device).type(dtype) + + if double_target: + embeds.append(torch.repeat_interleave(tgt_seq, 2, dim=0)) + else: + embeds.append(tgt_seq) + + seqs_lst.append(torch.concat(embeds)) + + seqs, padding_mask = pad_seqs(seqs_lst) + + if document_lengths is not None: + document_lengths = document_lengths.to(seqs.device) + + if tgt_seqs is not None: + source_lengths = torch.tensor( + [seq.size(0) for seq in src_seqs], device=seqs.device + ) + else: + source_lengths = None + + output = EmbeddingsBatch( + seqs, + padding_mask=padding_mask, + document_lengths=document_lengths, + source_lengths=source_lengths, + ) + + return output diff --git a/lcm/inference/lcm/__init__.py b/lcm/inference/lcm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..716bcc06ed3c1d531d134e9684258103b1947198 --- /dev/null +++ b/lcm/inference/lcm/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from lcm.inference.lcm.generator import LCMGenerator as LCMGenerator +from lcm.inference.lcm.generator import LCMGeneratorOptions as LCMGeneratorOptions + +__all__ = ["LCMGenerator", "LCMGeneratorOptions"] diff --git a/lcm/inference/lcm/generator.py b/lcm/inference/lcm/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3e81b904bbaae76d13f6dd945503b7dc95797a90 --- /dev/null +++ b/lcm/inference/lcm/generator.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from fairseq2.generation.generator import ( + GenerationCounters, + Hypothesis, + SequenceGeneratorOutput, +) +from fairseq2.logging import get_log_writer + +from lcm.datasets.batch import EmbeddingsBatch, PaddingMask +from lcm.models.abstract_lcm import AbstractLCModel +from lcm.nn.incremental_state import LCMIncrementalStateBag + +logger = get_log_writer(__name__) + + +""" +This generator follows the style of existing generators in Fairseq2 +""" + + +@dataclass +class LCMGeneratorOptions: + """Holds the options to pass to a sequence generator.""" + + max_seq_len: int = 200 + """The hard limit on maximum length of generated sequences.""" + + min_seq_len: int = 1 + """The minimum length of generated sequences.""" + + eos_threshold: Optional[float] = 0.9 + """Threshold for cosine similarity to the EOS vector""" + + sample_latent_variable: bool = True + """When using VAE models, whether to return the mean or sample""" + + stop_on_repetition_cosine_threshold: Optional[float] = None + """Stop the generation when the similarity of two consecutive concepts is above the threshold.""" + + include_eos_token: bool = False + """Whether the eos token should be included in the hypotheses (matters only if they are trimmed).""" + + trim_hypotheses: bool = False + """Whether the tokens after the EOS token should be included in the hypotheses.""" + + seed: Optional[int] = None + """Seed to make generation deterministic""" + + lcm_temperature: float = 1.0 + """Temperature for decoding in the LCM""" + + +class LCMGenerator: + """Generates with an LCM model.""" + + def __init__( + self, + model: AbstractLCModel, + options: Optional[LCMGeneratorOptions] = None, + eos_vec: Optional[torch.Tensor] = None, + ) -> None: + """ + :param model: + The LC model to use for 
generation. + """ + model.eval() + self.model = model + + if options is None: + options = LCMGeneratorOptions() + + self.eos_vec = eos_vec + if self.eos_vec is None and options.eos_threshold: + logger.warning( + f"eos_threshold is set to {options.eos_threshold}, but eos_vec is not provided" + ) + if options.eos_threshold: + logger.debug(f"The eos_vec in generator has been set to {self.eos_vec}") + + self.options = options + + self.max_seq_len = options.max_seq_len + self.min_seq_len = options.min_seq_len + + assert self.min_seq_len >= 1, ( + f"min_seq_len must be greater than or equal to 1, min_seq_len={options.min_seq_len}" + ) + + self.eos_threshold = options.eos_threshold + + self.seqs: torch.Tensor + self.step_nr = 0 + self.min_prompt_len: int + self.max_prompt_len: int + self.sample_indices: torch.Tensor + self.state_bag: Optional[LCMIncrementalStateBag] = None + self.prompt_seq_lens: Optional[torch.Tensor] = None + self.prompt_padding_mask: Optional[torch.Tensor] = None + self.lengths: torch.Tensor + self.step_scores: torch.Tensor + + @torch.inference_mode() + def __call__( + self, + batch_input: EmbeddingsBatch, + max_gen_len: Optional[int] = None, + min_gen_len: Optional[int] = None, + temperature: float = 0.0, + disable_cache: bool = False, + **kwargs, + ) -> SequenceGeneratorOutput: + """ + :param input: + `bacth_input` embedded and padded tensor sequence of the inputs + `max_gen_len` max length to be generated for the given input + `min_gen_len` minimum length to be generated for the given input + `temperature` temperature to control the generation + `disable_cache` if True, do not use kv-caching + :returns: + The output of the LCM generator, consists of :math:`N` lists of + hypotheses for :math:`N` prompts. Each list has 1 Hypothesis + (beam size = 1), of which `seq` has the *Shape:* math:`(S+T, D)` + (:math:`S` is the prompt length, :math:`T` the length of the + generated sequence after the prompt and :math:`D` the model + dimension.) 
+ + """ + if self.options.seed: + torch.manual_seed(self.options.seed) + + # Setup the variables + batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size() + prompt_padding_mask = batch_input.padding_mask + if prompt_padding_mask is None: + self.min_prompt_len = self.max_prompt_len + self.prompt_padding_mask = None + self.prompt_seq_lens = None + else: + self.prompt_seq_lens = prompt_padding_mask.seq_lens + assert self.prompt_seq_lens is not None, ( + "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + ) + self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item()) + + # Keep the materialized mask + self.prompt_padding_mask = prompt_padding_mask.materialize() + + if not max_gen_len: + max_gen_len = self.max_seq_len + + # Make sure we do not accidentally set a max_gen_len that exceeds + # the generator's model capability + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) + self.max_gen_len = max_gen_len + + if not min_gen_len: + min_gen_len = self.min_seq_len + + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) + self.min_gen_len = min_gen_len + + if temperature == 0.0: + # If the call doesn't pass a specific temperature, + # use the default one from the decoding options + temperature = self.options.lcm_temperature + + self.temperature = temperature + + for k, v in kwargs.items(): + if hasattr(self.options, k) and v: + setattr(self.options, k, v) + + # Holds the generated sequences, scores and sample-dependent variables + dtype = self.model.dtype + device = batch_input.seqs.device + + if disable_cache: + self.state_bag = None + else: + self.state_bag = LCMIncrementalStateBag( + self.max_prompt_len + self.max_gen_len + ) + + # reserving full sequences capacity + self.seqs = torch.zeros( + (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim), + device=device, + dtype=dtype, + ) + self.step_scores = torch.zeros( + (batch_size, self.max_prompt_len + self.max_gen_len), + device=device, + ) + self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1 + + # Hold the samples indices to return in order + self.sample_indices = torch.arange(batch_size, device=device) + # Output buffer + self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)] + + # Bootstrap the sequences with the provided prompt. 
+ self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len] + self.step_nr = self.min_prompt_len + self.prefill(**kwargs) + + for self.step_nr in range( + self.min_prompt_len, self.max_prompt_len + self.max_gen_len + ): + if not self._step(): + break + + return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters()) + + @torch.inference_mode() + def prefill(self, **kwargs) -> None: + """The initial forward pass in the decoder with the prefix/prompt + to populate the KV-cache""" + + if self.state_bag is None: + return + + # Prefilling with -1 since the next call to step will use the last token in the prefix + prefill_len = self.step_nr - 1 + + if prefill_len > 0: + _ = self._decode( + self.seqs[:, :prefill_len], + padding_mask=None, + ) + self.state_bag.increment_step_nr(prefill_len) # type: ignore + else: + logger.warning( + f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix" + ) + + @torch.inference_mode() + def _decode( + self, + seqs: torch.Tensor, + padding_mask: Optional[PaddingMask], + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.model.predict_next_sentence( + EmbeddingsBatch(seqs, padding_mask), + sample=self.options.sample_latent_variable, + temperature=self.temperature, + state_bag=self.state_bag, + **kwargs, + ) + + # Dummy scores + scores = torch.zeros(seqs.shape[:-1]) + return output.seqs, scores + + def _step(self) -> bool: + # Generate the next step output. + + if self.state_bag is None: + # Without a state_bag, we're forwarding the full prefix + model_output, step_score = self._decode( + seqs=self.seqs[:, : self.step_nr], + padding_mask=None, + ) + else: + # Since we're using a state_bag, we're only forwarding the last embedding + model_output, step_score = self._decode( + seqs=self.seqs[:, self.step_nr - 1 : self.step_nr], + padding_mask=None, + ) + + self.state_bag.increment_step_nr() + + # model_output: EmbeddingBag + return self.finalize_step(model_output, step_score) + + def finalize_step( + self, model_output: torch.Tensor, step_score: torch.Tensor + ) -> bool: + """Post-processing and finalizing a step + by checking all stopping criteria + Takes the model's outputed embeddings (model_output) + and their associated scores (step_score) + If we're stepping, return True, else return False + """ + already_finished = self.lengths > -1 + should_finish_now = torch.zeros_like(already_finished) + + model_last_output = model_output[:, -1] + device = model_last_output.device + + # Ignore prompt positions between min-max prompt_len + must_keep_going = None + if self.step_nr < self.max_prompt_len: + assert self.prompt_padding_mask is not None, ( + f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + ) + mask = self.prompt_padding_mask[:, self.step_nr] + model_last_output[mask] = self.seqs[mask, self.step_nr] + must_keep_going = mask + + # Check stopping based on EOS similarity. + if self.eos_threshold is not None and self.eos_vec is not None: + sim2eos = torch.nn.functional.cosine_similarity( + self.eos_vec.to(device), model_last_output + ) + logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}") + should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold) + + # Check stopping based on repetition. 
+ if ( + self.options.stop_on_repetition_cosine_threshold is not None + and self.step_nr > 0 + ): + sim2prev = torch.nn.functional.cosine_similarity( + self.seqs[:, self.step_nr - 1], model_last_output + ) + logger.debug( + f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}" + ) + should_finish_now = should_finish_now | sim2prev.ge( + self.options.stop_on_repetition_cosine_threshold + ) + + if must_keep_going is not None: + logger.debug( + f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}" + ) + should_finish_now = should_finish_now & ~must_keep_going + + # Keep going if output is shorter than min_gen_len: + if self.prompt_seq_lens is not None: + longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge( + self.min_gen_len + ) + else: + longer_than_min_gen_len = ( + self.step_nr - self.max_prompt_len + ) >= self.min_gen_len + + logger.debug( + f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}" + ) + should_finish_now = should_finish_now & longer_than_min_gen_len + stopped_on_eos = should_finish_now + + # Stop hypotheses that reached max_gen_len + if self.prompt_seq_lens is not None: + exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge( + self.max_gen_len + ) + logger.debug( + f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; promt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}" + ) + + else: + exceeds_max_gen_len = ( + self.step_nr - self.max_prompt_len + 1 + ) >= self.max_gen_len + logger.debug( + f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; promt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}" + ) + + logger.debug( + f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}" + ) + + should_finish_now = should_finish_now | exceeds_max_gen_len + + # Assign lengths to the sequences that have just finished. + should_finish_now = should_finish_now & ~already_finished + self.lengths[should_finish_now] = self.step_nr + 1 + + # Record the current step. + self.seqs[:, self.step_nr] = model_last_output.squeeze(1) + self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1] + + # Save completed hypsptheses + finished_mask = self.lengths.ne(-1) + finished_indices = finished_mask.nonzero() + + # Remove finished hypotheses and reorder variables/state_bag if any are left + if len(finished_indices) > 0: + for idx in finished_indices: + self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)])) + + active_mask = ~finished_mask + active_indices = active_mask.nonzero().squeeze(-1) + + if len(active_indices) == 0: + return False + + self.reorder_state(active_indices) + + return True + + def finish_sequence(self, idx: int, is_eos: bool = False) -> None: + seq_len = int(self.lengths[idx].item()) + + if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos: + seq_len = int(self.lengths[idx].item()) - int( + not self.options.include_eos_token + ) + + sample_idx = int(self.sample_indices[idx]) + self.hypotheses[sample_idx] = [ + Hypothesis( + seq=self.seqs[idx, :seq_len], + score=None, + step_scores=self.step_scores[idx], # Trim it as well? 
+ ) + ] + + def state_bag_reorder(self, new_order: torch.Tensor) -> None: + if self.state_bag is not None: + self.state_bag.reorder(new_order) + + def reorder_state(self, new_order: torch.Tensor) -> None: + self.state_bag_reorder(new_order) + + self.seqs = self.seqs.index_select(dim=0, index=new_order) + + self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order) + + self.step_scores = self.step_scores.index_select(dim=0, index=new_order) + + self.lengths = self.lengths.index_select(dim=0, index=new_order) + + if self.prompt_padding_mask is not None: + self.prompt_padding_mask = self.prompt_padding_mask.index_select( + dim=0, index=new_order + ) + + if self.prompt_seq_lens is not None: + self.prompt_seq_lens = self.prompt_seq_lens.index_select( + dim=0, index=new_order + ) diff --git a/lcm/inference/lcm/scorer.py b/lcm/inference/lcm/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..15cbd15c2d4ded819c56a332625cef4750a45565 --- /dev/null +++ b/lcm/inference/lcm/scorer.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +from typing import List, Optional + +import torch +from fairseq2.generation.generator import ( + GenerationCounters, + Hypothesis, + SequenceGeneratorOutput, +) + +from lcm.datasets.batch import EmbeddingsBatch +from lcm.inference.lcm.generator import LCMGenerator, LCMGeneratorOptions +from lcm.nn.incremental_state import LCMIncrementalStateBag + + +class LCMScorer(LCMGenerator): + """Generates with an LCM model in teacher-forcing mode.""" + + options: LCMGeneratorOptions + + @torch.inference_mode() + def __call__( # type: ignore + self, + batch_input: EmbeddingsBatch, + max_gen_len: Optional[int] = None, + min_gen_len: Optional[int] = None, + min_context_len: int = 1, + temperature: float = 0.0, + disable_cache: bool = False, + ) -> SequenceGeneratorOutput: + """ + :param input: + `bacth_input` embedded and padded tensor sequence of the inputs + `max_gen_len` max length to be generated for the given input + `min_gen_len` minimum length to be generated for the given input + `disable_cache` if True, do not use kv-caching + :returns: + The output of the LCM generator, consists of :math:`N` lists of + hypotheses for :math:`N` documents. 
Each list has 1 Hypothesis + (beam size = 1), of which `seq` has the *Shape:* math:`(T, D)` + (:math:`T` the length of the document and :math:`D` the model + dimension + + """ + if self.options.seed: + torch.manual_seed(self.options.seed) + + # Setup the variables + self.min_context_len = min_context_len + batch_size, self.max_text_len, embed_dim = batch_input.seqs.size() + text_padding_mask = batch_input.padding_mask + if text_padding_mask is None: + self.text_padding_mask = None + self.text_seq_lens = self.max_text_len * torch.ones( + batch_size, + dtype=torch.long, + device=batch_input.seqs.device, + ) + else: + self.text_seq_lens = text_padding_mask.seq_lens + assert self.text_seq_lens is not None, ( + "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + ) + + # Keep the materialized mask + self.text_padding_mask = text_padding_mask.materialize() + + if not max_gen_len: + max_gen_len = self.max_seq_len + + max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len) + assert max_gen_len is not None, "max_gen_len is None" + + # Make sure we do not accidentally set a max_gen_len that exceeds + # the generator's model capability + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) + self.max_gen_len = max_gen_len + + if not min_gen_len: + min_gen_len = self.min_seq_len + + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) + self.min_gen_len = min_gen_len + + if temperature == 0.0: + # If the call doesn't pass a specific temperature, + # use the default one from the decoding options + temperature = self.options.lcm_temperature + + # Holds the generated sequences, scores and sample-dependent variables + dtype = self.model.dtype + device = batch_input.seqs.device + self.temperature = temperature + + if disable_cache: + self.state_bag = None + else: + self.state_bag = LCMIncrementalStateBag(self.max_text_len) + + # reserving full sequences capacity + self.seqs = batch_input.seqs + self.preds = torch.zeros( + (batch_size, self.max_text_len - self.min_context_len, embed_dim), + device=device, + dtype=dtype, + ) + self.step_scores = torch.zeros( + (batch_size, self.max_text_len), + device=device, + ) + + # Hold the samples indices to return in order + self.sample_indices = torch.arange(batch_size, device=device) + # Output buffer + self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)] + + # the sequences with the provided prompt. + self.step_nr = self.min_context_len + self.prefill() + + for self.step_nr in range(self.min_context_len, self.max_text_len): + if not self._step(): + break + + return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters()) + + def finalize_step( + self, model_output: torch.Tensor, step_score: torch.Tensor + ) -> bool: + """Post-processing and finalizing a step + by checking all stopping criteria + Takes the model's outputed embeddings (model_output) + and their associated scores (step_score) + If we're stepping, return True, else return False + """ + model_last_output = model_output[:, -1] + must_keep_going = self.text_seq_lens.gt(self.step_nr + 1) + should_finish_now = ~must_keep_going + + # Record the current step prediction. 
+ self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze( + 1 + ) + self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1] + + # Save completed hypotheses + finished_indices = should_finish_now.nonzero() + + # Remove finished hypotheses and reorder variables/state_bag if any are left + if len(finished_indices) > 0: + for idx in finished_indices: + self.finish_sequence(int(idx)) + + active_mask = must_keep_going + active_indices = active_mask.nonzero().squeeze(-1) + + if len(active_indices) == 0: + return False + + self.reorder_state(active_indices) + + return True + + def finish_sequence(self, idx: int, is_eos: bool = False) -> None: + seq_len = int(self.text_seq_lens[idx].item()) + sample_idx = int(self.sample_indices[idx]) + self.hypotheses[sample_idx] = [ + Hypothesis( + seq=self.preds[idx, : seq_len - self.min_context_len], + score=None, + step_scores=self.step_scores[idx], # Trim it as well? + ) + ] + + def reorder_state(self, new_order: torch.Tensor) -> None: + self.state_bag_reorder(new_order) + + self.seqs = self.seqs.index_select(dim=0, index=new_order) + self.preds = self.preds.index_select(dim=0, index=new_order) + + self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order) + + self.step_scores = self.step_scores.index_select(dim=0, index=new_order) + + if self.text_padding_mask is not None: + self.text_padding_mask = self.text_padding_mask.index_select( + dim=0, index=new_order + ) + + self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order) diff --git a/lcm/inference/two_tower_diffusion_lcm/__init__.py b/lcm/inference/two_tower_diffusion_lcm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cbe6984ba4b936691eeac6957383e0b1fdc4d7b --- /dev/null +++ b/lcm/inference/two_tower_diffusion_lcm/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from lcm.inference.two_tower_diffusion_lcm.generator import ( + DiffusionLCMGeneratorOptions as DiffusionLCMGeneratorOptions, +) +from lcm.inference.two_tower_diffusion_lcm.generator import ( + TwoTowerDiffusionLCMGenerator as TwoTowerDiffusionLCMGenerator, +) + +__all__ = [ + "TwoTowerDiffusionLCMGenerator", + "DiffusionLCMGeneratorOptions", +] diff --git a/lcm/inference/two_tower_diffusion_lcm/generator.py b/lcm/inference/two_tower_diffusion_lcm/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..62487d8126204915769edbe807431809a2f33c2e --- /dev/null +++ b/lcm/inference/two_tower_diffusion_lcm/generator.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from fairseq2.generation.generator import ( + GenerationCounters, + Hypothesis, + SequenceGeneratorOutput, +) +from fairseq2.logging import get_log_writer + +from lcm.datasets.batch import EmbeddingsBatch, PaddingMask +from lcm.inference.lcm.generator import ( + LCMGenerator, + LCMGeneratorOptions, +) +from lcm.models.abstract_lcm import AbstractLCModel +from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel +from lcm.nn.incremental_state import LCMIncrementalStateBag + +logger = get_log_writer(__name__) + + +@dataclass +class DiffusionLCMGeneratorOptions(LCMGeneratorOptions): + """Holds the options to pass to a diffusion-based sequence generator.""" + + guidance_scale: float = 1.0 + """The weight of the regular classifier-free guidance. + Here `guidance_scale` is defined as the guidance weight `w` of + Equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf. + `guidance_scale = 1` corresponds to doing no classifier free guidance. + A higher guidance scale value encourages the model to generate outputs + closely related to the `prompt` at the expense of lower quality.""" + + guidance_rescale: float = 0.0 + """The rescaling factor for Classifier-Free Guidance with Rescale + (Algorithm 2 - https://arxiv.org/pdf/2305.08891)""" + + ddim_eta: float = 0.0 + """The weight of noise for added noise in diffusion step. + It controls the level of interpolation between a deterministic + DDIM (at eta=0.0) and a stochastic DDPM (at eta = 1.0) + See section 5 of the DDIM paper https://arxiv.org/pdf/2010.02502 """ + + epsilon_scaling: Optional[float] = None + """epsilon_scaling: Optional[float] if not None, the predicted epsilon will + be scaled down by the provided factor as + introduced in https://arxiv.org/pdf/2308.15321""" "" + + initial_noise_scale: float = 1.0 + """For Diffusion models, scaling of initial noise""" + + inference_timesteps: int = 100 + """For Diffusion models, number of denoising timesteps""" + + clip_noise: int = 100 + """For Diffusion models, factor to clip noise of the sampling steps""" + + thresholding: bool = False + """Whether to use the "dynamic thresholding" method. + This is unsuitable for latent-space diffusion models such as Stable Diffusion.""" + + dynamic_thresholding_ratio: float = 0.995 + """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.""" + + sample_max_value: float = 6.0 + """The threshold value for dynamic thresholding. 
Valid only when `thresholding=True`.""" + + +class TwoTowerDiffusionLCMGenerator(LCMGenerator): + """Generates with a Two-tower Diffusion LCM model.""" + + options: DiffusionLCMGeneratorOptions + + def __init__( + self, + model: AbstractLCModel, + options: Optional[LCMGeneratorOptions] = None, + eos_vec: Optional[torch.Tensor] = None, + ) -> None: + super().__init__(model, options, eos_vec) + + assert isinstance(self.model, TwoTowerDiffusionLCModel), ( + "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM" + ) + + logger.info( + f"Setting up the model with decoding_options: {options} -- {type(options)}" + ) + model.prep_for_denoising(options) + + @torch.inference_mode() + def __call__( + self, + batch_input: EmbeddingsBatch, + max_gen_len: Optional[int] = None, + min_gen_len: Optional[int] = None, + temperature: float = 0.0, + disable_cache: bool = False, + **kwargs, + ) -> SequenceGeneratorOutput: + """ + :param input: + `bacth_input` embedded and padded tensor sequence of the inputs + `max_gen_len` max length to be generated for the given input + `min_gen_len` minimum length to be generated for the given input + `disable_cache` if True, do not use kv-caching + `temperature` temperature to control the generation + :returns: + The output of the LCM generator, consists of :math:`N` lists of + hypotheses for :math:`N` prompts. Each list has 1 Hypothesis + (beam size = 1), of which `seq` has the *Shape:* math:`(S+T, D)` + (:math:`S` is the prompt length, :math:`T` the length of the + generated sequence after the prompt and :math:`D` the model + dimension.) + + """ + if self.options.seed: + torch.manual_seed(self.options.seed) + + # Setup the variables + batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size() + prompt_padding_mask = batch_input.padding_mask + if prompt_padding_mask is None: + self.min_prompt_len = self.max_prompt_len + self.prompt_padding_mask = None + self.prompt_seq_lens = None + else: + self.prompt_seq_lens = prompt_padding_mask.seq_lens + assert self.prompt_seq_lens is not None, ( + "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + ) + self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item()) + + # Keep the materialized mask + self.prompt_padding_mask = prompt_padding_mask.materialize() + + if not max_gen_len: + max_gen_len = self.max_seq_len + + # Make sure we do not accidentally set a max_gen_len that exceeds + # the generator's model capability + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) + self.max_gen_len = max_gen_len + + if not min_gen_len: + min_gen_len = self.min_seq_len + + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) + self.min_gen_len = min_gen_len + + if temperature == 0.0: + # If the call doesn't pass a specific temperature, + # use the default one from the decoding options + temperature = self.options.lcm_temperature + + # Holds the generated sequences, scores and sample-dependent variables + dtype = self.model.dtype + device = batch_input.seqs.device + self.temperature = temperature + + if disable_cache: + self.state_bag = None + self.context_state_bag = None + else: + self.state_bag = LCMIncrementalStateBag( + self.max_prompt_len + self.max_gen_len + ) + self.context_state_bag = LCMIncrementalStateBag( + self.max_prompt_len + self.max_gen_len + ) + + # reserving full sequences capacity + self.seqs = torch.zeros( + (batch_size, self.max_prompt_len + 
self.max_gen_len, embed_dim), + device=device, + dtype=dtype, + ) + self.step_scores = torch.zeros( + (batch_size, self.max_prompt_len + self.max_gen_len), + device=device, + ) + self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1 + + # Hold the samples indices to return in order + self.sample_indices = torch.arange(batch_size, device=device) + # Output buffer + self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)] + + # Bootstrap the sequences with the provided prompt. + self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len] + self.step_nr = self.min_prompt_len + + # A context we keep growing in each decoding step + self.prefill() + + for self.step_nr in range( + self.min_prompt_len, self.max_prompt_len + self.max_gen_len + ): + if not self._step(): + break + + return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters()) + + def state_bag_reorder(self, new_order: torch.Tensor) -> None: + if self.state_bag is not None: + self.state_bag.reorder(new_order) + + if self.context_state_bag is not None: + self.context_state_bag.reorder(new_order) + + @torch.inference_mode() + def prefill(self, **kwargs) -> None: + """encode the prefix with the context encoder""" + + assert self.context_state_bag is not None, ( + "Expecting a context state bag to prefill" + ) + + context: EmbeddingsBatch + + prefill_len = self.step_nr - 1 + if prefill_len > 0: + # normalize then encode + input_seqs = self.seqs[:, :prefill_len] + if self.model.config.sonar_normalizer_name is not None: + input_seqs = self.model.sonar_normalizer.normalize(input_seqs) + + context = self.model.encode( + EmbeddingsBatch(input_seqs, None), + state_bag=self.context_state_bag, + **kwargs, + ) + + self.context_state_bag.increment_step_nr(prefill_len) + + else: + logger.warning( + f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix" + ) + context = EmbeddingsBatch( + torch.empty( + (self.seqs.shape[0], 0, self.model.model_dim), + dtype=self.seqs.dtype, + device=self.seqs.device, + ) + ) + + self.context = context + + @torch.inference_mode() + def _decode( + self, + seqs: torch.Tensor, + padding_mask: Optional[PaddingMask] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + output, context = self.model.predict_next_sentence( + batch=EmbeddingsBatch(seqs, padding_mask), + context=self.context, + temperature=self.temperature, + state_bag=self.state_bag, + context_state_bag=self.context_state_bag, + **kwargs, + ) + self.context = context + + # Dummy scores + scores = torch.zeros(seqs.shape[:-1]) + return output.seqs, scores + + def _step(self) -> bool: + # Generate the next step output. 
+ + if self.state_bag is None: + # Without a state_bag, we're forwarding the full prefix + # Encode the full context: + + model_output, step_score = self._decode( + seqs=self.seqs[:, : self.step_nr], + padding_mask=None, + ) + else: + # Since we're using a state_bag, we're only forwarding the last embedding + model_output, step_score = self._decode( + seqs=self.seqs[:, self.step_nr - 1 : self.step_nr], + padding_mask=None, + ) + + self.state_bag.increment_step_nr() + + # model_output: EmbeddingBag + return self.finalize_step(model_output, step_score) + + def finalize_step( + self, model_output: torch.Tensor, step_score: torch.Tensor + ) -> bool: + """Post-processing and finalizing a step + by checking all stopping criteria + Takes the model's outputed embeddings (model_output) + and their associated scores (step_score) + If we're stepping, return True, else return False + """ + already_finished = self.lengths > -1 + should_finish_now = torch.zeros_like(already_finished) + + model_last_output = model_output[:, -1] + device = model_last_output.device + + # Ignore prompt positions between min-max prompt_len + must_keep_going = None + if self.step_nr < self.max_prompt_len: + assert self.prompt_padding_mask is not None, ( + f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + ) + mask = self.prompt_padding_mask[:, self.step_nr] + model_last_output[mask] = self.seqs[mask, self.step_nr] + must_keep_going = mask + + # Check stopping based on EOS similarity. + if self.eos_threshold is not None and self.eos_vec is not None: + sim2eos = torch.nn.functional.cosine_similarity( + self.eos_vec.to(device), model_last_output + ) + logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}") + should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold) + + # Check stopping based on repetition. 
+        if (
+            self.options.stop_on_repetition_cosine_threshold is not None
+            and self.step_nr > 0
+        ):
+            sim2prev = torch.nn.functional.cosine_similarity(
+                self.seqs[:, self.step_nr - 1], model_last_output
+            )
+            logger.debug(
+                f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
+            )
+            should_finish_now = should_finish_now | sim2prev.ge(
+                self.options.stop_on_repetition_cosine_threshold
+            )
+
+        if must_keep_going is not None:
+            logger.debug(
+                f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}"
+            )
+            should_finish_now = should_finish_now & ~must_keep_going
+
+        # Keep going if output is shorter than min_gen_len:
+        if self.prompt_seq_lens is not None:
+            longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
+                self.min_gen_len
+            )
+        else:
+            longer_than_min_gen_len = (
+                self.step_nr - self.max_prompt_len
+            ) >= self.min_gen_len
+
+        logger.debug(
+            f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
+        )
+        should_finish_now = should_finish_now & longer_than_min_gen_len
+        stopped_on_eos = should_finish_now
+
+        # Stop hypotheses that reached max_gen_len
+        if self.prompt_seq_lens is not None:
+            exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
+                self.max_gen_len
+            )
+            logger.debug(
+                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
+            )
+
+        else:
+            exceeds_max_gen_len = (
+                self.step_nr - self.max_prompt_len + 1
+            ) >= self.max_gen_len
+            logger.debug(
+                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
+            )
+
+        logger.debug(
+            f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
+        )
+
+        should_finish_now = should_finish_now | exceeds_max_gen_len
+
+        # Assign lengths to the sequences that have just finished.
+        should_finish_now = should_finish_now & ~already_finished
+        self.lengths[should_finish_now] = self.step_nr + 1
+
+        # Record the current step.
+        self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
+        self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]
+
+        # Save completed hypotheses
+        finished_mask = self.lengths.ne(-1)
+        finished_indices = finished_mask.nonzero()
+
+        # Remove finished hypotheses and reorder variables/state_bag if any are left
+        if len(finished_indices) > 0:
+            for idx in finished_indices:
+                self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))
+
+            active_mask = ~finished_mask
+            active_indices = active_mask.nonzero().squeeze(-1)
+
+            if len(active_indices) == 0:
+                return False
+
+            self.reorder_state(active_indices)
+
+        return True
+
+    def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
+        seq_len = int(self.lengths[idx].item())
+
+        if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
+            seq_len = int(self.lengths[idx].item()) - int(
+                not self.options.include_eos_token
+            )
+
+        sample_idx = int(self.sample_indices[idx])
+        self.hypotheses[sample_idx] = [
+            Hypothesis(
+                seq=self.seqs[idx, :seq_len],
+                score=None,
+                step_scores=self.step_scores[idx],  # Trim it as well?
+ ) + ] + + def reorder_state(self, new_order: torch.Tensor) -> None: + self.state_bag_reorder(new_order) + + self.context = EmbeddingsBatch( + self.context.seqs.index_select(dim=0, index=new_order), + self.context.padding_mask, + ) + + self.seqs = self.seqs.index_select(dim=0, index=new_order) + + self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order) + + self.step_scores = self.step_scores.index_select(dim=0, index=new_order) + + self.lengths = self.lengths.index_select(dim=0, index=new_order) + + if self.prompt_padding_mask is not None: + self.prompt_padding_mask = self.prompt_padding_mask.index_select( + dim=0, index=new_order + ) + + if self.prompt_seq_lens is not None: + self.prompt_seq_lens = self.prompt_seq_lens.index_select( + dim=0, index=new_order + ) diff --git a/lcm/inference/two_tower_diffusion_lcm/scorer.py b/lcm/inference/two_tower_diffusion_lcm/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..bad99d776e0537c592dca8cda625972515574fc6 --- /dev/null +++ b/lcm/inference/two_tower_diffusion_lcm/scorer.py @@ -0,0 +1,314 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from typing import List, Optional, Tuple + +import torch +from fairseq2.generation.generator import ( + GenerationCounters, + Hypothesis, + SequenceGeneratorOutput, +) +from fairseq2.logging import get_log_writer + +from lcm.datasets.batch import EmbeddingsBatch, PaddingMask +from lcm.inference.lcm.generator import LCMGeneratorOptions +from lcm.inference.two_tower_diffusion_lcm import ( + TwoTowerDiffusionLCMGenerator, +) +from lcm.models.abstract_lcm import AbstractLCModel +from lcm.nn.incremental_state import LCMIncrementalStateBag + +logger = get_log_writer(__name__) + + +class TwoTowerDiffusionLCMScorer(TwoTowerDiffusionLCMGenerator): + """Score by generating in teacher-forcing mode with a Two-tower Diffusion LCM model.""" + + def __init__( + self, + model: AbstractLCModel, + options: Optional[LCMGeneratorOptions] = None, + eos_vec: Optional[torch.Tensor] = None, + ) -> None: + super().__init__(model, options, eos_vec) + + @torch.inference_mode() + def __call__( # type: ignore + self, + batch_input: EmbeddingsBatch, + max_gen_len: Optional[int] = None, + min_gen_len: Optional[int] = None, + min_context_len: int = 1, + temperature: float = 0.0, + disable_cache: bool = False, + ) -> SequenceGeneratorOutput: + """ + :param input: + `bacth_input` embedded and padded tensor sequence of the inputs + `max_gen_len` max length to be generated for the given input + `min_gen_len` minimum length to be generated for the given input + `disable_cache` if True, do not use kv-caching + :returns: + The output of the LCM generator, consists of :math:`N` lists of + hypotheses for :math:`N` documents. Each list has 1 Hypothesis + (beam size = 1), of which `seq` has the *Shape:* math:`(T, D)` + (:math:`T` the length of the document and :math:`D` the model + dimension.) 
+ + """ + if self.options.seed: + torch.manual_seed(self.options.seed) + + # Setup the variables + self.min_context_len = min_context_len + batch_size, self.max_text_len, embed_dim = batch_input.seqs.size() + text_padding_mask = batch_input.padding_mask + if text_padding_mask is None: + self.text_padding_mask = None + self.text_seq_lens = self.max_text_len * torch.ones( + batch_size, + dtype=torch.long, + device=batch_input.seqs.device, + ) + else: + self.text_seq_lens = text_padding_mask.seq_lens + assert self.text_seq_lens is not None, ( + "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + ) + + # Keep the materialized mask + self.text_padding_mask = text_padding_mask.materialize() + + if not max_gen_len: + max_gen_len = self.max_seq_len + + max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len) + assert max_gen_len is not None, "max_gen_len is None" + + # Make sure we do not accidentally set a max_gen_len that exceeds + # the generator's model capability + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) + self.max_gen_len = max_gen_len + + if not min_gen_len: + min_gen_len = self.min_seq_len + + assert min_gen_len is not None, "A `min_gen_len` is required" + + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) + + self.min_gen_len = min_gen_len + + if temperature == 0.0: + # If the call doesn't pass a specific temperature, + # use the default one from the decoding options + temperature = self.options.lcm_temperature + + # Holds the generated sequences, scores and sample-dependent variables + dtype = self.model.dtype + device = batch_input.seqs.device + self.temperature = temperature + + if disable_cache: + self.state_bag = None + self.context_state_bag = None + else: + self.state_bag = LCMIncrementalStateBag(self.max_text_len) + self.context_state_bag = LCMIncrementalStateBag(self.max_text_len) + + # reserving full sequences capacity + self.seqs = batch_input.seqs + self.preds = torch.zeros( + (batch_size, self.max_text_len - self.min_context_len, embed_dim), + device=device, + dtype=dtype, + ) + + self.step_scores = torch.zeros( + (batch_size, self.max_text_len), + device=device, + ) + # Hold the samples indices to return in order + self.sample_indices = torch.arange(batch_size, device=device) + # Output buffer + self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)] + + # the sequences with the provided prompt. 
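+        # Scoring starts right after the minimal context; positions below
+        # `min_context_len` are treated as the teacher-forced prompt.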
+ self.step_nr = self.min_context_len + + # A context we keep growing in each decoding step + self.prefill() + + for self.step_nr in range(self.min_context_len, self.max_text_len): + if not self._step(): + break + + return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters()) + + def state_bag_reorder(self, new_order: torch.Tensor) -> None: + if self.state_bag is not None: + self.state_bag.reorder(new_order) + + if self.context_state_bag is not None: + self.context_state_bag.reorder(new_order) + + @torch.inference_mode() + def prefill(self, **kwargs) -> None: + """encode the prefix with the context encoder""" + + assert self.context_state_bag is not None, ( + "Expecting a context state bag to prefill" + ) + + context: EmbeddingsBatch + + # FIXME for this model we can prefill with self.step_nr + prefill_len = self.step_nr - 1 + if prefill_len > 0: + # normalize then encode + input_seqs = self.seqs[:, :prefill_len] + if self.model.config.sonar_normalizer_name is not None: + input_seqs = self.model.sonar_normalizer.normalize(input_seqs) + + context = self.model.encode( + EmbeddingsBatch(input_seqs, None), + state_bag=self.context_state_bag, + **kwargs, + ) + + self.context_state_bag.increment_step_nr(prefill_len) + + else: + logger.warning( + f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix" + ) + context = EmbeddingsBatch( + torch.empty( + (self.seqs.shape[0], 0, self.model.model_dim), + dtype=self.seqs.dtype, + device=self.seqs.device, + ) + ) + + self.context = context + + @torch.inference_mode() + def _decode( + self, + seqs: torch.Tensor, + padding_mask: Optional[PaddingMask] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + output, context = self.model.predict_next_sentence( + batch=EmbeddingsBatch(seqs, padding_mask), + context=self.context, + temperature=self.temperature, + state_bag=self.state_bag, + context_state_bag=self.context_state_bag, + **kwargs, + ) + self.context = context + + # Dummy score + scores = torch.zeros(seqs.shape[:-1]) + return output.seqs, scores + + def _step(self) -> bool: + # Generate the next step output. + + if self.state_bag is None: + # Without a state_bag, we're forwarding the full prefix + # Encode the full context: + + model_output, step_score = self._decode( + seqs=self.seqs[:, : self.step_nr], + padding_mask=None, + ) + else: + # Since we're using a state_bag, we're only forwarding the last embedding + model_output, step_score = self._decode( + seqs=self.seqs[:, self.step_nr - 1 : self.step_nr], + padding_mask=None, + ) + + self.state_bag.increment_step_nr() + + # model_output: EmbeddingBag + return self.finalize_step(model_output, step_score) + + def finalize_step( + self, model_output: torch.Tensor, step_score: torch.Tensor + ) -> bool: + """Post-processing and finalizing a step + by checking all stopping criteria + Takes the model's outputed embeddings (model_output) + and their associated scores (step_score) + If we're stepping, return True, else return False + """ + model_last_output = model_output[:, -1] + must_keep_going = self.text_seq_lens.gt(self.step_nr + 1) + should_finish_now = ~must_keep_going + + # Record the current step prediction. 
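+        # Predictions are stored with an offset of `min_context_len`, so index 0
+        # holds the prediction for the first position after the initial context.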
+ self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze( + 1 + ) + self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1] + + # Save completed hypsptheses + finished_indices = should_finish_now.nonzero() + + # Remove finished hypotheses and reorder variables/state_bag if any are left + if len(finished_indices) > 0: + for idx in finished_indices: + self.finish_sequence(int(idx)) + + active_mask = must_keep_going + active_indices = active_mask.nonzero().squeeze(-1) + + if len(active_indices) == 0: + return False + + self.reorder_state(active_indices) + + return True + + def finish_sequence(self, idx: int) -> None: # type: ignore + seq_len = int(self.text_seq_lens[idx].item()) + sample_idx = int(self.sample_indices[idx]) + self.hypotheses[sample_idx] = [ + Hypothesis( + seq=self.preds[idx, : seq_len - self.min_context_len], + score=None, + step_scores=self.step_scores[idx], # Trim it as well? + ) + ] + + def reorder_state(self, new_order: torch.Tensor) -> None: + self.state_bag_reorder(new_order) + + self.context = EmbeddingsBatch( + self.context.seqs.index_select(dim=0, index=new_order), + self.context.padding_mask, + ) + + self.seqs = self.seqs.index_select(dim=0, index=new_order) + self.preds = self.preds.index_select(dim=0, index=new_order) + + self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order) + + self.step_scores = self.step_scores.index_select(dim=0, index=new_order) + + if self.text_padding_mask is not None: + self.text_padding_mask = self.text_padding_mask.index_select( + dim=0, index=new_order + ) + + self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order) diff --git a/lcm/models/__init__.py b/lcm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7af6fccb2d07332d6686a1ad77e3fe205e7625 --- /dev/null +++ b/lcm/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +# We import all the model types in order to populate the model type registry +from lcm.models.base_lcm.loader import BASE_LCM_MODEL_TYPE +from lcm.models.two_tower_diffusion_lcm.loader import ( + TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, +) + +__all__ = [ + "BASE_LCM_MODEL_TYPE", + "TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE", +] diff --git a/lcm/models/abstract_lcm/__init__.py b/lcm/models/abstract_lcm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8105659de132a089d4daaeff714d2930a7a79f2 --- /dev/null +++ b/lcm/models/abstract_lcm/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +from lcm.models.abstract_lcm.builder import ( + AbstractLCModel, + AbstractLCModelBuilder, + AbstractLCModelConfig, +) + +__all__ = [ + "AbstractLCModel", + "AbstractLCModelBuilder", + "AbstractLCModelConfig", +] diff --git a/lcm/models/abstract_lcm/builder.py b/lcm/models/abstract_lcm/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..365215dd82fdf881dadaaac612c0d46bb1776f4a --- /dev/null +++ b/lcm/models/abstract_lcm/builder.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.logging import get_log_writer +from fairseq2.typing import DataType, Device +from torch.nn import Module + +from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model + +logger = get_log_writer(__name__) + + +""" +An abstract LCM model class for the bare minimum +""" + +ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm" + + +@dataclass +class AbstractLCModelConfig: + model_type: str = ABSTRACT_LCM_MODEL_TYPE + + sonar_embed_dim: int = 1024 + + sonar_normalizer_name: Optional[str] = None + + +lcm_archs = ConfigRegistry[AbstractLCModelConfig]() +lcm_arch = lcm_archs.decorator + + +class AbstractLCModel(Module): + """Asbtract Class for LCM models""" + + def __init__( + self, + config: AbstractLCModelConfig, + ) -> None: + """ + Asbtract LCM model + """ + super().__init__() + + self.config = config + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + +class AbstractLCModelBuilder: + """Builds modules of an LCM""" + + config: AbstractLCModelConfig + device: Optional[Device] + dtype: Optional[DataType] + + def __init__( + self, + config: AbstractLCModelConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + self.config = config + + self.device, self.dtype = device, dtype + + def build_sonar_normalizer( + self, + ) -> Optional[SonarNormalizer]: + if self.config.sonar_normalizer_name is not None: + logger.info( + f"Building sonar_normalizer = {self.config.sonar_normalizer_name}" + ) + return load_sonar_normalizer_model( + self.config.sonar_normalizer_name, + device=self.device, + dtype=self.dtype, + ) + return None + + @abstractmethod + def build_model(self) -> AbstractLCModel: + """Build a model.""" + ... diff --git a/lcm/models/base_lcm/__init__.py b/lcm/models/base_lcm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33f215314ff0d90ae8a0ae651ef7043cc8b690f0 --- /dev/null +++ b/lcm/models/base_lcm/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +# Register architectures +import lcm.models.base_lcm.archs # noqa +from lcm.models.base_lcm.builder import ( + BaseLCModel, + BaseLCModelBuilder, + BaseLCModelConfig, + create_base_lcm_model, +) + +__all__ = [ + "BaseLCModel", + "BaseLCModelBuilder", + "BaseLCModelConfig", + "create_base_lcm_model", +] diff --git a/lcm/models/base_lcm/archs.py b/lcm/models/base_lcm/archs.py new file mode 100644 index 0000000000000000000000000000000000000000..4945f600e3db49db286c3ee1989fd79584ffd544 --- /dev/null +++ b/lcm/models/base_lcm/archs.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from lcm.models.base_lcm.builder import ( + BaseLCModelConfig, + LCMFrontendConfig, + ProjectionConfig, + TransformerConfig, + lcm_arch, +) + + +# Every model must register a toy_{model_family} +@lcm_arch("toy_base_lcm") +def toy_base_lcm() -> BaseLCModelConfig: + return BaseLCModelConfig( + lcm=TransformerConfig(num_layers=2), + ) + + +@lcm_arch("base_lcm_1_6B") +def base_lcm_1_6B() -> BaseLCModelConfig: + """Base 1.6B model + Parameter Size: 1,647,635,456 + """ + model_dim: int = 2048 + num_attn_heads: int = 16 + return BaseLCModelConfig( + max_seq_len=4096, + model_dim=model_dim, + sonar_embed_dim=1024, + sonar_normalizer_name="dummy_sonar_normalizer", + frontend=LCMFrontendConfig(), + lcm=TransformerConfig( + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + ffn_inner_dim=model_dim * 4, + num_attn_heads=num_attn_heads, + num_layers=32, + pos_embedding_style="rope", + use_swiglu=True, + layer_normalization_style="rms", + ), + postnet=ProjectionConfig(), + ) diff --git a/lcm/models/base_lcm/builder.py b/lcm/models/base_lcm/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..da7258fabc5d7241a1798bc3b2ef4059f5b23aa7 --- /dev/null +++ b/lcm/models/base_lcm/builder.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass, field +from typing import Optional + +import torch.nn +from fairseq2.config_registry import ConfigRegistry +from fairseq2.logging import get_log_writer +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory +from fairseq2.typing import DataType, Device + +from lcm.datasets.batch import EmbeddingsBatch +from lcm.models.abstract_lcm import ( + AbstractLCModel, + AbstractLCModelBuilder, + AbstractLCModelConfig, +) +from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig +from lcm.nn.initialization import parse_norm_order +from lcm.nn.normalization import parse_layer_norm_factory +from lcm.nn.projection import Projection, ProjectionConfig +from lcm.nn.transformer import ( + LCMTransformerDecoder, + TransformerConfig, + TransformerFactory, +) + +logger = get_log_writer(__name__) + +BASE_LCM_MODEL_TYPE = "base_lcm" + + +@dataclass +class BaseLCModelConfig(AbstractLCModelConfig): + model_type: str = BASE_LCM_MODEL_TYPE + + max_seq_len: int = 2048 + + model_dim: int = 1024 + + model_output_dim: Optional[int] = None + """If ``None`` use SONAR dimension as output_dim.""" + + frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig()) + """The fronted config. This module maps from `sonar_embed_dim` to `model_dim` + and potentially adds positional embeddings""" + + lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig()) + """The core lcm config. This is causal Transformer decoder""" + + postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig()) + """The postnet config. 
A module mapping the output of the core lcm + back to `sonar_embed_dim`""" + + +lcm_archs = ConfigRegistry[BaseLCModelConfig]() +lcm_arch = lcm_archs.decorator + + +class BaseLCModel(AbstractLCModel): + """Base class for LCM models""" + + config: BaseLCModelConfig + + def __init__( + self, + config: BaseLCModelConfig, + lcm: LCMTransformerDecoder, + frontend: LCMFrontend, + postnet: Projection, + ) -> None: + """ + Basic LCM model with : + - fronted + - lcm + - postnet + """ + super().__init__(config) + + self.frontend = frontend + + self.lcm = lcm + + self.postnet = postnet + + self.model_dim = lcm.model_dim + + self.sonar_embed_dim = config.sonar_embed_dim + + def forward( + self, + batch: EmbeddingsBatch, + state_bag: Optional[IncrementalStateBag] = None, + **kwargs, + ) -> EmbeddingsBatch: + """ + Scaling + Positions + If a normalizer is provided, the features will be normalized in the + frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM) + """ + seqs, padding_mask = self.frontend( + batch.seqs, + batch.padding_mask, + diffusion_timesteps=batch.diffusion_timesteps, + state_bag=state_bag, + **kwargs, + ) + + # Core LCM + seqs, padding_mask = self.lcm( + seqs, + padding_mask, + state_bag=state_bag, + **kwargs, + ) + + # Postnet: + seqs = self.postnet(seqs) # type: ignore + + return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask) + + def predict_next_sentence( + self, + batch: EmbeddingsBatch, + sample: bool = False, + temperature: float = 1.0, + state_bag: Optional[IncrementalStateBag] = None, + **kwargs, + ) -> EmbeddingsBatch: + """ + The method for predicting the next sentence embeddings. + In the basic LCM, this is equivalent to just the forward method, + but the derived architectures may have a different implementation. + E.g. in VAE LCM, we run the VAE decoder on top of the `forward` results. + + Args: + batch (EmbeddingsBatch): the sequence of concepts which + the model should continue. + sample (bool): whether to predict the single most probable next sentence + or to sample from the predicted distribution. + temperature (float): a positive float indicating the degree of diversity + for the sampling (active only if `sample is True`). + Returns: + EmbeddingsBatch: the batch with predicted SONAR sentences. 
+ """ + # Normalize the input embeddings if we're expected to + # normalize outside of the model's forward pass + if self.frontend.sonar_normalizer is not None: + batch = batch.normalize_seqs(self.frontend.sonar_normalizer) + + # TODO: implement efficient sampling of multiple candidates + predicted_means = self.forward(batch, state_bag=state_bag, **kwargs) + + if sample and temperature > 0: + noise = torch.randn_like(predicted_means.seqs) * temperature + predicted_means.seqs = predicted_means.seqs + noise + + if self.frontend.sonar_normalizer is not None: + predicted_means = predicted_means.denormalize_seqs( + self.frontend.sonar_normalizer + ) + + return predicted_means + + +class BaseLCModelBuilder(AbstractLCModelBuilder): + """Builds modules of a base LCM model""" + + config: BaseLCModelConfig + device: Optional[Device] + dtype: Optional[DataType] + + def __init__( + self, + config: BaseLCModelConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + super().__init__(config=config, device=device, dtype=dtype) + self.lcm_factory = TransformerFactory( + model_dim=self.config.model_dim, + max_seq_len=self.config.max_seq_len, + config=self.config.lcm, + device=device, + dtype=dtype, + ) + + if config.model_output_dim is None: + self.model_output_dim = self.config.sonar_embed_dim + else: + self.model_output_dim = config.model_output_dim + + def build_model(self) -> BaseLCModel: + """Build a model.""" + + frontend = self.build_frontend() + + lcm = self.build_core_lcm() + + postnet = self.build_postnet() + + return BaseLCModel( + config=self.config, + frontend=frontend, + lcm=lcm, + postnet=postnet, + ) + + def build_frontend(self) -> LCMFrontend: + """Build the LCM front-end (i.e., prenet).""" + + return LCMFrontend( + sonar_embed_dim=self.config.sonar_embed_dim, + model_dim=self.config.model_dim, + config=self.config.frontend, + pos_encoder=self.lcm_factory.build_pos_encoder(), + sonar_normalizer=self.build_sonar_normalizer(), + device=self.device, + dtype=self.dtype, + ) + + def build_postnet(self) -> Projection: + return Projection( + output_dim=self.model_output_dim, + input_dim=self.config.model_dim, + config=self.config.postnet, + device=self.device, + dtype=self.dtype, + ) + + def build_attention_mask_factory(self): + self_attn_mask_factory: AttentionMaskFactory + + self_attn_mask_factory = CausalAttentionMaskFactory() + + return self_attn_mask_factory + + def build_core_lcm(self) -> LCMTransformerDecoder: + """Build the core LCM module.""" + + config = self.config.lcm + + layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)] + + self_attn_mask_factory = self.build_attention_mask_factory() + + if config.final_norm_order_style is None: + # The final norm order style will be that of the layer-level norm order + final_norm_order = parse_norm_order(config.norm_order_style) + else: + final_norm_order = parse_norm_order(config.final_norm_order_style) + + layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style) + + return LCMTransformerDecoder( + layers, # type: ignore + self_attn_mask_factory=self_attn_mask_factory, + norm_order=final_norm_order, + layer_norm_factory=layer_norm_factory, + dropout_p=config.final_dropout_p, + device=self.device, + dtype=self.dtype, + ) + + +def create_base_lcm_model( + config: BaseLCModelConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> BaseLCModel: + """Create an LCM model. + :param config: + The configuration. 
+ :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model() diff --git a/lcm/models/base_lcm/frontend.py b/lcm/models/base_lcm/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..6a553f98fcb9e15535eeb8493c5cddd7a7251c54 --- /dev/null +++ b/lcm/models/base_lcm/frontend.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from fairseq2.logging import get_log_writer +from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.projection import Linear +from fairseq2.typing import DataType, Device +from torch import Tensor +from torch.nn import Dropout, Module + +from lcm.models.sonar_normalizer.builder import SonarNormalizer +from lcm.nn.initialization import SONAR_STD, SUPPORTED_INIT_TYPES, get_init_fn + +logger = get_log_writer(__name__) + + +@dataclass +class LCMFrontendConfig: + dropout_p: float = 0.0 + """ The dropout probability applied to the module' output""" + + pre_linear_bias: bool = True + """ Whether or not the pre-linear layer has a bias term""" + + pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform" + + scale_embeddings: bool = False + """ Scale the embeddings by model_dim before + adding positions (and before the pre_linear) """ + + weight_normalization: bool = False + + embedding_std: float = SONAR_STD + """Most SONAR embeddings have a distribution with the mean close to 0 + and std close to 0.006. Initializing embedding-like parameters (e.g. end-of-text vector) + from a similar distribution is recommended, to minimize their disruption of the model training + """ + + +class LCMFrontend(Module): + """ + A fronted for the LCM with positional embeddings + """ + + embed: Embedding + scale: float + pos_encoder: Optional[PositionEncoder] + dropout: Optional[Dropout] + + def __init__( + self, + sonar_embed_dim: int, + model_dim: int, + config: LCMFrontendConfig, + pos_encoder: Optional[PositionEncoder], + timestep_embed_dim: int = 0, + sonar_normalizer: Optional[SonarNormalizer] = None, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param sonar_embed_dim + The embedding dimension of the sentence encoder, in this case SONAR + :param model_dim + The model embedding dimension + :param timestep_embed_dim + The embedding dimension of diffusion timesteps (if relevant, defaults to 0) + :param config: + A Frontend config. See `LCMFrontendConfig` + :param pos_encoder: + An optional position encoder. 
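+        :param sonar_normalizer:
+            An optional normalizer for the incoming SONAR embeddings.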
+ """ + + super().__init__() + + self.sonar_embed_dim = sonar_embed_dim + + self.model_dim = model_dim + + self.device = device + + self.embed_scale: float = model_dim**0.5 if config.scale_embeddings else 1.0 + + logger.info(f"Using LCMFrontend with embeddings scaler = {self.embed_scale}") + + # Optional sonar normalizer + self.sonar_normalizer = sonar_normalizer + + # Pre-linear to map to model dimension + + init_fn = get_init_fn(config.pre_linear_init_fn) + + lin = Linear( + sonar_embed_dim + timestep_embed_dim, + model_dim, + bias=config.pre_linear_bias, + device=device, + dtype=dtype, + init_fn=init_fn, + ) + + if config.weight_normalization: + self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin) + else: + self.pre_linear = lin + + if pos_encoder is not None: + if pos_encoder.encoding_dim != self.model_dim: + raise ValueError( + f"`encoding_dim` of `pos_encoder` and `embedding_dim` of \ + `embed` must be equal, but are {pos_encoder.encoding_dim} \ + and {self.model_dim} instead." + ) + + self.pos_encoder = pos_encoder + else: + self.register_module("pos_encoder", None) + + if config.dropout_p > 0.0: + self.dropout = Dropout(config.dropout_p) + else: + self.register_module("dropout", None) + + self.reset_parameters(embedding_std=config.embedding_std) + + def reset_parameters(self, embedding_std: float) -> None: + """Initialize module parameters. + The positional embeddings should be initialized with the + same order of magnitude as the semantic embeddings, in order + to make the early training as stable as possible. + Otherwise, the positional and special token embeddings would + flood out the semantic information. + """ + logger.info( + f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})" + ) + if isinstance(self.pos_encoder, LearnedPositionEncoder): + torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std) + + def pre_forward( + self, seqs: Tensor, diffusion_timesteps: Optional[Tensor] = None, **kwargs + ) -> Tensor: + return seqs + + def forward( + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + state_bag: Optional[IncrementalStateBag] = None, + diffusion_timesteps: Optional[Tensor] = None, + **kwargs, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + """ + Apply pre-linear (if relevant) and add positional embeddings + """ + + # Normalize in standard LCM or add timestep embeddings in diffusion frontentd + seqs = self.pre_forward(seqs, diffusion_timesteps, **kwargs) + + # pre-linear if any: + seqs = self.pre_linear(self.embed_scale * seqs) + + if self.pos_encoder is not None: + seqs = self.pos_encoder( + seqs, + padding_mask, + state_bag=state_bag, + **kwargs, + ) + + if self.dropout is not None: + seqs = self.dropout(seqs) + + return seqs, padding_mask diff --git a/lcm/models/base_lcm/loader.py b/lcm/models/base_lcm/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..97517d7164a6ac6d9967b999f57bd9bb780e44a5 --- /dev/null +++ b/lcm/models/base_lcm/loader.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +import logging +from typing import Any, Dict + +from fairseq2.models.config_loader import StandardModelConfigLoader +from fairseq2.models.loader import StandardModelLoader, load_model +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + +from lcm.models.base_lcm.builder import ( + BASE_LCM_MODEL_TYPE, + BaseLCModelConfig, + create_base_lcm_model, + lcm_archs, +) +from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry + +logger = logging.getLogger(__name__) + + +def convert_lcm_checkpoint( + checkpoint: Dict[str, Any], config: BaseLCModelConfig +) -> Dict[str, Any]: + # For DDP checkpoints + # We need to first remove the prefix "module." from state dict keys. + consume_prefix_in_state_dict_if_present(checkpoint["model"], "module.") + return checkpoint + + +load_base_lcm_config = StandardModelConfigLoader( + family=BASE_LCM_MODEL_TYPE, + config_kls=BaseLCModelConfig, + arch_configs=lcm_archs, +) + +load_base_lcm_model = StandardModelLoader( + config_loader=load_base_lcm_config, + factory=create_base_lcm_model, + checkpoint_converter=convert_lcm_checkpoint, + restrict_checkpoints=False, +) + +load_model.register(BASE_LCM_MODEL_TYPE, load_base_lcm_model) + +lcm_model_type_registry.register( + ModelTypeConfig( + model_type=BASE_LCM_MODEL_TYPE, + config_loader=load_base_lcm_config, + model_factory=create_base_lcm_model, + model_loader=load_base_lcm_model, + ) +) diff --git a/lcm/models/base_lcm/normalization.py b/lcm/models/base_lcm/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..f68e532b70f84377e7feef2e3f986c32916519dd --- /dev/null +++ b/lcm/models/base_lcm/normalization.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +from typing import Optional, final + +import torch +from fairseq2.nn import LayerNorm, RMSNorm +from fairseq2.typing import DataType, Device, override + + +@final +class FP32LayerNorm(LayerNorm): + """Applies Layer Normalization in single-precision.""" + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + w, b = self.weight, self.bias + + # cast input and params to float32 + fp32_x = x.float() + fp32_w = w.float() if w is not None else None + fp32_b = b.float() if b is not None else None + + y = torch.nn.functional.layer_norm( + fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps + ) + + return y.type_as(x) + + +def build_rms_layer_norm( + model_dim: int, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> LayerNorm: + """Build an RMS Layer Normalization module.""" + return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) + + +def build_fp32_layer_norm( + model_dim: int, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> LayerNorm: + """Build an Single-precision Layer Normalization module.""" + return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype) diff --git a/lcm/models/sonar_normalizer/__init__.py b/lcm/models/sonar_normalizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46efcf23c4d8dd1c4f6370ce33d2d1af251a92fb --- /dev/null +++ b/lcm/models/sonar_normalizer/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +# Register architectures +import lcm.models.sonar_normalizer.archs # noqa +from lcm.models.sonar_normalizer.builder import ( + SonarNormalizer, + SonarNormalizerConfig, + create_sonar_normalizer, +) +from lcm.models.sonar_normalizer.loader import load_sonar_normalizer_model + +__all__ = [ + "SonarNormalizer", + "SonarNormalizerConfig", + "create_sonar_normalizer", + "load_sonar_normalizer_model", +] diff --git a/lcm/models/sonar_normalizer/archs.py b/lcm/models/sonar_normalizer/archs.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbd5b4f670c7dd82ab172c057b89d51f1575267 --- /dev/null +++ b/lcm/models/sonar_normalizer/archs.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from lcm.models.sonar_normalizer.builder import ( + SonarNormalizerConfig, + sonar_normalizer_arch, +) + + +@sonar_normalizer_arch("base") +def _base_sonar_normalizer() -> SonarNormalizerConfig: + """The base architecture for all center-and-scale normalizers + regardless of how the center/scale are estimated""" + return SonarNormalizerConfig( + dim=1024, + ) + + +@sonar_normalizer_arch("base_page4k") +def _base_page_normalizer() -> SonarNormalizerConfig: + return SonarNormalizerConfig( + dim=4 * 1024, + ) + + +@sonar_normalizer_arch("base_fft") +def _base_fft_sonar_normalizer() -> SonarNormalizerConfig: + return SonarNormalizerConfig(dim=1024, with_fft=True) + + +@sonar_normalizer_arch("clipping") +def _clipping_sonar_normalizer() -> SonarNormalizerConfig: + return SonarNormalizerConfig(dim=1024, clip_proba=1e-4) + + +@sonar_normalizer_arch("clipping_fft") +def _clipping_fft_sonar_normalizer() -> SonarNormalizerConfig: + return SonarNormalizerConfig(dim=1024, clip_proba=1e-4, with_fft=True) diff --git a/lcm/models/sonar_normalizer/builder.py b/lcm/models/sonar_normalizer/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..fc577f082c5da9eb921946019ff4b5691737ff90 --- /dev/null +++ b/lcm/models/sonar_normalizer/builder.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +from fairseq2.config_registry import ConfigRegistry +from fairseq2.typing import DataType, Device +from torch import Tensor +from torch.nn import Module + + +@dataclass +class SonarNormalizerConfig: + dim: int = 1024 + """The dimension of the features to be normalized""" + + clip_proba: Optional[float] = None + """ + If `clip_proba` is not None, `clip_min` and `clip_max` will + be used to clip the features before normalizing. + `clip_min` and `clip_max` correspond to the pre-computed `clip_proba` + and `1-clip_proba` quantiles respectively. + """ + + with_fft: bool = False + """ + Applying FFT transform at the raw input before all other transforms. + """ + + quantile_min: float = 0.25 + """The lower quantile used to measure the IQR when estimating the scale with a robust scaler""" + + quantile_max: float = 0.75 + """The upper quantile used to measure the IQR when estimating the scale with a robust scaler""" + + normalization_method: Literal["standard", "robust", "gaussian_robust"] = ( + "gaussian_robust" + ) + """ + Dictates how the normalizer's scale is evaluated when fitting. 
+ (1) 'standard': center=mean, scale = std + (2) 'robust': center=median, scale = IQR = Qmax - Qmin + (3) 'gaussian_robust': center=median, scale = IQR / k, + where k=`stats.norm.ppf(q_max / 100.0) - stats.norm.ppf(q_min / 100.0)` + i.e scale = scale = 0.7413 x IQR if q_min=0.25 and q_max=0.75. + This is the robust normalization of https://arxiv.org/pdf/2307.05445 + """ + + +sonar_normalizer_archs = ConfigRegistry[SonarNormalizerConfig]() +sonar_normalizer_arch = sonar_normalizer_archs.decorator + + +class FFTInterface: + @staticmethod + def fft_transform(embeddings: Tensor) -> Tensor: + dtype = embeddings.dtype + if dtype in [torch.float16, torch.bfloat16]: + embeddings = embeddings.to(dtype=torch.float32) + embeddings = torch.fft.rfft(embeddings, norm="backward") + return torch.concat( + [torch.real(embeddings), torch.imag(embeddings)[..., 1:-1]], dim=-1 + ).to(dtype) + + @staticmethod + def fft_inverse_transform(embeddings: Tensor) -> Tensor: + assert embeddings.shape[-1] % 2 == 0 + dtype = embeddings.dtype + if dtype in [torch.float16, torch.bfloat16]: + embeddings = embeddings.to(dtype=torch.float32) + rr, im = torch.split( + embeddings, + [embeddings.shape[-1] // 2 + 1, embeddings.shape[-1] // 2 - 1], + dim=-1, + ) + im = torch.concat( + [torch.zeros_like(im[..., :1]), im, torch.zeros_like(im[..., :1])], dim=-1 + ) + embeddings = torch.fft.irfft(rr + im * 1j) + return embeddings.to(dtype) + + +class SonarNormalizer(FFTInterface, Module): + """ + To perform efficient diffusion modeling, SONAR embeddings need to be + normalized. This SonarNormalizer follows the robust normalization introduced in + https://arxiv.org/abs/2307.05445 + Quoting from the paper: "Due to the very long-tailed feature distribution, typical mean and standard deviation statistics will be + heavily biased. We thus propose a robust alternative based on the feature distribution quantiles. 
We + take the median as the center of the distribution and approximate its scale using the Normalized + InterQuartile Range (IQR) for a normal distribution: 0.7413 × IQR + """ + + def __init__( + self, + config: SonarNormalizerConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + super().__init__() + self.config = config + + self.register_buffer( + "center", torch.zeros(config.dim, dtype=dtype, device=device) + ) + self.register_buffer( + "scale", torch.ones(config.dim, dtype=dtype, device=device) + ) + if self.config.clip_proba is not None: + self.register_buffer( + "clip_min", torch.ones(config.dim, dtype=dtype, device=device) + ) + self.register_buffer( + "clip_max", torch.ones(config.dim, dtype=dtype, device=device) + ) + + def normalize(self, embeddings: Tensor) -> Tensor: + if self.config.with_fft: + embeddings = self.fft_transform(embeddings) + + embeddings = (embeddings - self.center) / self.scale + if self.config.clip_proba is not None: + embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max) + return embeddings + + def denormalize(self, embeddings: Tensor) -> Tensor: + if self.config.clip_proba is not None: + embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max) + + embeddings = (embeddings * self.scale) + self.center + if self.config.with_fft: + embeddings = self.fft_inverse_transform(embeddings) + return embeddings + + @torch.no_grad() + def fit(self, embeddings: Tensor): + if self.config.normalization_method in [ + "robust", + "gaussian_robust", + ]: + from sklearn.preprocessing import RobustScaler + + _scaler = RobustScaler( + unit_variance=self.config.normalization_method == "gaussian_robust", + quantile_range=(self.config.quantile_min, self.config.quantile_max), + ) + + elif self.config.normalization_method == "standard": + from sklearn.preprocessing import StandardScaler + + _scaler = StandardScaler() + else: + raise ValueError( + f"Unrecognizable method {self.config.normalization_method} for scaling input features" + ) + + assert embeddings.shape[-1] == self.config.dim + assert len(embeddings.shape) == 2 + + if self.config.with_fft: + embeddings = self.fft_transform(embeddings) + + embeddings = _scaler.fit_transform(embeddings.cpu().float().numpy()) + + if self.config.normalization_method in [ + "robust", + "gaussian_robust", + ]: + _center = _scaler.center_ + _scale = _scaler.scale_ + + elif self.config.normalization_method == "standard": + _center = _scaler.mean_ + _scale = _scaler.scale_ + + self.center[:] = torch.tensor( + _center, dtype=self.center.dtype, device=self.center.device + ) + self.scale[:] = torch.tensor( + _scale, dtype=self.scale.dtype, device=self.scale.device + ) + + if self.config.clip_proba is not None: + self.clip_min[:] = torch.quantile( + torch.tensor(embeddings), self.config.clip_proba, dim=0 + ).to(dtype=self.clip_min.dtype, device=self.clip_min.device) + self.clip_max[:] = torch.quantile( + torch.tensor(embeddings), 1 - self.config.clip_proba, dim=0 + ).to(dtype=self.clip_max.dtype, device=self.clip_max.device) + + +def create_sonar_normalizer( + config: SonarNormalizerConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> SonarNormalizer: + """Create an LCM model. + :param config: + The configuration. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. 
+ """ + return SonarNormalizer(config, device=device, dtype=dtype) diff --git a/lcm/models/sonar_normalizer/loader.py b/lcm/models/sonar_normalizer/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..785355a1b7780a83b4669397902a5407756171d0 --- /dev/null +++ b/lcm/models/sonar_normalizer/loader.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + + +from fairseq2.models.config_loader import StandardModelConfigLoader +from fairseq2.models.loader import StandardModelLoader, load_model + +from lcm.models.sonar_normalizer.builder import ( + SonarNormalizerConfig, + create_sonar_normalizer, + sonar_normalizer_archs, +) + +load_sonar_normalizer_config = StandardModelConfigLoader( + family="sonar_normalizer", + config_kls=SonarNormalizerConfig, + arch_configs=sonar_normalizer_archs, +) + +load_sonar_normalizer_model = StandardModelLoader( + config_loader=load_sonar_normalizer_config, + factory=create_sonar_normalizer, + restrict_checkpoints=False, +) + +load_model.register("sonar_normalizer", load_sonar_normalizer_model) diff --git a/lcm/models/two_tower_diffusion_lcm/__init__.py b/lcm/models/two_tower_diffusion_lcm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d19a452e106fb1d853a500c68a27884b69a5efb --- /dev/null +++ b/lcm/models/two_tower_diffusion_lcm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +# Register architectures +import lcm.models.two_tower_diffusion_lcm.archs # noqa diff --git a/lcm/models/two_tower_diffusion_lcm/archs.py b/lcm/models/two_tower_diffusion_lcm/archs.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2e3f4ca1f963c5cf1be0c21c6de23696fe7444 --- /dev/null +++ b/lcm/models/two_tower_diffusion_lcm/archs.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from lcm.models.two_tower_diffusion_lcm.builder import ( + DenoiserConfig, + EncoderFrontendConfig, + TransformerConfig, + TwoTowerDiffusionLCModelConfig, + lcm_arch, +) +from lcm.nn.projection import ProjectionConfig +from lcm.nn.schedulers import DDIMSchedulerConfig + + +@lcm_arch("toy_two_tower_diffusion_lcm") +def toy_lcm() -> TwoTowerDiffusionLCModelConfig: + return TwoTowerDiffusionLCModelConfig( + context_encoder=TransformerConfig(num_layers=2), + denoiser=DenoiserConfig(num_layers=2), + # TODO change normalizer name to align with the normalizer instructions + sonar_normalizer_name="dummy_sonar_normalizer_A", + ) + + +@lcm_arch("arch_lexa_lcm_pre0_toy") +def lexa_lcm_pre0_toy() -> TwoTowerDiffusionLCModelConfig: + return TwoTowerDiffusionLCModelConfig( + context_encoder=TransformerConfig(num_layers=2), + denoiser=DenoiserConfig(num_layers=2), + sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m", + trained_with_cf_guidance=True, + ) + + +@lcm_arch("arch_lexa_lcm_pre0_minimal") +def lexa_lcm_pre0_minimal() -> TwoTowerDiffusionLCModelConfig: + """4-layer encoder / 6-layer denoiser / model dim 768""" + model_dim: int = 768 # Reduced from 2048 to 768 + num_attn_heads: int = 12 # Reduced from 16 to 12 + return TwoTowerDiffusionLCModelConfig( + model_dim=model_dim, + max_seq_len=2048, + frontend=EncoderFrontendConfig(), + context_encoder=TransformerConfig( + num_layers=3, + ffn_inner_dim=3 * model_dim, # Reduced from 4 * model_dim to 3 * model_dim + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pos_embedding_style="rope", + ), + denoiser=DenoiserConfig( + num_layers=6, # Reduced from 13 to 6 + timestep_embed_dim=model_dim, + ffn_inner_dim=3 * model_dim, # Reduced from 4 * model_dim to 3 * model_dim + pos_embedding_style="none", + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pre_denoiser=ProjectionConfig(), + post_denoiser=ProjectionConfig(), + ), + sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m", + trained_with_cf_guidance=True, + noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), + ) + + +@lcm_arch("arch_lexa_lcm_pre0") +def lexa_lcm_pre0() -> TwoTowerDiffusionLCModelConfig: + """4-layer encoder / 10-layer denoiser / model dim 1024 + Parameter Size: 287,880,192""" + model_dim: int = 1024 # Reduced from 2048 to 1024 + num_attn_heads: int = 16 + return TwoTowerDiffusionLCModelConfig( + model_dim=model_dim, + max_seq_len=2048, + frontend=EncoderFrontendConfig(), + context_encoder=TransformerConfig( + num_layers=4, # Reduced from 5 to 4 + ffn_inner_dim=3 * model_dim, # Reduced from 4 * model_dim to 3 * model_dim + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pos_embedding_style="rope", + ), + denoiser=DenoiserConfig( + num_layers=10, # Reduced from 13 to 10 + timestep_embed_dim=model_dim, + ffn_inner_dim=3 * model_dim, # Reduced from 4 * model_dim to 3 * model_dim + pos_embedding_style="none", + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pre_denoiser=ProjectionConfig(), + post_denoiser=ProjectionConfig(), + ), 
+ sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m", + trained_with_cf_guidance=True, + noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), + ) + + +@lcm_arch("two_tower_diffusion_lcm_1_6B") +def two_tower_diffusion_lcm_1_6B() -> TwoTowerDiffusionLCModelConfig: + """5-layer encodder / 13-layer denoiser / model dim 2048 + Parameter Size: 1,635,101,696""" + model_dim: int = 2048 + num_attn_heads: int = 16 + return TwoTowerDiffusionLCModelConfig( + model_dim=model_dim, + max_seq_len=4096, + frontend=EncoderFrontendConfig(), + context_encoder=TransformerConfig( + num_layers=5, + ffn_inner_dim=4 * model_dim, + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pos_embedding_style="rope", + ), + denoiser=DenoiserConfig( + num_layers=13, + timestep_embed_dim=model_dim, + ffn_inner_dim=4 * model_dim, + pos_embedding_style="none", + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pre_denoiser=ProjectionConfig(), + post_denoiser=ProjectionConfig(), + ), + # TODO change normalizer name to align with the normalizer instructions + sonar_normalizer_name="dummy_sonar_normalizer_B", + trained_with_cf_guidance=True, + noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), + ) + + +@lcm_arch("two_tower_diffusion_lcm_7B") +def two_tower_diffusion_lcm_7B() -> TwoTowerDiffusionLCModelConfig: + # 5-layer encodder / 14-layer denoiser / model dim 4096 + # Parameter Size: 6,930,781,696 + model_dim: int = 4096 + num_attn_heads: int = 32 + return TwoTowerDiffusionLCModelConfig( + model_dim=model_dim, + max_seq_len=4096, + frontend=EncoderFrontendConfig(), + context_encoder=TransformerConfig( + num_layers=5, + ffn_inner_dim=4 * model_dim, + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pos_embedding_style="rope", + ), + denoiser=DenoiserConfig( + num_layers=14, + timestep_embed_dim=model_dim, + ffn_inner_dim=4 * model_dim, + pos_embedding_style="none", + num_attn_heads=num_attn_heads, + final_dropout_p=0.0, + attention_dropout_p=0.0, + dropout_p=0.1, + mha_output_proj_bias=True, + use_swiglu=True, + layer_normalization_style="rms", + pre_denoiser=ProjectionConfig(), + post_denoiser=ProjectionConfig(), + ), + # TODO change normalizer name to align with the normalizer instructions + sonar_normalizer_name="dummy_sonar_normalizer_C", + trained_with_cf_guidance=True, + noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), + ) diff --git a/lcm/models/two_tower_diffusion_lcm/builder.py b/lcm/models/two_tower_diffusion_lcm/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0572ad4d7362fc673a5f5441791a14e3acb439b1 --- /dev/null +++ b/lcm/models/two_tower_diffusion_lcm/builder.py @@ -0,0 +1,628 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from dataclasses import dataclass, field +from typing import Optional, Tuple + +import torch +from fairseq2.config_registry import ConfigRegistry +from fairseq2.logging import get_log_writer +from fairseq2.nn.padding import PaddingMask, get_seq_lens +from fairseq2.nn.transformer import CausalAttentionMaskFactory +from fairseq2.typing import DataType, Device +from torch import Tensor + +from lcm.datasets.batch import EmbeddingsBatch +from lcm.models.abstract_lcm import ( + AbstractLCModel, + AbstractLCModelBuilder, + AbstractLCModelConfig, +) +from lcm.models.sonar_normalizer.builder import SonarNormalizer +from lcm.models.two_tower_diffusion_lcm.frontend import ( + EncoderFrontend, + EncoderFrontendConfig, +) +from lcm.nn.denoisers import ( + DenoiserConfig, + LCMDenoiser, + LCMDenoiserTransformerFactory, +) +from lcm.nn.incremental_state import LCMIncrementalStateBag +from lcm.nn.initialization import parse_norm_order +from lcm.nn.normalization import parse_layer_norm_factory +from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig +from lcm.nn.transformer import ( + LCMTransformerDecoder, + TransformerConfig, + TransformerFactory, +) + +logger = get_log_writer(__name__) + + +TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE = "two_tower_diffusion_lcm" + + +@dataclass +class TwoTowerDiffusionLCModelConfig(AbstractLCModelConfig): + model_type: str = TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE + + max_seq_len: int = 2048 + + model_dim: int = 1024 + + frontend: EncoderFrontendConfig = field( + default_factory=lambda: EncoderFrontendConfig() + ) + """ The fronted config. This module maps from `sonar_embed_dim` to `model_dim` + and potentially adds positional embeddings""" + + context_encoder: TransformerConfig = field( + default_factory=lambda: TransformerConfig() + ) + """The context encoder config. This is causal Transformer decoder""" + + noise_scheduler: DDIMSchedulerConfig = field( + default_factory=lambda: DDIMSchedulerConfig() + ) + """The config of the noise scheduler. + See lcm/diffusion_schedulers/ddim for more""" + + denoiser: DenoiserConfig = field(default_factory=lambda: DenoiserConfig()) + """the config of the denoiser""" + + trained_with_cf_guidance: bool = False + """If `True`, the model will be trained with classifier-free guidance i.e., + unconditional embedding generation. + The CF-guidance probability is set in + DiffusionLCMCriterionConfig.cf_guidance_probability""" + + +lcm_archs = ConfigRegistry[TwoTowerDiffusionLCModelConfig]() +lcm_arch = lcm_archs.decorator + + +class TwoTowerDiffusionLCModel(AbstractLCModel): + """Class for a diffusion-based LCM model""" + + config: TwoTowerDiffusionLCModelConfig + + def __init__( + self, + config: TwoTowerDiffusionLCModelConfig, + sonar_normalizer: SonarNormalizer, + encoder_frontend: EncoderFrontend, + context_encoder: LCMTransformerDecoder, + denoiser: LCMDenoiser, + noise_scheduler: DDIMScheduler, + ) -> None: + super().__init__(config) + + self.model_dim = context_encoder.model_dim + + self.sonar_embed_dim = config.sonar_embed_dim + + self.sonar_normalizer = sonar_normalizer + + self.encoder_frontend = encoder_frontend + """The frontend of the context encoder. 
+ This frontend simply applies a pre-linear projection + (to increase dimensionality) then adds positional embeddings""" + + self.context_encoder = context_encoder + """A causal Transformer decoder""" + + self.noise_scheduler = noise_scheduler + """The diffusion noise scheduler""" + + self.denoiser = denoiser + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + return f"{s}, dtype={self.dtype}" + + def forward( + self, + batch: EmbeddingsBatch, + noisy_batch: EmbeddingsBatch, + cf_guidance_prob: float = 0.0, + ) -> EmbeddingsBatch: + """ + Arguments: + - batch (`EmbeddingsBatch`): The clean batch of embeddings to encode the context. + If `unsupervised` this is the source embeddings. + If `supervised` this is the source+target embeddings. + + - noisy_batch (`EmbeddingsBatch`): the embeddings noised by the noise scheduler + If `unsupervised` this is noised source embeddings. + If `supervised` this is noised target-only embeddings. + + - cf_guidance_prob: probability of training without any guiding context + """ + # Get source lengths if any: + source_lengths = batch.source_lengths + + # Encode as context: + context = self.encode(batch) + + # Predict denoised output + output_batch = self.denoise( + noisy_batch=noisy_batch, + context=context, + source_lengths=source_lengths, + cf_guidance_prob=cf_guidance_prob, + ) + return output_batch + + def encode( + self, + batch: EmbeddingsBatch, + state_bag: Optional[LCMIncrementalStateBag] = None, + **kwargs, + ) -> EmbeddingsBatch: + """ + The main context encoder that takes in a sequence of sonar embeddings in B, T, D + and returns a sequence of the same shape after causal contextualization. + + Main modules: + `frontend`: linear projection to model_dim + optional positional embeddings, + `context_encoder`: Causal Transformer decoder to causally encode the context + """ + # Frontend + seqs, padding_mask = self.encoder_frontend( + batch.seqs, + batch.padding_mask, + state_bag=state_bag, + **kwargs, + ) + + # Main Transformer + seqs, padding_mask = self.context_encoder( + seqs, + padding_mask, + state_bag=state_bag, + **kwargs, + ) + + return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask) + + def denoise( + self, + noisy_batch: EmbeddingsBatch, + context: EmbeddingsBatch, + source_lengths: Optional[Tensor] = None, + cf_guidance_prob: float = 0.0, + state_bag: Optional[LCMIncrementalStateBag] = None, + inference: bool = False, + ) -> EmbeddingsBatch: + """Diffuse a noised sonar embedding conditioned on the encoded context""" + seqs, padding_mask = self.denoiser( + seqs=noisy_batch.seqs, + diffusion_timesteps=noisy_batch.diffusion_timesteps, + padding_mask=noisy_batch.padding_mask, + conditioning_variables=context.seqs, + conditioning_variables_padding_mask=context.padding_mask, + source_lengths=source_lengths, + cf_guidance_prob=cf_guidance_prob, + inference=inference, + ) + return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask) + + def prep_for_denoising(self, decoding_options): + """This setup is done once when we initialize the generator""" + self.guidance_scale = decoding_options.guidance_scale + self.guidance_rescale = decoding_options.guidance_rescale + self.initial_noise_scale = decoding_options.initial_noise_scale + self.timesteps = decoding_options.inference_timesteps + self.clip_noise = decoding_options.clip_noise + self.ddim_eta = decoding_options.ddim_eta + self.epsilon_scaling = decoding_options.epsilon_scaling + + # if guidance_scale > 1.0 we will duplicate batches + 
self.do_classifier_free_guidance = self.guidance_scale != 1.0 + + # Setup the diffusion training-like noise scheduler + # by updating the timesteps according to the decoding `inference_timesteps` + self.noise_scheduler.set_timesteps(self.timesteps, device=self.device) + + # Override the initial noise scale + self.noise_scheduler.init_noise_sigma = self.initial_noise_scale + # Override thresholding options: + if decoding_options.thresholding: + self.noise_scheduler.config.thresholding = decoding_options.thresholding + self.noise_scheduler.config.dynamic_thresholding_ratio = ( + decoding_options.dynamic_thresholding_ratio + ) + self.noise_scheduler.config.sample_max_value = ( + decoding_options.sample_max_value + ) + + def sample_initial_noise_vectors(self, batch_size: int): + # Check that we called `prep_for_denoising`: + assert hasattr(self, "clip_noise"), ( + "The model is not properly set for decoding, make sure to call `model.prep_for_denoising()`" + ) + + # Sample a noise vector for next embedding prediction + latents = torch.randn( + batch_size, 1, self.config.sonar_embed_dim, device=self.device + ) + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.noise_scheduler.init_noise_sigma + + # clip? + latents = latents.clip(-self.clip_noise, self.clip_noise) + return latents + + @torch.inference_mode() + def predict_next_sentence( # type: ignore + self, + batch: EmbeddingsBatch, + context: EmbeddingsBatch, + temperature: float = 1.0, + state_bag: Optional[LCMIncrementalStateBag] = None, + context_state_bag: Optional[LCMIncrementalStateBag] = None, + **kwargs, + ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]: + assert context_state_bag is not None, ( + "Expected a state_bag to incrementally encode the context" + ) + + if self.do_classifier_free_guidance: + logger.debug("Running inference with CF-guidance...") + return self.predict_next_sentence_with_cf_guidance( + batch=batch, + context=context, + temperature=temperature, + state_bag=state_bag, + context_state_bag=context_state_bag, + **kwargs, + ) + + # Normalize the input embeddings if we're expected to + # normalize outside of the model's forward pass + if self.sonar_normalizer is not None: + batch = batch.normalize_seqs(self.sonar_normalizer) + + # Encode context: + new_context = self.encode(batch, context_state_bag) + context_state_bag.increment_step_nr(1) + + # Append to context + context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1)) + + # Sample latents: + latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0)) + + # Denoise + diffusion_timesteps_schedule = self.noise_scheduler.timesteps + + for diffusion_timestep in diffusion_timesteps_schedule: + input_batch = EmbeddingsBatch( + seqs=latents, + diffusion_timesteps=diffusion_timestep.long().repeat( + (latents.shape[0], 1) + ), + ) + # Get model output + model_prediction = self.denoise( + noisy_batch=input_batch, + context=context, + state_bag=None, + inference=True, + ) + + scheduler_outputs = self.noise_scheduler.step( + model_output=model_prediction.seqs, + timestep=diffusion_timestep, + sample=latents, + eta=self.ddim_eta, + epsilon_scaling=self.epsilon_scaling, + ) + + # setup latents for the next diffusion step + latents = scheduler_outputs.prev_sample + # clip? 
+ latents = latents.clip(-self.clip_noise, self.clip_noise) + + # Take the final predicted denoised sample (x_0 in the ddim paper) and denormalize if needed: + final_seqs = scheduler_outputs.pred_original_sample + + final_seqs = self.sonar_normalizer.denormalize(final_seqs) + + return EmbeddingsBatch(final_seqs, None), context + + @torch.inference_mode() + def predict_next_sentence_with_cf_guidance( # type: ignore + self, + batch: EmbeddingsBatch, + context: EmbeddingsBatch, + temperature: float = 1.0, + state_bag: Optional[LCMIncrementalStateBag] = None, + context_state_bag: Optional[LCMIncrementalStateBag] = None, + **kwargs, + ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]: + assert context_state_bag is not None, ( + "Expected a state_bag to incrementally encode the context" + ) + + # Normalize the input embeddings if we're expected to + # normalize outside of the model's forward pass + if self.sonar_normalizer is not None: + batch = batch.normalize_seqs(self.sonar_normalizer) + + # Encode context: + new_context = self.encode(batch, context_state_bag) + context_state_bag.increment_step_nr(1) + + # Append to context + context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1)) + + # Sample latents: + latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0)) + + # Denoise + diffusion_timesteps_schedule = self.noise_scheduler.timesteps + + # Duplicate the context and its padding mask, the second half will be ignored + _seq_lens = get_seq_lens(context.seqs, context.padding_mask) + + # add zeros: + _seq_lens = torch.concat((_seq_lens, torch.zeros_like(_seq_lens)), dim=0) + + context = EmbeddingsBatch( + torch.concat((context.seqs, torch.zeros_like(context.seqs)), dim=0), + PaddingMask(_seq_lens, batch_seq_len=context.seqs.size(1)), + ) + + batch_multiplier = 2 + for diffusion_timestep in diffusion_timesteps_schedule: + is_max_diffusion_step = ( + diffusion_timestep == self.noise_scheduler.num_diffusion_train_steps - 1 + ) + + input_batch = EmbeddingsBatch( + torch.concat(batch_multiplier * [latents], dim=0), + diffusion_timesteps=diffusion_timestep.long().repeat( + (latents.shape[0] * batch_multiplier, 1) + ), + ) + + model_prediction = self.denoise( + noisy_batch=input_batch, + context=context, + state_bag=None, + inference=True, + ) + + # If at the max step, do not step in the epsilon_scheduler + if is_max_diffusion_step: + # if beta_prod_t (denominator) is null i.e., + # the diffusion timestep is at its max value (num_training_stesp-1) + # no denoising will be performed. + + # Note that since the batch might be doubled because + # we're doing classifier-free guidance, we chunk the model output + # by batch_multiplier. 
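For classifier-free guidance the batch is doubled: the first half keeps the real context, the second half gets a zeroed context whose padding mask reports length 0, so it behaves as the unconditional branch. A minimal tensor-level sketch (padding masks simplified to plain length tensors):

```python
# Tensor-level sketch: doubled batch = [conditional ; unconditional].
import torch

B, T_CTX, D = 2, 7, 1024
context = torch.randn(B, T_CTX, D)
seq_lens = torch.full((B,), T_CTX)                               # true context lengths

doubled_context = torch.cat([context, torch.zeros_like(context)], dim=0)
doubled_lens = torch.cat([seq_lens, torch.zeros_like(seq_lens)], dim=0)  # second half: length 0

latents = torch.randn(B, 1, D)
doubled_latents = torch.cat([latents, latents], dim=0)           # same noise for both halves

model_out = doubled_latents                                      # placeholder for the denoiser call
cond_pred, uncond_pred = model_out.chunk(2)                      # split back into the two branches
```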
If not at max_diffusion_step + # this chunking is performed in apply_classifier_free_guidance + scheduler_outputs = self.noise_scheduler.step( + model_output=model_prediction.seqs.chunk(batch_multiplier)[0], + timestep=diffusion_timestep, + sample=latents, + eta=self.ddim_eta, + epsilon_scaling=self.epsilon_scaling, + ) + else: + # Predict the noise residual according to the prediction type + predicted_noise = self.noise_scheduler.get_epsilon( + model_output=model_prediction.seqs, + sample=input_batch.seqs, + timestep=diffusion_timestep, + ) + + if self.do_classifier_free_guidance: + # Perform guidance if trained with cf-guidance: + # The returned predicted noise will combine the conditional and + # unconditional predictions i.e., from (2 x batch_size, 1, C) + # to: (batch_size, 1, C) + predicted_noise = self.apply_classifier_free_guidance( + predicted_noise + ) + + # The cf-guidance operates on predicted noises and although we + # can go back and forth between epsilon and predicted sample + # once we combine cond and uncond we cannot go back to predicted_x0 + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_outputs = self.noise_scheduler.step( + model_output=predicted_noise, + timestep=diffusion_timestep, + sample=latents, + eta=self.ddim_eta, + epsilon_scaling=self.epsilon_scaling, + prediction_type="epsilon", + ) + + # setup latents for the next diffusion step + latents = scheduler_outputs.prev_sample + # clip? + latents = latents.clip(-self.clip_noise, self.clip_noise) + + # Take the final predicted denoised sample (x_0 in the ddim paper) and denormalize if needed: + final_seqs = scheduler_outputs.pred_original_sample + + final_seqs = self.sonar_normalizer.denormalize(final_seqs) + + return EmbeddingsBatch(final_seqs, None), context + + def apply_classifier_free_guidance(self, predicted_noise: Tensor) -> Tensor: + """ " + Apply Classifier-Free Guidance with Rescale as introduced in Algorithm 2 of https://arxiv.org/pdf/2305.08891 + `pos` would be the conditional prediction `cond_prediction` + and `neg` the unconditional prediction `uncond_prediction`: + The batch during prefilling is prepared with the conditioning prefix in + the first half + """ + # Chunk and follow algorithm 2 + cond_prediction, uncond_prediction = predicted_noise.chunk(2) + + # Regular classifier-free guidance: + guided_noise_prediction = uncond_prediction + self.guidance_scale * ( + cond_prediction - uncond_prediction + ) + + # Rescale classifier-free guidance to prevent over-exposure + # Calculate standard deviations. + std_pos = cond_prediction.std(dim=-1, keepdim=True) + std_cfg = guided_noise_prediction.std(dim=-1, keepdim=True) + + # Apply guidance rescale with fused operations. + factor = std_pos / std_cfg + factor = self.guidance_rescale * factor + (1 - self.guidance_rescale) + + return factor * guided_noise_prediction + + +class TwoTowerDiffusionLCModelBuilder(AbstractLCModelBuilder): + """Builds modules of a diffusion-based LCM""" + + config: TwoTowerDiffusionLCModelConfig + denoiser_factory: LCMDenoiserTransformerFactory + + def __init__( + self, + config: TwoTowerDiffusionLCModelConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. 
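`apply_classifier_free_guidance` implements guidance with rescale (Algorithm 2 of arXiv:2305.08891): extrapolate from the unconditional towards the conditional prediction, then pull the result back towards the conditional prediction's standard deviation. A small numeric sketch of the same arithmetic:

```python
# Numeric sketch of guidance + rescale, mirroring apply_classifier_free_guidance.
import torch

guidance_scale, guidance_rescale = 3.0, 0.7
cond = torch.randn(2, 1, 1024)     # conditional noise prediction
uncond = torch.randn(2, 1, 1024)   # unconditional noise prediction

# plain classifier-free guidance
guided = uncond + guidance_scale * (cond - uncond)

# rescale towards the conditional prediction's std to prevent over-exposure
std_pos = cond.std(dim=-1, keepdim=True)
std_cfg = guided.std(dim=-1, keepdim=True)
factor = guidance_rescale * (std_pos / std_cfg) + (1 - guidance_rescale)
rescaled = factor * guided

assert rescaled.shape == cond.shape
```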
+ """ + super().__init__(config=config, device=device, dtype=dtype) + + self.context_encoder_factory = TransformerFactory( + model_dim=self.config.model_dim, + max_seq_len=self.config.max_seq_len, + config=self.config.context_encoder, + device=device, + dtype=dtype, + ) + + self.denoiser_factory = LCMDenoiserTransformerFactory( + model_dim=self.config.model_dim, + num_diffusion_train_timesteps=self.config.noise_scheduler.num_diffusion_train_steps, + max_seq_len=self.config.max_seq_len, + config=self.config.denoiser, + input_dim=self.config.sonar_embed_dim, + device=device, + dtype=dtype, + ) + + def build_model(self) -> TwoTowerDiffusionLCModel: + """Build a model.""" + + sonar_normalizer = self.build_sonar_normalizer() + assert sonar_normalizer is not None, ( + "TwoTowerDiffusionLCModel expects a `sonar_normalizer`" + ) + + # the context encoder + encoder_frontend = self.build_frontend() + + context_encoder = self.build_context_encoder() + + # the denoiser + denoiser = self.build_denoiser() + + noise_scheduler = self.build_noise_scheduler() + + return TwoTowerDiffusionLCModel( + config=self.config, + sonar_normalizer=sonar_normalizer, + context_encoder=context_encoder, + encoder_frontend=encoder_frontend, + denoiser=denoiser, + noise_scheduler=noise_scheduler, + ) + + def build_frontend(self) -> EncoderFrontend: + """Build the context encoder front-end.""" + + return EncoderFrontend( + sonar_embed_dim=self.config.sonar_embed_dim, + model_dim=self.config.model_dim, + config=self.config.frontend, + pos_encoder=self.context_encoder_factory.build_pos_encoder(), + device=self.device, + dtype=self.dtype, + ) + + def build_context_encoder(self) -> LCMTransformerDecoder: + """Build the context encoder.""" + + config = self.config.context_encoder + + num_layers = config.num_layers + assert num_layers > 0, "The context encoder needs a non-zero number of layers" + + layers = [self.context_encoder_factory.build_layer() for _ in range(num_layers)] + + self_attn_mask_factory = CausalAttentionMaskFactory() + + if config.final_norm_order_style is None: + # The final norm order style will be that of + # the layer-level norm order + final_norm_order = parse_norm_order(config.norm_order_style) + else: + final_norm_order = parse_norm_order(config.final_norm_order_style) + + layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style) + + return LCMTransformerDecoder( + layers, + self_attn_mask_factory=self_attn_mask_factory, + norm_order=final_norm_order, + layer_norm_factory=layer_norm_factory, + dropout_p=config.final_dropout_p, + device=self.device, + dtype=self.dtype, + ) + + def build_noise_scheduler(self) -> DDIMScheduler: + return DDIMScheduler(self.config.noise_scheduler) + + def build_denoiser(self) -> LCMDenoiser: + """Build a Transformer for diffusing noised latents.""" + return self.denoiser_factory.build_model() + + +def create_two_tower_diffusion_lcm_model( + config: TwoTowerDiffusionLCModelConfig, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> TwoTowerDiffusionLCModel: + """Create a DiffusionLCM model. + :param config: + The configuration. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. 
+ """ + return TwoTowerDiffusionLCModelBuilder( + config, + device=device, + dtype=dtype, # type: ignore + ).build_model() diff --git a/lcm/models/two_tower_diffusion_lcm/frontend.py b/lcm/models/two_tower_diffusion_lcm/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..d061e32286d8ea52fb53d2ab7116e5acc39a861e --- /dev/null +++ b/lcm/models/two_tower_diffusion_lcm/frontend.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from fairseq2.logging import get_log_writer +from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.projection import Linear +from fairseq2.typing import DataType, Device +from torch import Tensor +from torch.nn import Dropout, Module + +from lcm.nn.initialization import SUPPORTED_INIT_TYPES, get_init_fn + +logger = get_log_writer(__name__) + + +@dataclass +class EncoderFrontendConfig: + dropout_p: float = 0.0 + """ The dropout probability applied to the module' output""" + + pre_linear_bias: bool = True + """ Whether or not the pre-linear layer has a bias term""" + + pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform" + + weight_normalization: bool = False + + embedding_std: float = 1.0 + + +class EncoderFrontend(Module): + """ + A fronted for the context encoder in encoder-decoder LCMs + """ + + embed: Embedding + pos_encoder: Optional[PositionEncoder] + dropout: Optional[Dropout] + + def __init__( + self, + sonar_embed_dim: int, + model_dim: int, + config: EncoderFrontendConfig, + pos_encoder: Optional[PositionEncoder], + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param sonar_embed_dim + The embedding dimension of the sentence encoder, in this case SONAR + :param model_dim + The model embedding dimension + :param config: + A Frontend config. See `LCMFrontendConfig` + :param pos_encoder: + An optional position encoder. + """ + + super().__init__() + + self.sonar_embed_dim = sonar_embed_dim + + self.model_dim = model_dim + + self.device = device + + # Pre-linear to map to model dimension + init_fn = get_init_fn(config.pre_linear_init_fn) + + lin = Linear( + sonar_embed_dim, + model_dim, + bias=config.pre_linear_bias, + device=device, + dtype=dtype, + init_fn=init_fn, + ) + + if config.weight_normalization: + self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin) + else: + self.pre_linear = lin + + if pos_encoder is not None: + if pos_encoder.encoding_dim != self.model_dim: + raise ValueError( + f"`encoding_dim` of `pos_encoder` and `embedding_dim` of \ + `embed` must be equal, but are {pos_encoder.encoding_dim} \ + and {self.model_dim} instead." + ) + + self.pos_encoder = pos_encoder + else: + self.register_module("pos_encoder", None) + + if config.dropout_p > 0.0: + self.dropout = Dropout(config.dropout_p) + else: + self.register_module("dropout", None) + + self.reset_parameters(embedding_std=config.embedding_std) + + def reset_parameters(self, embedding_std: float) -> None: + """Initialize module parameters. + The positional embeddings should be initialized with the + same order of magnitude as the semantic embeddings, in order + to make the early training as stable as possible. + Otherwise, the positional and special token embeddings would + flood out the semantic information. 
+ """ + logger.info( + f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})" + ) + if isinstance(self.pos_encoder, LearnedPositionEncoder): + torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std) + + def forward( + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + state_bag: Optional[IncrementalStateBag] = None, + **kwargs, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + """ + Apply pre-linear (if relevant) and add positional embeddings + """ + + # pre-linear if any: + seqs = self.pre_linear(seqs) + + if self.pos_encoder is not None: + seqs = self.pos_encoder( + seqs, + padding_mask, + state_bag=state_bag, + **kwargs, + ) + + if self.dropout is not None: + seqs = self.dropout(seqs) + + return seqs, padding_mask diff --git a/lcm/models/two_tower_diffusion_lcm/loader.py b/lcm/models/two_tower_diffusion_lcm/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ee66083fc5144fdffa502e3e5262af93a95b4dbf --- /dev/null +++ b/lcm/models/two_tower_diffusion_lcm/loader.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + + +from fairseq2.models.config_loader import StandardModelConfigLoader +from fairseq2.models.loader import StandardModelLoader, load_model + +from lcm.models.base_lcm.loader import convert_lcm_checkpoint +from lcm.models.two_tower_diffusion_lcm.builder import ( + TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, + TwoTowerDiffusionLCModelConfig, + create_two_tower_diffusion_lcm_model, + lcm_archs, +) +from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry + +load_two_tower_diffusion_lcm_config = StandardModelConfigLoader( + family=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, + config_kls=TwoTowerDiffusionLCModelConfig, + arch_configs=lcm_archs, +) + + +load_two_tower_diffusion_lcm_model = StandardModelLoader( # type: ignore # FIXME + config_loader=load_two_tower_diffusion_lcm_config, + factory=create_two_tower_diffusion_lcm_model, + checkpoint_converter=convert_lcm_checkpoint, + restrict_checkpoints=False, +) + +load_model.register( + TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, load_two_tower_diffusion_lcm_model +) + +lcm_model_type_registry.register( + ModelTypeConfig( + model_type=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, + config_loader=load_two_tower_diffusion_lcm_config, + model_factory=create_two_tower_diffusion_lcm_model, + model_loader=load_two_tower_diffusion_lcm_model, + ) +) diff --git a/lcm/nn/__init__.py b/lcm/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28faebf07e16031daaa4182323df5af116055ae2 --- /dev/null +++ b/lcm/nn/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# diff --git a/lcm/nn/denoisers/__init__.py b/lcm/nn/denoisers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f846dfe1c767674822d1e9b737574ae2a9e9b2 --- /dev/null +++ b/lcm/nn/denoisers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + + +from lcm.nn.denoisers.factory import ( + DenoiserConfig, + LCMDenoiser, + LCMDenoiserTransformerFactory, +) + +__all__ = [ + "DenoiserConfig", + "LCMDenoiser", + "LCMDenoiserTransformerFactory", +] diff --git a/lcm/nn/denoisers/attention_masks.py b/lcm/nn/denoisers/attention_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..2d96ea41352a0a037b3c68e92244a89becb18389 --- /dev/null +++ b/lcm/nn/denoisers/attention_masks.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +import math +from typing import Optional, final + +import torch +from fairseq2.nn.transformer import ( + AbstractAttentionMask, + AttentionMask, + AttentionMaskFactory, +) +from fairseq2.typing import DataType, Device, override +from torch import Tensor + +from lcm.nn.incremental_state import LCMIncrementalStateBag + + +def _get_shifted_causal_mask( + seq_len: int, + key_len: int, + shift: int = 0, + cf_guidance_prob: float = 0.0, + zero_vector: bool = False, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> Tensor: + causal_mask = torch.ones( + (seq_len, key_len), + device=device, + dtype=dtype, + ) + causal_mask.tril_(diagonal=shift) + + if cf_guidance_prob > 0.0: + num_rows_to_drop = math.floor((seq_len - 1) * cf_guidance_prob) + if num_rows_to_drop > 0: + rows_to_drop = 1 + torch.randperm(seq_len - 1)[:num_rows_to_drop] + if zero_vector: + causal_mask[rows_to_drop, 1:] = 0 + else: + causal_mask[rows_to_drop, :] = 0 + + return causal_mask + + +class NoAttentionMaskFactory(AttentionMaskFactory): + """Constructs instances of :class:`NoAttentionMask`.""" + + @override + def __call__( # type: ignore + self, + seqs: Tensor, + keys: Tensor, + *, + training: bool = True, + state_bag: Optional[LCMIncrementalStateBag] = None, + inference_without_caching: Optional[bool] = False, + **kwargs, + ) -> Optional[AttentionMask]: + mask: NoAttentionMask + + attn_len: Optional[int] = seqs.size(1) + seq_len = seqs.size(1) + key_len = keys.size(1) + + mask = NoAttentionMask( + seq_len=seq_len, + key_len=key_len, + attn_len=attn_len, + device=seqs.device, + dtype=seqs.dtype, + ) + return mask + + def __repr__(self) -> str: + return "NoAttentionMaskFactory()" + + +@final +class NoAttentionMask(AbstractAttentionMask): + """ + Represents a diagonal attention mask, i.e attention + on current position only. + This turns the self-attention layer into an FFN + """ + + def __init__( + self, + seq_len: int, + key_len: int, + attn_len: Optional[int], + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param seq_len: + The sequence length. 
+ """ + super().__init__() + + self.seq_len = seq_len + + self._device, self._dtype = device, dtype + + @override + def _do_materialize(self) -> Tensor: + mask = torch.eye((self.seq_len), device=self._device, dtype=self._dtype) + mask.log_() + return mask + + +class ShiftedCausalAttentionMaskFactory(AttentionMaskFactory): + """ + Constructs instances of :class:`ShiftedCausalAttentionMask` + """ + + @override + def __call__( # type: ignore + self, + seqs: Tensor, + keys: Tensor, + *, + source_lengths: Optional[Tensor] = None, + cf_guidance_prob: float = 0.0, + training: bool = True, + state_bag: Optional[LCMIncrementalStateBag] = None, + inference: bool = False, + ) -> Optional[AttentionMask]: + mask: Optional[ShiftedCausalAttentionMask] + + attn_len: Optional[int] = seqs.size(1) + seq_len = seqs.size(1) + key_len = keys.size(1) + + if inference: + mask = None + else: + mask = ShiftedCausalAttentionMask( + seq_len=seq_len, + key_len=key_len, + attn_len=attn_len, + source_lengths=source_lengths, + cf_guidance_prob=cf_guidance_prob, + device=seqs.device, + dtype=seqs.dtype, + ) + + return mask + + def __repr__(self) -> str: + return "ShiftedCausalAttentionMask()" + + +@final +class ShiftedCausalAttentionMask(AbstractAttentionMask): + """ + Represents a causal mask shifted by source_lengths + + In training time, Without source_lengths, the mask look like (e.g. seq_len = 5): + + [ 0., -inf, -inf, -inf, -inf, -inf], + [ 0., 0., -inf, -inf, -inf, -inf], + [ 0., 0., 0., -inf, -inf, -inf], + [ 0., 0., 0., 0., -inf, -inf], + [ 0., 0., 0., 0., 0., -inf] + + """ + + def __init__( + self, + seq_len: int, + key_len: int, + attn_len: Optional[int], + *, + source_lengths: Optional[Tensor] = None, + cf_guidance_prob: float = 0.0, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param seq_len: + The sequence length. + """ + super().__init__() + + self.seq_len = seq_len + self.key_len = key_len + self._source_lengths = source_lengths + self._cf_guidance_prob = cf_guidance_prob + self._device, self._dtype = device, dtype + + @override + def _do_materialize(self) -> Tensor: + if self._source_lengths is None: + causal_mask = _get_shifted_causal_mask( + seq_len=self.seq_len, + key_len=self.key_len, + shift=0, + cf_guidance_prob=self._cf_guidance_prob, + zero_vector=True, + device=self._device, + dtype=self._dtype, + ) + + else: + causal_mask = torch.stack( + [ + _get_shifted_causal_mask( + seq_len=self.seq_len, + key_len=self.key_len, + shift=src_len, + cf_guidance_prob=self._cf_guidance_prob, + zero_vector=True, + device=self._device, + dtype=self._dtype, + ) + for src_len in self._source_lengths + ] + ).unsqueeze(1) + # bs x 1 (head) x seq_len x seq_len + + causal_mask.log_() + + return causal_mask diff --git a/lcm/nn/denoisers/factory.py b/lcm/nn/denoisers/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..96516845a29d2262008c3f709e1dd54f4b5aa5f8 --- /dev/null +++ b/lcm/nn/denoisers/factory.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from dataclasses import dataclass, field +from typing import Literal, Optional + +from fairseq2.logging import get_log_writer +from fairseq2.typing import DataType, Device + +from lcm.nn.denoisers.attention_masks import ( + NoAttentionMaskFactory, + ShiftedCausalAttentionMaskFactory, +) +from lcm.nn.denoisers.lcm_denoiser import ( + LCMDenoiser, + LCMDenoiserLayer, +) +from lcm.nn.initialization import parse_norm_order +from lcm.nn.normalization import parse_layer_norm_factory +from lcm.nn.projection import ( + Projection, + ProjectionConfig, +) +from lcm.nn.timestep_encoder import DiTTimestepEncoder +from lcm.nn.transformer import TransformerConfig, TransformerFactory + +logger = get_log_writer(__name__) + + +@dataclass +class DenoiserConfig(TransformerConfig): + """Config for building the LCM's denoiser""" + + pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "none" + """By default, a denoiser does not have a positional embedder""" + + pre_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig()) + """the initial projection at the top of the denoiser""" + + post_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig()) + """the final output projection at the end of the denoiser""" + + timestep_embed_dim: int = 1024 + """Diffusion timestep embedding dimension""" + + +class LCMDenoiserTransformerFactory(TransformerFactory): + """Denoiser with hybrid AdaLN and cross-attention""" + + config: DenoiserConfig + + def __init__( + self, + model_dim: int, + max_seq_len: int, + num_diffusion_train_timesteps: int, + config: DenoiserConfig, + input_dim: int = 1024, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param model_dim: + The hidden model dimension of the Transformer + :params max_seqs_len: + Maximum supported sequence length by the model + :param config: + The configuration. + :param input_dim: + The input embedding dimension i.e `sonar_embed_dim`` + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + super().__init__( + model_dim=model_dim, + max_seq_len=max_seq_len, + config=config, + device=device, + dtype=dtype, + ) + + self.input_dim = input_dim + + self.num_diffusion_train_timesteps = num_diffusion_train_timesteps + + def build_cross_attention_mask(self): + return ShiftedCausalAttentionMaskFactory() + + def build_timestep_embedder(self): + return DiTTimestepEncoder( + embedding_dim=self.config.timestep_embed_dim, + dtype=self.dtype, + device=self.device, + ) + + def build_initial_proj(self) -> Projection: + # We will be concatenating context and timesteps embeddings + assert self.config.timestep_embed_dim == self.model_dim, ( + "Since the timestep embeddings will be added to the sequence of " + "conditioning variables, they need to be of the same dimension. 
" + f"Found timestep_embed_dim={self.config.timestep_embed_dim} " + f"and model_dim={self.model_dim}" + ) + + return Projection( + output_dim=self.model_dim, + input_dim=self.input_dim, + config=self.config.pre_denoiser, + device=self.device, + dtype=self.dtype, + ) + + def build_final_proj(self) -> Projection: + return Projection( + output_dim=self.input_dim, + input_dim=self.model_dim, + config=self.config.post_denoiser, + device=self.device, + dtype=self.dtype, + ) + + def build_model(self) -> LCMDenoiser: + """Build the denoiser with its layers and initial/final projections""" + embed_time = self.build_timestep_embedder() + + layers = [self.build_layer() for _ in range(self.config.num_layers)] + + norm_order = parse_norm_order(self.config.norm_order_style) + + # Self-attention here does not contextualize + self_attn_mask_factory = NoAttentionMaskFactory() + + cross_attention_mask_factory = self.build_cross_attention_mask() + + layer_norm_factory = parse_layer_norm_factory( + self.config.layer_normalization_style + ) + + pos_encoder = self.build_pos_encoder() + + return LCMDenoiser( + embed_time=embed_time, + layers=layers, + initial_proj=self.build_initial_proj(), + final_proj=self.build_final_proj(), + dropout_p=self.config.final_dropout_p, + norm_order=norm_order, + layer_norm_factory=layer_norm_factory, + self_attn_mask_factory=self_attn_mask_factory, + cross_attention_mask_factory=cross_attention_mask_factory, + pos_encoder=pos_encoder, + device=self.device, + dtype=self.dtype, + ) + + def build_layer(self) -> LCMDenoiserLayer: + """Build a Transformer decoder layer based on the provided config.""" + + assert isinstance(self.config, DenoiserConfig), ( + "Expecting a DenoiserConfig in the DenoiserTransformerFactory" + ) + + self_attn = self.build_attention() + + cross_attn = self.build_attention() + + ffn = self.build_ffn() + + norm_order = parse_norm_order(self.config.norm_order_style) + + layer_norm_factory = parse_layer_norm_factory( + self.config.layer_normalization_style + ) + + modulator_input_dim = self_attn.model_dim + + layer = LCMDenoiserLayer( + self_attn=self_attn, + cross_attention=cross_attn, + ffn=ffn, + modulator_input_dim=modulator_input_dim, + dropout_p=self.config.dropout_p, + norm_order=norm_order, + layer_norm_factory=layer_norm_factory, + device=self.device, + dtype=self.dtype, + ) + return layer diff --git a/lcm/nn/denoisers/lcm_denoiser.py b/lcm/nn/denoisers/lcm_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..6129f0d386e758109264162b33f5f8055641f8a6 --- /dev/null +++ b/lcm/nn/denoisers/lcm_denoiser.py @@ -0,0 +1,546 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# + +from typing import Iterable, Optional, Tuple, cast + +import torch +import torch.nn as nn +from fairseq2.nn import PositionEncoder +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.nn.normalization import LayerNorm +from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.transformer import ( + AttentionMask, + AttentionMaskFactory, + FeedForwardNetwork, + LayerNormFactory, + MultiheadAttention, + TransformerDecoderLayer, + TransformerNormOrder, + create_standard_layer_norm, +) +from fairseq2.typing import DataType, Device, override +from torch import Tensor +from torch.nn import Dropout, Module, ModuleList +from torch.nn.parameter import Parameter + +from lcm.nn.projection import Projection +from lcm.nn.timestep_encoder import DiTTimestepEncoder + + +class AdaLNModulator(Module): + """An adaptive LayerNorm modulator to estimate + shift, gate and scale for all 3 sub-modules.""" + + def __init__( + self, + input_dim: int, + output_dim: int, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ): + super().__init__() + + self.activate = nn.SiLU() + self.fc = nn.Linear( + input_dim, + 9 * output_dim, + bias=True, + device=device, + dtype=dtype, + ) + + def reset_parameters(self): + # zero-init + nn.init.constant_(self.fc.weight, 0) + nn.init.constant_(self.fc.bias, 0) + + def forward(self, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + (modulate_san, modulate_cross_attention, modulate_ffn) = self.fc( + self.activate(context) + ).chunk(3, dim=-1) + return modulate_san, modulate_cross_attention, modulate_ffn + + +class LCMDenoiser(Module): + """ + The main denoiser module of the two-tower diffusion LCM. + """ + + model_dim: int + layers: ModuleList + self_attn_mask_factory: AttentionMaskFactory + layer_norm: Optional[LayerNorm] + dropout_p: float + norm_order: TransformerNormOrder + cross_attention_mask_factory: AttentionMaskFactory + + def __init__( + self, + embed_time: DiTTimestepEncoder, + layers: Iterable[TransformerDecoderLayer], + initial_proj: Projection, + final_proj: Projection, + *, + self_attn_mask_factory: AttentionMaskFactory, + cross_attention_mask_factory: AttentionMaskFactory, + dropout_p: float = 0.0, + norm_order: TransformerNormOrder = TransformerNormOrder.POST, + pos_encoder: Optional[PositionEncoder] = None, + layer_norm_factory: Optional[LayerNormFactory] = None, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param layers: + The decoder layers. + :param self_attn_mask_factory: + The self attention mask factory. + :param cross_attention_mask_factory: + The cross attention mask factory. + :param dropout_p: + The dropout probability on decoder outputs. + :param norm_order: + The Layer Normalization order. + :param: pos_encoder: + An optional positional encoding module + :param layer_norm_factory: + The factory to construct the Layer Normalization module. 
+ """ + layer_list = ModuleList(layers) + + if not layer_list: + raise ValueError("`layers` must be non-empty.") + + model_dim = layer_list[0].model_dim + + super().__init__() + + self.model_dim = model_dim + + self.embed_time = embed_time + + self.initial_proj = initial_proj + + self.final_proj = final_proj + + self.pos_encoder = pos_encoder + + if layer_norm_factory is None: + layer_norm_factory = create_standard_layer_norm + + self.self_attn_mask_factory = self_attn_mask_factory + + self.layers = layer_list + + if norm_order != TransformerNormOrder.POST: + self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + else: + self.register_module("layer_norm", None) + + if dropout_p > 0.0: + self.dropout = Dropout(dropout_p) + else: + self.register_module("dropout", None) + + self.norm_order = norm_order + + self.cross_attention_mask_factory = cross_attention_mask_factory + + def forward( + self, + seqs: Tensor, + diffusion_timesteps: Tensor, + padding_mask: Optional[PaddingMask], + conditioning_variables: Optional[Tensor] = None, + conditioning_variables_padding_mask: Optional[PaddingMask] = None, + source_lengths: Optional[Tensor] = None, + cf_guidance_prob: float = 0.0, + *, + state_bag: Optional[IncrementalStateBag] = None, + inference: Optional[bool] = False, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + """ + Arguments: + - seqs (`Tensor`): the sequence of latents to denoise + - diffusion_timesteps (`Tensor`) the indices of the diffusion timesteps + to be embedded and fed as a conditioning variable. + - padding_mask (`PaddingMask`) mask of padded positions in the latents (seqs) + + - conditioning_variables (`Tensor`) the sequence of conditioning + variables that will be combined with the timestep embedding to + guide the diffusion process + - conditioning_variables_padding_mask (`PaddingMask`) the mask of padded + positions in `conditioning_variables` + - source_lengths (`Optional[Tensor]`) the lengths of the source embeddings + in `conditioning_variables` to properly shift the cross-attention mask + - cf_guidance_prob: probability rate with which to drop all conditioning variables when denoising + - state_bag (`IncrementalStateBag`) the incremental state bag of the denoiser to enable kv-caching + - inference (`bool`) if `True` the cross-attention mask will be adjusted accordingly + """ + + emb_timesteps = self.embed_time(diffusion_timesteps) + assert conditioning_variables is not None, ( + "Expected conditioning_variables, found None" + ) + + assert conditioning_variables is not None, ( + "Mypy - Expecting non-None conditioning_variables" + ) + + conditioning_variables = torch.cat( + [ + torch.zeros_like(conditioning_variables[:, 0:1]), + conditioning_variables, + ], + dim=1, + ) + + if conditioning_variables_padding_mask is not None: + # shift by the length of the prepended timesteps + conditioning_variables_padding_mask = PaddingMask( + conditioning_variables_padding_mask._seq_lens + 1, + conditioning_variables_padding_mask._batch_seq_len + 1, + ) + + # project to model_dim and add optional position codes: + seqs = self.initial_proj(seqs) + + if self.pos_encoder is not None: + seqs = self.pos_encoder(seqs, padding_mask) + + self_attn_mask = self.self_attn_mask_factory( + seqs, keys=seqs, training=self.training, state_bag=state_bag + ) + + assert conditioning_variables is not None + cross_attention_mask = self.cross_attention_mask_factory( + seqs, + keys=conditioning_variables, + source_lengths=source_lengths, + cf_guidance_prob=cf_guidance_prob, + 
training=self.training, + state_bag=state_bag, + inference=inference, # type: ignore + ) + + for layer_idx, layer in enumerate(self.layers): + layer_output, layer_padding_mask = layer( + seqs=seqs, + padding_mask=padding_mask, + self_attn_mask=self_attn_mask, + emb_timesteps=emb_timesteps, + conditioning_variables=conditioning_variables, + conditioning_variables_padding_mask=conditioning_variables_padding_mask, + cross_attention_mask=cross_attention_mask, + state_bag=state_bag, + ) + + seqs, padding_mask = layer_output, layer_padding_mask + + if self.layer_norm is not None: + seqs = self.layer_norm(seqs) + + if self.dropout is not None: + seqs = self.dropout(seqs) + + seqs = self.final_proj(seqs) + + return seqs, padding_mask + + +class LCMDenoiserLayer(TransformerDecoderLayer): + """A single layer of the hybrid denoiser""" + + self_attn: MultiheadAttention + self_attn_norm: Optional[LayerNorm] + self_attn_dropout: Optional[Dropout] + self_attn_layer_norm: LayerNorm + cross_attention: MultiheadAttention + cross_attention_dropout: Optional[Dropout] + cross_attention_layer_norm: Optional[LayerNorm] + ffn: FeedForwardNetwork + ffn_dropout: Optional[Dropout] + residual_scale: Optional[Parameter] + ffn_layer_norm: LayerNorm + norm_order: TransformerNormOrder + + def __init__( + self, + self_attn: MultiheadAttention, + ffn: FeedForwardNetwork, + cross_attention: MultiheadAttention, + *, + scale_residual: bool = False, + dropout_p: float = 0.0, + norm_order: TransformerNormOrder = TransformerNormOrder.POST, + layer_norm_factory: Optional[LayerNormFactory] = None, + modulator_input_dim: Optional[int] = None, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param self_attn: + The self attention layer. + :param cross_attention: + The cross attention layer if denoiser-type is `cross-attention`. + :param ffn: + The feed-forward network. + :param scale_residual: + If ``True``, scales residuals before adding them to the output of + the feed-forward network as described in + :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`. + :param dropout_p: + The dropout probability on outputs of the attention layers and the + feed-forward network. + :param norm_order: + The Layer Normalization order. + :param layer_norm_factory: + The factory to construct the Layer Normalization modules. 
+ """ + model_dim = self_attn.model_dim + + super().__init__(model_dim) + + if layer_norm_factory is None: + layer_norm_factory = create_standard_layer_norm + + self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + if norm_order != TransformerNormOrder.POST: + self.self_attn_layer_norm = self_attn_layer_norm + + self.self_attn = self_attn + + if norm_order == TransformerNormOrder.PRE_WITH_NORMFORMER: + self.self_attn_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + else: + self.register_module("self_attn_norm", None) + + if dropout_p > 0.0: + self.self_attn_dropout = Dropout(dropout_p) + else: + self.register_module("self_attn_dropout", None) + + if norm_order == TransformerNormOrder.POST: + self.self_attn_layer_norm = self_attn_layer_norm + + # Deal with the cross-attention layers: + if cross_attention is None: + self.register_module("cross_attention", None) + self.register_module("cross_attention_layer_norm", None) + else: + cross_attention_layer_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + + if norm_order != TransformerNormOrder.POST: + self.cross_attention_layer_norm = cross_attention_layer_norm + + self.cross_attention = cross_attention + + if dropout_p > 0.0: + self.cross_attention_dropout = Dropout(dropout_p) + else: + self.register_module("cross_attention_dropout", None) + + if norm_order == TransformerNormOrder.POST: + self.cross_attention_layer_norm = cross_attention_layer_norm + # / deal with cross-attention + + ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + if norm_order != TransformerNormOrder.POST: + self.ffn_layer_norm = ffn_layer_norm + + self.ffn = ffn + + if dropout_p > 0.0: + self.ffn_dropout = Dropout(dropout_p) + else: + self.register_module("ffn_dropout", None) + + if norm_order == TransformerNormOrder.POST: + self.ffn_layer_norm = ffn_layer_norm + + self.norm_order = norm_order + + # Add a modulator: + modulator_input_dim = modulator_input_dim or model_dim + self.modulator = AdaLNModulator( + input_dim=modulator_input_dim, + output_dim=model_dim, + device=device, + dtype=dtype, + ) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + # Zero-out the modulators: + self.modulator.reset_parameters() + + @override + def forward( # type: ignore + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + conditioning_variables: Tensor, + emb_timesteps: Tensor, + self_attn_mask: Optional[AttentionMask] = None, + conditioning_variables_padding_mask: Optional[PaddingMask] = None, + cross_attention_mask: Optional[AttentionMask] = None, + *, + state_bag: Optional[IncrementalStateBag] = None, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + # Get modulator output: + (modulate_san, modulate_cross_attention, modulate_ffn) = self.modulator( + emb_timesteps + ) + + seqs = self._forward_self_attn( + seqs=seqs, + padding_mask=padding_mask, + modulators=modulate_san, + self_attn_mask=self_attn_mask, + state_bag=state_bag, + ) + + seqs = self._forward_cross_attention( + seqs=seqs, + padding_mask=padding_mask, + conditioning_variables=conditioning_variables, + modulators=modulate_cross_attention, + cross_attention_mask=cross_attention_mask, + key_padding_mask=conditioning_variables_padding_mask, + state_bag=state_bag, + ) + + seqs = self._forward_ffn( + seqs=seqs, + modulators=modulate_ffn, + ) + + return seqs, padding_mask + + def _forward_self_attn( + self, + seqs: Tensor, + modulators: Tensor, + 
padding_mask: Optional[PaddingMask], + self_attn_mask: Optional[AttentionMask], + state_bag: Optional[IncrementalStateBag], + ) -> Tensor: + residual = seqs + + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expect pre-normalization" + ) + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + # split modulators into shift, scale and gate: + shift, scale, gate = modulators.chunk(3, dim=-1) + + # modulate the input: + seqs = seqs * (1 + scale) + shift + + seqs = self.self_attn( + seqs, + padding_mask, + keys=seqs, + key_padding_mask=None, + values=seqs, + attn_mask=self_attn_mask, + state_bag=state_bag, + ) + + if self.self_attn_norm is not None: + seqs = self.self_attn_norm(seqs) + + if self.self_attn_dropout is not None: + seqs = self.self_attn_dropout(seqs) + + # Scale the residual with the gate weights + seqs = residual + gate * seqs + + return seqs + + def _forward_cross_attention( + self, + seqs: Tensor, + modulators: Tensor, + padding_mask: Optional[PaddingMask], + conditioning_variables: Optional[Tensor], + key_padding_mask: Optional[PaddingMask], + cross_attention_mask: Optional[AttentionMask], + state_bag: Optional[IncrementalStateBag], + ) -> Tensor: + if conditioning_variables is None: + raise ValueError( + "`conditioning_variables` must not be `None` for cross attention." + ) + + residual = seqs + + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expect pre-normalization" + ) + + if self.norm_order != TransformerNormOrder.POST: + seqs = cast(LayerNorm, self.cross_attention_layer_norm)(seqs) + + # split modulators into shift, scale and gate: + shift, scale, gate = modulators.chunk(3, dim=-1) + + # modulate the input: + seqs = seqs * (1 + scale) + shift + + seqs = self.cross_attention( + seqs, + padding_mask, + keys=conditioning_variables, + key_padding_mask=key_padding_mask, + attn_mask=cross_attention_mask, + values=conditioning_variables, + state_bag=state_bag, + ) + + if self.cross_attention_dropout is not None: + seqs = self.cross_attention_dropout(seqs) + + # Scale the residual with the gate weights + seqs = residual + gate * seqs + + return seqs + + def _forward_ffn(self, seqs: Tensor, modulators: Tensor) -> Tensor: + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expects pre-normalization" + ) + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.ffn_layer_norm(seqs) + + # split modulators into shift, scale and gate: + shift, scale, gate = modulators.chunk(3, dim=-1) + + # modulate the input: + seqs = seqs * (1 + scale) + shift + + seqs = self.ffn(seqs) + + if self.ffn_dropout is not None: + seqs = self.ffn_dropout(seqs) + + # Scale the branch with the gate weights + seqs = residual + gate * seqs + + return seqs diff --git a/lcm/nn/incremental_state.py b/lcm/nn/incremental_state.py new file mode 100644 index 0000000000000000000000000000000000000000..d68b37b34e025091846fbea1a18d8c990c0887ca --- /dev/null +++ b/lcm/nn/incremental_state.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
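The AdaLN pattern used by `LCMDenoiserLayer`: one projection of the timestep embedding yields `9 * model_dim` values, chunked into (shift, scale, gate) triplets for the three sub-blocks; with the zero-initialized modulator each gated branch starts as an identity. A small worked sketch (the attention/FFN sub-module is omitted):

```python
# Worked sketch of the AdaLN modulation (attention/FFN sub-module omitted).
import torch

model_dim = 8
emb_timesteps = torch.randn(2, 1, model_dim)

fc = torch.nn.Linear(model_dim, 9 * model_dim)
torch.nn.init.zeros_(fc.weight)   # zero-init, as in AdaLNModulator.reset_parameters
torch.nn.init.zeros_(fc.bias)

mod_san, mod_xattn, mod_ffn = fc(torch.nn.functional.silu(emb_timesteps)).chunk(3, dim=-1)
shift, scale, gate = mod_san.chunk(3, dim=-1)   # one (shift, scale, gate) triplet per sub-block

x = torch.randn(2, 1, model_dim)
residual = x
x = x * (1 + scale) + shift      # modulate the (normalized) input
x = residual + gate * x          # gate the branch output

assert torch.allclose(x, residual)   # zero-init modulator => the layer starts as an identity
```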
+# +# + +from typing import Dict, Optional, final + +from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag +from fairseq2.nn.transformer import FullAttentionState +from torch import Tensor +from torch.nn import Module + + +@final +class LCMIncrementalStateBag(IncrementalStateBag): # type: ignore + """Holds the module states during incremental decoding.""" + + _module_states: Dict[Module, FullAttentionState] # type: ignore + + def __init__( + self, max_num_steps: int, *, capacity_increment: Optional[int] = 16 + ) -> None: + super().__init__( + max_num_steps=max_num_steps, capacity_increment=capacity_increment + ) + + def reorder(self, new_order: Tensor) -> None: + """Reorder the module states. + + See :meth:`IncrementalState.reorder` for more information. + """ + # FIXME Deal with reordering diffusion state bags here + for state in self._module_states.values(): + state.reorder(new_order) + + def set_state(self, m: Module, state: IncrementalState) -> None: + """Set the state of ``m``. + :param m: The module. + :param state: The state to store. + There is no current call to `set_state` when the bag + is frozen, but it's implemented here for completeness + """ + super().set_state(m, state) diff --git a/lcm/nn/initialization.py b/lcm/nn/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..190cdb424cab974b91a41485e5b63087ae2b18e7 --- /dev/null +++ b/lcm/nn/initialization.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +import math +from functools import partial +from typing import Literal, Optional + +import torch +from fairseq2.nn.projection import Linear +from fairseq2.nn.transformer import TransformerNormOrder +from torch.nn import Module + +SUPPORTED_INIT_TYPES = Literal[ + "xavier", + "sonar", + "zero", + "trunc_normal", + "kaiming_uniform", + "none", +] + + +SONAR_STD = 0.006 +# Most SONAR embeddings have a distribution with the mean close to 0 and std close to 0.006 +# Initializing embedding-like parameters (e.g. end-of-text vector) from a similar distribution is recommended, +# to minimize their disruption of the model training + + +def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD): + if style == "xavier": + return init_linear_xavier + + if style == "kaiming_uniform": + return init_linear_kaiming_uniform + + if style == "sonar": + return partial(init_linear_to_sonar, sonar_std=sonar_std) + + if style == "zero": + return init_linear_zero + + if style == "trunc_normal": + return init_linear_trunc_normal + + if style == "none": + return None + + else: + raise ValueError(f"Could not recognize initialization function {style}") + + +def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None: + """ + Initialize the post-lcm in such a way, that if it is fed layer-normed + lcm outputs (with zero mean and unit variance), its outputs have zero + mean and the variance of SONAR embeddings. 
+ """ + if layer.bias is not None: + torch.nn.init.zeros_(layer.bias) + + std = sonar_std * (3 / layer.input_dim) ** 0.5 + + torch.nn.init.uniform_(layer.weight, a=-std, b=std) + + +def init_linear_xavier(layer: Linear) -> None: + torch.nn.init.xavier_uniform_(layer.weight) + if layer.bias is not None: + torch.nn.init.zeros_(layer.bias) + + +def init_linear_zero(layer: Linear) -> None: + torch.nn.init.zeros_(layer.weight) + if layer.bias is not None: + torch.nn.init.zeros_(layer.bias) + + +def init_linear_trunc_normal(layer: Linear) -> None: + torch.nn.init.trunc_normal_(layer.weight, std=1e-3) + if layer.bias is not None: + torch.nn.init.zeros_(layer.bias) + + +def init_linear_kaiming_uniform(layer: Linear) -> None: + torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5)) + + if layer.bias is not None: + fan_in = layer.weight.size(1) + + m = 1 + if layer.weight.ndim > 2: + for s in layer.weight.shape[2:]: + m *= s + + fan_in *= m + + # We do not calculate the true standard deviation of the uniform + # distribution (i.e. multiply with sqrt(3)). See + # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575. + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + + torch.nn.init.uniform_(layer.bias, -bound, bound) + + +def parse_norm_order(var: str) -> TransformerNormOrder: + norm_order: TransformerNormOrder + if var == "pre": + norm_order = TransformerNormOrder.PRE + elif var == "post": + norm_order = TransformerNormOrder.POST + elif var == "normformer": + norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER + else: + raise ValueError(f"Unknown normalization order {var}") + + return norm_order + + +def parse_activation_fn(var: str = None) -> Optional[Module]: + if var is None: + return None + + activ_fn: Module + + if var == "relu": + activ_fn = torch.nn.ReLU() + elif var == "tanh": + activ_fn = torch.nn.Tanh() + elif var == "elu": + activ_fn = torch.nn.ELU() + elif var == "leaky_relu": + activ_fn = torch.nn.LeakyReLU() + elif var == "prelu": + activ_fn = torch.nn.PReLU() + elif var == "selu": + activ_fn = torch.nn.SELU() + elif var == "gelu": + activ_fn = torch.nn.GELU() + elif var == "silu": + activ_fn = torch.nn.SiLU() + elif var == "softsign": + activ_fn = torch.nn.Softsign() + elif var == "sigmoid": + activ_fn = torch.nn.Sigmoid() + elif var == "hardsigmoid": + activ_fn = torch.nn.Hardsigmoid() + else: + raise ValueError(f"Unknown activation function {var}") + + return activ_fn diff --git a/lcm/nn/normalization.py b/lcm/nn/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..825602b6f6787317793d212f2d7404868e61f403 --- /dev/null +++ b/lcm/nn/normalization.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. 
+# +# + +from typing import Literal, Optional, final + +import torch +from fairseq2.nn import LayerNorm, RMSNorm, StandardLayerNorm +from fairseq2.nn.transformer import LayerNormFactory, create_standard_layer_norm +from fairseq2.typing import DataType, Device, override + +SUPPORTED_LN_TYPES = Literal["standard", "fp32", "rms", "unit"] + + +@final +class FP32LayerNorm(LayerNorm): + """Applies Layer Normalization in single-precision.""" + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + w, b = self.weight, self.bias + + # cast input and params to float32 + fp32_x = x.float() + fp32_w = w.float() if w is not None else None + fp32_b = b.float() if b is not None else None + + y = torch.nn.functional.layer_norm( + fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps + ) + + return y.type_as(x) + + +def build_rms_layer_norm( + model_dim: int, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> LayerNorm: + """Build an RMS Layer Normalization module.""" + return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) + + +def build_fp32_layer_norm( + model_dim: int, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> LayerNorm: + """Build an Single-precision Layer Normalization module.""" + return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype) + + +def build_unit_layer_norm( + model_dim: int, + *, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> LayerNorm: + """Create an instance of :class:`StandardLayerNorm + without learnable mean and variance`.""" + return StandardLayerNorm( + model_dim, + bias=False, + elementwise_affine=False, + device=device, + dtype=dtype, + ) + + +def parse_layer_norm_factory(layer_normalization_style: str) -> LayerNormFactory: + if layer_normalization_style == "rms": + # Note that RMSNorm normalizes in single-precision by default + return build_rms_layer_norm + + elif layer_normalization_style == "unit": + return build_unit_layer_norm + + elif layer_normalization_style == "fp32": + return build_fp32_layer_norm + + elif layer_normalization_style == "standard": + return create_standard_layer_norm + + else: + raise ValueError(f"Unsupported LayerNorm style {layer_normalization_style}") diff --git a/lcm/nn/projection.py b/lcm/nn/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6116f1db99ee67781d6c2ad89b611bb388e8a0 --- /dev/null +++ b/lcm/nn/projection.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq2.nn.projection import Linear +from fairseq2.typing import DataType, Device +from torch import Tensor +from torch.nn import Module + +from lcm.nn.initialization import ( + SUPPORTED_INIT_TYPES, + get_init_fn, + parse_activation_fn, +) +from lcm.nn.normalization import SUPPORTED_LN_TYPES + + +@dataclass +class ProjectionConfig: + dropout_p: float = 0.0 + """ The dropout probability applied to the module' output""" + + linear_bias: bool = True + """ Whether or not the pre-linear layer has a bias term""" + + linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform" + + weight_normalization: bool = False + + layer_normalization_style: SUPPORTED_LN_TYPES = "standard" + + activation_name: Optional[str] = None + """the activation function to apply after fi any""" + + +class Projection(Module): + """ + An output projecton module. 
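A small sketch of the layer-norm dispatch: the `fp32` style normalizes in single precision and casts back to the input dtype, which helps stability in bf16/fp16 training:

```python
# Small sketch of the "fp32" layer-norm style in a low-precision setting.
import torch
from lcm.nn.normalization import parse_layer_norm_factory

layer_norm_factory = parse_layer_norm_factory("fp32")
ln = layer_norm_factory(768)                     # model_dim

x = torch.randn(2, 5, 768, dtype=torch.bfloat16)
y = ln(x)
assert y.dtype == torch.bfloat16                 # normalized in fp32, returned in the input dtype
```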
+ """ + + def __init__( + self, + output_dim: int, + input_dim: int, + config: ProjectionConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + super().__init__() + + self.dtype = dtype + + init_fn = get_init_fn(config.linear_init_fn) + + lin = Linear( + input_dim, + output_dim, + bias=config.linear_bias, + device=device, + dtype=dtype, + init_fn=init_fn, + ) + if config.weight_normalization: + self.fc = torch.nn.utils.parametrizations.weight_norm(lin) + else: + self.fc = lin + + self.activation_fn = parse_activation_fn(config.activation_name) + + if self.activation_fn is not None: + # some activation functions (e.g., PReLU) have parameters + # and so we need to move them to the right device + self.activation_fn.to(device) + + def forward(self, seqs: Tensor): + seqs = self.fc(seqs) + + if self.activation_fn is not None: + seqs = self.activation_fn(seqs) + + return seqs diff --git a/lcm/nn/schedulers/__init__.py b/lcm/nn/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a19eb595b4c1b1dcbb746790f2ff7a5b2a1aab56 --- /dev/null +++ b/lcm/nn/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + + +from lcm.nn.schedulers.ddim import ( + DDIMScheduler, + DDIMSchedulerConfig, + DDIMSchedulerOutput, +) + +__all__ = [ + "DDIMScheduler", + "DDIMSchedulerConfig", + "DDIMSchedulerOutput", +] diff --git a/lcm/nn/schedulers/ddim.py b/lcm/nn/schedulers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0992704111d24a98c4ed4438f74827845dc06f --- /dev/null +++ b/lcm/nn/schedulers/ddim.py @@ -0,0 +1,741 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +# This code is based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py, which is distributed under the Apache 2.0 License. +# HuggingFace's diffusers DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Literal, Optional, Tuple, Union + +import torch +from fairseq2.logging import get_log_writer +from fairseq2.typing import CPU +from torch import Tensor + +logger = get_log_writer(__name__) + + +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +def logit(x): + return math.log(x / (1 - x)) + + +@dataclass +class DDIMSchedulerOutput: + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. 
+ """ + + prev_sample: Tensor + pred_original_sample: Tensor + + +@dataclass +class DDIMSchedulerConfig: + num_diffusion_train_steps: int = 1000 + """The number of diffusion steps to train the model.""" + + beta_start: float = 0.0001 + """The starting `beta` value of inference.""" + + beta_end: float = 0.02 + """The final `beta` value.""" + """In DDPM (https://arxiv.org/pdf/2006.11239), $\beta_t$ is increasing + linearly from $\beta_1$ (`beta_start`)=1e−4 to $\beta_T$ (`beta_end`)=0.02. + These constants were chosen to be small relative to data scaled to [−1, 1], + ensuring that reverse and forward processes have approximately + the same functional form while keeping the signal-to-noise ratio at $x_T$ as small as possible. + Another common choice in HF:diffusers `beta_start=0.00085, beta_end=0.012,` + Note that `beta_start` and `beta_end` are irrelevant for `squaredcos_cap_v2` + """ + + beta_schedule: Literal[ + "linear", + "scaled_linear", + "squaredcos_cap_v2", + "sigmoid", + ] = "squaredcos_cap_v2" + """The beta schedule, a mapping from a beta range to a sequence of betas + for stepping the model (length=`num_diffusion_train_steps`). + Choose from: + - `linear`: Linearly spaced betas between `beta_start` and `beta_end`. + Referred to as `sqrt_linear` in stable-diffusion. + - `scaled_linear`: Squared values after linearly spacing form sqrt(beta_start) to sqrt(beta_end). + Referred to as `linear` in stable-diffusion. + -`squaredcos_cap_v2`: Creates a beta schedule that discretizes + math:: $\bar alpha(t) = {cos((t/T + s) / (1+s) * \pi/2)}^2$, HF:diffusers sets `s` to 0.008. + For the intuition behind how a cosine schedule compares to a linear schedule + see Figure 3 of https://arxiv.org/pdf/2102.09672 + - `sigmoid` our sigmoid schedule (see Equation 14 of the LCM paper). + """ + + scaled_linear_exponent: float = 2.0 + """Exponent for the scaled linear beta schedule. Default is quadratic (scaled_linear_exponent=2)""" + + sigmoid_schedule_alpha: float = 1.5 + sigmoid_schedule_beta: float = 0 + """alpha and beta hyper-parameters of the sigmoid beta-schedule""" + + clip_sample: bool = False + """Clip the predicted sample for numerical stability.""" + + clip_sample_range: float = 1.0 + """The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.""" + + set_alpha_to_one: bool = True + """Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0.""" + + prediction_type: Literal["sample", "epsilon", "v_prediction"] = "sample" + """If `sample`, the model predicts the clean ground truth embeddings. + If `epsilon`, the model predicts the added noise of the diffusion process. + If `v_epsilon`, the model predicts an interpolation of the ground truth clean + embeddings and the added noise. As introduced in section 2.4 of the Imagen paper + (https://imagen.research.google/video/paper.pdf) + """ + + thresholding: bool = False + """Whether to use the "dynamic thresholding" method. + This is unsuitable for latent-space diffusion models such as Stable Diffusion.""" + + dynamic_thresholding_ratio: float = 0.995 + """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.""" + + sample_max_value: float = 1.0 + """The threshold value for dynamic thresholding. 
Valid only when `thresholding=True`.""" + + rescale_betas_zero_snr: bool = True + """Whether to rescale the betas to have zero terminal SNR. This enables the + model to generate very bright and dark samples instead of limiting it to samples + with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).""" + + # Inference specific + timestep_spacing: Literal["linspace", "leading", "trailing"] = "trailing" + """The way the timesteps should be scaled. Refer to Table 2 of + https://arxiv.org/abs/2305.08891 for more information.""" + + +class DDIMScheduler: + def __init__(self, config: DDIMSchedulerConfig): + self.config = config + + # Make these 2 arguments easily accessible + self.num_diffusion_train_steps = self.config.num_diffusion_train_steps + + self.prediction_type = self.config.prediction_type + + beta_schedule = self.config.beta_schedule + + if beta_schedule == "linear": + self.betas = torch.linspace( + self.config.beta_start, + self.config.beta_end, + self.num_diffusion_train_steps, + dtype=torch.float32, + ) + elif beta_schedule == "scaled_linear": + # This schedule is very specific to the latent diffusion model. + exponent = self.config.scaled_linear_exponent + self.betas = ( + torch.linspace( + self.config.beta_start ** (1 / exponent), + self.config.beta_end ** (1 / exponent), + self.num_diffusion_train_steps, + dtype=torch.float32, + ) + ** exponent + ) + elif beta_schedule == "squaredcos_cap_v2": + # Cosine schedule as introduced in + # [Nichol and Dhariwal, 2021](https://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf) + self.betas = betas_for_alpha_bar( + self.num_diffusion_train_steps, + alpha_transform_type="cosine", + ) + + elif beta_schedule == "sigmoid": + self.betas = betas_for_alpha_bar( + self.num_diffusion_train_steps, + alpha_transform_type="sigmoid", + sigmoid_alpha=self.config.sigmoid_schedule_alpha, + sigmoid_beta=self.config.sigmoid_schedule_beta, + ) + + else: + raise NotImplementedError( + f"We do not recognize beta_schedule={beta_schedule}" + ) + + # Rescale for zero SNR + if self.config.rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. 
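For intuition about the default `squaredcos_cap_v2` schedule built above, here is a stand-alone computation of the betas and cumulative alphas using the same `alpha_bar` definition as `betas_for_alpha_bar` below (toy script for illustration, not part of the module):

```python
import math
import torch

T = 1000

def alpha_bar(t: float) -> float:
    # HF:diffusers cosine schedule with s = 0.008
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = torch.tensor(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

print(betas[0].item(), betas[-1].item())                     # tiny at t=0, capped at 0.999 at t=T
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())   # ~1 -> ~0
```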
+ self.final_alpha_cumprod = ( + torch.tensor(1.0) + if self.config.set_alpha_to_one + else self.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # timesteps for inference + self.num_inference_steps: Optional[int] = None + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * ( + 1 - alpha_prod_t / alpha_prod_t_prev + ) + return variance + + def get_variances(self) -> Tensor: + alpha_prod_t = self.alphas_cumprod + alpha_prod_t_prev = torch.cat( + (torch.tensor([self.final_alpha_cumprod]), alpha_prod_t[:-1]) + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * ( + 1 - alpha_prod_t / alpha_prod_t_prev + ) + return variance + + def get_snrs(self) -> Tensor: + alphas_cumprod = self.alphas_cumprod + snr = alphas_cumprod / (1 - alphas_cumprod) + return snr + + def _threshold_sample(self, sample: Tensor) -> Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain + percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), + and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) + inwards, thereby actively preventing pixels from saturation at each step. + We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, + especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, -1) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_diffusion_train_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_diffusion_train_steps`:" + f" {self.num_diffusion_train_steps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_diffusion_train_steps} timesteps." 
+ ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + # With T the number of training steps and S the number of inference steps + + if self.config.timestep_spacing == "linspace": + # Linspace: flip round(linspace(1, T, S)) + # With T=1000 and S=10; [999, 888, 777, 666, 555, 444, 333, 222, 111, 0] + timesteps = torch.linspace( + 0, + self.config.num_diffusion_train_steps - 1, + self.num_inference_steps, + device=device, + dtype=torch.long, + ) + timesteps = torch.flip(timesteps, dims=(0,)).round() + + elif self.config.timestep_spacing == "leading": + # Leading: flip arange(1, T + 1, floor(T /S)) + # With T=1000 and S=10: [900, 800, 700, 600, 500, 400, 300, 200, 100, 0] + + leading_step_ratio = ( + self.num_diffusion_train_steps // self.num_inference_steps + ) + timesteps = torch.arange( + start=0, + end=self.num_diffusion_train_steps, + step=leading_step_ratio, + device=device, + dtype=torch.long, + ) + timesteps = torch.flip(timesteps, dims=(0,)).round() + + elif self.config.timestep_spacing == "trailing": + # Trailing: round(flip(arange(T, 0, −T /S))) + # With T=1000 and S=10: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99] + trailing_step_ratio: float = ( + self.num_diffusion_train_steps / self.num_inference_steps + ) + # creates integer timesteps by multiplying by ratio + timesteps = torch.arange( + self.config.num_diffusion_train_steps, + 0, + -trailing_step_ratio, + device=device, + dtype=torch.long, + ).round() + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = timesteps + logger.debug( + f"With `{self.config.timestep_spacing}`, setting inference timesteps to {self.timesteps}" + ) + + def step( + self, + model_output: Tensor, + timestep: int, + sample: Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[Tensor] = None, + prediction_type: Optional[str] = None, + epsilon_scaling: Optional[float] = None, + ) -> DDIMSchedulerOutput: + """ + INFERENCE ONLY. + Predict the sample from the previous timestep by reversing the SDE. + This function propagates the diffusion + process from the learned model outputs. + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. 
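The three spacing rules can be reproduced outside the class; the following sketch uses the same T=1000, S=10 example as the comments above (values are for illustration only):

```python
import torch

T, S = 1000, 10

linspace = torch.linspace(0, T - 1, S).round().long().flip(0)   # [999, 888, ..., 0]
leading  = torch.arange(0, T, T // S).flip(0)                   # [900, 800, ..., 0]
trailing = torch.arange(T, 0, -T / S).round().long() - 1        # [999, 899, ..., 99]

print(trailing.tolist())  # [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
```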
+ prediction_type: Optional[str] if provided we step with a different prediction_type + than the one in the config + epsilon_scaling: Optional[float] if not None, the predicted epsilon will be scaled down by + the provided factor as introduced in https://arxiv.org/pdf/2308.15321 + + Returns: + DDIMSchedulerOutput + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. Get previous step value (=t-1) + prev_timestep = ( + timestep - self.config.num_diffusion_train_steps // self.num_inference_steps + ) + + # 2. Compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. Compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prediction_type = prediction_type or self.prediction_type + if prediction_type == "epsilon": + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + elif prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + ( + beta_prod_t**0.5 + ) * sample + else: + raise ValueError( + f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 3.a epsilon scaling: + if epsilon_scaling is not None: + pred_epsilon = pred_epsilon / epsilon_scaling + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. Compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + # 6. Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** ( + 0.5 + ) * pred_epsilon + # 7. Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = ( + alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + ) + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. 
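Written out for the `sample` prediction type with `eta=0`, a single DDIM update amounts to the following; the schedule and tensors below are toy stand-ins, not values used by the model:

```python
# One deterministic DDIM update (eta=0), mirroring formula (12) of arXiv:2010.02502.
import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)   # stand-in for a real schedule
t, t_prev = 999, 899                                   # e.g. trailing spacing with S=10

x_t = torch.randn(4, 1024)        # current noisy embeddings
x0_pred = torch.randn(4, 1024)    # model output when the model predicts x_0 directly

a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
eps_pred = (x_t - a_t.sqrt() * x0_pred) / (1 - a_t).sqrt()   # re-derived noise
direction = (1 - a_prev).sqrt() * eps_pred                   # "direction pointing to x_t"
x_prev = a_prev.sqrt() * x0_pred + direction                 # x_{t-1}
```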
Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + variance = std_dev_t * variance_noise + prev_sample = prev_sample + variance + + return DDIMSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + def add_noise( + self, + original_samples: Tensor, + noise: Tensor, + timesteps: Tensor, + ) -> Tensor: + """TRAINING ONLY + Forward noising process during training""" + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device).to(torch.int32) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def get_velocity(self, sample: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device).to(torch.int32) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def get_epsilon( + self, model_output: Tensor, sample: Tensor, timestep: int + ) -> Tensor: + """Given model inputs (sample) and outputs (model_output) + Predict the noise residual according to the scheduler's + prediction type""" + + pred_type = self.prediction_type + + alpha_prod_t = self.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** ( + 0.5 + ) + + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"The scheduler's prediction type {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + 
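The training-time noising in `add_noise` (and the `v_prediction` target from `get_velocity`) reduces per timestep to the expressions below, again with toy stand-in values:

```python
# x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
x0 = torch.randn(4, 1024)
eps = torch.randn_like(x0)
t = torch.tensor([10, 250, 500, 900])             # one timestep per batch element

abar = alphas_cumprod[t].unsqueeze(-1)            # broadcast over the feature dim
x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps  # noisy sample fed to the model
v = abar.sqrt() * eps - (1 - abar).sqrt() * x0    # "v-prediction" target (get_velocity)
```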
dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = ( + generator.device.type + if not isinstance(generator, list) + else generator[0].device.type + ) + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = CPU + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError( + f"Cannot generate a {device} tensor from a generator of type {gen_device_type}." + ) + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] # type: ignore + latents_list = [ + torch.randn( + shape, + generator=generator[i], + device=rand_device, + dtype=dtype, + layout=layout, + ) + for i in range(batch_size) + ] + latents = torch.cat(latents_list, dim=0).to(device) + else: + latents = torch.randn( + shape, generator=generator, device=rand_device, dtype=dtype, layout=layout + ).to(device) + + return latents + + +def betas_for_alpha_bar( + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: Literal["cosine", "exp", "sigmoid"] = "cosine", + sigmoid_alpha: float = 1.5, + sigmoid_beta: float = 0, +) -> Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` + sigmoid_alpha/sigmoid_beta: additional hyper-parameters for the sigmoid schedule + + Returns: + betas (`Tensor`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "sigmoid": + + def alpha_bar_fn(t): + epsilon = 1e-32 + return sigmoid( + sigmoid_beta + - sigmoid_alpha + * logit(torch.clamp(torch.tensor(t), min=epsilon, max=1 - epsilon)) + ) + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas: Tensor) -> Tensor: + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + Args: + betas (`Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas diff --git a/lcm/nn/timestep_encoder.py b/lcm/nn/timestep_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcfd29a44c8d5048b54ec637522543ffc9d463d --- /dev/null +++ b/lcm/nn/timestep_encoder.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +import math +from typing import Optional + +import torch +from fairseq2.nn.projection import Linear +from fairseq2.typing import DataType, Device +from torch import Tensor +from torch.nn import Module + +from lcm.nn.initialization import parse_activation_fn + + +class DiTTimestepEncoder(Module): + """ + Embeds scalar timesteps into vector representations. 
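A quick sanity check of the rescaling above, assuming the `lcm` package is importable: after `rescale_zero_terminal_snr`, the cumulative alpha at the final step, and hence the terminal SNR, is zero.

```python
import torch
from lcm.nn.schedulers.ddim import betas_for_alpha_bar, rescale_zero_terminal_snr

betas = betas_for_alpha_bar(1000, alpha_transform_type="cosine")
rescaled = rescale_zero_terminal_snr(betas)

abar = torch.cumprod(1.0 - rescaled, dim=0)
print(abar[-1].item())   # ~0.0: SNR = abar / (1 - abar) vanishes at the final step
```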
+ Based on DiT's `TimestepEmbedder` + https://github.com/facebookresearch/DiT/blob/main/models.py + """ + + def __init__( + self, + embedding_dim: int, + frequency_embedding_size: int = 256, + activation_fn_name: str = "silu", + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ): + super().__init__() + + self.dtype = dtype + + self.device = device + + self.embedding_dim = embedding_dim + + self.frequency_embedding_size = frequency_embedding_size + + self.fc1 = Linear( + frequency_embedding_size, + embedding_dim, + bias=True, + device=device, + dtype=dtype, + ) + self.nonlin = parse_activation_fn(activation_fn_name) + self.fc2 = Linear( + embedding_dim, + embedding_dim, + bias=True, + device=device, + dtype=dtype, + ) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + torch.nn.init.normal_(self.fc1.weight, std=0.02) + torch.nn.init.normal_(self.fc2.weight, std=0.02) + + if self.fc1.bias is not None: + torch.nn.init.zeros_(self.fc1.bias) + + if self.fc2.bias is not None: + torch.nn.init.zeros_(self.fc2.bias) + + @staticmethod + def sinusoidal_timestep_embedding( + timestep, frequency_embedding_size, max_period=10000 + ): + """ + Create sinusoidal timestep embeddings. + :param timestep: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param frequency_embedding_size: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + + Based on DiT's `TimestepEmbedder` + https://github.com/facebookresearch/DiT/blob/main/models.py + """ + half = frequency_embedding_size // 2 + + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timestep.device) + + args = timestep[:, None].float() * freqs[None] + + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + if frequency_embedding_size % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + + return embedding + + def forward(self, timesteps: Tensor) -> Tensor: + initial_size = timesteps.size() + + flat_timesteps = timesteps.view(-1, 1) + + t_freq = self.sinusoidal_timestep_embedding( + flat_timesteps, self.frequency_embedding_size + ).to(self.dtype) + + t_emb = self.fc1(t_freq) + + if self.nonlin is not None: + t_emb = self.nonlin(t_emb) + + t_emb = self.fc2(t_emb) + + return t_emb.view(*initial_size, self.embedding_dim) diff --git a/lcm/nn/transformer/__init__.py b/lcm/nn/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70c7920a75b5c4d9238aff564cf3865cb6d6a603 --- /dev/null +++ b/lcm/nn/transformer/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. 
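Shape sanity check for the timestep encoder above (assumes `lcm` and `fairseq2` are importable): scalar timesteps go in, `(N, frequency_embedding_size)` sinusoidal features come out, and the two-layer MLP maps them to `embedding_dim`.

```python
import torch
from lcm.nn.timestep_encoder import DiTTimestepEncoder

t = torch.tensor([0.0, 10.0, 250.0, 999.0])
emb = DiTTimestepEncoder.sinusoidal_timestep_embedding(t, frequency_embedding_size=256)
print(emb.shape)          # torch.Size([4, 256])

encoder = DiTTimestepEncoder(embedding_dim=1024, dtype=torch.float32)
print(encoder(t).shape)   # torch.Size([4, 1024])
```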
+# +# + +from lcm.nn.transformer.attention import ( + QKNormMultiheadAttention, +) +from lcm.nn.transformer.decoder import ( + LCMStandardTransformerDecoderLayer, + LCMTransformerDecoder, +) +from lcm.nn.transformer.factory import ( + TransformerConfig, + TransformerFactory, +) + +__all__ = [ + "QKNormMultiheadAttention", + "LCMStandardTransformerDecoderLayer", + "LCMTransformerDecoder", + "TransformerConfig", + "TransformerFactory", +] diff --git a/lcm/nn/transformer/attention.py b/lcm/nn/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5275673f7ed7722c99128cb4f73ed50c1b70036f --- /dev/null +++ b/lcm/nn/transformer/attention.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from typing import Optional, Tuple, final + +import torch +import torch.nn as nn +from fairseq2.nn.ops import repeat_interleave +from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.position_encoder import PositionEncoder +from fairseq2.nn.projection import Projection +from fairseq2.nn.transformer import ( + AttentionMask, + AttentionMaskFactory, + AttentionState, + AttentionStateFactory, + FullAttentionState, + LayerNormFactory, + StandardMultiheadAttention, + create_standard_layer_norm, +) +from fairseq2.nn.transformer.attention import SDPA +from fairseq2.typing import DataType, Device, override +from torch import Tensor +from torch.nn.parameter import Parameter + +# FIXME revert to fs2's standard state bag if possible +from lcm.nn.incremental_state import ( + LCMIncrementalStateBag, +) + + +@final +class QKNormMultiheadAttention(StandardMultiheadAttention): # type: ignore + """Represents a Transformer multi-head attention as described in + :cite:t:`https://doi.org/10.48550/arxiv.1706.03762` + with two additional layer-normalization for keys and queries + as described in https://arxiv.org/pdf/2302.05442 + and other related work + """ + + kv_dim: int + num_key_value_heads: int + q_proj: Projection + k_proj: Projection + v_proj: Projection + attn_mask_factory: Optional[AttentionMaskFactory] + pos_encoder: Optional[PositionEncoder] + bias_k: Optional[Parameter] + bias_v: Optional[Parameter] + add_zero_attn: bool + sdpa: SDPA + head_scale_weight: Optional[Parameter] + output_proj: Projection + state_factory: Optional[AttentionStateFactory] + layer_norm_factory: Optional[LayerNormFactory] + + """ + For full parameters description see fairseq2/src/fairseq2/nn/transformer/multihead_attention.py + Parameters of interest to us: + :param num_key_value_heads: + The number of key/value heads for Grouped Query Attention as + described in :cite:t:`https://doi.org/10.48550/arXiv.2305.13245`. + If ``None`` or set to ``num_heads``, it is equivalent to standard + Multi Head Attention (MHA); if set to 1, it is equivalent to Multi + Query Attention (MQA). 
+ + :param enable_qk_layernorm: + If True follow Q/K projections with LayerNorms + + :param weight_normalization: + If True, wrap K/Q/V projections with weight normalization for regularization + + :param pos_encoder: + For RoPE positional encoder that adds positional encoding to keys + and queries before computing the attention scores + """ + + def __init__( + self, + model_dim: int, + num_heads: int, + *, + kv_dim: Optional[int] = None, + num_key_value_heads: Optional[int] = None, + q_proj: Optional[Projection] = None, + k_proj: Optional[Projection] = None, + v_proj: Optional[Projection] = None, + attn_mask_factory: Optional[AttentionMaskFactory] = None, + pos_encoder: Optional[PositionEncoder] = None, + sdpa: Optional[SDPA] = None, + scale_heads: bool = False, + output_proj: Optional[Projection] = None, + bias: bool = True, + state_factory: Optional[AttentionStateFactory] = None, + enable_qk_layernorm: bool = False, + weight_normalization: bool = False, + layer_norm_factory: Optional[LayerNormFactory] = None, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + super().__init__( + model_dim=model_dim, + num_heads=num_heads, + kv_dim=kv_dim, + num_key_value_heads=num_key_value_heads, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + attn_mask_factory=attn_mask_factory, + pos_encoder=pos_encoder, + sdpa=sdpa, + scale_heads=scale_heads, + output_proj=output_proj, + bias=bias, + state_factory=state_factory, + device=device, + dtype=dtype, + ) + + # wrap linear layers with weight norm + if weight_normalization: + self.k_proj = nn.utils.parametrizations.weight_norm(self.k_proj) + self.q_proj = nn.utils.parametrizations.weight_norm(self.q_proj) + self.v_proj = nn.utils.parametrizations.weight_norm(self.v_proj) + + self.enable_qk_layernorm = enable_qk_layernorm + # initialize q-k LayerNorms if needed + if self.enable_qk_layernorm: + if layer_norm_factory is None: + # use default LayerNorm factory + layer_norm_factory = create_standard_layer_norm + + self.q_layer_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + self.k_layer_norm = layer_norm_factory( + self.kv_dim, device=device, dtype=dtype + ) + + @override + def _project_q( # type: ignore + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + state_bag: Optional[LCMIncrementalStateBag] = None, + ) -> Tensor: + # (N, S, M) -> (N, S, K_proj) + q = self.q_proj(seqs) + + # normalize queries + if self.enable_qk_layernorm: + q = self.q_layer_norm(q) + + # (N, S, K_proj) -> (N, H, S, K_h) + q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + + if self.pos_encoder is not None: + q = self.pos_encoder( + q, + padding_mask, + state_bag=state_bag, + ) + + return q # type: ignore[no-any-return] + + @override + def _project_kv( # type: ignore + self, + keys: Tensor, + key_padding_mask: Optional[PaddingMask], + values: Tensor, + state_bag: Optional[LCMIncrementalStateBag] = None, + ) -> Tuple[Tensor, Tensor]: + # (N, S, K) -> (N, S, K_proj) + k = self.k_proj(keys) + + # normalize keys + if self.enable_qk_layernorm: + k = self.k_layer_norm(k) + + # (N, S, V) -> (N, S, V_proj) + v = self.v_proj(values) + + # (N, S, K_proj) -> (N, H, S, K_h) + k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2) + # (N, S, V_proj) -> (N, H, S, V_h) + v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2) + + if self.pos_encoder is not None: + k = self.pos_encoder( + k, + key_padding_mask, + state_bag=state_bag, + ) + + return k, v + + @override + def forward( # type: ignore 
+ self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + keys: Tensor, + key_padding_mask: Optional[PaddingMask], + values: Tensor, + *, + attn_mask: Optional[AttentionMask] = None, + state_bag: Optional[LCMIncrementalStateBag] = None, + ) -> Tensor: + # (N, S, M) -> (N, H, S, K_h) + q = self._project_q( + seqs, + padding_mask, + state_bag, + ) + if self.training or state_bag is None: + # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h) + # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h) + k, v = self._project_kv( + keys, + key_padding_mask, + values, + ) + else: + if key_padding_mask is not None: + raise ValueError( + "`key_padding_mask` must be `None` during incremental decoding." + ) + + # k: (N, S_step, M) -> (N, H_kv, S_step, K_h) + # v: (N, S_step, M) -> (N, H_kv, S_step, V_h) + k, v = self._project_kv(keys, key_padding_mask, values, state_bag) + + state = state_bag.get_state(self, AttentionState) # type: ignore + + if state is None: + state_factory = self.state_factory or FullAttentionState + + state = state_factory( + k, v, state_bag.max_num_steps, state_bag.capacity_increment + ) + + state_bag.set_state(self, state) + else: + state.append(k, v) + + # k: (N, H_kv, S_kv, K_h) + # v: (N, H_kv, S_kv, V_h) + + k, v = state.get() + + # With Grouped Query Attention, each key/value head is repeated. + if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1: + # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h) + k = repeat_interleave(k, dim=1, repeat=num_query_groups) + # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, V_h) + v = repeat_interleave(v, dim=1, repeat=num_query_groups) + + if self.attn_mask_factory is not None: + attn_mask = self.attn_mask_factory( + seqs, keys=keys, training=self.training, state_bag=state_bag + ) + + needs_weights = len(self._attn_weight_hooks) > 0 + + # attn: (N, H, S, V_h) + # attn_weights: (N, H, S, S_kv) + + attn, attn_weights = self.sdpa( + q, + k, + key_padding_mask, + v, + attn_mask=attn_mask, + needs_weights=needs_weights, + ) + + if attn_weights is not None: + for hook in self._attn_weight_hooks.values(): + hook(self, attn, attn_weights) + + # (N, H, S, V_h) -> (N, S, H, V_h) + attn = attn.transpose(1, 2) + + if self.head_scale_weight is not None: + attn = torch.einsum("nshv,h->nshv", attn, self.head_scale_weight) + + # (N, S, H, V_h) -> (N, S, V_proj) + attn = attn.flatten(2, 3) + + # (N, S, V_proj) -> (N, S, M) + + attn = self.output_proj(attn) + + return attn # type: ignore[no-any-return] + + @override + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + s = f"{s}, enable_qk_layernorm={self.enable_qk_layernorm}" + + return s diff --git a/lcm/nn/transformer/decoder.py b/lcm/nn/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..449b80f731f555ecd1964debf0eb98f57f76217c --- /dev/null +++ b/lcm/nn/transformer/decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
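A plain-PyTorch sketch of the two ideas implemented above, QK-LayerNorm on the query/key projections and grouped key/value heads repeated via `repeat_interleave` before scaled dot-product attention; shapes follow the `(N, H, S, K_h)` convention in the comments, and all names are illustrative:

```python
import torch
import torch.nn.functional as F
from torch import nn

N, S, model_dim, H, H_kv = 2, 16, 512, 8, 2
head_dim = model_dim // H

x = torch.randn(N, S, model_dim)
wq = nn.Linear(model_dim, H * head_dim)
wk = nn.Linear(model_dim, H_kv * head_dim)
wv = nn.Linear(model_dim, H_kv * head_dim)
q_norm = nn.LayerNorm(H * head_dim)      # QK-norm: LayerNorm after the Q projection
k_norm = nn.LayerNorm(H_kv * head_dim)   # ... and after the K projection

q = q_norm(wq(x)).unflatten(-1, (H, head_dim)).transpose(1, 2)      # (N, H, S, K_h)
k = k_norm(wk(x)).unflatten(-1, (H_kv, head_dim)).transpose(1, 2)   # (N, H_kv, S, K_h)
v = wv(x).unflatten(-1, (H_kv, head_dim)).transpose(1, 2)           # (N, H_kv, S, V_h)

# Grouped Query Attention: each key/value head serves H // H_kv query heads.
k = torch.repeat_interleave(k, H // H_kv, dim=1)
v = torch.repeat_interleave(v, H // H_kv, dim=1)

attn = F.scaled_dot_product_attention(q, k, v)   # (N, H, S, V_h)
out = attn.transpose(1, 2).flatten(2, 3)         # (N, S, model_dim)
```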
+# +# + +from typing import List, Optional, Tuple + +from fairseq2.nn.padding import PaddingMask +from fairseq2.nn.transformer import ( + AttentionMask, + AttentionMaskFactory, + LayerNormFactory, + StandardTransformerDecoderLayer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerNormOrder, +) +from fairseq2.typing import DataType, Device, override +from torch import Generator, Tensor +from torch.nn import Dropout, ModuleList + +from lcm.nn.incremental_state import LCMIncrementalStateBag + + +class LCMStandardTransformerDecoderLayer(StandardTransformerDecoderLayer): # type: ignore + """Pass on `source_lengths` to StandardTransformerDecoderLayer's forward_pass.""" + + @override + def forward( # type: ignore + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + self_attn_mask: Optional[AttentionMask] = None, + encoder_output: Optional[Tensor] = None, + encoder_padding_mask: Optional[PaddingMask] = None, + *, + state_bag: Optional[LCMIncrementalStateBag] = None, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + seqs = self._forward_self_attn( + seqs, + padding_mask, + self_attn_mask, + state_bag, + ) + + seqs = self._forward_encoder_decoder_attn( + seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag + ) + + seqs = self._forward_ffn(seqs) + + return seqs, padding_mask + + @override + def _forward_self_attn( # type: ignore + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + self_attn_mask: Optional[AttentionMask], + state_bag: Optional[LCMIncrementalStateBag], + ) -> Tensor: + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + seqs = self.self_attn( + seqs, + padding_mask, + keys=seqs, + key_padding_mask=padding_mask, + values=seqs, + attn_mask=self_attn_mask, + state_bag=state_bag, + ) + + if self.self_attn_norm is not None: + seqs = self.self_attn_norm(seqs) + + if self.self_attn_dropout is not None: + seqs = self.self_attn_dropout(seqs) + + seqs = seqs + residual + + if self.norm_order == TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + return seqs + + +class LCMTransformerDecoder(TransformerDecoder): + def __init__( + self, + layers: List[TransformerDecoderLayer], + layer_norm_factory: LayerNormFactory, + self_attn_mask_factory: AttentionMaskFactory, + use_causal_attn_mask: bool = True, + generator: Optional[Generator] = None, + dropout_p: float = 0.0, + norm_order: TransformerNormOrder = TransformerNormOrder.POST, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + layer_list = ModuleList(layers) + + if not layer_list: + raise ValueError("`layers` must be non-empty.") + + model_dim = layer_list[0].model_dim + + super().__init__(model_dim) + + self.self_attn_mask_factory = self_attn_mask_factory + + self.layers = layer_list + + self.generator = generator + + if norm_order != TransformerNormOrder.POST: + self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + else: + self.register_module("layer_norm", None) + + if dropout_p > 0.0: + self.dropout = Dropout(dropout_p) + else: + self.register_module("dropout", None) + + self.norm_order = norm_order + + @override + def forward( # type: ignore + self, + seqs: Tensor, + padding_mask: Optional[PaddingMask], + encoder_output: Optional[Tensor] = None, + encoder_padding_mask: Optional[PaddingMask] = None, + *, + state_bag: Optional[LCMIncrementalStateBag] = None, + **kwargs, + ) -> Tuple[Tensor, Optional[PaddingMask]]: + """Pass on two additional arguments to 
StandardTransformerDecoder's forward_pass:""" + num_layers = len(self.layers) + + self_attn_mask: Optional[AttentionMask] = None + if self.self_attn_mask_factory is not None: + self_attn_mask = self.self_attn_mask_factory( + seqs, + keys=seqs, + training=self.training, + state_bag=state_bag, + ) + + for layer_idx, layer in enumerate(self.layers): + layer_output, layer_padding_mask = layer( + seqs, + padding_mask, + self_attn_mask, + encoder_output, + encoder_padding_mask, + state_bag=state_bag, + ) + + seqs, padding_mask = layer_output, layer_padding_mask + + for hook in self._layer_output_hooks.values(): + if not hook(layer_idx, seqs, padding_mask, num_layers): + break + + if self.layer_norm is not None: + seqs = self.layer_norm(seqs) + + if self.dropout is not None: + seqs = self.dropout(seqs) + + return seqs, padding_mask diff --git a/lcm/nn/transformer/factory.py b/lcm/nn/transformer/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcfa08ea69533121a43a980747c33a60ac60a94 --- /dev/null +++ b/lcm/nn/transformer/factory.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +from fairseq2.logging import get_log_writer +from fairseq2.nn import PositionEncoder +from fairseq2.nn.position_encoder import ( + LearnedPositionEncoder, + RotaryEncoder, + SinusoidalPositionEncoder, +) +from fairseq2.nn.projection import Linear +from fairseq2.nn.transformer import ( + FeedForwardNetwork, + GLUFeedForwardNetwork, + MultiheadAttention, + StandardFeedForwardNetwork, + TransformerDecoderLayer, + create_default_sdpa, +) +from fairseq2.typing import DataType, Device + +from lcm.nn.initialization import ( + SUPPORTED_INIT_TYPES, + get_init_fn, + parse_activation_fn, + parse_norm_order, +) +from lcm.nn.normalization import SUPPORTED_LN_TYPES, parse_layer_norm_factory +from lcm.nn.transformer import LCMStandardTransformerDecoderLayer +from lcm.nn.transformer.attention import ( + FullAttentionState, + QKNormMultiheadAttention, +) + +SUPPORTED_NORM_ORDERS = Literal["pre", "post", "normformer"] + + +logger = get_log_writer(__name__) + + +@dataclass +class TransformerConfig: + """A config object to group all config + hyper-parameters of a LCMTransformerDecoder""" + + num_layers: int = 2 + + num_attn_heads: int = 8 + + # Dropout rates + dropout_p: float = 0.1 + """ The dropout probability outputs of the attention layers and the + feed-forward network (before joining the residual stream)""" + + final_dropout_p: float = 0.1 + """ The dropout probability on decoder outputs""" + + attention_dropout_p: float = 0.0 + """the dropout rate on attention weights in SDPA""" + + # FFN + ffn_inner_dim: int = 1024 * 4 + + use_swiglu: bool = False + """Use GLUFeedForwardNetwork instead of regular FFN blocks""" + + ffn_inner_activation_name: str = "relu" + + """The activation to apply to outputs of the FFN inner projection layer. + Default is `relu `i.e., `torch.nn.ReLU`. This is only relevant when `use_swiglu= False`""" + + # positional embedding + pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "learned" + + """If `rope`: a rotary positional encoder in used in the attention layers. + If `sine`: Sinusoidal positional embeddings will be added in + the frontend before heading into the decoder + If `learned`: Learned positional embeddings will be added in + the frontend before heading into the decoder. 
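The pre-normalization residual order used by `_forward_self_attn` above (normalize, attend, dropout, add the residual) can be sketched with stock PyTorch modules; this illustrates the ordering only and is not the repo's fairseq2-based layer:

```python
import torch
from torch import nn

class PreNormSelfAttnBlock(nn.Module):
    def __init__(self, dim: int, heads: int, dropout_p: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, seqs: torch.Tensor, attn_mask=None) -> torch.Tensor:
        residual = seqs
        seqs = self.norm(seqs)                              # pre-norm (norm_order != POST)
        seqs, _ = self.attn(seqs, seqs, seqs, attn_mask=attn_mask)
        return residual + self.dropout(seqs)                # join the residual stream

block = PreNormSelfAttnBlock(dim=512, heads=8)
causal = nn.Transformer.generate_square_subsequent_mask(16)
y = block(torch.randn(2, 16, 512), attn_mask=causal)
```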
+ If `None`: no positional embeddings will be used (e.g. in the case + of unconditional diffusion of a single vector).""" + + rope_theta: float = 10_000.0 + """ The coefficient of the long-term decay of RoPE embeddings.""" + + # Normalization + layer_normalization_style: SUPPORTED_LN_TYPES = "standard" + + norm_order_style: SUPPORTED_NORM_ORDERS = "pre" + """LayerNorm order in the transformer decoder, + default is pre-normalization (`pre`). Other options are post-normalization (`post`) + and normformer-style normalization (`normformer`)""" + + final_norm_order_style: Optional[SUPPORTED_NORM_ORDERS] = None + """Controls lcm-level norm-order, using ``post`` here with a ``pre`` layer-level norm-order + means that we will skip the last layernorm in the stack""" + + enable_qk_layernorm: bool = False + """If ``True``, LayerNorms will be applied to queries and keys in self-attention layers + QK-LayerNorm described in https://arxiv.org/pdf/2302.05442 and subsequent work + is recommended to alleviate Transformer training instabilities + """ + mha_qkv_weight_normalization: bool = False + """if ``True`` wrap the K/Q/V linears of MHA in weight normalization""" + + mha_output_weight_normalization: bool = False + """if ``True`` wrap the output projection of MHA with weight normalization. + This is a temporary fix to resume training some models and will be removed""" + + # Miscellaneous + mha_output_proj_bias: bool = False + """If ``True`` add a bias term to the MHA output projection""" + + scale_residual: Optional[float] = None + """scale to multiply the residual in the Transformer decoder""" + + attention_output_init_fn: SUPPORTED_INIT_TYPES = "xavier" + + +class TransformerFactory: + def __init__( + self, + model_dim: int, + max_seq_len: int, + config: TransformerConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param model_dim: + The hidden model dimension of the Transformer + :params max_seq_len: + Maximum supported sequence length by the model + :param config: + The configuration. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. 
+ """ + self.model_dim = model_dim + self.max_seq_len = max_seq_len + self.config = config + self.device, self.dtype = device, dtype + + def build_layer(self) -> TransformerDecoderLayer: + """Build a Transformer decoder layer based on the provided config.""" + + self_attn = self.build_attention() + + ffn = self.build_ffn() + + norm_order = parse_norm_order(self.config.norm_order_style) + + layer_norm_factory = parse_layer_norm_factory( + self.config.layer_normalization_style + ) + + layer = LCMStandardTransformerDecoderLayer( + self_attn=self_attn, + encoder_decoder_attn=None, + ffn=ffn, + dropout_p=self.config.dropout_p, + norm_order=norm_order, + layer_norm_factory=layer_norm_factory, + scale_residual=self.config.scale_residual is not None, + device=self.device, + dtype=self.dtype, + ) + # reset residual_scale + if layer.residual_scale is not None: + assert self.config.scale_residual is not None, ( + f"Layer has a resiudal scale but scale={self.config.scale_residual}" + ) + torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual) + logger.info( + f"Initializing the residual scale at {self.config.scale_residual}" + ) + return layer + + def build_pos_encoder(self) -> Optional[PositionEncoder]: + """Build the positional encoder (learned or sinusoidal, if any) + that will be used in the frontend""" + pos_encoder: Optional[PositionEncoder] + + if self.config.pos_embedding_style == "learned": + pos_encoder = LearnedPositionEncoder( + self.model_dim, + self.max_seq_len, + device=self.device, + dtype=self.dtype, + ) + elif self.config.pos_embedding_style == "sine": + pos_encoder = SinusoidalPositionEncoder( + self.model_dim, + self.max_seq_len, + device=self.device, + ) + + else: + pos_encoder = None + + return pos_encoder + + def build_attention_pos_encoder(self) -> Optional[PositionEncoder]: + """Build the position encoder that can + potentially be used in the MHA module""" + + pos_encoder: Optional[PositionEncoder] + + if self.config.pos_embedding_style == "rope": + pos_encoder = RotaryEncoder( + encoding_dim=self.model_dim // self.config.num_attn_heads, + max_seq_len=self.max_seq_len, + theta=self.config.rope_theta, + device=self.device, + ) + else: + pos_encoder = None + return pos_encoder + + def build_attention(self) -> MultiheadAttention: + """Build a Transformer multi-head attention layer.""" + + # allow for a different kv_dim + kv_dim = self.model_dim + + # fairseq2.nn.transformer.attention.TorchSDPA + sdpa = create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p) + + init_fn = get_init_fn(self.config.attention_output_init_fn) + + # How does Rope play with encoder-decoder attention? 
+ pos_encoder = self.build_attention_pos_encoder() + + layer_norm_factory = parse_layer_norm_factory( + self.config.layer_normalization_style + ) + + # build output_proj: + output_proj = Linear( + self.model_dim, + self.model_dim, + bias=self.config.mha_output_proj_bias, + init_fn=init_fn, + device=self.device, + dtype=self.dtype, + ) + if self.config.mha_output_weight_normalization: + output_proj = torch.nn.utils.parametrizations.weight_norm(output_proj) + + return QKNormMultiheadAttention( + self.model_dim, + self.config.num_attn_heads, + kv_dim=kv_dim, + pos_encoder=pos_encoder, + sdpa=sdpa, + output_proj=output_proj, + enable_qk_layernorm=self.config.enable_qk_layernorm, + weight_normalization=self.config.mha_qkv_weight_normalization, + layer_norm_factory=layer_norm_factory, + state_factory=FullAttentionState, + device=self.device, + dtype=self.dtype, + ) + + def build_ffn(self) -> FeedForwardNetwork: + """Build a Transformer feed-forward network.""" + if self.config.use_swiglu: + # Default gate_activation is torch.nn.SiLU + return GLUFeedForwardNetwork( + self.model_dim, + self.config.ffn_inner_dim, + bias=True, + inner_dim_scale=2 / 3, + inner_dim_to_multiple=256, + device=self.device, + dtype=self.dtype, + ) + + ffn_inner_activation = parse_activation_fn( + self.config.ffn_inner_activation_name + ) + norm_order = parse_norm_order(self.config.norm_order_style) + + return StandardFeedForwardNetwork( + self.model_dim, + self.config.ffn_inner_dim, + inner_activation=ffn_inner_activation, + bias=True, + norm_order=norm_order, + device=self.device, + dtype=self.dtype, + ) diff --git a/lcm/utils/__init__.py b/lcm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a7068d9a823ab535754fe3f5d1d5637f4801ac --- /dev/null +++ b/lcm/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# diff --git a/lcm/utils/card_utils.py b/lcm/utils/card_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab337ea569c148bfe51a7fc0bc77c4bad57737d --- /dev/null +++ b/lcm/utils/card_utils.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +import dataclasses +import logging +from pathlib import Path +from typing import Dict, Optional, Union + +import yaml +from fairseq2.assets import ( + AssetNotFoundError, + InProcAssetMetadataProvider, + default_asset_store, +) +from fairseq2.assets.card import AssetCard +from fairseq2.checkpoint import FileCheckpointManager +from fairseq2.gang import FakeGang +from fairseq2.models import get_model_family +from fairseq2.typing import DataType, Device + +from lcm.models.abstract_lcm import AbstractLCModel, AbstractLCModelConfig +from lcm.utils.model_type_registry import lcm_model_type_registry + +logger = logging.getLogger(__file__) + + +def create_model_card( + checkpoint_path: Path, + model_config: Union[Dict, AbstractLCModelConfig, None], + model_type: str, # TODO: take this parameter from the config + model_name="on_the_fly_lcm", + model_arch: Optional[str] = None, + **additional_card_kwargs, +) -> AssetCard: + """ + Create an LCModel card given the checkpoint path and model args + Args: + - `checkpoint_path`: Path to the checkpoint to evaluate + - `model_config`: model parmeters + the default arch + """ + + # Create a fairseq2 model card on the fly. 
+ # assert ( + # checkpoint_path.is_file() + # ), f"Couldn't find the checkpoint at {checkpoint_path}" + + if isinstance(model_config, AbstractLCModelConfig): + model_config = dataclasses.asdict(model_config) + + model_card_info = { + "name": model_name, + "model_family": model_type, + "checkpoint": "file://" + checkpoint_path.as_posix(), + **additional_card_kwargs, + } + + if model_config is not None: + model_card_info["model_config"] = model_config + + if model_arch is not None: + model_card_info["model_arch"] = model_arch + + default_asset_store.metadata_providers.append( + InProcAssetMetadataProvider([model_card_info]) + ) + return default_asset_store.retrieve_card(model_name) + + +def load_model_with_overrides( + model_dir: Path, + step: Optional[int] = None, + model_type: Optional[str] = None, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + model_filename: str = "model.pt", +): + if step is not None: + checkpoint_path = model_dir / f"checkpoints/step_{step}" / model_filename + else: + checkpoint_path = model_dir / model_filename + + # New checkpoint + config_path = checkpoint_path.parent / "model_card.yaml" + if config_path.exists(): + try: + return load_model_from_card( + config_path.as_posix(), device=device, dtype=dtype + ) + except Exception as exc: + logger.warning( + f"Model card {config_path} exists but is not valid ({exc}). " + "Try global config instead." + ) + + # Old checkpoint + config_path = model_dir / "config_logs/all_config.yaml" + if config_path.exists(): + assert model_type, f"Need explicit model_type for checkpoint {checkpoint_path}" + with open(config_path, "r") as f: + config = yaml.full_load(f) + model_config = config["trainer"]["model_config_or_name"] + + temporary_card = create_model_card( + checkpoint_path=checkpoint_path, + model_config=model_config, + model_type=model_type, + model_arch=f"toy_{model_type}", + ) + loader_fn = lcm_model_type_registry.get_model_loader(model_type=model_type) + return loader_fn(temporary_card, device=device, dtype=dtype) # type: ignore + else: + raise ValueError(f"{model_dir} is not a valid model directory") + + +def create_model_card_from_training_folder( + folder: Union[str, Path], + card_name: str, + step_nr: Optional[int] = None, +) -> AssetCard: + """ + Extract the model config and the last checkpoint path using the checkpoint manager. + Create and return a model card + """ + folder_path = Path(folder) + assert folder_path.exists(), f"Model directory {folder} does not exist." + cp_dir = folder_path / "checkpoints" + + gang = FakeGang() + checkpoint_manager = FileCheckpointManager(cp_dir, gang) + + if step_nr is None: + step_numbers = checkpoint_manager.get_step_numbers() + if not step_numbers: + raise ValueError( + f"In {cp_dir}, no step number with model checkpoints was detected!" + ) + step_nr = step_numbers[-1] + logger.info(f"Automatically setting step number as {step_nr}") + + metadata = checkpoint_manager.load_metadata(step_nr) + assert metadata is not None, "The checkpoint does not have metadata." + + training_config = metadata["config"] + model_config = training_config.model_config_or_name + + cp_fn = checkpoint_manager._checkpoint_dir / f"step_{step_nr}" / "model.pt" + assert cp_fn, ( + f"Checkpoint manager could not extract checkpoint path for step {step_nr}." 
+ ) + # TODO: deal with the fine-tuning case, where model_config is a string + if isinstance(model_config, str): + parent_card = default_asset_store.retrieve_card(model_config) + model_config = parent_card._metadata["model_config_or_name"] + model_type = parent_card._metadata["model_family"] + else: + model_type = model_config.model_type + + card = create_model_card( + checkpoint_path=cp_fn.absolute(), + model_config=model_config, + model_type=model_type, + model_arch=f"toy_{model_type}", # TODO: get rid of the toy architecture when FS2 allows it + model_name=card_name, + ) + return card + + +def save_model_card(card: AssetCard, path: Union[str, Path]) -> None: + """Save a model card as YAML.""" + card_data = card._metadata # TODO: use the exposed attribute when available + with open(path, "w", encoding="utf-8") as outfile: + yaml.dump(card_data, outfile, default_flow_style=False) + + +def load_model_from_card( + model_name: str, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> AbstractLCModel: + """ + Load LC model from the given assed card or path. + The parameter `model_name` can be interpreted in multiple ways: + - as the name of the model card + - as the path to the yaml file of the model card + - as the path to the training directory of the model + - as the path to the model checkpoint (within a training directory, because we need to find the config) + """ + try: + card = default_asset_store.retrieve_card(model_name) + except AssetNotFoundError as err: + path = Path(model_name) + # If the card is not found, try looking it up by interpreting model_name as a path to the yaml card. + if path.exists() and path.suffix == ".yaml": + with open(path, "r", encoding="utf-8") as f: + card_data = yaml.full_load(f) + model_name = card_data["name"] + card = AssetCard(card_data) + # If the card is not found, try interpreting model_name as the model training directory + elif (path / "checkpoints").exists(): + card = create_model_card_from_training_folder( + path, card_name="temporary_card" + ) + # If the card is not found, try interpreting model_name as the path to the checkpoint within a training directory + elif ( + path.suffix == ".pt" + and path.parent.name.startswith("step_") + and path.parent.parent.name == "checkpoints" + ): + training_dir = path.parent.parent.parent + step_nr = int(path.parent.name[5:]) + card = create_model_card_from_training_folder( + training_dir, card_name="temporary_card", step_nr=step_nr + ) + else: + raise err + logger.info(f"Card loaded: {card}") + model_type = get_model_family(card) + loader = lcm_model_type_registry.get_model_loader(model_type=model_type) + model = loader(card, device=device, dtype=dtype) + return model diff --git a/lcm/utils/common.py b/lcm/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..462b866f8589936af10a3bfe254b500d48f6ab7d --- /dev/null +++ b/lcm/utils/common.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. 
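Hypothetical call sketch for `load_model_from_card` above; the checkpoint path is a placeholder that merely follows the `checkpoints/step_*/model.pt` layout the function expects, and any of the four accepted forms (card name, card YAML path, training directory, checkpoint path) could be passed instead.

```python
import torch
from lcm.utils.card_utils import load_model_from_card

model = load_model_from_card(
    "path/to/training_dir/checkpoints/step_2000/model.pt",   # placeholder path
    device=torch.device("cuda"),
    dtype=torch.float16,
)
```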
+# +# + +import ctypes +from abc import abstractmethod +from pathlib import Path +from typing import ( + Any, + Dict, + Iterable, + Optional, + Protocol, + Sized, + Type, + TypeVar, + Union, + runtime_checkable, +) + +import torch +from omegaconf import DictConfig, OmegaConf + +root_working_dir = Path(__file__).parent.parent.parent + + +def set_mkl_num_threads(): + """Setting mkl num threads to 1, so that we don't get thread explosion.""" + mkl_rt = ctypes.CDLL("libmkl_rt.so") + mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1))) + + +def working_dir_resolver(p: str): + """The omegaconf resolver that translates a relative path to the absolute path""" + return "file://" + str(root_working_dir.joinpath(p).resolve()) + + +def setup_conf(): + """Register the common Hydra config groups used in LCM (for now only the launcher)""" + from stopes.pipelines import config_registry # noqa + + recipe_root = Path(__file__).parent.parent.parent / "recipes" + config_registry["lcm-common"] = "file://" + str((recipe_root / "common").resolve()) + config_registry["lcm-root"] = "file://" + str(recipe_root.resolve()) + + # Register omegaconf resovlers + OmegaConf.register_new_resolver("realpath", working_dir_resolver, replace=True) + + +def torch_type( + dtype: Optional[Union[str, torch.dtype]] = None, +) -> Optional[torch.dtype]: + # Convert dtyp string from the checkpoint to torch.dtype + # https://github.com/pytorch/pytorch/issues/40471 + if dtype is None: + return None + + if isinstance(dtype, torch.dtype): + return dtype + + _dtype = eval(dtype) # type: ignore + assert isinstance(_dtype, torch.dtype), f"Invalid dtype value: {dtype}" + return _dtype + + +@runtime_checkable +class Batched(Sized, Protocol): + """Abstract class for batched data""" + + @abstractmethod + def __getitem__(self, i: int) -> Any: ... + + +T = TypeVar("T") + + +def promote_config(config: Union[T, DictConfig, Dict], config_cls: Type[T]) -> T: + if isinstance(config, (Dict, DictConfig)): + import dacite + + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config) # type: ignore + + return dacite.from_dict( + data_class=config_cls, + data=config, # type: ignore + config=dacite.Config(cast=[Path]), # type: ignore + ) + else: + assert isinstance(config, config_cls), f"Unknown config type: {type(config)}" + return config + + +def batched(inputs: Iterable, batch_size=10000) -> Iterable: + batch = [] + for line in inputs: + batch.append(line) + if len(batch) == batch_size: + yield batch + batch = [] + if len(batch) > 0: + yield batch diff --git a/lcm/utils/data_utils.py b/lcm/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d061fd309123ac0a2d623553dc67410f69dd14d9 --- /dev/null +++ b/lcm/utils/data_utils.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# + + +from dataclasses import fields +from typing import Any, List, Mapping + +from fairseq2.typing import DataClass, is_dataclass_instance + + +def update_dataclass( + obj: DataClass, + overrides: Mapping[str, Any], +) -> List[str]: + """Update ``obj`` with the data contained in ``overrides`` Return the unknown fields. + Copied from an old version of fairseq2 with simplification. + + :param obj: + The data class instance to update. + :param overrides: + The dictionary containing the data to set in ``obj``. 
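A toy illustration of the override semantics described above: nested dataclass fields are updated recursively and unknown keys are returned; the dataclasses here are made up for the example.

```python
from dataclasses import dataclass, field

from lcm.utils.data_utils import update_dataclass

@dataclass
class OptimizerConfig:
    lr: float = 1e-4
    warmup: int = 1000

@dataclass
class TrainConfig:
    max_steps: int = 250_000
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

cfg = TrainConfig()
unknown = update_dataclass(cfg, {"optimizer": {"lr": 3e-4}, "typo_field": 1})
print(cfg.optimizer.lr)   # 0.0003
print(unknown)            # ['typo_field']
```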
+    """
+
+    unknown_fields: List[str] = []
+
+    field_path: List[str] = []
+
+    # The dataset config has a special attribute `silent_freeze` that does not allow hard update
+    forbidden_fields_ = ["silent_freeze"]
+
+    def update(obj_: DataClass, overrides_: Mapping[str, Any]) -> None:
+        overrides_copy = {**overrides_}
+
+        for field in fields(obj_):
+            if field.name in forbidden_fields_:
+                continue
+            value = getattr(obj_, field.name)
+
+            try:
+                override = overrides_copy.pop(field.name)
+            except KeyError:
+                continue
+
+            # Recursively traverse child dataclasses.
+            if override is not None and is_dataclass_instance(value):
+                if not isinstance(override, Mapping):
+                    pathname = ".".join(field_path + [field.name])
+
+                    raise RuntimeError(
+                        pathname,
+                        f"The field '{pathname}' is expected to be of type `{type(value)}`, but is of type `{type(override)}` instead.",  # fmt: skip
+                    )
+
+                field_path.append(field.name)
+
+                update(value, override)
+
+                field_path.pop()
+            else:
+                setattr(obj_, field.name, override)
+
+        if overrides_copy:
+            unknown_fields.extend(
+                ".".join(field_path + [name]) for name in overrides_copy
+            )
+
+    update(obj, overrides)
+
+    unknown_fields.sort()
+
+    return unknown_fields
diff --git a/lcm/utils/distributed.py b/lcm/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..04d5d11e552c64fcab883b80b735391115232609
--- /dev/null
+++ b/lcm/utils/distributed.py
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+#
+
+import logging
+import os
+import random
+import subprocess
+import warnings
+from datetime import timedelta
+from functools import partial
+from typing import Any, List, Literal, Optional, Set, Tuple, Type
+
+import submitit
+import torch
+import torch.distributed as dist
+from fairseq2.gang import Gang, ProcessGroupGang
+from fairseq2.logging import get_log_writer
+from fairseq2.nn.fsdp import (
+    FSDP_LOW_MEMORY_POLICY,
+    FSDP_STANDARD_MEMORY_POLICY,
+    FSDP_VERY_LOW_MEMORY_POLICY,
+    FSDPMemoryPolicy,
+    FSDPWrapPolicy,
+)
+from fairseq2.nn.transformer import (
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.nn import Module
+
+logger = get_log_writer(__name__)
+
+
+SUPPORTED_FSDP_MEMORY_POLICIES = Literal["standard", "low", "very_low"]
+SUPPORTED_FSDP_WRAP_POLICIES = Literal["layer", "stack", "model"]
+
+
+def get_fsdp_memory_policy(
+    policy: SUPPORTED_FSDP_MEMORY_POLICIES = "standard",
+) -> FSDPMemoryPolicy:
+    fsdp_memory_policy: FSDPMemoryPolicy
+    if policy == "standard":
+        fsdp_memory_policy = FSDP_STANDARD_MEMORY_POLICY
+    elif policy == "low":
+        fsdp_memory_policy = FSDP_LOW_MEMORY_POLICY
+    elif policy == "very_low":
+        fsdp_memory_policy = FSDP_VERY_LOW_MEMORY_POLICY
+    else:
+        raise ValueError(f"Unsupported policy '{policy}'. Choose from 'standard', 'low', or 'very_low'.")
+
+    return fsdp_memory_policy
+
+
+def get_fsdp_wrap_policy(
+    model: Module, wrap_granularity: SUPPORTED_FSDP_WRAP_POLICIES = "layer"
+) -> Tuple[Optional[FSDPWrapPolicy], Optional[List[Module]]]:
+    """Return the FSDP wrap policy for ``model`` along with ignored modules.
+
+    :param model:
+        The model to be wrapped.
+    :param wrap_granularity:
+        The granularity at which to wrap modules of ``model``.
+
+        - 'layer': Wraps individual layers (e.g. :class:`TransformerDecoderLayer`).
+        - 'stack': Wraps layer stacks (e.g. :class:`TransformerDecoder`).
+        - 'model': Wraps ``model`` only.
+ + Copied over from fs2 to experiment easily with fsdp wrap policies + """ + if wrap_granularity == "model": + return None, None + + kls: Set[Type[Module]] + + if wrap_granularity == "stack": + kls = {TransformerEncoder, TransformerDecoder} + elif wrap_granularity == "layer": + kls = { + TransformerEncoderLayer, + TransformerDecoderLayer, + } + else: + raise ValueError( + f"`wrap_granularity` must be 'layer', 'stack', or 'model', but is '{wrap_granularity}' instead." + ) + + wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=kls) + + return wrap_policy, None + + +def init_process_group(config: Any, logger: logging.Logger) -> Gang: + if getattr(config, "use_submitit", True): + try: + submitit.helpers.TorchDistributedEnvironment().export(overwrite=True) + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + + except RuntimeError: + warnings.warn( + "looks like you are not in a submitit/stopes job. \ + You probably want to override use_submitit=false", + stacklevel=2, + ) + + timeout = timedelta(minutes=15) + + gang = ProcessGroupGang.init_default_process_group( + ok_initialized=False, + timeout=timeout, + ) + logger.info(f"Initialized gang with default process group (timeout={timeout})") + + return gang + + +def is_torch_run() -> bool: + return os.environ.get("TORCHELASTIC_RUN_ID") is not None + + +def is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def get_global_rank() -> int: + if dist.is_initialized(): + return dist.get_rank() + if is_torch_run(): + return int(os.environ["RANK"]) + if is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + return 0 + + +def get_local_rank() -> int: + if is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + if is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + return 0 + + +def get_world_size() -> int: + if dist.is_initialized(): + return dist.get_world_size() + if is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + if is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + return 1 + + +def get_master_addr() -> str: + if is_torch_run(): + return os.environ["MASTER_ADDR"] + if is_slurm_job(): + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]] + ) + return hostnames.split()[0].decode("utf-8") + return "127.0.0.1" + + +def get_master_port(job_id: int) -> Optional[int]: + if is_torch_run(): + return int(os.environ["MASTER_PORT"]) + else: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000) + rng = random.Random(job_id) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + +def init_torch_distributed( + backend: str = "cpu:gloo,cuda:nccl", + port: Optional[str] = None, + max_attempt: int = 5, +) -> None: + if dist.is_initialized(): + return + os.environ["RANK"] = str(get_global_rank()) + os.environ["WORLD_SIZE"] = str(get_world_size()) + + master_addr = get_master_addr() + + # Allow max_attempt to be set directly via os environment variable + # TORCH_DISTRIBUTED_PORT_ATTEMPTS + if os.environ.get("TORCH_DISTRIBUTED_PORT_ATTEMPTS", None): + max_attempt = int(os.environ["TORCH_DISTRIBUTED_PORT_ATTEMPTS"]) + attempt = 0 + while True: + try: + os.environ["MASTER_ADDR"] = master_addr + if port is None: + port = str( + get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1))) + ) + os.environ["MASTER_PORT"] = port + local_rank = get_local_rank() + if "nccl" in backend: + torch.cuda.set_device(local_rank) + timeout = timedelta(hours=10) + dist.init_process_group(backend=backend, timeout=timeout) + break + except 
(dist.DistNetworkError, RuntimeError) as e: + attempt += 1 + if attempt == max_attempt: + raise RuntimeError( + "Failed to initialize torch.distributed after 5 max attempts" + ) from e diff --git a/lcm/utils/logging.py b/lcm/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..66f550bca2156fb0771923546d93cf29becde7d1 --- /dev/null +++ b/lcm/utils/logging.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# + +import os +import subprocess +from pathlib import Path +from typing import Dict + +import torch.distributed as dist +from fairseq2.gang import get_rank +from fairseq2.logging import get_log_writer +from fairseq2.recipes.logging import _setup_aten_logging, _setup_nccl_logging +from fairseq2.recipes.utils.log import log_environment_info +from fairseq2.typing import Device + +logger = get_log_writer(__name__) + +LCM_REPOS = ["lcm", "fairseq2", "sonar", "stopes"] + + +def setup_additional_logging(log_folder: Path): + slurm_job_id: str = os.environ.get("SLURM_JOB_ID", "local") + base_log_file = log_folder / f"{slurm_job_id}_{get_rank()}.log" + _setup_aten_logging(base_log_file, force=False) + _setup_nccl_logging(base_log_file, force=False) + + +def log_git_status( + repo: str = "lcm", + tolerate_uncommitted: bool = False, +) -> str: + assert repo in LCM_REPOS, ( + f"Only the LCM core repos ({LCM_REPOS}) are supported in `log_git_status`" + ) + + repo_path = os.path.dirname(globals()[repo].__file__) + + try: + # check for modifications + mod_output = subprocess.run( + f"cd {repo_path}; git status --porcelain", capture_output=True, shell=True + ) + modifications = mod_output.stdout.decode("utf-8").split("\n") + uncommitted = len( + [ + m + for m in modifications + if m.startswith(" M") or m.startswith(("M ", "A ", "D ", "R ")) + ] + ) + if uncommitted > 0: + if tolerate_uncommitted: + logger.warning( + ( + "Changes to {} should be committed before running a job " + "- found {} change(s)." + " We will continue regardless, but the git commit hashes are unreliable!" + ).format(repo, uncommitted) + ) + else: + raise AssertionError( + f"Changes to {repo} should be committed before running a job - found {uncommitted} change(s). 
If running tests try adding `--debug-training`"
+                )
+
+        # get commit hash
+        output = subprocess.run(
+            f"cd {repo_path}; git rev-parse HEAD", capture_output=True, shell=True
+        )
+        commit_hash = output.stdout.decode("ascii").strip()
+        logger.info(f"{repo} ({repo_path}) commit hash: {commit_hash}")
+
+        return commit_hash
+
+    except AssertionError:
+        raise
+
+    except BaseException:
+        raise ValueError(
+            f"Could not check the git revision hash, make sure you can run `git status` in {repo} ({repo_path})"
+        )
+
+
+def log_lcm_environment(tolerate_uncommitted: bool = False) -> Dict:
+    """
+    For traceability and reproducibility, get the latest commit hash for the four key repos
+    """
+
+    commit_hashes = {
+        repo: log_git_status(repo, tolerate_uncommitted) for repo in LCM_REPOS
+    }
+
+    return commit_hashes
+
+
+def log_env_variables(device: Device) -> None:
+    """Log environment variables useful for debugging, including
+    fs2's `log_environment_info` to dump Fairseq2, torch, nccl and other relevant metadata
+    """
+    for key in sorted(os.environ.keys()):
+        if not (
+            key.startswith(
+                ("SLURM_", "SUBMITIT_", "NCCL_", "FI_", "CUDA_", "FAIRSEQ2_", "TORCH_")
+            )
+            or key
+            in (
+                "MASTER_ADDR",
+                "MASTER_PORT",
+                "RANK",
+                "WORLD_SIZE",
+                "LOCAL_RANK",
+                "LOCAL_WORLD_SIZE",
+            )
+        ):
+            continue
+        value = os.environ[key]
+        logger.info(f"R{dist.get_rank()} -- {key}={value}")
+
+    # For Fairseq2, torch and devices
+    log_environment_info(logger, device)
diff --git a/lcm/utils/model_type_registry.py b/lcm/utils/model_type_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..974f7ebe0951e36b3b7f5c850467b0667f0d084e
--- /dev/null
+++ b/lcm/utils/model_type_registry.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+#
+
+
+from dataclasses import dataclass
+from typing import Callable, Dict
+
+
+@dataclass
+class ModelTypeConfig:
+    """A container for all functions associated with a specific model type."""
+
+    model_type: str
+    config_loader: Callable
+    model_factory: Callable
+    model_loader: Callable
+
+
+class ModelTypeRegistry:
+    """
+    Represents a registry of model types.
+    In fairseq2 terms, "architecture" refers to a set of model hyperparameters,
+    and "model type" refers to a more generic way of constructing the model with the given hyperparameters.
+    """
+
+    _configs: Dict[str, ModelTypeConfig]
+
+    def __init__(self) -> None:
+        self._configs = {}
+
+    def register(self, model_type_config: ModelTypeConfig) -> None:
+        """Register a new model type.
+
+        :param model_type_config:
+            The :class:`ModelTypeConfig` bundling the config loader, model
+            factory, and model loader callables associated with the new
+            model type.
+        """
+        model_type = model_type_config.model_type
+        assert model_type, (
+            "To register a model type, the model_type parameter should be non-empty."
+        )
+        if model_type in self._configs:
+            raise ValueError(
+                f"`model_type` must be a unique model type name, but '{model_type}' is already registered."
+            )
+        self._configs[model_type] = model_type_config
+
+    def get_config(self, model_type: str) -> ModelTypeConfig:
+        """Return the ModelTypeConfig for the specified model type.
+
+        :param model_type:
+            The model type.
+        """
+        # we import lcm.modules at runtime in order to populate the registry and avoid cyclical imports
+
+        try:
+            return self._configs[model_type]
+        except KeyError:
+            raise ValueError(
+                f"The registry of model types does not contain a model type named '{model_type}'."
+ ) + + def get_model_loader(self, model_type: str) -> Callable: + """Get a model loader function for the given model type.""" + model_type_config = self.get_config(model_type) + return model_type_config.model_loader + + +lcm_model_type_registry = ModelTypeRegistry() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f25358569de0581850c1fbb9a41dfa41e8dc61d9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,139 @@ +[project] +name = "LexaLCM_Pre0_288M" +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0" +description = "A pre-trained LCM model with 288M parameters based on Meta FAIR's LCM architecture." +dependencies = [ + "dacite>=1.8.1", + "fire>=0.7.0", + "hydra-core>=1.3.2", + "importlib-resources~=6.4", + "numpy>=1.21", + "polars>=1.16.0", + "pyarrow>=16.1.0", + "retrying>=1.3.4", + "sentence-splitter>=1.4", + "sonar-space>=0.3.2", + "stopes[mono]>=2.2.0", + "tensorboard>=2.18.0", + "wandb>=0.19.11", +] + +classifiers = [ + "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering", + "Development Status :: 4 - Beta", +] + +[build-system] +requires = ["flit_core >=3.2,<4", "setuptools < 74"] +build-backend = "flit_core.buildapi" + +[tool.flit.module] +name = "lcm" # TODO change module name + +[project.optional-dependencies] +gpu = [ + "torch==2.5.1", + "torchaudio==2.5.1", + "fairseq2n==0.3.0rc1", + "fairseq2[arrow]==0.3.0rc1", +] +# cpu = [ +# "torch==2.5.1+cpu", +# "torchaudio==2.5.1+cpu", +# "fairseq2n==0.3.0rc1", +# "fairseq2[arrow]==0.3.0rc1", +# ] +eval = [ + "accelerate>=1.2.0", + "bert-score>=0.3.13", + "editdistance>=0.8.1", + "jinja2>=3.1.3", + "nltk>=3.9.1", + "rouge-score>=0.1.2", + "sacrebleu>=2.4.3", + "scikit-learn>=1.5.2", + "spacy>=3.7.5", + "textdescriptives>=2.8.2", + "tiktoken>=0.8.0", + "transformers>=4.45.0", + "fairscale>=0.4.13", +] +data = [ + "numpy>=1.21", + "numba>=0.60.0", + "spacy>=3.7.5", + "en_core_web_sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl", + "sacremoses>=0.1.1", + "nltk>=3.8.1", + "scipy>=1.14", + "editdistance>=0.8.1", + "sacrebleu>=2.4.1", + "datasets>=2.18.0", + "wtpsplit>=2.1.0", + "transformers>=4.45.0", +] + + +[tool.ruff] +target-version = "py310" + +[tool.mypy] +python_version = "3.10" +show_error_codes = true +check_untyped_defs = true +ignore_missing_imports = true +implicit_optional = true +implicit_reexport = true + +files = [ + "lcm/", # TODO +] + +[tool.uv] +prerelease = "explicit" # for fairseq2 0.3.0rc1 + +# TODO Change versions +[tool.uv.sources] +fairseq2 = [ + { index = "fairseq2-gpu", extra = 'gpu' }, + # { index = "fairseq2-cpu", extra = 'cpu' } +] +fairseq2n = [ + { index = "fairseq2-gpu", extra = 'gpu' }, + # { index = "fairseq2-cpu", extra = 'cpu' } +] +torch={index="pytorch-gpu"} +torchaudio={index="pytorch-gpu"} +# torch={index="pytorch-cpu"} +# torchaudio={index="pytorch-cpu"} +# sonar-space = { git = "https://github.com/facebookresearch/SONAR", branch = "update_fs2" } # TODO + +[[tool.uv.index]] +name = "fairseq2-gpu" +url = "https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cu121" +explicit = true + +# [[tool.uv.index]] +# name = "fairseq2-cpu" +# url = "https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cpu/" +# explicit = true + +[[tool.uv.index]] +name = "pytorch-gpu" +url = "https://download.pytorch.org/whl/cu121" +explicit = true + +# [[tool.uv.index]] +# name = "pytorch-cpu" +# url = "https://download.pytorch.org/whl/cpu" +# explicit = true 
+ +[dependency-groups] +dev = ["pytest-asyncio>=0.23.2", "pytest>=8.0.0"] + + +[project.entry-points."fairseq2"] +"fairseq2" = "lcm:setup_fairseq2" # TODO update lcm diff --git a/scripts/run_inference.py b/scripts/run_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d58ca43785b779ef1d07343e0bf8e01849b1f1ae --- /dev/null +++ b/scripts/run_inference.py @@ -0,0 +1,108 @@ +import torch +from lcm.utils.card_utils import load_model_from_card +from lcm.inference.two_tower_diffusion_lcm import TwoTowerDiffusionLCMGenerator, DiffusionLCMGeneratorOptions +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline +from lcm.datasets.batch import EmbeddingsBatch + +def main(): + # Setup device + device = torch.device("cuda:0") + + # Load model + model_card = "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml" + model = load_model_from_card(model_card, device=device, dtype=torch.float32) + + # Setup generator options + options = DiffusionLCMGeneratorOptions( + guidance_scale=3.0, # Increased from 1.0 to make generation more focused + guidance_rescale=0.0, + ddim_eta=0.0, + initial_noise_scale=1.0, + inference_timesteps=100, + clip_noise=100, + thresholding=False, + dynamic_thresholding_ratio=0.995, + sample_max_value=6.0 + ) + + # Create generator + generator = TwoTowerDiffusionLCMGenerator(model=model, options=options) + + # Setup text encoders/decoders + text_encoder = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + dtype=torch.float32 + ) + + text_decoder = EmbeddingToTextModelPipeline( + decoder="text_sonar_basic_decoder", + tokenizer="text_sonar_basic_decoder", + device=device, + dtype=torch.float32 + ) + + # Get EOS embedding + eos_text = "End of text." 
+    eos_embedding = text_encoder.predict([eos_text], source_lang="eng_Latn")
+    generator.eos_vec = eos_embedding.squeeze(0)  # Remove batch dimension
+
+    # Example prompts (each inner list is a multi-sentence prompt)
+    prompts = [
+        ["Petals fall in the wind.", "They swirl and dance and float away.", "Then all becomes still again."],
+        ["Like wisps of light, the moonlight meets the rolling brook.", "Upon seeing it glimmer, she turns and smiles.", "Her friend was glad they could share this tranquil moment."],
+        ["Tokyo is the modern capital of Japan.", "Although it is currently the case, historically, cities such as Kyoto and Nara have also served as the capital."]
+    ]
+
+    print("\nProcessing prompts:")
+    all_prompt_embeddings = []
+    max_sentences = max(len(prompt) for prompt in prompts)
+
+    # Process each multi-sentence prompt
+    for i, prompt_sentences in enumerate(prompts):
+        print(f"\nPrompt {i+1}:")
+        for sentence in prompt_sentences:
+            print(f" {sentence}")
+
+        # Encode each sentence separately
+        sentence_embeddings = text_encoder.predict(prompt_sentences, source_lang="eng_Latn")
+        print(f" Sentence embeddings shape: {sentence_embeddings.shape}")
+
+        # Pad to max_sentences if needed
+        if len(prompt_sentences) < max_sentences:
+            padding = torch.zeros((max_sentences - len(prompt_sentences), sentence_embeddings.shape[1]),
+                                  device=device, dtype=sentence_embeddings.dtype)
+            sentence_embeddings = torch.cat([sentence_embeddings, padding], dim=0)
+
+        all_prompt_embeddings.append(sentence_embeddings)
+
+    # Stack all prompts into a batch
+    prompt_embeddings = torch.stack(all_prompt_embeddings)
+    print("\nFinal batch shape:", prompt_embeddings.shape)
+
+    batch = EmbeddingsBatch(prompt_embeddings, None)
+
+    # Generate
+    output = generator(
+        batch_input=batch,
+        max_gen_len=24,
+        min_gen_len=10,
+        temperature=0.1
+    )
+
+    # Decode generated embeddings
+    print("\nGenerated outputs:")
+    for i, hypotheses in enumerate(output.hypotheses):
+        print(f"\nOutput for Prompt {i+1}:")
+        for hypothesis in hypotheses:
+            generated_text = text_decoder.predict(
+                hypothesis.seq,
+                target_lang="eng_Latn",
+                max_seq_len=256,
+                temperature=1.0
+            )
+            print("Generated text:", generated_text)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
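For quick reference, here is a minimal sketch of the lookup paths that `lcm.utils.card_utils.load_model_from_card` supports for its `model_name` argument, as described in its docstring. The paths reuse the local checkpoint layout referenced by `scripts/run_inference.py`; the checkpoint filename in the last variant is a hypothetical placeholder, not a file named in this diff.

```python
import torch

from lcm.utils.card_utils import load_model_from_card

device = torch.device("cuda:0")

# 1) Path to a model card YAML (the form used by scripts/run_inference.py).
model = load_model_from_card(
    "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml",
    device=device,
    dtype=torch.float32,
)

# 2) Path to a training directory containing a `checkpoints/` folder:
#    a temporary card is built from the training folder.
# model = load_model_from_card("./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0", device=device)

# 3) Path to a checkpoint inside <training_dir>/checkpoints/step_<N>/
#    (the .pt filename below is a hypothetical placeholder).
# model = load_model_from_card(
#     "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model.pt",
#     device=device,
# )

# 4) A card name registered with fairseq2's default_asset_store also resolves directly.
```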