Lexa committed · Commit 3d79eb3 · 0 parent(s)

Initial commit

Note: this view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +118 -0
  2. .pre-commit-config.yaml +16 -0
  3. LICENSE +21 -0
  4. README.md +29 -0
  5. lcm/__init__.py +22 -0
  6. lcm/cards/Normalizer_Wikipedia_En_1M.pt +0 -0
  7. lcm/cards/sonar_normalizer.yaml +4 -0
  8. lcm/datacards/datacards.yaml +5 -0
  9. lcm/datasets/__init__.py +4 -0
  10. lcm/datasets/batch.py +425 -0
  11. lcm/inference/lcm/__init__.py +9 -0
  12. lcm/inference/lcm/generator.py +448 -0
  13. lcm/inference/lcm/scorer.py +198 -0
  14. lcm/inference/two_tower_diffusion_lcm/__init__.py +16 -0
  15. lcm/inference/two_tower_diffusion_lcm/generator.py +466 -0
  16. lcm/inference/two_tower_diffusion_lcm/scorer.py +314 -0
  17. lcm/models/__init__.py +15 -0
  18. lcm/models/abstract_lcm/__init__.py +16 -0
  19. lcm/models/abstract_lcm/builder.py +106 -0
  20. lcm/models/base_lcm/__init__.py +20 -0
  21. lcm/models/base_lcm/archs.py +49 -0
  22. lcm/models/base_lcm/builder.py +285 -0
  23. lcm/models/base_lcm/frontend.py +183 -0
  24. lcm/models/base_lcm/loader.py +55 -0
  25. lcm/models/base_lcm/normalization.py +50 -0
  26. lcm/models/sonar_normalizer/__init__.py +20 -0
  27. lcm/models/sonar_normalizer/archs.py +40 -0
  28. lcm/models/sonar_normalizer/builder.py +210 -0
  29. lcm/models/sonar_normalizer/loader.py +28 -0
  30. lcm/models/two_tower_diffusion_lcm/__init__.py +7 -0
  31. lcm/models/two_tower_diffusion_lcm/archs.py +207 -0
  32. lcm/models/two_tower_diffusion_lcm/builder.py +628 -0
  33. lcm/models/two_tower_diffusion_lcm/frontend.py +152 -0
  34. lcm/models/two_tower_diffusion_lcm/loader.py +44 -0
  35. lcm/nn/__init__.py +4 -0
  36. lcm/nn/denoisers/__init__.py +17 -0
  37. lcm/nn/denoisers/attention_masks.py +228 -0
  38. lcm/nn/denoisers/factory.py +192 -0
  39. lcm/nn/denoisers/lcm_denoiser.py +546 -0
  40. lcm/nn/incremental_state.py +43 -0
  41. lcm/nn/initialization.py +152 -0
  42. lcm/nn/normalization.py +88 -0
  43. lcm/nn/projection.py +86 -0
  44. lcm/nn/schedulers/__init__.py +17 -0
  45. lcm/nn/schedulers/ddim.py +741 -0
  46. lcm/nn/timestep_encoder.py +122 -0
  47. lcm/nn/transformer/__init__.py +24 -0
  48. lcm/nn/transformer/attention.py +307 -0
  49. lcm/nn/transformer/decoder.py +176 -0
  50. lcm/nn/transformer/factory.py +300 -0
.gitignore ADDED
@@ -0,0 +1,118 @@
+
+ # JetBrains PyCharm IDE
+ .idea/
+
+ # Byte-compiled / optimized / DLL files
+ **/*/__pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # macOS dir files
+ .DS_Store
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # dotenv
+ .env
+
+ # virtualenv
+ .venv
+ venv/
+ ENV/
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ .pytest_cache
+ .ruff_cache
+
+ # VSCODE
+ .vscode/ftp-sync.json
+ .vscode/settings.json
+ .vscode/launch.json
+
+ # stopes logs
+ executor_logs/
+ config_logs/
+ outputs/
+
+ logs/
+ **/dask_jobqueue_logs
+ core.*
+ mortimer_env.txt
+
+ # datasets
+ _LexaLCM_Block0/Datasets/
+
+ # UV
+ uv.lock
.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
+ repos:
+   - repo: https://github.com/astral-sh/uv-pre-commit
+     rev: 0.5.7
+     hooks:
+       - id: uv-lock
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     rev: v0.8.2
+     hooks:
+       # Lint
+       - id: ruff
+         args: [ --fix ]
+       # sort imports
+       - id: ruff
+         args: ["check", "--select", "I", "--fix"]
+       # format
+       - id: ruff-format
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Alexandra 'Lexa' Baldwin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,29 @@
+ # LexaLCM Pre0 288M Pre-trained Large Concept Model
+ A pre-trained LCM model with 288M parameters, based on Meta FAIR's LCM architecture.
+
+ [[Paper]](https://ai.meta.com/research/publications/large-concept-models-language-modeling-in-a-sentence-representation-space/)
+
+ Note: These instructions are for running the model on a single machine with a single GPU. If your system does not have a GPU that supports at least CUDA 12.1, or if you intend to run this in the cloud, you'll need to modify the code to fit your requirements.
+
+ ## 1. Install the Intel MKL runtime
+ ```bash
+ sudo apt update
+ sudo apt install libmkl-rt
+ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH
+ source ~/.bashrc
+ ```
+
+ ## 2. Install dependencies
+ ```bash
+ uv sync --extra gpu --extra eval --extra data
+ ```
+
+ ## 3. Update the model cards' paths
+ The paths in these two model cards must be updated to match where the files live on your local filesystem (a hypothetical sketch follows these instructions):
+ * `_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/model_card.yaml`
+ * `_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml`
+
+ ## 4. Test the model's inference
+ ```bash
+ uv run --extra gpu scripts/run_inference.py
+ ```
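For step 3, the edit is typically just repointing a checkpoint path. The cards themselves are not included in this 50-file view, so the fragment below is purely a hypothetical illustration; the `checkpoint:` key is assumed by analogy with `lcm/cards/sonar_normalizer.yaml`, and the real field names may differ.

```yaml
# Hypothetical model_card.yaml fragment -- field names are assumptions.
# Point the checkpoint at its absolute location on your machine:
checkpoint: "/abs/path/to/_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/<checkpoint file>"
```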
lcm/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
+
+ """
+ LCM: Modular and Extensible Reasoning in an Embedding Space
+ Code base for training different LCM models.
+ """
+
+ from fairseq2 import setup_extensions
+ from fairseq2.assets import default_asset_store
+
+ __version__ = "0.1.0.dev0"
+
+
+ def setup_fairseq2() -> None:
+     default_asset_store.add_package_metadata_provider("lcm.cards")
+
+
+ # This call activates setup_fairseq2 and potentially other extensions.
+ setup_extensions()
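Because `setup_extensions()` runs at import time, merely importing `lcm` makes the cards under `lcm/cards` visible to fairseq2. A minimal sketch, not part of the commit; `retrieve_card` and `field(...).as_(...)` are assumed from fairseq2's asset API:

```python
# Importing lcm registers lcm/cards with fairseq2's default asset store.
import lcm  # noqa: F401  # imported for its registration side effect

from fairseq2.assets import default_asset_store

# The card name comes from lcm/cards/sonar_normalizer.yaml.
card = default_asset_store.retrieve_card("sonar_normalizer_wikipedia_en_1m")
print(card.field("model_family").as_(str))  # -> "sonar_normalizer"
```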
lcm/cards/Normalizer_Wikipedia_En_1M.pt ADDED
Binary file (9.99 kB)
lcm/cards/sonar_normalizer.yaml ADDED
@@ -0,0 +1,4 @@
+ name: sonar_normalizer_wikipedia_en_1m
+ model_family: sonar_normalizer
+ model_arch: base
+ checkpoint: Normalizer_Wikipedia_En_1M.pt
lcm/datacards/datacards.yaml ADDED
@@ -0,0 +1,5 @@
+ name: "Data_Wikipedia_En_1M"
+ parquet_path:
+   local: "./_LexaLCM_Pre0/Datasets/Wikipedia_En_1M"
+ source_column: "text_sentences_sonar_emb"
+ source_text_column: "text_sentences"
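The datacard points at a local parquet dataset holding raw sentences alongside one SONAR embedding per sentence. A minimal inspection sketch, not part of the commit; it assumes pyarrow is installed and that the dataset exists at the configured path:

```python
# Peek at the first row of the parquet dataset referenced by the datacard.
import pyarrow.dataset as ds

dataset = ds.dataset("./_LexaLCM_Pre0/Datasets/Wikipedia_En_1M", format="parquet")
first_row = dataset.head(1)

sentences = first_row["text_sentences"][0].as_py()             # raw sentences
embeddings = first_row["text_sentences_sonar_emb"][0].as_py()  # one vector per sentence
print(len(sentences), len(embeddings))  # the two counts should match
```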
lcm/datasets/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
lcm/datasets/batch.py ADDED
@@ -0,0 +1,425 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+ #
+ #
+
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from fairseq2.logging import get_log_writer
+ from fairseq2.models.sequence import SequenceBatch
+ from fairseq2.nn.padding import PaddingMask, pad_seqs
+ from fairseq2.typing import Device
+ from torch import Tensor
+ from torch.nn import Module
+
+ from lcm.utils.common import Batched
+
+ logger = get_log_writer(__name__)
+
+
+ DOC_LENGTHS = "__doc_lengths"
+
+
+ class LCMStyle(Enum):
+     """Specifies a style for preparing the LCM input."""
+
+     SUPERVISED = 1
+     """For when the model is fed supervised data with source & target sentences."""
+
+     UNSUPERVISED = 2
+     """For when the model is fed unsupervised data with source sentences only."""
+
+     PACKED_UNSUPERVISED = 3
+     """For when the model is fed ``packed`` unsupervised data with source sentences only.
+     This means that we will look for document_lengths and propagate them to the
+     packed causal masked attention and the packed position encoders"""
+
+
+ @dataclass
+ class EmbeddingsBatch:
+     """Represents a batch of embedding sequences.
+     Resembles fairseq2's SequenceBatch with additional properties"""
+
+     seqs: Tensor
+     """The sequences. *Shape:* :math:`(B,S,D)`, where :math:`B` is the batch
+     size, :math:`S` is the sequence length (in sentences per document),
+     and :math:`D` the embedding dimension
+     """
+
+     padding_mask: Optional[PaddingMask] = None
+     """The padding mask of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
+     the batch size and :math:`S` is the sequence length."""
+
+     diffusion_timesteps: Optional[Tensor] = None
+     """Diffusion timesteps of the noising process of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
+     the batch size and :math:`S` is the sequence length."""
+
+     document_lengths: Optional[Tensor] = None
+     """Lengths of the documents (in sentences) present in the batch
+     Shape: (len_doc, )
+     """
+
+     source_lengths: Optional[Tensor] = None
+     """Lengths of the source part for each element in the batch, so that `seqs[i, :source_lengths[i]]` corresponds to the source for each i in [0, batch_size).
+     Shape: (batch_size, )
+     """
+
+     def __post_init__(self):
+         if self.document_lengths is not None:
+             assert self.document_lengths.sum() == self.seqs.size(
+                 1
+             ) or 2 * self.document_lengths.sum() == self.seqs.size(1), (
+                 "The lengths do not sum up to the sequence length "
+                 "(nor half the length for doubled diffusion sequences). "
+                 f"We have seqs.size={self.seqs.size()} and lengths={self.document_lengths} "
+                 f"summing to {self.document_lengths.sum()}"
+             )
+
+     def __len__(self) -> int:
+         return self.batch_size
+
+     @property
+     def batch_size(self) -> int:
+         """The size of the batch."""
+         return self.seqs.size(0)
+
+     @property
+     def shape(self) -> torch.Size:
+         """The shape of the batch."""
+         return self.seqs.shape
+
+     @property
+     def device(self) -> Device:
+         """The device of the batch."""
+         return self.seqs.device
+
+     def clone(self):
+         return deepcopy(self)
+
+     def __getitem__(self, i: int) -> Any:
+         raise NotImplementedError(
+             "Access to each item in EmbeddingsBatch not allowed yet"
+         )
+
+     def unbatch(self) -> List[Tensor]:
+         if self.padding_mask is None:
+             return list(self.seqs)
+         else:
+             return [
+                 tt[:length] for tt, length in zip(self.seqs, self.padding_mask.seq_lens)
+             ]
+
+     def get_last_element(self) -> Tensor:
+         if self.padding_mask:
+             return self.seqs[
+                 torch.arange(len(self.padding_mask.seq_lens), device=self.seqs.device),
+                 (self.padding_mask.seq_lens - 1),
+             ]
+         else:
+             return self.seqs[:, -1]
+
+     def set_last_element(self, element: Tensor) -> None:
+         element = element.to(self.seqs.device)
+         if self.padding_mask:
+             for i, slen in enumerate(self.padding_mask.seq_lens):
+                 self.seqs[i, slen - 1] = element[i]
+         else:
+             self.seqs[:, -1] = element
+
+     def normalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
+         if normalizer is None:
+             logger.warning(
+                 "The normalizer is None; as such, the features will remain unchanged"
+             )
+             return self
+
+         return EmbeddingsBatch(
+             seqs=normalizer.normalize(self.seqs),
+             padding_mask=self.padding_mask,
+             diffusion_timesteps=self.diffusion_timesteps,
+             document_lengths=self.document_lengths,
+             source_lengths=self.source_lengths,
+         )
+
+     def denormalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
+         if normalizer is None:
+             logger.warning(
+                 "The normalizer is None; as such, the features will remain unchanged"
+             )
+             return self
+
+         return EmbeddingsBatch(
+             seqs=normalizer.denormalize(self.seqs),
+             padding_mask=self.padding_mask,
+             diffusion_timesteps=self.diffusion_timesteps,
+             document_lengths=self.document_lengths,
+             source_lengths=self.source_lengths,
+         )
+
+     def double_seqs(self) -> "EmbeddingsBatch":
+         """
+         Repeats each sequence element along the sequence dim:
+         1, 2, 3 -> 1, 1, 2, 2, 3, 3
+         x, y -> x, x, y, y
+         """
+         if self.padding_mask is not None:
+             doubled_padding_mask = PaddingMask(
+                 seq_lens=2 * self.padding_mask._seq_lens,
+                 batch_seq_len=2 * self.padding_mask._batch_seq_len,
+             )
+         else:
+             doubled_padding_mask = None
+
+         return EmbeddingsBatch(
+             seqs=torch.repeat_interleave(self.seqs, 2, dim=1),
+             padding_mask=doubled_padding_mask,
+             diffusion_timesteps=self.diffusion_timesteps,
+             document_lengths=self.document_lengths,
+             source_lengths=(
+                 torch.repeat_interleave(self.source_lengths, 2, dim=0)
+                 if self.source_lengths is not None
+                 else None
+             ),
+         )
+
+     def flatten_to_sentences(self) -> Tensor:
+         """Flatten the sequence of embeddings
+         from B, S, D to B*~S, D after removing the padded positions
+         """
+
+         embed_dim = self.seqs.size(-1)
+
+         if self.padding_mask is not None:
+             seq_lens = self.padding_mask.seq_lens
+
+             embeds_mask = self.padding_mask.materialize().unsqueeze(-1)
+
+             # Remove padded positions and reshape as B*~S, D
+             flat_embeds = torch.masked_select(self.seqs, embeds_mask).reshape(
+                 (-1, embed_dim)
+             )
+
+             # split per document/paragraph
+             flat_embeds_per_doc = list(torch.split(flat_embeds, seq_lens.tolist()))
+
+             # Concatenate back
+             flat_embeds = torch.concat(flat_embeds_per_doc)
+
+         else:
+             embeds = self.seqs
+
+             flat_embeds = embeds.reshape((-1, embed_dim))
+
+         return flat_embeds
+
+
+ @dataclass
+ class LCMInput(Batched):
+     """Dataclass for a pair of source/target sequences of SONAR embeddings"""
+
+     source: List[Tensor]
+     """source: SONAR embeddings of the source text,
+     i.e. [X^1 in (N_1, D), ... X^M in (N_M, D)]"""
+
+     target: Union[None, List[Tensor]]
+     """target: If supervised data: SONAR embeddings of the target text"""
+
+     tokens: Union[None, SequenceBatch] = None
+     """tokens: Tokenized flattened sentences for the SONAR decoder
+     (see the dataloader `_prepare_subword_tokens`)"""
+
+     target_tokens: Union[None, SequenceBatch] = None
+     """target_tokens: a sequence of the same shape as `tokens`, but shifted, to serve as the target.
+     (see the dataloader `_prepare_subword_tokens`)"""
+
+     name: Optional[str] = None
+     """
+     name of the dataset the input is coming from
+     """
+     batch: Optional[Dict[str, Any]] = None
+     """raw batch of the dataloader, used for tracking and debugging"""
+
+     def __post_init__(self):
+         assert self.source is not None
+
+         length = len(self.source)
+
+         assert (self.target is None) or (len(self.target) == length), (
+             f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
+         )
+
+     def __len__(self) -> int:
+         return len(self.source)
+
+     def __getitem__(self, i: int) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+         """
+         Return the content of an item in the batch
+         """
+         if self.target is None:
+             return self.source[i]
+         else:
+             return self.source[i], self.target[i]
+
+     def prepare_input(
+         self,
+         style: LCMStyle = LCMStyle.UNSUPERVISED,
+     ) -> EmbeddingsBatch:
+         """
+         Adds special tokens to the source (& target) and prepares
+         the EmbeddingsBatch (tensor & its padding mask) that will be
+         forwarded in the LCM model.
+
+         `style`: LCMStyle is either supervised (requires target
+         embeddings) or unsupervised
+         """
+
+         if style == LCMStyle.UNSUPERVISED:
+             return get_embeddings_sequence(src_seqs=self.source)
+
+         elif style == LCMStyle.PACKED_UNSUPERVISED:
+             # If using PACKED_UNSUPERVISED, document_lengths will be added to `EmbeddingsBatch`
+             document_lengths = None
+             if self.batch is not None and self.batch.get(DOC_LENGTHS, None) is not None:
+                 # document_lengths will only be consumed if the batch_size is 1
+                 assert len(self.batch[DOC_LENGTHS]) == 1, "Expecting batch size of 1"
+
+                 document_lengths = self.batch[DOC_LENGTHS][0].type(torch.int64)
+
+             return get_embeddings_sequence(
+                 src_seqs=self.source,
+                 document_lengths=document_lengths,
+             )
+
+         elif style == LCMStyle.SUPERVISED:
+             assert self.target is not None, (
+                 "Missing target embeddings for a supervised batch"
+             )
+             return get_embeddings_sequence(
+                 src_seqs=self.source,
+                 tgt_seqs=self.target,
+             )
+
+         raise ValueError(f"Unsupported style={style} - could not prepare input")
+
+     def prepare_target_mask(
+         self,
+         embeddings: EmbeddingsBatch,
+         style: LCMStyle,
+         min_context_size: Optional[int] = None,
+     ) -> Tensor:
+         """Prepare a target mask signaling which positions
+         we should predict and optimize the model for
+
+         Args:
+             - min_context_size: the minimum context used to predict the next
+               concept (only used for unsupervised training)
+
+         """
+
+         batch_size, maxlen, _ = embeddings.seqs.size()
+
+         device = embeddings.seqs.device
+
+         if style == LCMStyle.UNSUPERVISED:
+             # A target mask for unsupervised next sentence prediction
+             # All positions are optimized/predicted starting from min_context_size
+             target_mask = torch.ones(
+                 (batch_size, maxlen),
+                 dtype=torch.bool,
+                 device=device,
+             )
+             if min_context_size is not None:
+                 target_mask[:, : min(min_context_size, target_mask.size(1))] = False
+
+         elif style == LCMStyle.PACKED_UNSUPERVISED:
+             # A target mask for unsupervised next sentence prediction when the data is packed
+             # All positions are optimized starting from min_context_size in each document
+             document_lengths = embeddings.document_lengths
+             if document_lengths is not None:  # training
+
+                 def get_document_target_mask(doc_length):
+                     mask = torch.ones(doc_length, dtype=torch.bool, device=device)
+                     mask[: min(min_context_size, doc_length)] = False
+                     return mask
+
+                 target_mask = torch.cat(
+                     [get_document_target_mask(length) for length in document_lengths]
+                 ).unsqueeze(0)
+
+             else:  # validation with unpacked data:
+                 target_mask = torch.ones(
+                     (batch_size, maxlen),
+                     dtype=torch.bool,
+                     device=device,
+                 )
+                 if min_context_size is not None:
+                     target_mask[:, : min(min_context_size, target_mask.size(1))] = False
+
+         elif style == LCMStyle.SUPERVISED:
+             # A target mask for target prediction
+             indices = torch.arange(maxlen, device=device).expand(batch_size, -1)
+
+             source_lengths = torch.tensor(
+                 [seq.size(0) for seq in self.source],
+                 device=device,
+             )
+
+             target_mask = indices >= source_lengths.unsqueeze(1).expand(-1, maxlen)
+
+         # Factor in padded positions:
+         if embeddings.padding_mask is not None:
+             target_mask = target_mask * embeddings.padding_mask.materialize()
+
+         return target_mask.detach()
+
+
+ def get_embeddings_sequence(
+     src_seqs: List[Tensor],
+     tgt_seqs: Optional[List[Tensor]] = None,
+     document_lengths: Optional[Tensor] = None,
+     double_target: bool = False,
+ ) -> EmbeddingsBatch:
+     seqs_lst: List[Tensor] = []
+     for src_seq, tgt_seq in zip(src_seqs, tgt_seqs or [None] * len(src_seqs)):  # type: ignore
+         embeds: List[Tensor] = []
+         device, dtype = src_seq.device, src_seq.dtype
+
+         # mandatory src_seq
+         embeds.append(src_seq)
+
+         # supervised tgt_seq
+         if tgt_seq is not None:
+             tgt_seq = tgt_seq.to(device).type(dtype)
+
+             if double_target:
+                 embeds.append(torch.repeat_interleave(tgt_seq, 2, dim=0))
+             else:
+                 embeds.append(tgt_seq)
+
+         seqs_lst.append(torch.concat(embeds))
+
+     seqs, padding_mask = pad_seqs(seqs_lst)
+
+     if document_lengths is not None:
+         document_lengths = document_lengths.to(seqs.device)
+
+     if tgt_seqs is not None:
+         source_lengths = torch.tensor(
+             [seq.size(0) for seq in src_seqs], device=seqs.device
+         )
+     else:
+         source_lengths = None
+
+     output = EmbeddingsBatch(
+         seqs,
+         padding_mask=padding_mask,
+         document_lengths=document_lengths,
+         source_lengths=source_lengths,
+     )
+
+     return output
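To make the helpers above concrete, here is a minimal sketch (not part of the commit) that batches two toy "documents" of 3 and 2 sentence embeddings; 1024 is SONAR's embedding dimension:

```python
import torch

from lcm.datasets.batch import get_embeddings_sequence

src = [torch.randn(3, 1024), torch.randn(2, 1024)]  # one (N_i, D) tensor per document
batch = get_embeddings_sequence(src_seqs=src)       # pads into an EmbeddingsBatch

print(batch.shape)                  # torch.Size([2, 3, 1024])
print(batch.padding_mask.seq_lens)  # tensor([3, 2])

doubled = batch.double_seqs()       # 1, 2, 3 -> 1, 1, 2, 2, 3, 3 along the sequence dim
print(doubled.shape)                # torch.Size([2, 6, 1024])
```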
lcm/inference/lcm/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from lcm.inference.lcm.generator import LCMGenerator as LCMGenerator
+ from lcm.inference.lcm.generator import LCMGeneratorOptions as LCMGeneratorOptions
+
+ __all__ = ["LCMGenerator", "LCMGeneratorOptions"]
lcm/inference/lcm/generator.py ADDED
@@ -0,0 +1,448 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ from fairseq2.generation.generator import (
+     GenerationCounters,
+     Hypothesis,
+     SequenceGeneratorOutput,
+ )
+ from fairseq2.logging import get_log_writer
+
+ from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
+ from lcm.models.abstract_lcm import AbstractLCModel
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
+
+ logger = get_log_writer(__name__)
+
+
+ """
+ This generator follows the style of existing generators in fairseq2
+ """
+
+
+ @dataclass
+ class LCMGeneratorOptions:
+     """Holds the options to pass to a sequence generator."""
+
+     max_seq_len: int = 200
+     """The hard limit on the maximum length of generated sequences."""
+
+     min_seq_len: int = 1
+     """The minimum length of generated sequences."""
+
+     eos_threshold: Optional[float] = 0.9
+     """Threshold on the cosine similarity to the EOS vector"""
+
+     sample_latent_variable: bool = True
+     """When using VAE models, whether to sample the latent variable or return the mean"""
+
+     stop_on_repetition_cosine_threshold: Optional[float] = None
+     """Stop the generation when the similarity of two consecutive concepts is above this threshold."""
+
+     include_eos_token: bool = False
+     """Whether the EOS token should be included in the hypotheses (matters only if they are trimmed)."""
+
+     trim_hypotheses: bool = False
+     """Whether to trim the tokens after the EOS token from the hypotheses."""
+
+     seed: Optional[int] = None
+     """Seed to make generation deterministic"""
+
+     lcm_temperature: float = 1.0
+     """Temperature for decoding in the LCM"""
+
+
+ class LCMGenerator:
+     """Generates with an LCM model."""
+
+     def __init__(
+         self,
+         model: AbstractLCModel,
+         options: Optional[LCMGeneratorOptions] = None,
+         eos_vec: Optional[torch.Tensor] = None,
+     ) -> None:
+         """
+         :param model:
+             The LC model to use for generation.
+         """
+         model.eval()
+         self.model = model
+
+         if options is None:
+             options = LCMGeneratorOptions()
+
+         self.eos_vec = eos_vec
+         if self.eos_vec is None and options.eos_threshold:
+             logger.warning(
+                 f"eos_threshold is set to {options.eos_threshold}, but eos_vec is not provided"
+             )
+         if options.eos_threshold:
+             logger.debug(f"The eos_vec in the generator has been set to {self.eos_vec}")
+
+         self.options = options
+
+         self.max_seq_len = options.max_seq_len
+         self.min_seq_len = options.min_seq_len
+
+         assert self.min_seq_len >= 1, (
+             f"min_seq_len must be greater than or equal to 1, min_seq_len={options.min_seq_len}"
+         )
+
+         self.eos_threshold = options.eos_threshold
+
+         self.seqs: torch.Tensor
+         self.step_nr = 0
+         self.min_prompt_len: int
+         self.max_prompt_len: int
+         self.sample_indices: torch.Tensor
+         self.state_bag: Optional[LCMIncrementalStateBag] = None
+         self.prompt_seq_lens: Optional[torch.Tensor] = None
+         self.prompt_padding_mask: Optional[torch.Tensor] = None
+         self.lengths: torch.Tensor
+         self.step_scores: torch.Tensor
+
+     @torch.inference_mode()
+     def __call__(
+         self,
+         batch_input: EmbeddingsBatch,
+         max_gen_len: Optional[int] = None,
+         min_gen_len: Optional[int] = None,
+         temperature: float = 0.0,
+         disable_cache: bool = False,
+         **kwargs,
+     ) -> SequenceGeneratorOutput:
+         """
+         :param batch_input: embedded and padded tensor sequence of the inputs
+         :param max_gen_len: max length to be generated for the given input
+         :param min_gen_len: minimum length to be generated for the given input
+         :param temperature: temperature to control the generation
+         :param disable_cache: if True, do not use kv-caching
+         :returns:
+             The output of the LCM generator consists of :math:`N` lists of
+             hypotheses for :math:`N` prompts. Each list has 1 Hypothesis
+             (beam size = 1), of which `seq` has the *Shape:* :math:`(S+T, D)`
+             (:math:`S` is the prompt length, :math:`T` the length of the
+             generated sequence after the prompt and :math:`D` the model
+             dimension.)
+
+         """
+         if self.options.seed:
+             torch.manual_seed(self.options.seed)
+
+         # Set up the variables
+         batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size()
+         prompt_padding_mask = batch_input.padding_mask
+         if prompt_padding_mask is None:
+             self.min_prompt_len = self.max_prompt_len
+             self.prompt_padding_mask = None
+             self.prompt_seq_lens = None
+         else:
+             self.prompt_seq_lens = prompt_padding_mask.seq_lens
+             assert self.prompt_seq_lens is not None, (
+                 "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`"
+             )
+             self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item())
+
+             # Keep the materialized mask
+             self.prompt_padding_mask = prompt_padding_mask.materialize()
+
+         if not max_gen_len:
+             max_gen_len = self.max_seq_len
+
+         # Make sure we do not accidentally set a max_gen_len that exceeds
+         # the generator's model capability
+         assert max_gen_len <= self.max_seq_len, (
+             f"Generator can generate sequences of up to {self.max_seq_len} steps, max_gen_len={max_gen_len}"
+         )
+         self.max_gen_len = max_gen_len
+
+         if not min_gen_len:
+             min_gen_len = self.min_seq_len
+
+         assert min_gen_len > 0, (
+             f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
+         )
+         self.min_gen_len = min_gen_len
+
+         if temperature == 0.0:
+             # If the call doesn't pass a specific temperature,
+             # use the default one from the decoding options
+             temperature = self.options.lcm_temperature
+
+         self.temperature = temperature
+
+         for k, v in kwargs.items():
+             if hasattr(self.options, k) and v:
+                 setattr(self.options, k, v)
+
+         # Holds the generated sequences, scores and sample-dependent variables
+         dtype = self.model.dtype
+         device = batch_input.seqs.device
+
+         if disable_cache:
+             self.state_bag = None
+         else:
+             self.state_bag = LCMIncrementalStateBag(
+                 self.max_prompt_len + self.max_gen_len
+             )
+
+         # reserving full sequences capacity
+         self.seqs = torch.zeros(
+             (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim),
+             device=device,
+             dtype=dtype,
+         )
+         self.step_scores = torch.zeros(
+             (batch_size, self.max_prompt_len + self.max_gen_len),
+             device=device,
+         )
+         self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1
+
+         # Hold the sample indices to return in order
+         self.sample_indices = torch.arange(batch_size, device=device)
+         # Output buffer
+         self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]
+
+         # Bootstrap the sequences with the provided prompt.
+         self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len]
+         self.step_nr = self.min_prompt_len
+         self.prefill(**kwargs)
+
+         for self.step_nr in range(
+             self.min_prompt_len, self.max_prompt_len + self.max_gen_len
+         ):
+             if not self._step():
+                 break
+
+         return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())
+
+     @torch.inference_mode()
+     def prefill(self, **kwargs) -> None:
+         """The initial forward pass in the decoder with the prefix/prompt
+         to populate the KV-cache"""
+
+         if self.state_bag is None:
+             return
+
+         # Prefilling with -1 since the next call to step will use the last token in the prefix
+         prefill_len = self.step_nr - 1
+
+         if prefill_len > 0:
+             _ = self._decode(
+                 self.seqs[:, :prefill_len],
+                 padding_mask=None,
+             )
+             self.state_bag.increment_step_nr(prefill_len)  # type: ignore
+         else:
+             logger.warning(
+                 f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
+             )
+
+     @torch.inference_mode()
+     def _decode(
+         self,
+         seqs: torch.Tensor,
+         padding_mask: Optional[PaddingMask],
+         **kwargs,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         output = self.model.predict_next_sentence(
+             EmbeddingsBatch(seqs, padding_mask),
+             sample=self.options.sample_latent_variable,
+             temperature=self.temperature,
+             state_bag=self.state_bag,
+             **kwargs,
+         )
+
+         # Dummy scores
+         scores = torch.zeros(seqs.shape[:-1])
+         return output.seqs, scores
+
+     def _step(self) -> bool:
+         # Generate the next step output.
+
+         if self.state_bag is None:
+             # Without a state_bag, we're forwarding the full prefix
+             model_output, step_score = self._decode(
+                 seqs=self.seqs[:, : self.step_nr],
+                 padding_mask=None,
+             )
+         else:
+             # Since we're using a state_bag, we're only forwarding the last embedding
+             model_output, step_score = self._decode(
+                 seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
+                 padding_mask=None,
+             )
+
+             self.state_bag.increment_step_nr()
+
+         # model_output: tensor of predicted embeddings
+         return self.finalize_step(model_output, step_score)
+
+     def finalize_step(
+         self, model_output: torch.Tensor, step_score: torch.Tensor
+     ) -> bool:
+         """Post-process and finalize a step by checking all stopping criteria.
+         Takes the model's output embeddings (model_output)
+         and their associated scores (step_score).
+         Returns True if we keep stepping, False otherwise.
+         """
+         already_finished = self.lengths > -1
+         should_finish_now = torch.zeros_like(already_finished)
+
+         model_last_output = model_output[:, -1]
+         device = model_last_output.device
+
+         # Ignore prompt positions between min-max prompt_len
+         must_keep_going = None
+         if self.step_nr < self.max_prompt_len:
+             assert self.prompt_padding_mask is not None, (
+                 f"If self.prompt_padding_mask is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}"
+             )
+             mask = self.prompt_padding_mask[:, self.step_nr]
+             model_last_output[mask] = self.seqs[mask, self.step_nr]
+             must_keep_going = mask
+
+         # Check stopping based on EOS similarity.
+         if self.eos_threshold is not None and self.eos_vec is not None:
+             sim2eos = torch.nn.functional.cosine_similarity(
+                 self.eos_vec.to(device), model_last_output
+             )
+             logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}")
+             should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold)
+
+         # Check stopping based on repetition.
+         if (
+             self.options.stop_on_repetition_cosine_threshold is not None
+             and self.step_nr > 0
+         ):
+             sim2prev = torch.nn.functional.cosine_similarity(
+                 self.seqs[:, self.step_nr - 1], model_last_output
+             )
+             logger.debug(
+                 f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
+             )
+             should_finish_now = should_finish_now | sim2prev.ge(
+                 self.options.stop_on_repetition_cosine_threshold
+             )
+
+         if must_keep_going is not None:
+             logger.debug(
+                 f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}"
+             )
+             should_finish_now = should_finish_now & ~must_keep_going
+
+         # Keep going if the output is shorter than min_gen_len:
+         if self.prompt_seq_lens is not None:
+             longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
+                 self.min_gen_len
+             )
+         else:
+             longer_than_min_gen_len = (
+                 self.step_nr - self.max_prompt_len
+             ) >= self.min_gen_len
+
+         logger.debug(
+             f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
+         )
+         should_finish_now = should_finish_now & longer_than_min_gen_len
+         stopped_on_eos = should_finish_now
+
+         # Stop hypotheses that reached max_gen_len
+         if self.prompt_seq_lens is not None:
+             exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
+                 self.max_gen_len
+             )
+             logger.debug(
+                 f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
+             )
+
+         else:
+             exceeds_max_gen_len = (
+                 self.step_nr - self.max_prompt_len + 1
+             ) >= self.max_gen_len
+             logger.debug(
+                 f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
+             )
+
+         logger.debug(
+             f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
+         )
+
+         should_finish_now = should_finish_now | exceeds_max_gen_len
+
+         # Assign lengths to the sequences that have just finished.
+         should_finish_now = should_finish_now & ~already_finished
+         self.lengths[should_finish_now] = self.step_nr + 1
+
+         # Record the current step.
+         self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
+         self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]
+
+         # Save completed hypotheses
+         finished_mask = self.lengths.ne(-1)
+         finished_indices = finished_mask.nonzero()
+
+         # Remove finished hypotheses and reorder variables/state_bag if any are left
+         if len(finished_indices) > 0:
+             for idx in finished_indices:
+                 self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))
+
+         active_mask = ~finished_mask
+         active_indices = active_mask.nonzero().squeeze(-1)
+
+         if len(active_indices) == 0:
+             return False
+
+         self.reorder_state(active_indices)
+
+         return True
+
+     def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
+         seq_len = int(self.lengths[idx].item())
+
+         if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
+             seq_len = int(self.lengths[idx].item()) - int(
+                 not self.options.include_eos_token
+             )
+
+         sample_idx = int(self.sample_indices[idx])
+         self.hypotheses[sample_idx] = [
+             Hypothesis(
+                 seq=self.seqs[idx, :seq_len],
+                 score=None,
+                 step_scores=self.step_scores[idx],  # Trim it as well?
+             )
+         ]
+
+     def state_bag_reorder(self, new_order: torch.Tensor) -> None:
+         if self.state_bag is not None:
+             self.state_bag.reorder(new_order)
+
+     def reorder_state(self, new_order: torch.Tensor) -> None:
+         self.state_bag_reorder(new_order)
+
+         self.seqs = self.seqs.index_select(dim=0, index=new_order)
+
+         self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
+
+         self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
+
+         self.lengths = self.lengths.index_select(dim=0, index=new_order)
+
+         if self.prompt_padding_mask is not None:
+             self.prompt_padding_mask = self.prompt_padding_mask.index_select(
+                 dim=0, index=new_order
+             )
+
+         if self.prompt_seq_lens is not None:
+             self.prompt_seq_lens = self.prompt_seq_lens.index_select(
+                 dim=0, index=new_order
+             )
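A hypothetical usage sketch of the generator above (not part of the commit): `model`, `eos_vec`, and `prompt` are stand-ins for a loaded `AbstractLCModel`, its end-of-text embedding, and an `EmbeddingsBatch` of prompt concepts.

```python
from lcm.inference.lcm import LCMGenerator, LCMGeneratorOptions

options = LCMGeneratorOptions(
    max_seq_len=64,
    eos_threshold=0.9,                         # stop when close to the EOS vector
    stop_on_repetition_cosine_threshold=0.98,  # or when consecutive concepts repeat
)
generator = LCMGenerator(model, options=options, eos_vec=eos_vec)  # assumed inputs

output = generator(prompt, max_gen_len=32)
hypothesis = output.hypotheses[0][0]  # one hypothesis per prompt (beam size = 1)
generated = hypothesis.seq            # (S + T, D): prompt plus generated concepts
```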
lcm/inference/lcm/scorer.py ADDED
@@ -0,0 +1,198 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
+
+ from typing import List, Optional
+
+ import torch
+ from fairseq2.generation.generator import (
+     GenerationCounters,
+     Hypothesis,
+     SequenceGeneratorOutput,
+ )
+
+ from lcm.datasets.batch import EmbeddingsBatch
+ from lcm.inference.lcm.generator import LCMGenerator, LCMGeneratorOptions
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
+
+
+ class LCMScorer(LCMGenerator):
+     """Generates with an LCM model in teacher-forcing mode."""
+
+     options: LCMGeneratorOptions
+
+     @torch.inference_mode()
+     def __call__(  # type: ignore
+         self,
+         batch_input: EmbeddingsBatch,
+         max_gen_len: Optional[int] = None,
+         min_gen_len: Optional[int] = None,
+         min_context_len: int = 1,
+         temperature: float = 0.0,
+         disable_cache: bool = False,
+     ) -> SequenceGeneratorOutput:
+         """
+         :param batch_input: embedded and padded tensor sequence of the inputs
+         :param max_gen_len: max length to be generated for the given input
+         :param min_gen_len: minimum length to be generated for the given input
+         :param disable_cache: if True, do not use kv-caching
+         :returns:
+             The output of the LCM generator consists of :math:`N` lists of
+             hypotheses for :math:`N` documents. Each list has 1 Hypothesis
+             (beam size = 1), of which `seq` has the *Shape:* :math:`(T, D)`
+             (:math:`T` the length of the document and :math:`D` the model
+             dimension)
+
+         """
+         if self.options.seed:
+             torch.manual_seed(self.options.seed)
+
+         # Set up the variables
+         self.min_context_len = min_context_len
+         batch_size, self.max_text_len, embed_dim = batch_input.seqs.size()
+         text_padding_mask = batch_input.padding_mask
+         if text_padding_mask is None:
+             self.text_padding_mask = None
+             self.text_seq_lens = self.max_text_len * torch.ones(
+                 batch_size,
+                 dtype=torch.long,
+                 device=batch_input.seqs.device,
+             )
+         else:
+             self.text_seq_lens = text_padding_mask.seq_lens
+             assert self.text_seq_lens is not None, (
+                 "Expecting a valid `self.text_seq_lens` Tensor, found `None`"
+             )
+
+             # Keep the materialized mask
+             self.text_padding_mask = text_padding_mask.materialize()
+
+         if not max_gen_len:
+             max_gen_len = self.max_seq_len
+
+         max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len)
+         assert max_gen_len is not None, "max_gen_len is None"
+
+         # Make sure we do not accidentally set a max_gen_len that exceeds
+         # the generator's model capability
+         assert max_gen_len <= self.max_seq_len, (
+             f"Generator can generate sequences of up to {self.max_seq_len} steps, max_gen_len={max_gen_len}"
+         )
+         self.max_gen_len = max_gen_len
+
+         if not min_gen_len:
+             min_gen_len = self.min_seq_len
+
+         assert min_gen_len > 0, (
+             f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
+         )
+         self.min_gen_len = min_gen_len
+
+         if temperature == 0.0:
+             # If the call doesn't pass a specific temperature,
+             # use the default one from the decoding options
+             temperature = self.options.lcm_temperature
+
+         # Holds the generated sequences, scores and sample-dependent variables
+         dtype = self.model.dtype
+         device = batch_input.seqs.device
+         self.temperature = temperature
+
+         if disable_cache:
+             self.state_bag = None
+         else:
+             self.state_bag = LCMIncrementalStateBag(self.max_text_len)
+
+         # reserving full sequences capacity
+         self.seqs = batch_input.seqs
+         self.preds = torch.zeros(
+             (batch_size, self.max_text_len - self.min_context_len, embed_dim),
+             device=device,
+             dtype=dtype,
+         )
+         self.step_scores = torch.zeros(
+             (batch_size, self.max_text_len),
+             device=device,
+         )
+
+         # Hold the sample indices to return in order
+         self.sample_indices = torch.arange(batch_size, device=device)
+         # Output buffer
+         self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]
+
+         # Start from the minimal context of the provided documents.
+         self.step_nr = self.min_context_len
+         self.prefill()
+
+         for self.step_nr in range(self.min_context_len, self.max_text_len):
+             if not self._step():
+                 break
+
+         return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())
+
+     def finalize_step(
+         self, model_output: torch.Tensor, step_score: torch.Tensor
+     ) -> bool:
+         """Post-process and finalize a step by checking all stopping criteria.
+         Takes the model's output embeddings (model_output)
+         and their associated scores (step_score).
+         Returns True if we keep stepping, False otherwise.
+         """
+         model_last_output = model_output[:, -1]
+         must_keep_going = self.text_seq_lens.gt(self.step_nr + 1)
+         should_finish_now = ~must_keep_going
+
+         # Record the current step prediction.
+         self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze(
+             1
+         )
+         self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1]
+
+         # Save completed hypotheses
+         finished_indices = should_finish_now.nonzero()
+
+         # Remove finished hypotheses and reorder variables/state_bag if any are left
+         if len(finished_indices) > 0:
+             for idx in finished_indices:
+                 self.finish_sequence(int(idx))
+
+         active_mask = must_keep_going
+         active_indices = active_mask.nonzero().squeeze(-1)
+
+         if len(active_indices) == 0:
+             return False
+
+         self.reorder_state(active_indices)
+
+         return True
+
+     def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
+         seq_len = int(self.text_seq_lens[idx].item())
+         sample_idx = int(self.sample_indices[idx])
+         self.hypotheses[sample_idx] = [
+             Hypothesis(
+                 seq=self.preds[idx, : seq_len - self.min_context_len],
+                 score=None,
+                 step_scores=self.step_scores[idx],  # Trim it as well?
+             )
+         ]
+
+     def reorder_state(self, new_order: torch.Tensor) -> None:
+         self.state_bag_reorder(new_order)
+
+         self.seqs = self.seqs.index_select(dim=0, index=new_order)
+         self.preds = self.preds.index_select(dim=0, index=new_order)
+
+         self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
+
+         self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
+
+         if self.text_padding_mask is not None:
+             self.text_padding_mask = self.text_padding_mask.index_select(
+                 dim=0, index=new_order
+             )
+
+         self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order)
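Unlike the free-running generator, the scorer teacher-forces the ground-truth prefix at every step. A hedged sketch of what it returns, again assuming a loaded `model` and an `EmbeddingsBatch` `docs` of full documents:

```python
from lcm.inference.lcm.scorer import LCMScorer

scorer = LCMScorer(model)  # `model` is an assumed loaded AbstractLCModel
output = scorer(docs, min_context_len=1)

# For a document of length T, the hypothesis holds the teacher-forced
# next-concept predictions for positions min_context_len .. T-1,
# i.e. `seq` has shape (T - min_context_len, D).
preds = output.hypotheses[0][0].seq
```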
lcm/inference/two_tower_diffusion_lcm/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from lcm.inference.two_tower_diffusion_lcm.generator import (
+     DiffusionLCMGeneratorOptions as DiffusionLCMGeneratorOptions,
+ )
+ from lcm.inference.two_tower_diffusion_lcm.generator import (
+     TwoTowerDiffusionLCMGenerator as TwoTowerDiffusionLCMGenerator,
+ )
+
+ __all__ = [
+     "TwoTowerDiffusionLCMGenerator",
+     "DiffusionLCMGeneratorOptions",
+ ]
lcm/inference/two_tower_diffusion_lcm/generator.py ADDED
@@ -0,0 +1,466 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+ from fairseq2.generation.generator import (
11
+ GenerationCounters,
12
+ Hypothesis,
13
+ SequenceGeneratorOutput,
14
+ )
15
+ from fairseq2.logging import get_log_writer
16
+
17
+ from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
18
+ from lcm.inference.lcm.generator import (
19
+ LCMGenerator,
20
+ LCMGeneratorOptions,
21
+ )
22
+ from lcm.models.abstract_lcm import AbstractLCModel
23
+ from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
24
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
25
+
26
+ logger = get_log_writer(__name__)
27
+
28
+
29
+ @dataclass
30
+ class DiffusionLCMGeneratorOptions(LCMGeneratorOptions):
31
+ """Holds the options to pass to a diffusion-based sequence generator."""
32
+
33
+ guidance_scale: float = 1.0
34
+ """The weight of the regular classifier-free guidance.
35
+ Here `guidance_scale` is defined as the guidance weight `w` of
36
+ Equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf.
37
+ `guidance_scale = 1` corresponds to doing no classifier free guidance.
38
+ A higher guidance scale value encourages the model to generate outputs
39
+ closely related to the `prompt` at the expense of lower quality."""
40
+
41
+ guidance_rescale: float = 0.0
42
+ """The rescaling factor for Classifier-Free Guidance with Rescale
43
+ (Algorithm 2 - https://arxiv.org/pdf/2305.08891)"""
44
+
45
+ ddim_eta: float = 0.0
46
+ """The weight of noise for added noise in diffusion step.
47
+ It controls the level of interpolation between a deterministic
48
+ DDIM (at eta=0.0) and a stochastic DDPM (at eta = 1.0)
49
+ See section 5 of the DDIM paper https://arxiv.org/pdf/2010.02502 """
50
+
51
+ epsilon_scaling: Optional[float] = None
52
+ """epsilon_scaling: Optional[float] if not None, the predicted epsilon will
53
+ be scaled down by the provided factor as
54
+ introduced in https://arxiv.org/pdf/2308.15321""" ""
55
+
56
+ initial_noise_scale: float = 1.0
57
+ """For Diffusion models, scaling of initial noise"""
58
+
59
+ inference_timesteps: int = 100
60
+ """For Diffusion models, number of denoising timesteps"""
61
+
62
+ clip_noise: int = 100
63
+ """For Diffusion models, factor to clip noise of the sampling steps"""
64
+
65
+ thresholding: bool = False
66
+ """Whether to use the "dynamic thresholding" method.
67
+ This is unsuitable for latent-space diffusion models such as Stable Diffusion."""
68
+
69
+ dynamic_thresholding_ratio: float = 0.995
70
+ """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""
71
+
72
+ sample_max_value: float = 6.0
73
+ """The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""
74
+
75
+
76
+ class TwoTowerDiffusionLCMGenerator(LCMGenerator):
77
+ """Generates with a Two-tower Diffusion LCM model."""
78
+
79
+ options: DiffusionLCMGeneratorOptions
80
+
81
+ def __init__(
82
+ self,
83
+ model: AbstractLCModel,
84
+ options: Optional[LCMGeneratorOptions] = None,
85
+ eos_vec: Optional[torch.Tensor] = None,
86
+ ) -> None:
87
+ super().__init__(model, options, eos_vec)
88
+
89
+ assert isinstance(self.model, TwoTowerDiffusionLCModel), (
90
+ "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM"
91
+ )
92
+
93
+ logger.info(
94
+ f"Setting up the model with decoding_options: {options} -- {type(options)}"
95
+ )
96
+ model.prep_for_denoising(options)
97
+
98
+ @torch.inference_mode()
99
+ def __call__(
100
+ self,
101
+ batch_input: EmbeddingsBatch,
102
+ max_gen_len: Optional[int] = None,
103
+ min_gen_len: Optional[int] = None,
104
+ temperature: float = 0.0,
105
+ disable_cache: bool = False,
106
+ **kwargs,
107
+ ) -> SequenceGeneratorOutput:
108
+ """
109
+ :param input:
110
+ `bacth_input` embedded and padded tensor sequence of the inputs
111
+ `max_gen_len` max length to be generated for the given input
112
+ `min_gen_len` minimum length to be generated for the given input
113
+ `disable_cache` if True, do not use kv-caching
114
+ `temperature` temperature to control the generation
115
+ :returns:
116
+ The output of the LCM generator, consists of :math:`N` lists of
117
+ hypotheses for :math:`N` prompts. Each list has 1 Hypothesis
118
+ (beam size = 1), of which `seq` has the *Shape:* math:`(S+T, D)`
119
+ (:math:`S` is the prompt length, :math:`T` the length of the
120
+ generated sequence after the prompt and :math:`D` the model
121
+ dimension.)
122
+
123
+ """
124
+ if self.options.seed:
125
+ torch.manual_seed(self.options.seed)
126
+
127
+ # Setup the variables
128
+ batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size()
129
+ prompt_padding_mask = batch_input.padding_mask
130
+ if prompt_padding_mask is None:
131
+ self.min_prompt_len = self.max_prompt_len
132
+ self.prompt_padding_mask = None
133
+ self.prompt_seq_lens = None
134
+ else:
135
+ self.prompt_seq_lens = prompt_padding_mask.seq_lens
136
+ assert self.prompt_seq_lens is not None, (
137
+ "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`"
138
+ )
139
+ self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item())
140
+
141
+ # Keep the materialized mask
142
+ self.prompt_padding_mask = prompt_padding_mask.materialize()
143
+
144
+ if not max_gen_len:
145
+ max_gen_len = self.max_seq_len
146
+
147
+ # Make sure we do not accidentally set a max_gen_len that exceeds
148
+ # the generator's model capability
149
+ assert max_gen_len <= self.max_seq_len, (
150
+ f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}"
151
+ )
152
+ self.max_gen_len = max_gen_len
153
+
154
+ if not min_gen_len:
155
+ min_gen_len = self.min_seq_len
156
+
157
+ assert min_gen_len > 0, (
158
+ f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
159
+ )
160
+ self.min_gen_len = min_gen_len
161
+
162
+ if temperature == 0.0:
163
+ # If the call doesn't pass a specific temperature,
164
+ # use the default one from the decoding options
165
+ temperature = self.options.lcm_temperature
166
+
167
+ # Holds the generated sequences, scores and sample-dependent variables
168
+ dtype = self.model.dtype
169
+ device = batch_input.seqs.device
170
+ self.temperature = temperature
171
+
172
+ if disable_cache:
173
+ self.state_bag = None
174
+ self.context_state_bag = None
175
+ else:
176
+ self.state_bag = LCMIncrementalStateBag(
177
+ self.max_prompt_len + self.max_gen_len
178
+ )
179
+ self.context_state_bag = LCMIncrementalStateBag(
180
+ self.max_prompt_len + self.max_gen_len
181
+ )
182
+
183
+ # reserving full sequences capacity
184
+ self.seqs = torch.zeros(
185
+ (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim),
186
+ device=device,
187
+ dtype=dtype,
188
+ )
189
+ self.step_scores = torch.zeros(
190
+ (batch_size, self.max_prompt_len + self.max_gen_len),
191
+ device=device,
192
+ )
193
+ self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1
194
+
195
+ # Hold the samples indices to return in order
196
+ self.sample_indices = torch.arange(batch_size, device=device)
197
+ # Output buffer
198
+ self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]
199
+
200
+ # Bootstrap the sequences with the provided prompt.
201
+ self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len]
202
+ self.step_nr = self.min_prompt_len
203
+
204
+ # A context we keep growing in each decoding step
205
+ self.prefill()
206
+
207
+ for self.step_nr in range(
208
+ self.min_prompt_len, self.max_prompt_len + self.max_gen_len
209
+ ):
210
+ if not self._step():
211
+ break
212
+
213
+ return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())
214
+
215
+ def state_bag_reorder(self, new_order: torch.Tensor) -> None:
216
+ if self.state_bag is not None:
217
+ self.state_bag.reorder(new_order)
218
+
219
+ if self.context_state_bag is not None:
220
+ self.context_state_bag.reorder(new_order)
221
+
222
+ @torch.inference_mode()
223
+ def prefill(self, **kwargs) -> None:
224
+ """encode the prefix with the context encoder"""
225
+
226
+ assert self.context_state_bag is not None, (
227
+ "Expecting a context state bag to prefill"
228
+ )
229
+
230
+ context: EmbeddingsBatch
231
+
232
+ prefill_len = self.step_nr - 1
233
+ if prefill_len > 0:
234
+ # normalize then encode
235
+ input_seqs = self.seqs[:, :prefill_len]
236
+ if self.model.config.sonar_normalizer_name is not None:
237
+ input_seqs = self.model.sonar_normalizer.normalize(input_seqs)
238
+
239
+ context = self.model.encode(
240
+ EmbeddingsBatch(input_seqs, None),
241
+ state_bag=self.context_state_bag,
242
+ **kwargs,
243
+ )
244
+
245
+ self.context_state_bag.increment_step_nr(prefill_len)
246
+
247
+ else:
248
+ logger.warning(
249
+ f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
250
+ )
251
+ context = EmbeddingsBatch(
252
+ torch.empty(
253
+ (self.seqs.shape[0], 0, self.model.model_dim),
254
+ dtype=self.seqs.dtype,
255
+ device=self.seqs.device,
256
+ )
257
+ )
258
+
259
+ self.context = context
260
+
261
+ @torch.inference_mode()
262
+ def _decode(
263
+ self,
264
+ seqs: torch.Tensor,
265
+ padding_mask: Optional[PaddingMask] = None,
266
+ **kwargs,
267
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
268
+ output, context = self.model.predict_next_sentence(
269
+ batch=EmbeddingsBatch(seqs, padding_mask),
270
+ context=self.context,
271
+ temperature=self.temperature,
272
+ state_bag=self.state_bag,
273
+ context_state_bag=self.context_state_bag,
274
+ **kwargs,
275
+ )
276
+ self.context = context
277
+
278
+ # Dummy scores
279
+ scores = torch.zeros(seqs.shape[:-1])
280
+ return output.seqs, scores
281
+
282
+ def _step(self) -> bool:
283
+ # Generate the next step output.
284
+
285
+ if self.state_bag is None:
286
+ # Without a state_bag, we're forwarding the full prefix
287
+ # Encode the full context:
288
+
289
+ model_output, step_score = self._decode(
290
+ seqs=self.seqs[:, : self.step_nr],
291
+ padding_mask=None,
292
+ )
293
+ else:
294
+ # Since we're using a state_bag, we're only forwarding the last embedding
295
+ model_output, step_score = self._decode(
296
+ seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
297
+ padding_mask=None,
298
+ )
299
+
300
+ self.state_bag.increment_step_nr()
301
+
302
+ # model_output: Tensor of predicted embeddings
303
+ return self.finalize_step(model_output, step_score)
304
+
305
+ def finalize_step(
306
+ self, model_output: torch.Tensor, step_score: torch.Tensor
307
+ ) -> bool:
308
+ """Post-processing and finalizing a step
309
+ by checking all stopping criteria
310
+ Takes the model's output embeddings (model_output)
311
+ and their associated scores (step_score)
312
+ If we're stepping, return True, else return False
313
+ """
314
+ already_finished = self.lengths > -1
315
+ should_finish_now = torch.zeros_like(already_finished)
316
+
317
+ model_last_output = model_output[:, -1]
318
+ device = model_last_output.device
319
+
320
+ # Ignore prompt positions between min-max prompt_len
321
+ must_keep_going = None
322
+ if self.step_nr < self.max_prompt_len:
323
+ assert self.prompt_padding_mask is not None, (
324
+ f"If self.prompt_padding_mask is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}"
325
+ )
326
+ mask = self.prompt_padding_mask[:, self.step_nr]
327
+ model_last_output[mask] = self.seqs[mask, self.step_nr]
328
+ must_keep_going = mask
329
+
330
+ # Check stopping based on EOS similarity.
331
+ if self.eos_threshold is not None and self.eos_vec is not None:
332
+ sim2eos = torch.nn.functional.cosine_similarity(
333
+ self.eos_vec.to(device), model_last_output
334
+ )
335
+ logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}")
336
+ should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold)
337
+
338
+ # Check stopping based on repetition.
339
+ if (
340
+ self.options.stop_on_repetition_cosine_threshold is not None
341
+ and self.step_nr > 0
342
+ ):
343
+ sim2prev = torch.nn.functional.cosine_similarity(
344
+ self.seqs[:, self.step_nr - 1], model_last_output
345
+ )
346
+ logger.debug(
347
+ f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
348
+ )
349
+ should_finish_now = should_finish_now | sim2prev.ge(
350
+ self.options.stop_on_repetition_cosine_threshold
351
+ )
352
+
353
+ if must_keep_going is not None:
354
+ logger.debug(
355
+ f"Must keep going (to cover max_prompt_len={self.max_prompt_len}): {must_keep_going}"
356
+ )
357
+ should_finish_now = should_finish_now & ~must_keep_going
358
+
359
+ # Keep going if output is shorter than min_gen_len:
360
+ if self.prompt_seq_lens is not None:
361
+ longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
362
+ self.min_gen_len
363
+ )
364
+ else:
365
+ longer_than_min_gen_len = (
366
+ self.step_nr - self.max_prompt_len
367
+ ) >= self.min_gen_len
368
+
369
+ logger.debug(
370
+ f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
371
+ )
372
+ should_finish_now = should_finish_now & longer_than_min_gen_len
373
+ stopped_on_eos = should_finish_now
374
+
375
+ # Stop hypotheses that reached max_gen_len
376
+ if self.prompt_seq_lens is not None:
377
+ exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
378
+ self.max_gen_len
379
+ )
380
+ logger.debug(
381
+ f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
382
+ )
383
+
384
+ else:
385
+ exceeds_max_gen_len = (
386
+ self.step_nr - self.max_prompt_len + 1
387
+ ) >= self.max_gen_len
388
+ logger.debug(
389
+ f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
390
+ )
391
+
392
+ logger.debug(
393
+ f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
394
+ )
395
+
396
+ should_finish_now = should_finish_now | exceeds_max_gen_len
397
+
398
+ # Assign lengths to the sequences that have just finished.
399
+ should_finish_now = should_finish_now & ~already_finished
400
+ self.lengths[should_finish_now] = self.step_nr + 1
401
+
402
+ # Record the current step.
403
+ self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
404
+ self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]
405
+
406
+ # Save completed hypotheses
407
+ finished_mask = self.lengths.ne(-1)
408
+ finished_indices = finished_mask.nonzero()
409
+
410
+ # Remove finished hypotheses and reorder variables/state_bag if any are left
411
+ if len(finished_indices) > 0:
412
+ for idx in finished_indices:
413
+ self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))
414
+
415
+ active_mask = ~finished_mask
416
+ active_indices = active_mask.nonzero().squeeze(-1)
417
+
418
+ if len(active_indices) == 0:
419
+ return False
420
+
421
+ self.reorder_state(active_indices)
422
+
423
+ return True
424
+
425
+ def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
426
+ seq_len = int(self.lengths[idx].item())
427
+
428
+ if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
429
+ seq_len = int(self.lengths[idx].item()) - int(
430
+ not self.options.include_eos_token
431
+ )
432
+
433
+ sample_idx = int(self.sample_indices[idx])
434
+ self.hypotheses[sample_idx] = [
435
+ Hypothesis(
436
+ seq=self.seqs[idx, :seq_len],
437
+ score=None,
438
+ step_scores=self.step_scores[idx], # Trim it as well?
439
+ )
440
+ ]
441
+
442
+ def reorder_state(self, new_order: torch.Tensor) -> None:
443
+ self.state_bag_reorder(new_order)
444
+
445
+ self.context = EmbeddingsBatch(
446
+ self.context.seqs.index_select(dim=0, index=new_order),
447
+ self.context.padding_mask,
448
+ )
449
+
450
+ self.seqs = self.seqs.index_select(dim=0, index=new_order)
451
+
452
+ self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
453
+
454
+ self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
455
+
456
+ self.lengths = self.lengths.index_select(dim=0, index=new_order)
457
+
458
+ if self.prompt_padding_mask is not None:
459
+ self.prompt_padding_mask = self.prompt_padding_mask.index_select(
460
+ dim=0, index=new_order
461
+ )
462
+
463
+ if self.prompt_seq_lens is not None:
464
+ self.prompt_seq_lens = self.prompt_seq_lens.index_select(
465
+ dim=0, index=new_order
466
+ )
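Below is a minimal, hedged usage sketch for the generator defined above. The construction of `generator` is assumed (build it from a loaded model and `LCMGeneratorOptions`); only the `__call__` contract shown in this file is relied upon.

```python
# Usage sketch for TwoTowerDiffusionLCMGenerator.__call__; `generator` is an
# assumed, already-constructed instance (not defined in this snippet).
import torch

from lcm.datasets.batch import EmbeddingsBatch

prompt = torch.randn(2, 5, 1024)  # (batch, prompt_len, sonar_embed_dim)
output = generator(
    EmbeddingsBatch(prompt, None),  # padding_mask=None: equal-length prompts
    max_gen_len=64,
    min_gen_len=1,
    temperature=0.0,  # 0.0 falls back to options.lcm_temperature
)

# One hypothesis per prompt (beam size = 1); seq has shape (S+T, D).
for hyps in output.hypotheses:
    print(hyps[0].seq.shape)
```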
lcm/inference/two_tower_diffusion_lcm/scorer.py ADDED
@@ -0,0 +1,314 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ from fairseq2.generation.generator import (
10
+ GenerationCounters,
11
+ Hypothesis,
12
+ SequenceGeneratorOutput,
13
+ )
14
+ from fairseq2.logging import get_log_writer
15
+
16
+ from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
17
+ from lcm.inference.lcm.generator import LCMGeneratorOptions
18
+ from lcm.inference.two_tower_diffusion_lcm import (
19
+ TwoTowerDiffusionLCMGenerator,
20
+ )
21
+ from lcm.models.abstract_lcm import AbstractLCModel
22
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
23
+
24
+ logger = get_log_writer(__name__)
25
+
26
+
27
+ class TwoTowerDiffusionLCMScorer(TwoTowerDiffusionLCMGenerator):
28
+ """Score by generating in teacher-forcing mode with a Two-tower Diffusion LCM model."""
29
+
30
+ def __init__(
31
+ self,
32
+ model: AbstractLCModel,
33
+ options: Optional[LCMGeneratorOptions] = None,
34
+ eos_vec: Optional[torch.Tensor] = None,
35
+ ) -> None:
36
+ super().__init__(model, options, eos_vec)
37
+
38
+ @torch.inference_mode()
39
+ def __call__( # type: ignore
40
+ self,
41
+ batch_input: EmbeddingsBatch,
42
+ max_gen_len: Optional[int] = None,
43
+ min_gen_len: Optional[int] = None,
44
+ min_context_len: int = 1,
45
+ temperature: float = 0.0,
46
+ disable_cache: bool = False,
47
+ ) -> SequenceGeneratorOutput:
48
+ """
49
+ :param input:
50
+ `batch_input` embedded and padded tensor sequence of the inputs
51
+ `max_gen_len` max length to be generated for the given input
52
+ `min_gen_len` minimum length to be generated for the given input
53
+ `disable_cache` if True, do not use kv-caching
54
+ :returns:
55
+ The output of the LCM generator, consists of :math:`N` lists of
56
+ hypotheses for :math:`N` documents. Each list has 1 Hypothesis
57
+ (beam size = 1), of which `seq` has the *Shape:* :math:`(T, D)`
58
+ (:math:`T` the length of the document and :math:`D` the model
59
+ dimension.)
60
+
61
+ """
62
+ if self.options.seed:
63
+ torch.manual_seed(self.options.seed)
64
+
65
+ # Setup the variables
66
+ self.min_context_len = min_context_len
67
+ batch_size, self.max_text_len, embed_dim = batch_input.seqs.size()
68
+ text_padding_mask = batch_input.padding_mask
69
+ if text_padding_mask is None:
70
+ self.text_padding_mask = None
71
+ self.text_seq_lens = self.max_text_len * torch.ones(
72
+ batch_size,
73
+ dtype=torch.long,
74
+ device=batch_input.seqs.device,
75
+ )
76
+ else:
77
+ self.text_seq_lens = text_padding_mask.seq_lens
78
+ assert self.text_seq_lens is not None, (
79
+ "Expecting a valid `self.text_seq_lens` Tensor, found `None`"
80
+ )
81
+
82
+ # Keep the materialized mask
83
+ self.text_padding_mask = text_padding_mask.materialize()
84
+
85
+ if not max_gen_len:
86
+ max_gen_len = self.max_seq_len
87
+
88
+ max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len)
89
+ assert max_gen_len is not None, "max_gen_len is None"
90
+
91
+ # Make sure we do not accidentally set a max_gen_len that exceeds
92
+ # the generator's model capability
93
+ assert max_gen_len <= self.max_seq_len, (
94
+ f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}"
95
+ )
96
+ self.max_gen_len = max_gen_len
97
+
98
+ if not min_gen_len:
99
+ min_gen_len = self.min_seq_len
100
+
101
+ assert min_gen_len is not None, "A `min_gen_len` is required"
102
+
103
+ assert min_gen_len > 0, (
104
+ f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
105
+ )
106
+
107
+ self.min_gen_len = min_gen_len
108
+
109
+ if temperature == 0.0:
110
+ # If the call doesn't pass a specific temperature,
111
+ # use the default one from the decoding options
112
+ temperature = self.options.lcm_temperature
113
+
114
+ # Holds the generated sequences, scores and sample-dependent variables
115
+ dtype = self.model.dtype
116
+ device = batch_input.seqs.device
117
+ self.temperature = temperature
118
+
119
+ if disable_cache:
120
+ self.state_bag = None
121
+ self.context_state_bag = None
122
+ else:
123
+ self.state_bag = LCMIncrementalStateBag(self.max_text_len)
124
+ self.context_state_bag = LCMIncrementalStateBag(self.max_text_len)
125
+
126
+ # reserving full sequences capacity
127
+ self.seqs = batch_input.seqs
128
+ self.preds = torch.zeros(
129
+ (batch_size, self.max_text_len - self.min_context_len, embed_dim),
130
+ device=device,
131
+ dtype=dtype,
132
+ )
133
+
134
+ self.step_scores = torch.zeros(
135
+ (batch_size, self.max_text_len),
136
+ device=device,
137
+ )
138
+ # Hold the samples indices to return in order
139
+ self.sample_indices = torch.arange(batch_size, device=device)
140
+ # Output buffer
141
+ self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]
142
+
143
+ # Bootstrap the sequences with the provided prompt.
144
+ self.step_nr = self.min_context_len
145
+
146
+ # A context we keep growing in each decoding step
147
+ self.prefill()
148
+
149
+ for self.step_nr in range(self.min_context_len, self.max_text_len):
150
+ if not self._step():
151
+ break
152
+
153
+ return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())
154
+
155
+ def state_bag_reorder(self, new_order: torch.Tensor) -> None:
156
+ if self.state_bag is not None:
157
+ self.state_bag.reorder(new_order)
158
+
159
+ if self.context_state_bag is not None:
160
+ self.context_state_bag.reorder(new_order)
161
+
162
+ @torch.inference_mode()
163
+ def prefill(self, **kwargs) -> None:
164
+ """encode the prefix with the context encoder"""
165
+
166
+ assert self.context_state_bag is not None, (
167
+ "Expecting a context state bag to prefill"
168
+ )
169
+
170
+ context: EmbeddingsBatch
171
+
172
+ # FIXME for this model we can prefill with self.step_nr
173
+ prefill_len = self.step_nr - 1
174
+ if prefill_len > 0:
175
+ # normalize then encode
176
+ input_seqs = self.seqs[:, :prefill_len]
177
+ if self.model.config.sonar_normalizer_name is not None:
178
+ input_seqs = self.model.sonar_normalizer.normalize(input_seqs)
179
+
180
+ context = self.model.encode(
181
+ EmbeddingsBatch(input_seqs, None),
182
+ state_bag=self.context_state_bag,
183
+ **kwargs,
184
+ )
185
+
186
+ self.context_state_bag.increment_step_nr(prefill_len)
187
+
188
+ else:
189
+ logger.warning(
190
+ f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
191
+ )
192
+ context = EmbeddingsBatch(
193
+ torch.empty(
194
+ (self.seqs.shape[0], 0, self.model.model_dim),
195
+ dtype=self.seqs.dtype,
196
+ device=self.seqs.device,
197
+ )
198
+ )
199
+
200
+ self.context = context
201
+
202
+ @torch.inference_mode()
203
+ def _decode(
204
+ self,
205
+ seqs: torch.Tensor,
206
+ padding_mask: Optional[PaddingMask] = None,
207
+ **kwargs,
208
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
209
+ output, context = self.model.predict_next_sentence(
210
+ batch=EmbeddingsBatch(seqs, padding_mask),
211
+ context=self.context,
212
+ temperature=self.temperature,
213
+ state_bag=self.state_bag,
214
+ context_state_bag=self.context_state_bag,
215
+ **kwargs,
216
+ )
217
+ self.context = context
218
+
219
+ # Dummy score
220
+ scores = torch.zeros(seqs.shape[:-1])
221
+ return output.seqs, scores
222
+
223
+ def _step(self) -> bool:
224
+ # Generate the next step output.
225
+
226
+ if self.state_bag is None:
227
+ # Without a state_bag, we're forwarding the full prefix
228
+ # Encode the full context:
229
+
230
+ model_output, step_score = self._decode(
231
+ seqs=self.seqs[:, : self.step_nr],
232
+ padding_mask=None,
233
+ )
234
+ else:
235
+ # Since we're using a state_bag, we're only forwarding the last embedding
236
+ model_output, step_score = self._decode(
237
+ seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
238
+ padding_mask=None,
239
+ )
240
+
241
+ self.state_bag.increment_step_nr()
242
+
243
+ # model_output: Tensor of predicted embeddings
244
+ return self.finalize_step(model_output, step_score)
245
+
246
+ def finalize_step(
247
+ self, model_output: torch.Tensor, step_score: torch.Tensor
248
+ ) -> bool:
249
+ """Post-processing and finalizing a step
250
+ by checking all stopping criteria
251
+ Takes the model's output embeddings (model_output)
252
+ and their associated scores (step_score)
253
+ If we're stepping, return True, else return False
254
+ """
255
+ model_last_output = model_output[:, -1]
256
+ must_keep_going = self.text_seq_lens.gt(self.step_nr + 1)
257
+ should_finish_now = ~must_keep_going
258
+
259
+ # Record the current step prediction.
260
+ self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze(
261
+ 1
262
+ )
263
+ self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1]
264
+
265
+ # Save completed hypotheses
266
+ finished_indices = should_finish_now.nonzero()
267
+
268
+ # Remove finished hypotheses and reorder variables/state_bag if any are left
269
+ if len(finished_indices) > 0:
270
+ for idx in finished_indices:
271
+ self.finish_sequence(int(idx))
272
+
273
+ active_mask = must_keep_going
274
+ active_indices = active_mask.nonzero().squeeze(-1)
275
+
276
+ if len(active_indices) == 0:
277
+ return False
278
+
279
+ self.reorder_state(active_indices)
280
+
281
+ return True
282
+
283
+ def finish_sequence(self, idx: int) -> None: # type: ignore
284
+ seq_len = int(self.text_seq_lens[idx].item())
285
+ sample_idx = int(self.sample_indices[idx])
286
+ self.hypotheses[sample_idx] = [
287
+ Hypothesis(
288
+ seq=self.preds[idx, : seq_len - self.min_context_len],
289
+ score=None,
290
+ step_scores=self.step_scores[idx], # Trim it as well?
291
+ )
292
+ ]
293
+
294
+ def reorder_state(self, new_order: torch.Tensor) -> None:
295
+ self.state_bag_reorder(new_order)
296
+
297
+ self.context = EmbeddingsBatch(
298
+ self.context.seqs.index_select(dim=0, index=new_order),
299
+ self.context.padding_mask,
300
+ )
301
+
302
+ self.seqs = self.seqs.index_select(dim=0, index=new_order)
303
+ self.preds = self.preds.index_select(dim=0, index=new_order)
304
+
305
+ self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
306
+
307
+ self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
308
+
309
+ if self.text_padding_mask is not None:
310
+ self.text_padding_mask = self.text_padding_mask.index_select(
311
+ dim=0, index=new_order
312
+ )
313
+
314
+ self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order)
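The scorer above re-predicts each sentence of a complete document from its ground-truth prefix (teacher forcing). A hedged sketch of how its output lines up with the input, assuming a `scorer` instance built like the generator:

```python
# Teacher-forcing sketch: hypotheses hold predictions for positions
# [min_context_len, doc_len); `scorer` is an assumed TwoTowerDiffusionLCMScorer.
import torch

from lcm.datasets.batch import EmbeddingsBatch

docs = torch.randn(4, 12, 1024)  # (batch, doc_len, sonar_embed_dim)
out = scorer(EmbeddingsBatch(docs, None), min_context_len=1)

pred = out.hypotheses[0][0].seq  # shape: (doc_len - min_context_len, D)
mse = torch.mean((pred - docs[0, 1:]) ** 2)  # compare against gold sentences
```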
lcm/models/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ # We import all the model types in order to populate the model type registry
7
+ from lcm.models.base_lcm.loader import BASE_LCM_MODEL_TYPE
8
+ from lcm.models.two_tower_diffusion_lcm.loader import (
9
+ TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
10
+ )
11
+
12
+ __all__ = [
13
+ "BASE_LCM_MODEL_TYPE",
14
+ "TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE",
15
+ ]
lcm/models/abstract_lcm/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from lcm.models.abstract_lcm.builder import (
7
+ AbstractLCModel,
8
+ AbstractLCModelBuilder,
9
+ AbstractLCModelConfig,
10
+ )
11
+
12
+ __all__ = [
13
+ "AbstractLCModel",
14
+ "AbstractLCModelBuilder",
15
+ "AbstractLCModelConfig",
16
+ ]
lcm/models/abstract_lcm/builder.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from abc import abstractmethod
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ from fairseq2.config_registry import ConfigRegistry
11
+ from fairseq2.logging import get_log_writer
12
+ from fairseq2.typing import DataType, Device
13
+ from torch.nn import Module
14
+
15
+ from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model
16
+
17
+ logger = get_log_writer(__name__)
18
+
19
+
20
+ """
21
+ An abstract LCM model class defining the bare-minimum interface
22
+ """
23
+
24
+ ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm"
25
+
26
+
27
+ @dataclass
28
+ class AbstractLCModelConfig:
29
+ model_type: str = ABSTRACT_LCM_MODEL_TYPE
30
+
31
+ sonar_embed_dim: int = 1024
32
+
33
+ sonar_normalizer_name: Optional[str] = None
34
+
35
+
36
+ lcm_archs = ConfigRegistry[AbstractLCModelConfig]()
37
+ lcm_arch = lcm_archs.decorator
38
+
39
+
40
+ class AbstractLCModel(Module):
41
+ """Abstract class for LCM models"""
42
+
43
+ def __init__(
44
+ self,
45
+ config: AbstractLCModelConfig,
46
+ ) -> None:
47
+ """
48
+ Abstract LCM model
49
+ """
50
+ super().__init__()
51
+
52
+ self.config = config
53
+
54
+ @property
55
+ def dtype(self):
56
+ return next(self.parameters()).dtype
57
+
58
+ @property
59
+ def device(self):
60
+ return next(self.parameters()).device
61
+
62
+
63
+ class AbstractLCModelBuilder:
64
+ """Builds modules of an LCM"""
65
+
66
+ config: AbstractLCModelConfig
67
+ device: Optional[Device]
68
+ dtype: Optional[DataType]
69
+
70
+ def __init__(
71
+ self,
72
+ config: AbstractLCModelConfig,
73
+ *,
74
+ device: Optional[Device] = None,
75
+ dtype: Optional[DataType] = None,
76
+ ) -> None:
77
+ """
78
+ :param config:
79
+ The configuration.
80
+ :param device:
81
+ The device on which to initialize modules.
82
+ :param dtype:
83
+ The data type of module parameters and buffers.
84
+ """
85
+ self.config = config
86
+
87
+ self.device, self.dtype = device, dtype
88
+
89
+ def build_sonar_normalizer(
90
+ self,
91
+ ) -> Optional[SonarNormalizer]:
92
+ if self.config.sonar_normalizer_name is not None:
93
+ logger.info(
94
+ f"Building sonar_normalizer = {self.config.sonar_normalizer_name}"
95
+ )
96
+ return load_sonar_normalizer_model(
97
+ self.config.sonar_normalizer_name,
98
+ device=self.device,
99
+ dtype=self.dtype,
100
+ )
101
+ return None
102
+
103
+ @abstractmethod
104
+ def build_model(self) -> AbstractLCModel:
105
+ """Build a model."""
106
+ ...
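The abstract builder leaves only `build_model` to implement; `build_sonar_normalizer` is shared. A minimal illustrative subclass (all names here are invented for the sketch):

```python
# Sketch of the AbstractLCModelBuilder contract; TinyLCModel and
# TinyLCModelBuilder are hypothetical names, not part of the codebase.
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)


class TinyLCModel(AbstractLCModel):
    """Illustration only: the base class just stores the config."""


class TinyLCModelBuilder(AbstractLCModelBuilder):
    def build_model(self) -> AbstractLCModel:
        # Returns None unless config.sonar_normalizer_name is set.
        self.build_sonar_normalizer()
        return TinyLCModel(self.config)


model = TinyLCModelBuilder(AbstractLCModelConfig()).build_model()
```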
lcm/models/base_lcm/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ # Register architectures
7
+ import lcm.models.base_lcm.archs # noqa
8
+ from lcm.models.base_lcm.builder import (
9
+ BaseLCModel,
10
+ BaseLCModelBuilder,
11
+ BaseLCModelConfig,
12
+ create_base_lcm_model,
13
+ )
14
+
15
+ __all__ = [
16
+ "BaseLCModel",
17
+ "BaseLCModelBuilder",
18
+ "BaseLCModelConfig",
19
+ "create_base_lcm_model",
20
+ ]
lcm/models/base_lcm/archs.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from lcm.models.base_lcm.builder import (
7
+ BaseLCModelConfig,
8
+ LCMFrontendConfig,
9
+ ProjectionConfig,
10
+ TransformerConfig,
11
+ lcm_arch,
12
+ )
13
+
14
+
15
+ # Every model must register a toy_{model_family}
16
+ @lcm_arch("toy_base_lcm")
17
+ def toy_base_lcm() -> BaseLCModelConfig:
18
+ return BaseLCModelConfig(
19
+ lcm=TransformerConfig(num_layers=2),
20
+ )
21
+
22
+
23
+ @lcm_arch("base_lcm_1_6B")
24
+ def base_lcm_1_6B() -> BaseLCModelConfig:
25
+ """Base 1.6B model
26
+ Parameter Size: 1,647,635,456
27
+ """
28
+ model_dim: int = 2048
29
+ num_attn_heads: int = 16
30
+ return BaseLCModelConfig(
31
+ max_seq_len=4096,
32
+ model_dim=model_dim,
33
+ sonar_embed_dim=1024,
34
+ sonar_normalizer_name="dummy_sonar_normalizer",
35
+ frontend=LCMFrontendConfig(),
36
+ lcm=TransformerConfig(
37
+ final_dropout_p=0.0,
38
+ attention_dropout_p=0.0,
39
+ dropout_p=0.1,
40
+ mha_output_proj_bias=True,
41
+ ffn_inner_dim=model_dim * 4,
42
+ num_attn_heads=num_attn_heads,
43
+ num_layers=32,
44
+ pos_embedding_style="rope",
45
+ use_swiglu=True,
46
+ layer_normalization_style="rms",
47
+ ),
48
+ postnet=ProjectionConfig(),
49
+ )
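New architectures register through the same decorator; a sketch with an invented name:

```python
# Registering a hypothetical extra architecture ("base_lcm_tiny" is made up
# for illustration; pick any unused registry name).
from lcm.models.base_lcm.builder import (
    BaseLCModelConfig,
    TransformerConfig,
    lcm_arch,
)


@lcm_arch("base_lcm_tiny")
def base_lcm_tiny() -> BaseLCModelConfig:
    return BaseLCModelConfig(
        model_dim=512,
        lcm=TransformerConfig(num_layers=4, ffn_inner_dim=4 * 512),
    )
```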
lcm/models/base_lcm/builder.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+ import torch.nn
10
+ from fairseq2.config_registry import ConfigRegistry
11
+ from fairseq2.logging import get_log_writer
12
+ from fairseq2.nn.incremental_state import IncrementalStateBag
13
+ from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory
14
+ from fairseq2.typing import DataType, Device
15
+
16
+ from lcm.datasets.batch import EmbeddingsBatch
17
+ from lcm.models.abstract_lcm import (
18
+ AbstractLCModel,
19
+ AbstractLCModelBuilder,
20
+ AbstractLCModelConfig,
21
+ )
22
+ from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig
23
+ from lcm.nn.initialization import parse_norm_order
24
+ from lcm.nn.normalization import parse_layer_norm_factory
25
+ from lcm.nn.projection import Projection, ProjectionConfig
26
+ from lcm.nn.transformer import (
27
+ LCMTransformerDecoder,
28
+ TransformerConfig,
29
+ TransformerFactory,
30
+ )
31
+
32
+ logger = get_log_writer(__name__)
33
+
34
+ BASE_LCM_MODEL_TYPE = "base_lcm"
35
+
36
+
37
+ @dataclass
38
+ class BaseLCModelConfig(AbstractLCModelConfig):
39
+ model_type: str = BASE_LCM_MODEL_TYPE
40
+
41
+ max_seq_len: int = 2048
42
+
43
+ model_dim: int = 1024
44
+
45
+ model_output_dim: Optional[int] = None
46
+ """If ``None`` use SONAR dimension as output_dim."""
47
+
48
+ frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig())
49
+ """The frontend config. This module maps from `sonar_embed_dim` to `model_dim`
50
+ and potentially adds positional embeddings"""
51
+
52
+ lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig())
53
+ """The core lcm config. This is a causal Transformer decoder"""
54
+
55
+ postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
56
+ """The postnet config. A module mapping the output of the core lcm
57
+ back to `sonar_embed_dim`"""
58
+
59
+
60
+ lcm_archs = ConfigRegistry[BaseLCModelConfig]()
61
+ lcm_arch = lcm_archs.decorator
62
+
63
+
64
+ class BaseLCModel(AbstractLCModel):
65
+ """Base class for LCM models"""
66
+
67
+ config: BaseLCModelConfig
68
+
69
+ def __init__(
70
+ self,
71
+ config: BaseLCModelConfig,
72
+ lcm: LCMTransformerDecoder,
73
+ frontend: LCMFrontend,
74
+ postnet: Projection,
75
+ ) -> None:
76
+ """
77
+ Basic LCM model with :
78
+ - frontend
79
+ - lcm
80
+ - postnet
81
+ """
82
+ super().__init__(config)
83
+
84
+ self.frontend = frontend
85
+
86
+ self.lcm = lcm
87
+
88
+ self.postnet = postnet
89
+
90
+ self.model_dim = lcm.model_dim
91
+
92
+ self.sonar_embed_dim = config.sonar_embed_dim
93
+
94
+ def forward(
95
+ self,
96
+ batch: EmbeddingsBatch,
97
+ state_bag: Optional[IncrementalStateBag] = None,
98
+ **kwargs,
99
+ ) -> EmbeddingsBatch:
100
+ """
101
+ Scaling + Positions
102
+ If a normalizer is provided, the features will be normalized in the
103
+ frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM)
104
+ """
105
+ seqs, padding_mask = self.frontend(
106
+ batch.seqs,
107
+ batch.padding_mask,
108
+ diffusion_timesteps=batch.diffusion_timesteps,
109
+ state_bag=state_bag,
110
+ **kwargs,
111
+ )
112
+
113
+ # Core LCM
114
+ seqs, padding_mask = self.lcm(
115
+ seqs,
116
+ padding_mask,
117
+ state_bag=state_bag,
118
+ **kwargs,
119
+ )
120
+
121
+ # Postnet:
122
+ seqs = self.postnet(seqs) # type: ignore
123
+
124
+ return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)
125
+
126
+ def predict_next_sentence(
127
+ self,
128
+ batch: EmbeddingsBatch,
129
+ sample: bool = False,
130
+ temperature: float = 1.0,
131
+ state_bag: Optional[IncrementalStateBag] = None,
132
+ **kwargs,
133
+ ) -> EmbeddingsBatch:
134
+ """
135
+ The method for predicting the next sentence embeddings.
136
+ In the basic LCM, this is equivalent to just the forward method,
137
+ but the derived architectures may have a different implementation.
138
+ E.g. in VAE LCM, we run the VAE decoder on top of the `forward` results.
139
+
140
+ Args:
141
+ batch (EmbeddingsBatch): the sequence of concepts which
142
+ the model should continue.
143
+ sample (bool): whether to predict the single most probable next sentence
144
+ or to sample from the predicted distribution.
145
+ temperature (float): a positive float indicating the degree of diversity
146
+ for the sampling (active only if `sample is True`).
147
+ Returns:
148
+ EmbeddingsBatch: the batch with predicted SONAR sentences.
149
+ """
150
+ # Normalize the input embeddings if we're expected to
151
+ # normalize outside of the model's forward pass
152
+ if self.frontend.sonar_normalizer is not None:
153
+ batch = batch.normalize_seqs(self.frontend.sonar_normalizer)
154
+
155
+ # TODO: implement efficient sampling of multiple candidates
156
+ predicted_means = self.forward(batch, state_bag=state_bag, **kwargs)
157
+
158
+ if sample and temperature > 0:
159
+ noise = torch.randn_like(predicted_means.seqs) * temperature
160
+ predicted_means.seqs = predicted_means.seqs + noise
161
+
162
+ if self.frontend.sonar_normalizer is not None:
163
+ predicted_means = predicted_means.denormalize_seqs(
164
+ self.frontend.sonar_normalizer
165
+ )
166
+
167
+ return predicted_means
168
+
169
+
170
+ class BaseLCModelBuilder(AbstractLCModelBuilder):
171
+ """Builds modules of a base LCM model"""
172
+
173
+ config: BaseLCModelConfig
174
+ device: Optional[Device]
175
+ dtype: Optional[DataType]
176
+
177
+ def __init__(
178
+ self,
179
+ config: BaseLCModelConfig,
180
+ *,
181
+ device: Optional[Device] = None,
182
+ dtype: Optional[DataType] = None,
183
+ ) -> None:
184
+ super().__init__(config=config, device=device, dtype=dtype)
185
+ self.lcm_factory = TransformerFactory(
186
+ model_dim=self.config.model_dim,
187
+ max_seq_len=self.config.max_seq_len,
188
+ config=self.config.lcm,
189
+ device=device,
190
+ dtype=dtype,
191
+ )
192
+
193
+ if config.model_output_dim is None:
194
+ self.model_output_dim = self.config.sonar_embed_dim
195
+ else:
196
+ self.model_output_dim = config.model_output_dim
197
+
198
+ def build_model(self) -> BaseLCModel:
199
+ """Build a model."""
200
+
201
+ frontend = self.build_frontend()
202
+
203
+ lcm = self.build_core_lcm()
204
+
205
+ postnet = self.build_postnet()
206
+
207
+ return BaseLCModel(
208
+ config=self.config,
209
+ frontend=frontend,
210
+ lcm=lcm,
211
+ postnet=postnet,
212
+ )
213
+
214
+ def build_frontend(self) -> LCMFrontend:
215
+ """Build the LCM front-end (i.e., prenet)."""
216
+
217
+ return LCMFrontend(
218
+ sonar_embed_dim=self.config.sonar_embed_dim,
219
+ model_dim=self.config.model_dim,
220
+ config=self.config.frontend,
221
+ pos_encoder=self.lcm_factory.build_pos_encoder(),
222
+ sonar_normalizer=self.build_sonar_normalizer(),
223
+ device=self.device,
224
+ dtype=self.dtype,
225
+ )
226
+
227
+ def build_postnet(self) -> Projection:
228
+ return Projection(
229
+ output_dim=self.model_output_dim,
230
+ input_dim=self.config.model_dim,
231
+ config=self.config.postnet,
232
+ device=self.device,
233
+ dtype=self.dtype,
234
+ )
235
+
236
+ def build_attention_mask_factory(self):
237
+ self_attn_mask_factory: AttentionMaskFactory
238
+
239
+ self_attn_mask_factory = CausalAttentionMaskFactory()
240
+
241
+ return self_attn_mask_factory
242
+
243
+ def build_core_lcm(self) -> LCMTransformerDecoder:
244
+ """Build the core LCM module."""
245
+
246
+ config = self.config.lcm
247
+
248
+ layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)]
249
+
250
+ self_attn_mask_factory = self.build_attention_mask_factory()
251
+
252
+ if config.final_norm_order_style is None:
253
+ # The final norm order style will be that of the layer-level norm order
254
+ final_norm_order = parse_norm_order(config.norm_order_style)
255
+ else:
256
+ final_norm_order = parse_norm_order(config.final_norm_order_style)
257
+
258
+ layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)
259
+
260
+ return LCMTransformerDecoder(
261
+ layers, # type: ignore
262
+ self_attn_mask_factory=self_attn_mask_factory,
263
+ norm_order=final_norm_order,
264
+ layer_norm_factory=layer_norm_factory,
265
+ dropout_p=config.final_dropout_p,
266
+ device=self.device,
267
+ dtype=self.dtype,
268
+ )
269
+
270
+
271
+ def create_base_lcm_model(
272
+ config: BaseLCModelConfig,
273
+ *,
274
+ device: Optional[Device] = None,
275
+ dtype: Optional[DataType] = None,
276
+ ) -> BaseLCModel:
277
+ """Create an LCM model.
278
+ :param config:
279
+ The configuration.
280
+ :param device:
281
+ The device on which to initialize modules.
282
+ :param dtype:
283
+ The data type of module parameters and buffers.
284
+ """
285
+ return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model()
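End to end, the factory wires frontend, core decoder and postnet so that input and output stay in SONAR space. A hedged smoke-test sketch using the default config (no normalizer card needed):

```python
# Smoke test for create_base_lcm_model: the frontend maps sonar_embed_dim ->
# model_dim and the postnet maps back, so shapes round-trip.
import torch

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.base_lcm import BaseLCModelConfig, create_base_lcm_model

config = BaseLCModelConfig()  # sonar_normalizer_name=None by default
model = create_base_lcm_model(config, dtype=torch.float32)

batch = EmbeddingsBatch(torch.randn(2, 8, config.sonar_embed_dim), None)
out = model.predict_next_sentence(batch)
assert out.seqs.shape == (2, 8, config.sonar_embed_dim)
```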
lcm/models/base_lcm/frontend.py ADDED
@@ -0,0 +1,183 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ from fairseq2.logging import get_log_writer
11
+ from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder
12
+ from fairseq2.nn.incremental_state import IncrementalStateBag
13
+ from fairseq2.nn.padding import PaddingMask
14
+ from fairseq2.nn.projection import Linear
15
+ from fairseq2.typing import DataType, Device
16
+ from torch import Tensor
17
+ from torch.nn import Dropout, Module
18
+
19
+ from lcm.models.sonar_normalizer.builder import SonarNormalizer
20
+ from lcm.nn.initialization import SONAR_STD, SUPPORTED_INIT_TYPES, get_init_fn
21
+
22
+ logger = get_log_writer(__name__)
23
+
24
+
25
+ @dataclass
26
+ class LCMFrontendConfig:
27
+ dropout_p: float = 0.0
28
+ """ The dropout probability applied to the module's output"""
29
+
30
+ pre_linear_bias: bool = True
31
+ """ Whether or not the pre-linear layer has a bias term"""
32
+
33
+ pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"
34
+
35
+ scale_embeddings: bool = False
36
+ """ Scale the embeddings by model_dim before
37
+ adding positions (and before the pre_linear) """
38
+
39
+ weight_normalization: bool = False
40
+
41
+ embedding_std: float = SONAR_STD
42
+ """Most SONAR embeddings have a distribution with the mean close to 0
43
+ and std close to 0.006. Initializing embedding-like parameters (e.g. end-of-text vector)
44
+ from a similar distribution is recommended, to minimize their disruption of the model training
45
+ """
46
+
47
+
48
+ class LCMFrontend(Module):
49
+ """
50
+ A fronted for the LCM with positional embeddings
51
+ """
52
+
53
+ embed: Embedding
54
+ scale: float
55
+ pos_encoder: Optional[PositionEncoder]
56
+ dropout: Optional[Dropout]
57
+
58
+ def __init__(
59
+ self,
60
+ sonar_embed_dim: int,
61
+ model_dim: int,
62
+ config: LCMFrontendConfig,
63
+ pos_encoder: Optional[PositionEncoder],
64
+ timestep_embed_dim: int = 0,
65
+ sonar_normalizer: Optional[SonarNormalizer] = None,
66
+ *,
67
+ device: Optional[Device] = None,
68
+ dtype: Optional[DataType] = None,
69
+ ) -> None:
70
+ """
71
+ :param sonar_embed_dim
72
+ The embedding dimension of the sentence encoder, in this case SONAR
73
+ :param model_dim
74
+ The model embedding dimension
75
+ :param timestep_embed_dim
76
+ The embedding dimension of diffusion timesteps (if relevant, defaults to 0)
77
+ :param config:
78
+ A Frontend config. See `LCMFrontendConfig`
79
+ :param pos_encoder:
80
+ An optional position encoder.
81
+ """
82
+
83
+ super().__init__()
84
+
85
+ self.sonar_embed_dim = sonar_embed_dim
86
+
87
+ self.model_dim = model_dim
88
+
89
+ self.device = device
90
+
91
+ self.embed_scale: float = model_dim**0.5 if config.scale_embeddings else 1.0
92
+
93
+ logger.info(f"Using LCMFrontend with embeddings scaler = {self.embed_scale}")
94
+
95
+ # Optional sonar normalizer
96
+ self.sonar_normalizer = sonar_normalizer
97
+
98
+ # Pre-linear to map to model dimension
99
+
100
+ init_fn = get_init_fn(config.pre_linear_init_fn)
101
+
102
+ lin = Linear(
103
+ sonar_embed_dim + timestep_embed_dim,
104
+ model_dim,
105
+ bias=config.pre_linear_bias,
106
+ device=device,
107
+ dtype=dtype,
108
+ init_fn=init_fn,
109
+ )
110
+
111
+ if config.weight_normalization:
112
+ self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin)
113
+ else:
114
+ self.pre_linear = lin
115
+
116
+ if pos_encoder is not None:
117
+ if pos_encoder.encoding_dim != self.model_dim:
118
+ raise ValueError(
119
+ f"`encoding_dim` of `pos_encoder` and `embedding_dim` of \
120
+ `embed` must be equal, but are {pos_encoder.encoding_dim} \
121
+ and {self.model_dim} instead."
122
+ )
123
+
124
+ self.pos_encoder = pos_encoder
125
+ else:
126
+ self.register_module("pos_encoder", None)
127
+
128
+ if config.dropout_p > 0.0:
129
+ self.dropout = Dropout(config.dropout_p)
130
+ else:
131
+ self.register_module("dropout", None)
132
+
133
+ self.reset_parameters(embedding_std=config.embedding_std)
134
+
135
+ def reset_parameters(self, embedding_std: float) -> None:
136
+ """Initialize module parameters.
137
+ The positional embeddings should be initialized with the
138
+ same order of magnitude as the semantic embeddings, in order
139
+ to make the early training as stable as possible.
140
+ Otherwise, the positional and special token embeddings would
141
+ flood out the semantic information.
142
+ """
143
+ logger.info(
144
+ f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})"
145
+ )
146
+ if isinstance(self.pos_encoder, LearnedPositionEncoder):
147
+ torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std)
148
+
149
+ def pre_forward(
150
+ self, seqs: Tensor, diffusion_timesteps: Optional[Tensor] = None, **kwargs
151
+ ) -> Tensor:
152
+ return seqs
153
+
154
+ def forward(
155
+ self,
156
+ seqs: Tensor,
157
+ padding_mask: Optional[PaddingMask],
158
+ state_bag: Optional[IncrementalStateBag] = None,
159
+ diffusion_timesteps: Optional[Tensor] = None,
160
+ **kwargs,
161
+ ) -> Tuple[Tensor, Optional[PaddingMask]]:
162
+ """
163
+ Apply pre-linear (if relevant) and add positional embeddings
164
+ """
165
+
166
+ # Normalize in standard LCM or add timestep embeddings in diffusion frontend
167
+ seqs = self.pre_forward(seqs, diffusion_timesteps, **kwargs)
168
+
169
+ # pre-linear if any:
170
+ seqs = self.pre_linear(self.embed_scale * seqs)
171
+
172
+ if self.pos_encoder is not None:
173
+ seqs = self.pos_encoder(
174
+ seqs,
175
+ padding_mask,
176
+ state_bag=state_bag,
177
+ **kwargs,
178
+ )
179
+
180
+ if self.dropout is not None:
181
+ seqs = self.dropout(seqs)
182
+
183
+ return seqs, padding_mask
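The frontend can also be exercised standalone; a shape-only sketch (passing `pos_encoder=None`, e.g. when positions are applied inside attention via rope):

```python
# Shape check for LCMFrontend: pre_linear maps sonar_embed_dim -> model_dim.
import torch

from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig

frontend = LCMFrontend(
    sonar_embed_dim=1024,
    model_dim=2048,
    config=LCMFrontendConfig(),
    pos_encoder=None,  # e.g. when the core LCM uses rope positions
)
seqs, padding_mask = frontend(torch.randn(2, 8, 1024), padding_mask=None)
assert seqs.shape == (2, 8, 2048)
```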
lcm/models/base_lcm/loader.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import logging
7
+ from typing import Any, Dict
8
+
9
+ from fairseq2.models.config_loader import StandardModelConfigLoader
10
+ from fairseq2.models.loader import StandardModelLoader, load_model
11
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
12
+
13
+ from lcm.models.base_lcm.builder import (
14
+ BASE_LCM_MODEL_TYPE,
15
+ BaseLCModelConfig,
16
+ create_base_lcm_model,
17
+ lcm_archs,
18
+ )
19
+ from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def convert_lcm_checkpoint(
25
+ checkpoint: Dict[str, Any], config: BaseLCModelConfig
26
+ ) -> Dict[str, Any]:
27
+ # For DDP checkpoints
28
+ # We need to first remove the prefix "module." from state dict keys.
29
+ consume_prefix_in_state_dict_if_present(checkpoint["model"], "module.")
30
+ return checkpoint
31
+
32
+
33
+ load_base_lcm_config = StandardModelConfigLoader(
34
+ family=BASE_LCM_MODEL_TYPE,
35
+ config_kls=BaseLCModelConfig,
36
+ arch_configs=lcm_archs,
37
+ )
38
+
39
+ load_base_lcm_model = StandardModelLoader(
40
+ config_loader=load_base_lcm_config,
41
+ factory=create_base_lcm_model,
42
+ checkpoint_converter=convert_lcm_checkpoint,
43
+ restrict_checkpoints=False,
44
+ )
45
+
46
+ load_model.register(BASE_LCM_MODEL_TYPE, load_base_lcm_model)
47
+
48
+ lcm_model_type_registry.register(
49
+ ModelTypeConfig(
50
+ model_type=BASE_LCM_MODEL_TYPE,
51
+ config_loader=load_base_lcm_config,
52
+ model_factory=create_base_lcm_model,
53
+ model_loader=load_base_lcm_model,
54
+ )
55
+ )
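With this module imported, the generic fairseq2 loader can resolve `base_lcm` asset cards; a sketch with a placeholder card name (the exact `load_model` keyword arguments are an assumption, only the call-by-name form is relied on):

```python
# Loading sketch: importing lcm.models registers the loaders as a side
# effect; "my_base_lcm_card" is a placeholder asset card name.
import lcm.models  # noqa: F401  (populates the model-type registry)
from fairseq2.models.loader import load_model

model = load_model("my_base_lcm_card")
```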
lcm/models/base_lcm/normalization.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from typing import Optional, final
7
+
8
+ import torch
9
+ from fairseq2.nn import LayerNorm, RMSNorm
10
+ from fairseq2.typing import DataType, Device, override
11
+
12
+
13
+ @final
14
+ class FP32LayerNorm(LayerNorm):
15
+ """Applies Layer Normalization in single-precision."""
16
+
17
+ @override
18
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
19
+ w, b = self.weight, self.bias
20
+
21
+ # cast input and params to float32
22
+ fp32_x = x.float()
23
+ fp32_w = w.float() if w is not None else None
24
+ fp32_b = b.float() if b is not None else None
25
+
26
+ y = torch.nn.functional.layer_norm(
27
+ fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps
28
+ )
29
+
30
+ return y.type_as(x)
31
+
32
+
33
+ def build_rms_layer_norm(
34
+ model_dim: int,
35
+ *,
36
+ device: Optional[Device] = None,
37
+ dtype: Optional[DataType] = None,
38
+ ) -> LayerNorm:
39
+ """Build an RMS Layer Normalization module."""
40
+ return RMSNorm(model_dim, bias=False, device=device, dtype=dtype)
41
+
42
+
43
+ def build_fp32_layer_norm(
44
+ model_dim: int,
45
+ *,
46
+ device: Optional[Device] = None,
47
+ dtype: Optional[DataType] = None,
48
+ ) -> LayerNorm:
49
+ """Build a single-precision Layer Normalization module."""
50
+ return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype)
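A quick numerical note on `FP32LayerNorm`: statistics are computed in float32 even for half-precision inputs, then cast back, which keeps layer norm stable under fp16/bf16 training. Sketch:

```python
# FP32LayerNorm computes in float32 and returns the input dtype.
import torch

from lcm.models.base_lcm.normalization import FP32LayerNorm

ln = FP32LayerNorm(1024, bias=False, dtype=torch.float16)
y = ln(torch.randn(2, 8, 1024, dtype=torch.float16))
assert y.dtype == torch.float16
```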
lcm/models/sonar_normalizer/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ # Register architectures
7
+ import lcm.models.sonar_normalizer.archs # noqa
8
+ from lcm.models.sonar_normalizer.builder import (
9
+ SonarNormalizer,
10
+ SonarNormalizerConfig,
11
+ create_sonar_normalizer,
12
+ )
13
+ from lcm.models.sonar_normalizer.loader import load_sonar_normalizer_model
14
+
15
+ __all__ = [
16
+ "SonarNormalizer",
17
+ "SonarNormalizerConfig",
18
+ "create_sonar_normalizer",
19
+ "load_sonar_normalizer_model",
20
+ ]
lcm/models/sonar_normalizer/archs.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from lcm.models.sonar_normalizer.builder import (
7
+ SonarNormalizerConfig,
8
+ sonar_normalizer_arch,
9
+ )
10
+
11
+
12
+ @sonar_normalizer_arch("base")
13
+ def _base_sonar_normalizer() -> SonarNormalizerConfig:
14
+ """The base architecture for all center-and-scale normalizers
15
+ regardless of how the center/scale are estimated"""
16
+ return SonarNormalizerConfig(
17
+ dim=1024,
18
+ )
19
+
20
+
21
+ @sonar_normalizer_arch("base_page4k")
22
+ def _base_page_normalizer() -> SonarNormalizerConfig:
23
+ return SonarNormalizerConfig(
24
+ dim=4 * 1024,
25
+ )
26
+
27
+
28
+ @sonar_normalizer_arch("base_fft")
29
+ def _base_fft_sonar_normalizer() -> SonarNormalizerConfig:
30
+ return SonarNormalizerConfig(dim=1024, with_fft=True)
31
+
32
+
33
+ @sonar_normalizer_arch("clipping")
34
+ def _clipping_sonar_normalizer() -> SonarNormalizerConfig:
35
+ return SonarNormalizerConfig(dim=1024, clip_proba=1e-4)
36
+
37
+
38
+ @sonar_normalizer_arch("clipping_fft")
39
+ def _clipping_fft_sonar_normalizer() -> SonarNormalizerConfig:
40
+ return SonarNormalizerConfig(dim=1024, clip_proba=1e-4, with_fft=True)
lcm/models/sonar_normalizer/builder.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Literal, Optional
8
+
9
+ import torch
10
+ from fairseq2.config_registry import ConfigRegistry
11
+ from fairseq2.typing import DataType, Device
12
+ from torch import Tensor
13
+ from torch.nn import Module
14
+
15
+
16
+ @dataclass
17
+ class SonarNormalizerConfig:
18
+ dim: int = 1024
19
+ """The dimension of the features to be normalized"""
20
+
21
+ clip_proba: Optional[float] = None
22
+ """
23
+ If `clip_proba` is not None, `clip_min` and `clip_max` will
24
+ be used to clip the features in the normalized space.
25
+ `clip_min` and `clip_max` correspond to the pre-computed `clip_proba`
26
+ and `1-clip_proba` quantiles respectively.
27
+ """
28
+
29
+ with_fft: bool = False
30
+ """
31
+ Apply an FFT transform to the raw input before all other transforms.
32
+ """
33
+
34
+ quantile_min: float = 0.25
35
+ """The lower quantile used to measure the IQR when estimating the scale with a robust scaler"""
36
+
37
+ quantile_max: float = 0.75
38
+ """The upper quantile used to measure the IQR when estimating the scale with a robust scaler"""
39
+
40
+ normalization_method: Literal["standard", "robust", "gaussian_robust"] = (
41
+ "gaussian_robust"
42
+ )
43
+ """
44
+ Dictates how the normalizer's scale is evaluated when fitting.
45
+ (1) 'standard': center=mean, scale = std
46
+ (2) 'robust': center=median, scale = IQR = Qmax - Qmin
47
+ (3) 'gaussian_robust': center=median, scale = IQR / k,
48
+ where k=`stats.norm.ppf(q_max / 100.0) - stats.norm.ppf(q_min / 100.0)`
49
+ i.e. scale = 0.7413 x IQR if q_min=0.25 and q_max=0.75.
50
+ This is the robust normalization of https://arxiv.org/pdf/2307.05445
51
+ """
52
+
53
+
54
+ sonar_normalizer_archs = ConfigRegistry[SonarNormalizerConfig]()
55
+ sonar_normalizer_arch = sonar_normalizer_archs.decorator
56
+
57
+
58
+ class FFTInterface:
59
+ @staticmethod
60
+ def fft_transform(embeddings: Tensor) -> Tensor:
61
+ dtype = embeddings.dtype
62
+ if dtype in [torch.float16, torch.bfloat16]:
63
+ embeddings = embeddings.to(dtype=torch.float32)
64
+ embeddings = torch.fft.rfft(embeddings, norm="backward")
65
+ return torch.concat(
66
+ [torch.real(embeddings), torch.imag(embeddings)[..., 1:-1]], dim=-1
67
+ ).to(dtype)
68
+
69
+ @staticmethod
70
+ def fft_inverse_transform(embeddings: Tensor) -> Tensor:
71
+ assert embeddings.shape[-1] % 2 == 0
72
+ dtype = embeddings.dtype
73
+ if dtype in [torch.float16, torch.bfloat16]:
74
+ embeddings = embeddings.to(dtype=torch.float32)
75
+ rr, im = torch.split(
76
+ embeddings,
77
+ [embeddings.shape[-1] // 2 + 1, embeddings.shape[-1] // 2 - 1],
78
+ dim=-1,
79
+ )
80
+ im = torch.concat(
81
+ [torch.zeros_like(im[..., :1]), im, torch.zeros_like(im[..., :1])], dim=-1
82
+ )
83
+ embeddings = torch.fft.irfft(rr + im * 1j)
84
+ return embeddings.to(dtype)
85
+
86
+
87
+ class SonarNormalizer(FFTInterface, Module):
88
+ """
89
+ To perform efficient diffusion modeling, SONAR embeddings need to be
90
+ normalized. This SonarNormalizer follows the robust normalization introduced in
91
+ https://arxiv.org/abs/2307.05445
92
+ Quoting from the paper: "Due to the very long-tailed feature distribution, typical mean and standard deviation statistics will be
93
+ heavily biased. We thus propose a robust alternative based on the feature distribution quantiles. We
94
+ take the median as the center of the distribution and approximate its scale using the Normalized
95
+ InterQuartile Range (IQR) for a normal distribution: 0.7413 × IQR
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ config: SonarNormalizerConfig,
101
+ device: Optional[Device] = None,
102
+ dtype: Optional[DataType] = None,
103
+ ) -> None:
104
+ super().__init__()
105
+ self.config = config
106
+
107
+ self.register_buffer(
108
+ "center", torch.zeros(config.dim, dtype=dtype, device=device)
109
+ )
110
+ self.register_buffer(
111
+ "scale", torch.ones(config.dim, dtype=dtype, device=device)
112
+ )
113
+ if self.config.clip_proba is not None:
114
+ self.register_buffer(
115
+ "clip_min", torch.ones(config.dim, dtype=dtype, device=device)
116
+ )
117
+ self.register_buffer(
118
+ "clip_max", torch.ones(config.dim, dtype=dtype, device=device)
119
+ )
120
+
121
+ def normalize(self, embeddings: Tensor) -> Tensor:
122
+ if self.config.with_fft:
123
+ embeddings = self.fft_transform(embeddings)
124
+
125
+ embeddings = (embeddings - self.center) / self.scale
126
+ if self.config.clip_proba is not None:
127
+ embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)
128
+ return embeddings
129
+
130
+ def denormalize(self, embeddings: Tensor) -> Tensor:
131
+ if self.config.clip_proba is not None:
132
+ embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)
133
+
134
+ embeddings = (embeddings * self.scale) + self.center
135
+ if self.config.with_fft:
136
+ embeddings = self.fft_inverse_transform(embeddings)
137
+ return embeddings
138
+
139
+ @torch.no_grad()
140
+ def fit(self, embeddings: Tensor):
141
+ if self.config.normalization_method in [
142
+ "robust",
143
+ "gaussian_robust",
144
+ ]:
145
+ from sklearn.preprocessing import RobustScaler
146
+
147
+ _scaler = RobustScaler(
148
+ unit_variance=self.config.normalization_method == "gaussian_robust",
149
+ quantile_range=(self.config.quantile_min, self.config.quantile_max),
150
+ )
151
+
152
+ elif self.config.normalization_method == "standard":
153
+ from sklearn.preprocessing import StandardScaler
154
+
155
+ _scaler = StandardScaler()
156
+ else:
157
+ raise ValueError(
158
+ f"Unrecognizable method {self.config.normalization_method} for scaling input features"
159
+ )
160
+
161
+ assert embeddings.shape[-1] == self.config.dim
162
+ assert len(embeddings.shape) == 2
163
+
164
+ if self.config.with_fft:
165
+ embeddings = self.fft_transform(embeddings)
166
+
167
+ embeddings = _scaler.fit_transform(embeddings.cpu().float().numpy())
168
+
169
+ if self.config.normalization_method in [
170
+ "robust",
171
+ "gaussian_robust",
172
+ ]:
173
+ _center = _scaler.center_
174
+ _scale = _scaler.scale_
175
+
176
+ elif self.config.normalization_method == "standard":
177
+ _center = _scaler.mean_
178
+ _scale = _scaler.scale_
179
+
180
+ self.center[:] = torch.tensor(
181
+ _center, dtype=self.center.dtype, device=self.center.device
182
+ )
183
+ self.scale[:] = torch.tensor(
184
+ _scale, dtype=self.scale.dtype, device=self.scale.device
185
+ )
186
+
187
+ if self.config.clip_proba is not None:
188
+ self.clip_min[:] = torch.quantile(
189
+ torch.tensor(embeddings), self.config.clip_proba, dim=0
190
+ ).to(dtype=self.clip_min.dtype, device=self.clip_min.device)
191
+ self.clip_max[:] = torch.quantile(
192
+ torch.tensor(embeddings), 1 - self.config.clip_proba, dim=0
193
+ ).to(dtype=self.clip_max.dtype, device=self.clip_max.device)
194
+
195
+
196
+ def create_sonar_normalizer(
197
+ config: SonarNormalizerConfig,
198
+ *,
199
+ device: Optional[Device] = None,
200
+ dtype: Optional[DataType] = None,
201
+ ) -> SonarNormalizer:
202
+ """Create a SONAR normalizer.
203
+ :param config:
204
+ The configuration.
205
+ :param device:
206
+ The device on which to initialize modules.
207
+ :param dtype:
208
+ The data type of module parameters and buffers.
209
+ """
210
+ return SonarNormalizer(config, device=device, dtype=dtype)
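A round-trip sketch for the normalizer: `fit` expects a 2-D sample (and scikit-learn at fit time), after which `denormalize` inverts `normalize` exactly when clipping is disabled:

```python
# Fit/normalize/denormalize round trip with the default configuration
# (clip_proba=None, normalization_method="gaussian_robust").
import torch

from lcm.models.sonar_normalizer import (
    SonarNormalizerConfig,
    create_sonar_normalizer,
)

normalizer = create_sonar_normalizer(SonarNormalizerConfig(dim=1024))
sample = 0.006 * torch.randn(10_000, 1024)  # roughly SONAR-scale features
normalizer.fit(sample)

z = normalizer.normalize(sample)
assert torch.allclose(normalizer.denormalize(z), sample, atol=1e-4)
```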
lcm/models/sonar_normalizer/loader.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+
7
+ from fairseq2.models.config_loader import StandardModelConfigLoader
8
+ from fairseq2.models.loader import StandardModelLoader, load_model
9
+
10
+ from lcm.models.sonar_normalizer.builder import (
11
+ SonarNormalizerConfig,
12
+ create_sonar_normalizer,
13
+ sonar_normalizer_archs,
14
+ )
15
+
16
+ load_sonar_normalizer_config = StandardModelConfigLoader(
17
+ family="sonar_normalizer",
18
+ config_kls=SonarNormalizerConfig,
19
+ arch_configs=sonar_normalizer_archs,
20
+ )
21
+
22
+ load_sonar_normalizer_model = StandardModelLoader(
23
+ config_loader=load_sonar_normalizer_config,
24
+ factory=create_sonar_normalizer,
25
+ restrict_checkpoints=False,
26
+ )
27
+
28
+ load_model.register("sonar_normalizer", load_sonar_normalizer_model)
lcm/models/two_tower_diffusion_lcm/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ # Register architectures
7
+ import lcm.models.two_tower_diffusion_lcm.archs # noqa
lcm/models/two_tower_diffusion_lcm/archs.py ADDED
@@ -0,0 +1,207 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from lcm.models.two_tower_diffusion_lcm.builder import (
+     DenoiserConfig,
+     EncoderFrontendConfig,
+     TransformerConfig,
+     TwoTowerDiffusionLCModelConfig,
+     lcm_arch,
+ )
+ from lcm.nn.projection import ProjectionConfig
+ from lcm.nn.schedulers import DDIMSchedulerConfig
+
+
+ @lcm_arch("toy_two_tower_diffusion_lcm")
+ def toy_lcm() -> TwoTowerDiffusionLCModelConfig:
+     return TwoTowerDiffusionLCModelConfig(
+         context_encoder=TransformerConfig(num_layers=2),
+         denoiser=DenoiserConfig(num_layers=2),
+         # TODO change normalizer name to align with the normalizer instructions
+         sonar_normalizer_name="dummy_sonar_normalizer_A",
+     )
+
+
+ @lcm_arch("arch_lexa_lcm_pre0_toy")
+ def lexa_lcm_pre0_toy() -> TwoTowerDiffusionLCModelConfig:
+     return TwoTowerDiffusionLCModelConfig(
+         context_encoder=TransformerConfig(num_layers=2),
+         denoiser=DenoiserConfig(num_layers=2),
+         sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
+         trained_with_cf_guidance=True,
+     )
+
+
+ @lcm_arch("arch_lexa_lcm_pre0_minimal")
+ def lexa_lcm_pre0_minimal() -> TwoTowerDiffusionLCModelConfig:
+     """3-layer encoder / 6-layer denoiser / model dim 768"""
+     model_dim: int = 768  # Reduced from 2048 to 768
+     num_attn_heads: int = 12  # Reduced from 16 to 12
+     return TwoTowerDiffusionLCModelConfig(
+         model_dim=model_dim,
+         max_seq_len=2048,
+         frontend=EncoderFrontendConfig(),
+         context_encoder=TransformerConfig(
+             num_layers=3,
+             ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pos_embedding_style="rope",
+         ),
+         denoiser=DenoiserConfig(
+             num_layers=6,  # Reduced from 13 to 6
+             timestep_embed_dim=model_dim,
+             ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
+             pos_embedding_style="none",
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pre_denoiser=ProjectionConfig(),
+             post_denoiser=ProjectionConfig(),
+         ),
+         sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
+         trained_with_cf_guidance=True,
+         noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
+     )
+
+
+ @lcm_arch("arch_lexa_lcm_pre0")
+ def lexa_lcm_pre0() -> TwoTowerDiffusionLCModelConfig:
+     """4-layer encoder / 10-layer denoiser / model dim 1024
+     Parameter Size: 287,880,192"""
+     model_dim: int = 1024  # Reduced from 2048 to 1024
+     num_attn_heads: int = 16
+     return TwoTowerDiffusionLCModelConfig(
+         model_dim=model_dim,
+         max_seq_len=2048,
+         frontend=EncoderFrontendConfig(),
+         context_encoder=TransformerConfig(
+             num_layers=4,  # Reduced from 5 to 4
+             ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pos_embedding_style="rope",
+         ),
+         denoiser=DenoiserConfig(
+             num_layers=10,  # Reduced from 13 to 10
+             timestep_embed_dim=model_dim,
+             ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
+             pos_embedding_style="none",
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pre_denoiser=ProjectionConfig(),
+             post_denoiser=ProjectionConfig(),
+         ),
+         sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
+         trained_with_cf_guidance=True,
+         noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
+     )
+
+
+ @lcm_arch("two_tower_diffusion_lcm_1_6B")
+ def two_tower_diffusion_lcm_1_6B() -> TwoTowerDiffusionLCModelConfig:
+     """5-layer encoder / 13-layer denoiser / model dim 2048
+     Parameter Size: 1,635,101,696"""
+     model_dim: int = 2048
+     num_attn_heads: int = 16
+     return TwoTowerDiffusionLCModelConfig(
+         model_dim=model_dim,
+         max_seq_len=4096,
+         frontend=EncoderFrontendConfig(),
+         context_encoder=TransformerConfig(
+             num_layers=5,
+             ffn_inner_dim=4 * model_dim,
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pos_embedding_style="rope",
+         ),
+         denoiser=DenoiserConfig(
+             num_layers=13,
+             timestep_embed_dim=model_dim,
+             ffn_inner_dim=4 * model_dim,
+             pos_embedding_style="none",
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pre_denoiser=ProjectionConfig(),
+             post_denoiser=ProjectionConfig(),
+         ),
+         # TODO change normalizer name to align with the normalizer instructions
+         sonar_normalizer_name="dummy_sonar_normalizer_B",
+         trained_with_cf_guidance=True,
+         noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
+     )
+
+
+ @lcm_arch("two_tower_diffusion_lcm_7B")
+ def two_tower_diffusion_lcm_7B() -> TwoTowerDiffusionLCModelConfig:
+     # 5-layer encoder / 14-layer denoiser / model dim 4096
+     # Parameter Size: 6,930,781,696
+     model_dim: int = 4096
+     num_attn_heads: int = 32
+     return TwoTowerDiffusionLCModelConfig(
+         model_dim=model_dim,
+         max_seq_len=4096,
+         frontend=EncoderFrontendConfig(),
+         context_encoder=TransformerConfig(
+             num_layers=5,
+             ffn_inner_dim=4 * model_dim,
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pos_embedding_style="rope",
+         ),
+         denoiser=DenoiserConfig(
+             num_layers=14,
+             timestep_embed_dim=model_dim,
+             ffn_inner_dim=4 * model_dim,
+             pos_embedding_style="none",
+             num_attn_heads=num_attn_heads,
+             final_dropout_p=0.0,
+             attention_dropout_p=0.0,
+             dropout_p=0.1,
+             mha_output_proj_bias=True,
+             use_swiglu=True,
+             layer_normalization_style="rms",
+             pre_denoiser=ProjectionConfig(),
+             post_denoiser=ProjectionConfig(),
+         ),
+         # TODO change normalizer name to align with the normalizer instructions
+         sonar_normalizer_name="dummy_sonar_normalizer_C",
+         trained_with_cf_guidance=True,
+         noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
+     )
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ from fairseq2.config_registry import ConfigRegistry
11
+ from fairseq2.logging import get_log_writer
12
+ from fairseq2.nn.padding import PaddingMask, get_seq_lens
13
+ from fairseq2.nn.transformer import CausalAttentionMaskFactory
14
+ from fairseq2.typing import DataType, Device
15
+ from torch import Tensor
16
+
17
+ from lcm.datasets.batch import EmbeddingsBatch
18
+ from lcm.models.abstract_lcm import (
19
+ AbstractLCModel,
20
+ AbstractLCModelBuilder,
21
+ AbstractLCModelConfig,
22
+ )
23
+ from lcm.models.sonar_normalizer.builder import SonarNormalizer
24
+ from lcm.models.two_tower_diffusion_lcm.frontend import (
25
+ EncoderFrontend,
26
+ EncoderFrontendConfig,
27
+ )
28
+ from lcm.nn.denoisers import (
29
+ DenoiserConfig,
30
+ LCMDenoiser,
31
+ LCMDenoiserTransformerFactory,
32
+ )
33
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
34
+ from lcm.nn.initialization import parse_norm_order
35
+ from lcm.nn.normalization import parse_layer_norm_factory
36
+ from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig
37
+ from lcm.nn.transformer import (
38
+ LCMTransformerDecoder,
39
+ TransformerConfig,
40
+ TransformerFactory,
41
+ )
42
+
43
+ logger = get_log_writer(__name__)
44
+
45
+
46
+ TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE = "two_tower_diffusion_lcm"
47
+
48
+
49
+ @dataclass
50
+ class TwoTowerDiffusionLCModelConfig(AbstractLCModelConfig):
51
+ model_type: str = TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE
52
+
53
+ max_seq_len: int = 2048
54
+
55
+ model_dim: int = 1024
56
+
57
+ frontend: EncoderFrontendConfig = field(
58
+ default_factory=lambda: EncoderFrontendConfig()
59
+ )
60
+ """ The fronted config. This module maps from `sonar_embed_dim` to `model_dim`
61
+ and potentially adds positional embeddings"""
62
+
63
+ context_encoder: TransformerConfig = field(
64
+ default_factory=lambda: TransformerConfig()
65
+ )
66
+ """The context encoder config. This is causal Transformer decoder"""
67
+
68
+ noise_scheduler: DDIMSchedulerConfig = field(
69
+ default_factory=lambda: DDIMSchedulerConfig()
70
+ )
71
+ """The config of the noise scheduler.
72
+ See lcm/diffusion_schedulers/ddim for more"""
73
+
74
+ denoiser: DenoiserConfig = field(default_factory=lambda: DenoiserConfig())
75
+ """the config of the denoiser"""
76
+
77
+ trained_with_cf_guidance: bool = False
78
+ """If `True`, the model will be trained with classifier-free guidance i.e.,
79
+ unconditional embedding generation.
80
+ The CF-guidance probability is set in
81
+ DiffusionLCMCriterionConfig.cf_guidance_probability"""
82
+
83
+
84
+ lcm_archs = ConfigRegistry[TwoTowerDiffusionLCModelConfig]()
85
+ lcm_arch = lcm_archs.decorator
86
+
87
+
88
+ class TwoTowerDiffusionLCModel(AbstractLCModel):
89
+ """Class for a diffusion-based LCM model"""
90
+
91
+ config: TwoTowerDiffusionLCModelConfig
92
+
93
+ def __init__(
94
+ self,
95
+ config: TwoTowerDiffusionLCModelConfig,
96
+ sonar_normalizer: SonarNormalizer,
97
+ encoder_frontend: EncoderFrontend,
98
+ context_encoder: LCMTransformerDecoder,
99
+ denoiser: LCMDenoiser,
100
+ noise_scheduler: DDIMScheduler,
101
+ ) -> None:
102
+ super().__init__(config)
103
+
104
+ self.model_dim = context_encoder.model_dim
105
+
106
+ self.sonar_embed_dim = config.sonar_embed_dim
107
+
108
+ self.sonar_normalizer = sonar_normalizer
109
+
110
+ self.encoder_frontend = encoder_frontend
111
+ """The frontend of the context encoder.
112
+ This frontend simply applies a pre-linear projection
113
+ (to increase dimensionality) then adds positional embeddings"""
114
+
115
+ self.context_encoder = context_encoder
116
+ """A causal Transformer decoder"""
117
+
118
+ self.noise_scheduler = noise_scheduler
119
+ """The diffusion noise scheduler"""
120
+
121
+ self.denoiser = denoiser
122
+
123
+ def extra_repr(self) -> str:
124
+ """:meta private:"""
125
+ s = super().extra_repr()
126
+ return f"{s}, dtype={self.dtype}"
127
+
128
+ def forward(
129
+ self,
130
+ batch: EmbeddingsBatch,
131
+ noisy_batch: EmbeddingsBatch,
132
+ cf_guidance_prob: float = 0.0,
133
+ ) -> EmbeddingsBatch:
134
+ """
135
+ Arguments:
136
+ - batch (`EmbeddingsBatch`): The clean batch of embeddings to encode the context.
137
+ If `unsupervised` this is the source embeddings.
138
+ If `supervised` this is the source+target embeddings.
139
+
140
+ - noisy_batch (`EmbeddingsBatch`): the embeddings noised by the noise scheduler
141
+ If `unsupervised` this is noised source embeddings.
142
+ If `supervised` this is noised target-only embeddings.
143
+
144
+ - cf_guidance_prob: probability of training without any guiding context
145
+ """
146
+ # Get source lengths if any:
147
+ source_lengths = batch.source_lengths
148
+
149
+ # Encode as context:
150
+ context = self.encode(batch)
151
+
152
+ # Predict denoised output
153
+ output_batch = self.denoise(
154
+ noisy_batch=noisy_batch,
155
+ context=context,
156
+ source_lengths=source_lengths,
157
+ cf_guidance_prob=cf_guidance_prob,
158
+ )
159
+ return output_batch
160
+
161
+ def encode(
162
+ self,
163
+ batch: EmbeddingsBatch,
164
+ state_bag: Optional[LCMIncrementalStateBag] = None,
165
+ **kwargs,
166
+ ) -> EmbeddingsBatch:
167
+ """
168
+ The main context encoder that takes in a sequence of sonar embeddings in B, T, D
169
+ and returns a sequence of the same shape after causal contextualization.
170
+
171
+ Main modules:
172
+ `frontend`: linear projection to model_dim + optional positional embeddings,
173
+ `context_encoder`: Causal Transformer decoder to causally encode the context
174
+ """
175
+ # Frontend
176
+ seqs, padding_mask = self.encoder_frontend(
177
+ batch.seqs,
178
+ batch.padding_mask,
179
+ state_bag=state_bag,
180
+ **kwargs,
181
+ )
182
+
183
+ # Main Transformer
184
+ seqs, padding_mask = self.context_encoder(
185
+ seqs,
186
+ padding_mask,
187
+ state_bag=state_bag,
188
+ **kwargs,
189
+ )
190
+
191
+ return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)
192
+
193
+ def denoise(
194
+ self,
195
+ noisy_batch: EmbeddingsBatch,
196
+ context: EmbeddingsBatch,
197
+ source_lengths: Optional[Tensor] = None,
198
+ cf_guidance_prob: float = 0.0,
199
+ state_bag: Optional[LCMIncrementalStateBag] = None,
200
+ inference: bool = False,
201
+ ) -> EmbeddingsBatch:
202
+ """Diffuse a noised sonar embedding conditioned on the encoded context"""
203
+ seqs, padding_mask = self.denoiser(
204
+ seqs=noisy_batch.seqs,
205
+ diffusion_timesteps=noisy_batch.diffusion_timesteps,
206
+ padding_mask=noisy_batch.padding_mask,
207
+ conditioning_variables=context.seqs,
208
+ conditioning_variables_padding_mask=context.padding_mask,
209
+ source_lengths=source_lengths,
210
+ cf_guidance_prob=cf_guidance_prob,
211
+ inference=inference,
212
+ )
213
+ return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)
214
+
215
+ def prep_for_denoising(self, decoding_options):
216
+ """This setup is done once when we initialize the generator"""
217
+ self.guidance_scale = decoding_options.guidance_scale
218
+ self.guidance_rescale = decoding_options.guidance_rescale
219
+ self.initial_noise_scale = decoding_options.initial_noise_scale
220
+ self.timesteps = decoding_options.inference_timesteps
221
+ self.clip_noise = decoding_options.clip_noise
222
+ self.ddim_eta = decoding_options.ddim_eta
223
+ self.epsilon_scaling = decoding_options.epsilon_scaling
224
+
225
+ # if guidance_scale > 1.0 we will duplicate batches
226
+ self.do_classifier_free_guidance = self.guidance_scale != 1.0
227
+
228
+ # Setup the diffusion training-like noise scheduler
229
+ # by updating the timesteps according to the decoding `inference_timesteps`
230
+ self.noise_scheduler.set_timesteps(self.timesteps, device=self.device)
231
+
232
+ # Override the initial noise scale
233
+ self.noise_scheduler.init_noise_sigma = self.initial_noise_scale
234
+ # Override thresholding options:
235
+ if decoding_options.thresholding:
236
+ self.noise_scheduler.config.thresholding = decoding_options.thresholding
237
+ self.noise_scheduler.config.dynamic_thresholding_ratio = (
238
+ decoding_options.dynamic_thresholding_ratio
239
+ )
240
+ self.noise_scheduler.config.sample_max_value = (
241
+ decoding_options.sample_max_value
242
+ )
243
+
244
+ def sample_initial_noise_vectors(self, batch_size: int):
245
+ # Check that we called `prep_for_denoising`:
246
+ assert hasattr(self, "clip_noise"), (
247
+ "The model is not properly set for decoding, make sure to call `model.prep_for_denoising()`"
248
+ )
249
+
250
+ # Sample a noise vector for next embedding prediction
251
+ latents = torch.randn(
252
+ batch_size, 1, self.config.sonar_embed_dim, device=self.device
253
+ )
254
+
255
+ # Scale the initial noise by the standard deviation required by the scheduler
256
+ latents = latents * self.noise_scheduler.init_noise_sigma
257
+
258
+ # clip?
259
+ latents = latents.clip(-self.clip_noise, self.clip_noise)
260
+ return latents
261
+
262
+ @torch.inference_mode()
263
+ def predict_next_sentence( # type: ignore
264
+ self,
265
+ batch: EmbeddingsBatch,
266
+ context: EmbeddingsBatch,
267
+ temperature: float = 1.0,
268
+ state_bag: Optional[LCMIncrementalStateBag] = None,
269
+ context_state_bag: Optional[LCMIncrementalStateBag] = None,
270
+ **kwargs,
271
+ ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
272
+ assert context_state_bag is not None, (
273
+ "Expected a state_bag to incrementally encode the context"
274
+ )
275
+
276
+ if self.do_classifier_free_guidance:
277
+ logger.debug("Running inference with CF-guidance...")
278
+ return self.predict_next_sentence_with_cf_guidance(
279
+ batch=batch,
280
+ context=context,
281
+ temperature=temperature,
282
+ state_bag=state_bag,
283
+ context_state_bag=context_state_bag,
284
+ **kwargs,
285
+ )
286
+
287
+ # Normalize the input embeddings if we're expected to
288
+ # normalize outside of the model's forward pass
289
+ if self.sonar_normalizer is not None:
290
+ batch = batch.normalize_seqs(self.sonar_normalizer)
291
+
292
+ # Encode context:
293
+ new_context = self.encode(batch, context_state_bag)
294
+ context_state_bag.increment_step_nr(1)
295
+
296
+ # Append to context
297
+ context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))
298
+
299
+ # Sample latents:
300
+ latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))
301
+
302
+ # Denoise
303
+ diffusion_timesteps_schedule = self.noise_scheduler.timesteps
304
+
305
+ for diffusion_timestep in diffusion_timesteps_schedule:
306
+ input_batch = EmbeddingsBatch(
307
+ seqs=latents,
308
+ diffusion_timesteps=diffusion_timestep.long().repeat(
309
+ (latents.shape[0], 1)
310
+ ),
311
+ )
312
+ # Get model output
313
+ model_prediction = self.denoise(
314
+ noisy_batch=input_batch,
315
+ context=context,
316
+ state_bag=None,
317
+ inference=True,
318
+ )
319
+
320
+ scheduler_outputs = self.noise_scheduler.step(
321
+ model_output=model_prediction.seqs,
322
+ timestep=diffusion_timestep,
323
+ sample=latents,
324
+ eta=self.ddim_eta,
325
+ epsilon_scaling=self.epsilon_scaling,
326
+ )
327
+
328
+ # setup latents for the next diffusion step
329
+ latents = scheduler_outputs.prev_sample
330
+ # clip?
331
+ latents = latents.clip(-self.clip_noise, self.clip_noise)
332
+
333
+ # Take the final predicted denoised sample (x_0 in the ddim paper) and denormalize if needed:
334
+ final_seqs = scheduler_outputs.pred_original_sample
335
+
336
+ final_seqs = self.sonar_normalizer.denormalize(final_seqs)
337
+
338
+ return EmbeddingsBatch(final_seqs, None), context
339
+
340
+ @torch.inference_mode()
341
+ def predict_next_sentence_with_cf_guidance( # type: ignore
342
+ self,
343
+ batch: EmbeddingsBatch,
344
+ context: EmbeddingsBatch,
345
+ temperature: float = 1.0,
346
+ state_bag: Optional[LCMIncrementalStateBag] = None,
347
+ context_state_bag: Optional[LCMIncrementalStateBag] = None,
348
+ **kwargs,
349
+ ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
350
+ assert context_state_bag is not None, (
351
+ "Expected a state_bag to incrementally encode the context"
352
+ )
353
+
354
+ # Normalize the input embeddings if we're expected to
355
+ # normalize outside of the model's forward pass
356
+ if self.sonar_normalizer is not None:
357
+ batch = batch.normalize_seqs(self.sonar_normalizer)
358
+
359
+ # Encode context:
360
+ new_context = self.encode(batch, context_state_bag)
361
+ context_state_bag.increment_step_nr(1)
362
+
363
+ # Append to context
364
+ context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))
365
+
366
+ # Sample latents:
367
+ latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))
368
+
369
+ # Denoise
370
+ diffusion_timesteps_schedule = self.noise_scheduler.timesteps
371
+
372
+ # Duplicate the context and its padding mask, the second half will be ignored
373
+ _seq_lens = get_seq_lens(context.seqs, context.padding_mask)
374
+
375
+ # add zeros:
376
+ _seq_lens = torch.concat((_seq_lens, torch.zeros_like(_seq_lens)), dim=0)
377
+
378
+ context = EmbeddingsBatch(
379
+ torch.concat((context.seqs, torch.zeros_like(context.seqs)), dim=0),
380
+ PaddingMask(_seq_lens, batch_seq_len=context.seqs.size(1)),
381
+ )
382
+
383
+ batch_multiplier = 2
384
+ for diffusion_timestep in diffusion_timesteps_schedule:
385
+ is_max_diffusion_step = (
386
+ diffusion_timestep == self.noise_scheduler.num_diffusion_train_steps - 1
387
+ )
388
+
389
+ input_batch = EmbeddingsBatch(
390
+ torch.concat(batch_multiplier * [latents], dim=0),
391
+ diffusion_timesteps=diffusion_timestep.long().repeat(
392
+ (latents.shape[0] * batch_multiplier, 1)
393
+ ),
394
+ )
395
+
396
+ model_prediction = self.denoise(
397
+ noisy_batch=input_batch,
398
+ context=context,
399
+ state_bag=None,
400
+ inference=True,
401
+ )
402
+
403
+ # If at the max step, do not step in the epsilon_scheduler
404
+ if is_max_diffusion_step:
405
+ # if beta_prod_t (denominator) is null i.e.,
406
+ # the diffusion timestep is at its max value (num_training_stesp-1)
407
+ # no denoising will be performed.
408
+
409
+ # Note that since the batch might be doubled because
410
+ # we're doing classifier-free guidance, we chunk the model output
411
+ # by batch_multiplier. If not at max_diffusion_step
412
+ # this chunking is performed in apply_classifier_free_guidance
413
+ scheduler_outputs = self.noise_scheduler.step(
414
+ model_output=model_prediction.seqs.chunk(batch_multiplier)[0],
415
+ timestep=diffusion_timestep,
416
+ sample=latents,
417
+ eta=self.ddim_eta,
418
+ epsilon_scaling=self.epsilon_scaling,
419
+ )
420
+ else:
421
+ # Predict the noise residual according to the prediction type
422
+ predicted_noise = self.noise_scheduler.get_epsilon(
423
+ model_output=model_prediction.seqs,
424
+ sample=input_batch.seqs,
425
+ timestep=diffusion_timestep,
426
+ )
427
+
428
+ if self.do_classifier_free_guidance:
429
+ # Perform guidance if trained with cf-guidance:
430
+ # The returned predicted noise will combine the conditional and
431
+ # unconditional predictions i.e., from (2 x batch_size, 1, C)
432
+ # to: (batch_size, 1, C)
433
+ predicted_noise = self.apply_classifier_free_guidance(
434
+ predicted_noise
435
+ )
436
+
437
+ # The cf-guidance operates on predicted noises and although we
438
+ # can go back and forth between epsilon and predicted sample
439
+ # once we combine cond and uncond we cannot go back to predicted_x0
440
+
441
+ # compute the previous noisy sample x_t -> x_t-1
442
+ scheduler_outputs = self.noise_scheduler.step(
443
+ model_output=predicted_noise,
444
+ timestep=diffusion_timestep,
445
+ sample=latents,
446
+ eta=self.ddim_eta,
447
+ epsilon_scaling=self.epsilon_scaling,
448
+ prediction_type="epsilon",
449
+ )
450
+
451
+ # setup latents for the next diffusion step
452
+ latents = scheduler_outputs.prev_sample
453
+ # clip?
454
+ latents = latents.clip(-self.clip_noise, self.clip_noise)
455
+
456
+ # Take the final predicted denoised sample (x_0 in the ddim paper) and denormalize if needed:
457
+ final_seqs = scheduler_outputs.pred_original_sample
458
+
459
+ final_seqs = self.sonar_normalizer.denormalize(final_seqs)
460
+
461
+ return EmbeddingsBatch(final_seqs, None), context
462
+
463
+ def apply_classifier_free_guidance(self, predicted_noise: Tensor) -> Tensor:
464
+ """ "
465
+ Apply Classifier-Free Guidance with Rescale as introduced in Algorithm 2 of https://arxiv.org/pdf/2305.08891
466
+ `pos` would be the conditional prediction `cond_prediction`
467
+ and `neg` the unconditional prediction `uncond_prediction`:
468
+ The batch during prefilling is prepared with the conditioning prefix in
469
+ the first half
470
+ """
471
+ # Chunk and follow algorithm 2
472
+ cond_prediction, uncond_prediction = predicted_noise.chunk(2)
473
+
474
+ # Regular classifier-free guidance:
475
+ guided_noise_prediction = uncond_prediction + self.guidance_scale * (
476
+ cond_prediction - uncond_prediction
477
+ )
478
+
479
+ # Rescale classifier-free guidance to prevent over-exposure
480
+ # Calculate standard deviations.
481
+ std_pos = cond_prediction.std(dim=-1, keepdim=True)
482
+ std_cfg = guided_noise_prediction.std(dim=-1, keepdim=True)
483
+
484
+ # Apply guidance rescale with fused operations.
485
+ factor = std_pos / std_cfg
486
+ factor = self.guidance_rescale * factor + (1 - self.guidance_rescale)
487
+
488
+ return factor * guided_noise_prediction
489
+
490
+
491
+ class TwoTowerDiffusionLCModelBuilder(AbstractLCModelBuilder):
492
+ """Builds modules of a diffusion-based LCM"""
493
+
494
+ config: TwoTowerDiffusionLCModelConfig
495
+ denoiser_factory: LCMDenoiserTransformerFactory
496
+
497
+ def __init__(
498
+ self,
499
+ config: TwoTowerDiffusionLCModelConfig,
500
+ *,
501
+ device: Optional[Device] = None,
502
+ dtype: Optional[DataType] = None,
503
+ ) -> None:
504
+ """
505
+ :param config:
506
+ The configuration.
507
+ :param device:
508
+ The device on which to initialize modules.
509
+ :param dtype:
510
+ The data type of module parameters and buffers.
511
+ """
512
+ super().__init__(config=config, device=device, dtype=dtype)
513
+
514
+ self.context_encoder_factory = TransformerFactory(
515
+ model_dim=self.config.model_dim,
516
+ max_seq_len=self.config.max_seq_len,
517
+ config=self.config.context_encoder,
518
+ device=device,
519
+ dtype=dtype,
520
+ )
521
+
522
+ self.denoiser_factory = LCMDenoiserTransformerFactory(
523
+ model_dim=self.config.model_dim,
524
+ num_diffusion_train_timesteps=self.config.noise_scheduler.num_diffusion_train_steps,
525
+ max_seq_len=self.config.max_seq_len,
526
+ config=self.config.denoiser,
527
+ input_dim=self.config.sonar_embed_dim,
528
+ device=device,
529
+ dtype=dtype,
530
+ )
531
+
532
+ def build_model(self) -> TwoTowerDiffusionLCModel:
533
+ """Build a model."""
534
+
535
+ sonar_normalizer = self.build_sonar_normalizer()
536
+ assert sonar_normalizer is not None, (
537
+ "TwoTowerDiffusionLCModel expects a `sonar_normalizer`"
538
+ )
539
+
540
+ # the context encoder
541
+ encoder_frontend = self.build_frontend()
542
+
543
+ context_encoder = self.build_context_encoder()
544
+
545
+ # the denoiser
546
+ denoiser = self.build_denoiser()
547
+
548
+ noise_scheduler = self.build_noise_scheduler()
549
+
550
+ return TwoTowerDiffusionLCModel(
551
+ config=self.config,
552
+ sonar_normalizer=sonar_normalizer,
553
+ context_encoder=context_encoder,
554
+ encoder_frontend=encoder_frontend,
555
+ denoiser=denoiser,
556
+ noise_scheduler=noise_scheduler,
557
+ )
558
+
559
+ def build_frontend(self) -> EncoderFrontend:
560
+ """Build the context encoder front-end."""
561
+
562
+ return EncoderFrontend(
563
+ sonar_embed_dim=self.config.sonar_embed_dim,
564
+ model_dim=self.config.model_dim,
565
+ config=self.config.frontend,
566
+ pos_encoder=self.context_encoder_factory.build_pos_encoder(),
567
+ device=self.device,
568
+ dtype=self.dtype,
569
+ )
570
+
571
+ def build_context_encoder(self) -> LCMTransformerDecoder:
572
+ """Build the context encoder."""
573
+
574
+ config = self.config.context_encoder
575
+
576
+ num_layers = config.num_layers
577
+ assert num_layers > 0, "The context encoder needs a non-zero number of layers"
578
+
579
+ layers = [self.context_encoder_factory.build_layer() for _ in range(num_layers)]
580
+
581
+ self_attn_mask_factory = CausalAttentionMaskFactory()
582
+
583
+ if config.final_norm_order_style is None:
584
+ # The final norm order style will be that of
585
+ # the layer-level norm order
586
+ final_norm_order = parse_norm_order(config.norm_order_style)
587
+ else:
588
+ final_norm_order = parse_norm_order(config.final_norm_order_style)
589
+
590
+ layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)
591
+
592
+ return LCMTransformerDecoder(
593
+ layers,
594
+ self_attn_mask_factory=self_attn_mask_factory,
595
+ norm_order=final_norm_order,
596
+ layer_norm_factory=layer_norm_factory,
597
+ dropout_p=config.final_dropout_p,
598
+ device=self.device,
599
+ dtype=self.dtype,
600
+ )
601
+
602
+ def build_noise_scheduler(self) -> DDIMScheduler:
603
+ return DDIMScheduler(self.config.noise_scheduler)
604
+
605
+ def build_denoiser(self) -> LCMDenoiser:
606
+ """Build a Transformer for diffusing noised latents."""
607
+ return self.denoiser_factory.build_model()
608
+
609
+
610
+ def create_two_tower_diffusion_lcm_model(
611
+ config: TwoTowerDiffusionLCModelConfig,
612
+ *,
613
+ device: Optional[Device] = None,
614
+ dtype: Optional[DataType] = None,
615
+ ) -> TwoTowerDiffusionLCModel:
616
+ """Create a DiffusionLCM model.
617
+ :param config:
618
+ The configuration.
619
+ :param device:
620
+ The device on which to initialize modules.
621
+ :param dtype:
622
+ The data type of module parameters and buffers.
623
+ """
624
+ return TwoTowerDiffusionLCModelBuilder(
625
+ config,
626
+ device=device,
627
+ dtype=dtype, # type: ignore
628
+ ).build_model()
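
The guidance-and-rescale step in `apply_classifier_free_guidance` is compact, so here is a self-contained numeric sketch of the same computation (plain torch, no model classes; the function name `cfg_rescale` is ours):

    import torch

    def cfg_rescale(cond, uncond, guidance_scale=3.0, guidance_rescale=0.7):
        # Standard classifier-free guidance.
        guided = uncond + guidance_scale * (cond - uncond)
        # Rescale so the guided prediction's per-vector std matches the
        # conditional prediction's, interpolated by `guidance_rescale`.
        factor = cond.std(dim=-1, keepdim=True) / guided.std(dim=-1, keepdim=True)
        factor = guidance_rescale * factor + (1 - guidance_rescale)
        return factor * guided

    cond = torch.randn(2, 1, 1024)    # conditional noise prediction
    uncond = torch.randn(2, 1, 1024)  # unconditional noise prediction
    out = cfg_rescale(cond, uncond)   # same shape: (2, 1, 1024)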
lcm/models/two_tower_diffusion_lcm/frontend.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from fairseq2.logging import get_log_writer
+ from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder
+ from fairseq2.nn.incremental_state import IncrementalStateBag
+ from fairseq2.nn.padding import PaddingMask
+ from fairseq2.nn.projection import Linear
+ from fairseq2.typing import DataType, Device
+ from torch import Tensor
+ from torch.nn import Dropout, Module
+
+ from lcm.nn.initialization import SUPPORTED_INIT_TYPES, get_init_fn
+
+ logger = get_log_writer(__name__)
+
+
+ @dataclass
+ class EncoderFrontendConfig:
+     dropout_p: float = 0.0
+     """The dropout probability applied to the module's output"""
+
+     pre_linear_bias: bool = True
+     """Whether or not the pre-linear layer has a bias term"""
+
+     pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"
+
+     weight_normalization: bool = False
+
+     embedding_std: float = 1.0
+
+
+ class EncoderFrontend(Module):
+     """
+     A frontend for the context encoder in encoder-decoder LCMs
+     """
+
+     embed: Embedding
+     pos_encoder: Optional[PositionEncoder]
+     dropout: Optional[Dropout]
+
+     def __init__(
+         self,
+         sonar_embed_dim: int,
+         model_dim: int,
+         config: EncoderFrontendConfig,
+         pos_encoder: Optional[PositionEncoder],
+         *,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param sonar_embed_dim:
+             The embedding dimension of the sentence encoder, in this case SONAR
+         :param model_dim:
+             The model embedding dimension
+         :param config:
+             A frontend config. See `EncoderFrontendConfig`
+         :param pos_encoder:
+             An optional position encoder.
+         """
+
+         super().__init__()
+
+         self.sonar_embed_dim = sonar_embed_dim
+
+         self.model_dim = model_dim
+
+         self.device = device
+
+         # Pre-linear to map to model dimension
+         init_fn = get_init_fn(config.pre_linear_init_fn)
+
+         lin = Linear(
+             sonar_embed_dim,
+             model_dim,
+             bias=config.pre_linear_bias,
+             device=device,
+             dtype=dtype,
+             init_fn=init_fn,
+         )
+
+         if config.weight_normalization:
+             self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin)
+         else:
+             self.pre_linear = lin
+
+         if pos_encoder is not None:
+             if pos_encoder.encoding_dim != self.model_dim:
+                 raise ValueError(
+                     f"`encoding_dim` of `pos_encoder` and `embedding_dim` of "
+                     f"`embed` must be equal, but are {pos_encoder.encoding_dim} "
+                     f"and {self.model_dim} instead."
+                 )
+
+             self.pos_encoder = pos_encoder
+         else:
+             self.register_module("pos_encoder", None)
+
+         if config.dropout_p > 0.0:
+             self.dropout = Dropout(config.dropout_p)
+         else:
+             self.register_module("dropout", None)
+
+         self.reset_parameters(embedding_std=config.embedding_std)
+
+     def reset_parameters(self, embedding_std: float) -> None:
+         """Initialize module parameters.
+         The positional embeddings should be initialized with the
+         same order of magnitude as the semantic embeddings, in order
+         to make the early training as stable as possible.
+         Otherwise, the positional and special token embeddings would
+         flood out the semantic information.
+         """
+         logger.info(
+             f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})"
+         )
+         if isinstance(self.pos_encoder, LearnedPositionEncoder):
+             torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std)
+
+     def forward(
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         state_bag: Optional[IncrementalStateBag] = None,
+         **kwargs,
+     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+         """
+         Apply the pre-linear projection and add positional embeddings
+         """
+
+         # Map to model dimension:
+         seqs = self.pre_linear(seqs)
+
+         if self.pos_encoder is not None:
+             seqs = self.pos_encoder(
+                 seqs,
+                 padding_mask,
+                 state_bag=state_bag,
+                 **kwargs,
+             )
+
+         if self.dropout is not None:
+             seqs = self.dropout(seqs)
+
+         return seqs, padding_mask
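
A shape walkthrough of the frontend's data flow, as a hedged plain-torch sketch (fairseq2's `Linear` behaves like `torch.nn.Linear` for this purpose; the dimensions are illustrative):

    import torch

    sonar_embed_dim, model_dim = 1024, 2048
    pre_linear = torch.nn.Linear(sonar_embed_dim, model_dim)

    seqs = torch.randn(4, 128, sonar_embed_dim)  # (batch, seq_len, sonar_dim)
    seqs = pre_linear(seqs)                      # (batch, seq_len, model_dim)
    # When configured, a position encoder and dropout are applied on top.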
lcm/models/two_tower_diffusion_lcm/loader.py ADDED
@@ -0,0 +1,44 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+
+ from fairseq2.models.config_loader import StandardModelConfigLoader
+ from fairseq2.models.loader import StandardModelLoader, load_model
+
+ from lcm.models.base_lcm.loader import convert_lcm_checkpoint
+ from lcm.models.two_tower_diffusion_lcm.builder import (
+     TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
+     TwoTowerDiffusionLCModelConfig,
+     create_two_tower_diffusion_lcm_model,
+     lcm_archs,
+ )
+ from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry
+
+ load_two_tower_diffusion_lcm_config = StandardModelConfigLoader(
+     family=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
+     config_kls=TwoTowerDiffusionLCModelConfig,
+     arch_configs=lcm_archs,
+ )
+
+
+ load_two_tower_diffusion_lcm_model = StandardModelLoader(  # type: ignore # FIXME
+     config_loader=load_two_tower_diffusion_lcm_config,
+     factory=create_two_tower_diffusion_lcm_model,
+     checkpoint_converter=convert_lcm_checkpoint,
+     restrict_checkpoints=False,
+ )
+
+ load_model.register(
+     TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, load_two_tower_diffusion_lcm_model
+ )
+
+ lcm_model_type_registry.register(
+     ModelTypeConfig(
+         model_type=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
+         config_loader=load_two_tower_diffusion_lcm_config,
+         model_factory=create_two_tower_diffusion_lcm_model,
+         model_loader=load_two_tower_diffusion_lcm_model,
+     )
+ )
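
Once this module has been imported, any card of type "two_tower_diffusion_lcm" can be loaded through fairseq2's generic entry point. A hedged sketch; the card name is hypothetical and the keyword arguments follow the usual fairseq2 loader convention:

    import lcm.models.two_tower_diffusion_lcm.loader  # noqa: F401  (registers the family)
    from fairseq2.models.loader import load_model

    # "my_lcm_card" is a hypothetical model card of type "two_tower_diffusion_lcm".
    model = load_model("my_lcm_card", device=None, dtype=None)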
lcm/nn/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
lcm/nn/denoisers/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+
+ from lcm.nn.denoisers.factory import (
+     DenoiserConfig,
+     LCMDenoiser,
+     LCMDenoiserTransformerFactory,
+ )
+
+ __all__ = [
+     "DenoiserConfig",
+     "LCMDenoiser",
+     "LCMDenoiserTransformerFactory",
+ ]
lcm/nn/denoisers/attention_masks.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ import math
+ from typing import Optional, final
+
+ import torch
+ from fairseq2.nn.transformer import (
+     AbstractAttentionMask,
+     AttentionMask,
+     AttentionMaskFactory,
+ )
+ from fairseq2.typing import DataType, Device, override
+ from torch import Tensor
+
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
+
+
+ def _get_shifted_causal_mask(
+     seq_len: int,
+     key_len: int,
+     shift: int = 0,
+     cf_guidance_prob: float = 0.0,
+     zero_vector: bool = False,
+     device: Optional[Device] = None,
+     dtype: Optional[DataType] = None,
+ ) -> Tensor:
+     """Build a 0/1 causal mask of shape (seq_len, key_len), optionally
+     dropping rows for classifier-free guidance."""
+     causal_mask = torch.ones(
+         (seq_len, key_len),
+         device=device,
+         dtype=dtype,
+     )
+     causal_mask.tril_(diagonal=shift)
+
+     if cf_guidance_prob > 0.0:
+         num_rows_to_drop = math.floor((seq_len - 1) * cf_guidance_prob)
+         if num_rows_to_drop > 0:
+             rows_to_drop = 1 + torch.randperm(seq_len - 1)[:num_rows_to_drop]
+             if zero_vector:
+                 causal_mask[rows_to_drop, 1:] = 0
+             else:
+                 causal_mask[rows_to_drop, :] = 0
+
+     return causal_mask
+
+
+ class NoAttentionMaskFactory(AttentionMaskFactory):
+     """Constructs instances of :class:`NoAttentionMask`."""
+
+     @override
+     def __call__(  # type: ignore
+         self,
+         seqs: Tensor,
+         keys: Tensor,
+         *,
+         training: bool = True,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+         inference_without_caching: Optional[bool] = False,
+         **kwargs,
+     ) -> Optional[AttentionMask]:
+         mask: NoAttentionMask
+
+         attn_len: Optional[int] = seqs.size(1)
+         seq_len = seqs.size(1)
+         key_len = keys.size(1)
+
+         mask = NoAttentionMask(
+             seq_len=seq_len,
+             key_len=key_len,
+             attn_len=attn_len,
+             device=seqs.device,
+             dtype=seqs.dtype,
+         )
+         return mask
+
+     def __repr__(self) -> str:
+         return "NoAttentionMaskFactory()"
+
+
+ @final
+ class NoAttentionMask(AbstractAttentionMask):
+     """
+     Represents a diagonal attention mask, i.e., attention on the current
+     position only.
+     This turns the self-attention layer into an FFN.
+     """
+
+     def __init__(
+         self,
+         seq_len: int,
+         key_len: int,
+         attn_len: Optional[int],
+         *,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param seq_len:
+             The sequence length.
+         """
+         super().__init__()
+
+         self.seq_len = seq_len
+
+         self._device, self._dtype = device, dtype
+
+     @override
+     def _do_materialize(self) -> Tensor:
+         # An identity matrix, then an elementwise log: 1 -> 0, 0 -> -inf.
+         mask = torch.eye(self.seq_len, device=self._device, dtype=self._dtype)
+         mask.log_()
+         return mask
+
+
+ class ShiftedCausalAttentionMaskFactory(AttentionMaskFactory):
+     """
+     Constructs instances of :class:`ShiftedCausalAttentionMask`
+     """
+
+     @override
+     def __call__(  # type: ignore
+         self,
+         seqs: Tensor,
+         keys: Tensor,
+         *,
+         source_lengths: Optional[Tensor] = None,
+         cf_guidance_prob: float = 0.0,
+         training: bool = True,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+         inference: bool = False,
+     ) -> Optional[AttentionMask]:
+         mask: Optional[ShiftedCausalAttentionMask]
+
+         attn_len: Optional[int] = seqs.size(1)
+         seq_len = seqs.size(1)
+         key_len = keys.size(1)
+
+         if inference:
+             mask = None
+         else:
+             mask = ShiftedCausalAttentionMask(
+                 seq_len=seq_len,
+                 key_len=key_len,
+                 attn_len=attn_len,
+                 source_lengths=source_lengths,
+                 cf_guidance_prob=cf_guidance_prob,
+                 device=seqs.device,
+                 dtype=seqs.dtype,
+             )
+
+         return mask
+
+     def __repr__(self) -> str:
+         return "ShiftedCausalAttentionMaskFactory()"
+
+
+ @final
+ class ShiftedCausalAttentionMask(AbstractAttentionMask):
+     """
+     Represents a causal mask shifted by source_lengths.
+
+     At training time, without source_lengths, the mask looks like
+     (e.g. seq_len = 5, key_len = 6):
+
+     [ 0., -inf, -inf, -inf, -inf, -inf],
+     [ 0.,   0., -inf, -inf, -inf, -inf],
+     [ 0.,   0.,   0., -inf, -inf, -inf],
+     [ 0.,   0.,   0.,   0., -inf, -inf],
+     [ 0.,   0.,   0.,   0.,   0., -inf]
+
+     """
+
+     def __init__(
+         self,
+         seq_len: int,
+         key_len: int,
+         attn_len: Optional[int],
+         *,
+         source_lengths: Optional[Tensor] = None,
+         cf_guidance_prob: float = 0.0,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param seq_len:
+             The sequence length.
+         """
+         super().__init__()
+
+         self.seq_len = seq_len
+         self.key_len = key_len
+         self._source_lengths = source_lengths
+         self._cf_guidance_prob = cf_guidance_prob
+         self._device, self._dtype = device, dtype
+
+     @override
+     def _do_materialize(self) -> Tensor:
+         if self._source_lengths is None:
+             causal_mask = _get_shifted_causal_mask(
+                 seq_len=self.seq_len,
+                 key_len=self.key_len,
+                 shift=0,
+                 cf_guidance_prob=self._cf_guidance_prob,
+                 zero_vector=True,
+                 device=self._device,
+                 dtype=self._dtype,
+             )
+
+         else:
+             causal_mask = torch.stack(
+                 [
+                     _get_shifted_causal_mask(
+                         seq_len=self.seq_len,
+                         key_len=self.key_len,
+                         shift=src_len,
+                         cf_guidance_prob=self._cf_guidance_prob,
+                         zero_vector=True,
+                         device=self._device,
+                         dtype=self._dtype,
+                     )
+                     for src_len in self._source_lengths
+                 ]
+             ).unsqueeze(1)
+             # bs x 1 (head) x seq_len x seq_len
+
+         causal_mask.log_()
+
+         return causal_mask
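
For intuition, a minimal worked example of what `_get_shifted_causal_mask` returns before the caller takes its elementwise log (which turns 1s into 0s and 0s into -inf). With shift=1, position i may attend to keys 0..i+1:

    import torch
    from lcm.nn.denoisers.attention_masks import _get_shifted_causal_mask

    mask = _get_shifted_causal_mask(seq_len=4, key_len=5, shift=1, dtype=torch.float32)
    print(mask)
    # tensor([[1., 1., 0., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 0.],
    #         [1., 1., 1., 1., 1.]])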
lcm/nn/denoisers/factory.py ADDED
@@ -0,0 +1,192 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from dataclasses import dataclass, field
+ from typing import Literal, Optional
+
+ from fairseq2.logging import get_log_writer
+ from fairseq2.typing import DataType, Device
+
+ from lcm.nn.denoisers.attention_masks import (
+     NoAttentionMaskFactory,
+     ShiftedCausalAttentionMaskFactory,
+ )
+ from lcm.nn.denoisers.lcm_denoiser import (
+     LCMDenoiser,
+     LCMDenoiserLayer,
+ )
+ from lcm.nn.initialization import parse_norm_order
+ from lcm.nn.normalization import parse_layer_norm_factory
+ from lcm.nn.projection import (
+     Projection,
+     ProjectionConfig,
+ )
+ from lcm.nn.timestep_encoder import DiTTimestepEncoder
+ from lcm.nn.transformer import TransformerConfig, TransformerFactory
+
+ logger = get_log_writer(__name__)
+
+
+ @dataclass
+ class DenoiserConfig(TransformerConfig):
+     """Config for building the LCM's denoiser"""
+
+     pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "none"
+     """By default, a denoiser does not have a positional embedder"""
+
+     pre_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
+     """The initial projection at the top of the denoiser"""
+
+     post_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
+     """The final output projection at the end of the denoiser"""
+
+     timestep_embed_dim: int = 1024
+     """Diffusion timestep embedding dimension"""
+
+
+ class LCMDenoiserTransformerFactory(TransformerFactory):
+     """Denoiser with hybrid AdaLN and cross-attention"""
+
+     config: DenoiserConfig
+
+     def __init__(
+         self,
+         model_dim: int,
+         max_seq_len: int,
+         num_diffusion_train_timesteps: int,
+         config: DenoiserConfig,
+         input_dim: int = 1024,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param model_dim:
+             The hidden model dimension of the Transformer
+         :param max_seq_len:
+             Maximum sequence length supported by the model
+         :param config:
+             The configuration.
+         :param input_dim:
+             The input embedding dimension, i.e., `sonar_embed_dim`
+         :param device:
+             The device on which to initialize modules.
+         :param dtype:
+             The data type of module parameters and buffers.
+         """
+         super().__init__(
+             model_dim=model_dim,
+             max_seq_len=max_seq_len,
+             config=config,
+             device=device,
+             dtype=dtype,
+         )
+
+         self.input_dim = input_dim
+
+         self.num_diffusion_train_timesteps = num_diffusion_train_timesteps
+
+     def build_cross_attention_mask(self):
+         return ShiftedCausalAttentionMaskFactory()
+
+     def build_timestep_embedder(self):
+         return DiTTimestepEncoder(
+             embedding_dim=self.config.timestep_embed_dim,
+             dtype=self.dtype,
+             device=self.device,
+         )
+
+     def build_initial_proj(self) -> Projection:
+         # The timestep embeddings join the sequence of conditioning variables
+         assert self.config.timestep_embed_dim == self.model_dim, (
+             "Since the timestep embeddings will be added to the sequence of "
+             "conditioning variables, they need to be of the same dimension. "
+             f"Found timestep_embed_dim={self.config.timestep_embed_dim} "
+             f"and model_dim={self.model_dim}"
+         )
+
+         return Projection(
+             output_dim=self.model_dim,
+             input_dim=self.input_dim,
+             config=self.config.pre_denoiser,
+             device=self.device,
+             dtype=self.dtype,
+         )
+
+     def build_final_proj(self) -> Projection:
+         return Projection(
+             output_dim=self.input_dim,
+             input_dim=self.model_dim,
+             config=self.config.post_denoiser,
+             device=self.device,
+             dtype=self.dtype,
+         )
+
+     def build_model(self) -> LCMDenoiser:
+         """Build the denoiser with its layers and initial/final projections"""
+         embed_time = self.build_timestep_embedder()
+
+         layers = [self.build_layer() for _ in range(self.config.num_layers)]
+
+         norm_order = parse_norm_order(self.config.norm_order_style)
+
+         # Self-attention here does not contextualize
+         self_attn_mask_factory = NoAttentionMaskFactory()
+
+         cross_attention_mask_factory = self.build_cross_attention_mask()
+
+         layer_norm_factory = parse_layer_norm_factory(
+             self.config.layer_normalization_style
+         )
+
+         pos_encoder = self.build_pos_encoder()
+
+         return LCMDenoiser(
+             embed_time=embed_time,
+             layers=layers,
+             initial_proj=self.build_initial_proj(),
+             final_proj=self.build_final_proj(),
+             dropout_p=self.config.final_dropout_p,
+             norm_order=norm_order,
+             layer_norm_factory=layer_norm_factory,
+             self_attn_mask_factory=self_attn_mask_factory,
+             cross_attention_mask_factory=cross_attention_mask_factory,
+             pos_encoder=pos_encoder,
+             device=self.device,
+             dtype=self.dtype,
+         )
+
+     def build_layer(self) -> LCMDenoiserLayer:
+         """Build a Transformer decoder layer based on the provided config."""
+
+         assert isinstance(self.config, DenoiserConfig), (
+             "Expecting a DenoiserConfig in the DenoiserTransformerFactory"
+         )
+
+         self_attn = self.build_attention()
+
+         cross_attn = self.build_attention()
+
+         ffn = self.build_ffn()
+
+         norm_order = parse_norm_order(self.config.norm_order_style)
+
+         layer_norm_factory = parse_layer_norm_factory(
+             self.config.layer_normalization_style
+         )
+
+         modulator_input_dim = self_attn.model_dim
+
+         layer = LCMDenoiserLayer(
+             self_attn=self_attn,
+             cross_attention=cross_attn,
+             ffn=ffn,
+             modulator_input_dim=modulator_input_dim,
+             dropout_p=self.config.dropout_p,
+             norm_order=norm_order,
+             layer_norm_factory=layer_norm_factory,
+             device=self.device,
+             dtype=self.dtype,
+         )
+         return layer
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from typing import Iterable, Optional, Tuple, cast
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from fairseq2.nn import PositionEncoder
11
+ from fairseq2.nn.incremental_state import IncrementalStateBag
12
+ from fairseq2.nn.normalization import LayerNorm
13
+ from fairseq2.nn.padding import PaddingMask
14
+ from fairseq2.nn.transformer import (
15
+ AttentionMask,
16
+ AttentionMaskFactory,
17
+ FeedForwardNetwork,
18
+ LayerNormFactory,
19
+ MultiheadAttention,
20
+ TransformerDecoderLayer,
21
+ TransformerNormOrder,
22
+ create_standard_layer_norm,
23
+ )
24
+ from fairseq2.typing import DataType, Device, override
25
+ from torch import Tensor
26
+ from torch.nn import Dropout, Module, ModuleList
27
+ from torch.nn.parameter import Parameter
28
+
29
+ from lcm.nn.projection import Projection
30
+ from lcm.nn.timestep_encoder import DiTTimestepEncoder
31
+
32
+
33
+ class AdaLNModulator(Module):
34
+ """An adaptive LayerNorm modulator to estimate
35
+ shift, gate and scale for all 3 sub-modules."""
36
+
37
+ def __init__(
38
+ self,
39
+ input_dim: int,
40
+ output_dim: int,
41
+ device: Optional[Device] = None,
42
+ dtype: Optional[DataType] = None,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.activate = nn.SiLU()
47
+ self.fc = nn.Linear(
48
+ input_dim,
49
+ 9 * output_dim,
50
+ bias=True,
51
+ device=device,
52
+ dtype=dtype,
53
+ )
54
+
55
+ def reset_parameters(self):
56
+ # zero-init
57
+ nn.init.constant_(self.fc.weight, 0)
58
+ nn.init.constant_(self.fc.bias, 0)
59
+
60
+ def forward(self, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
61
+ (modulate_san, modulate_cross_attention, modulate_ffn) = self.fc(
62
+ self.activate(context)
63
+ ).chunk(3, dim=-1)
64
+ return modulate_san, modulate_cross_attention, modulate_ffn
65
+
66
+
67
+ class LCMDenoiser(Module):
+     """
+     The main denoiser module of the two-tower diffusion LCM.
+     """
+
+     model_dim: int
+     layers: ModuleList
+     self_attn_mask_factory: AttentionMaskFactory
+     layer_norm: Optional[LayerNorm]
+     dropout_p: float
+     norm_order: TransformerNormOrder
+     cross_attention_mask_factory: AttentionMaskFactory
+
+     def __init__(
+         self,
+         embed_time: DiTTimestepEncoder,
+         layers: Iterable[TransformerDecoderLayer],
+         initial_proj: Projection,
+         final_proj: Projection,
+         *,
+         self_attn_mask_factory: AttentionMaskFactory,
+         cross_attention_mask_factory: AttentionMaskFactory,
+         dropout_p: float = 0.0,
+         norm_order: TransformerNormOrder = TransformerNormOrder.POST,
+         pos_encoder: Optional[PositionEncoder] = None,
+         layer_norm_factory: Optional[LayerNormFactory] = None,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param embed_time:
+             The encoder of diffusion timesteps.
+         :param layers:
+             The decoder layers.
+         :param self_attn_mask_factory:
+             The self-attention mask factory.
+         :param cross_attention_mask_factory:
+             The cross-attention mask factory.
+         :param dropout_p:
+             The dropout probability on decoder outputs.
+         :param norm_order:
+             The Layer Normalization order.
+         :param pos_encoder:
+             An optional positional encoding module.
+         :param layer_norm_factory:
+             The factory to construct the Layer Normalization module.
+         """
+         layer_list = ModuleList(layers)
+
+         if not layer_list:
+             raise ValueError("`layers` must be non-empty.")
+
+         model_dim = layer_list[0].model_dim
+
+         super().__init__()
+
+         self.model_dim = model_dim
+         self.embed_time = embed_time
+         self.initial_proj = initial_proj
+         self.final_proj = final_proj
+         self.pos_encoder = pos_encoder
+
+         if layer_norm_factory is None:
+             layer_norm_factory = create_standard_layer_norm
+
+         self.self_attn_mask_factory = self_attn_mask_factory
+         self.layers = layer_list
+
+         if norm_order != TransformerNormOrder.POST:
+             self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+         else:
+             self.register_module("layer_norm", None)
+
+         if dropout_p > 0.0:
+             self.dropout = Dropout(dropout_p)
+         else:
+             self.register_module("dropout", None)
+
+         self.norm_order = norm_order
+         self.cross_attention_mask_factory = cross_attention_mask_factory
+
+     def forward(
+         self,
+         seqs: Tensor,
+         diffusion_timesteps: Tensor,
+         padding_mask: Optional[PaddingMask],
+         conditioning_variables: Optional[Tensor] = None,
+         conditioning_variables_padding_mask: Optional[PaddingMask] = None,
+         source_lengths: Optional[Tensor] = None,
+         cf_guidance_prob: float = 0.0,
+         *,
+         state_bag: Optional[IncrementalStateBag] = None,
+         inference: Optional[bool] = False,
+     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+         """
+         Arguments:
+             - seqs (`Tensor`): the sequence of latents to denoise.
+             - diffusion_timesteps (`Tensor`): the indices of the diffusion timesteps
+               to be embedded and fed as a conditioning variable.
+             - padding_mask (`PaddingMask`): the mask of padded positions in the latents (`seqs`).
+             - conditioning_variables (`Tensor`): the sequence of conditioning
+               variables that will be combined with the timestep embedding to
+               guide the diffusion process.
+             - conditioning_variables_padding_mask (`PaddingMask`): the mask of padded
+               positions in `conditioning_variables`.
+             - source_lengths (`Optional[Tensor]`): the lengths of the source embeddings
+               in `conditioning_variables`, needed to properly shift the cross-attention mask.
+             - cf_guidance_prob (`float`): the probability with which all conditioning
+               variables are dropped when denoising (classifier-free guidance).
+             - state_bag (`IncrementalStateBag`): the incremental state bag of the
+               denoiser, enabling kv-caching.
+             - inference (`bool`): if `True`, the cross-attention mask is adjusted accordingly.
+         """
+         emb_timesteps = self.embed_time(diffusion_timesteps)
+
+         assert conditioning_variables is not None, (
+             "Expected conditioning_variables, found None"
+         )
+
+         # Prepend an all-zero position that cross-attention falls back to when
+         # the conditioning variables are dropped for classifier-free guidance.
+         conditioning_variables = torch.cat(
+             [
+                 torch.zeros_like(conditioning_variables[:, 0:1]),
+                 conditioning_variables,
+             ],
+             dim=1,
+         )
+
+         if conditioning_variables_padding_mask is not None:
+             # shift by the length of the prepended zero position
+             conditioning_variables_padding_mask = PaddingMask(
+                 conditioning_variables_padding_mask._seq_lens + 1,
+                 conditioning_variables_padding_mask._batch_seq_len + 1,
+             )
+
+         # Project to model_dim and add optional position codes:
+         seqs = self.initial_proj(seqs)
+
+         if self.pos_encoder is not None:
+             seqs = self.pos_encoder(seqs, padding_mask)
+
+         self_attn_mask = self.self_attn_mask_factory(
+             seqs, keys=seqs, training=self.training, state_bag=state_bag
+         )
+
+         cross_attention_mask = self.cross_attention_mask_factory(
+             seqs,
+             keys=conditioning_variables,
+             source_lengths=source_lengths,
+             cf_guidance_prob=cf_guidance_prob,
+             training=self.training,
+             state_bag=state_bag,
+             inference=inference,  # type: ignore
+         )
+
+         for layer in self.layers:
+             seqs, padding_mask = layer(
+                 seqs=seqs,
+                 padding_mask=padding_mask,
+                 self_attn_mask=self_attn_mask,
+                 emb_timesteps=emb_timesteps,
+                 conditioning_variables=conditioning_variables,
+                 conditioning_variables_padding_mask=conditioning_variables_padding_mask,
+                 cross_attention_mask=cross_attention_mask,
+                 state_bag=state_bag,
+             )
+
+         if self.layer_norm is not None:
+             seqs = self.layer_norm(seqs)
+
+         if self.dropout is not None:
+             seqs = self.dropout(seqs)
+
+         seqs = self.final_proj(seqs)
+
+         return seqs, padding_mask
+
+
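A note on the prepended zero vector above: it acts as a "null conditioning" position, and dropping the real conditioning with probability `cf_guidance_prob` during training is what enables classifier-free guidance at sampling time. The conditional and unconditional predictions are then typically combined with the generic CFG formula (sketched here with a hypothetical helper, not code from this repo):

import torch

def guided_prediction(
    pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float = 3.0
) -> torch.Tensor:
    # Standard classifier-free guidance: move away from the unconditional
    # prediction, towards the conditional one.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)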
+ class LCMDenoiserLayer(TransformerDecoderLayer):
+     """A single layer of the hybrid denoiser."""
+
+     self_attn: MultiheadAttention
+     self_attn_norm: Optional[LayerNorm]
+     self_attn_dropout: Optional[Dropout]
+     self_attn_layer_norm: LayerNorm
+     cross_attention: MultiheadAttention
+     cross_attention_dropout: Optional[Dropout]
+     cross_attention_layer_norm: Optional[LayerNorm]
+     ffn: FeedForwardNetwork
+     ffn_dropout: Optional[Dropout]
+     residual_scale: Optional[Parameter]
+     ffn_layer_norm: LayerNorm
+     norm_order: TransformerNormOrder
+
+     def __init__(
+         self,
+         self_attn: MultiheadAttention,
+         ffn: FeedForwardNetwork,
+         cross_attention: Optional[MultiheadAttention],
+         *,
+         scale_residual: bool = False,
+         dropout_p: float = 0.0,
+         norm_order: TransformerNormOrder = TransformerNormOrder.POST,
+         layer_norm_factory: Optional[LayerNormFactory] = None,
+         modulator_input_dim: Optional[int] = None,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param self_attn:
+             The self-attention layer.
+         :param cross_attention:
+             The cross-attention layer if the denoiser type is `cross-attention`.
+         :param ffn:
+             The feed-forward network.
+         :param scale_residual:
+             If ``True``, scales residuals before adding them to the output of
+             the feed-forward network as described in
+             :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`.
+         :param dropout_p:
+             The dropout probability on outputs of the attention layers and the
+             feed-forward network.
+         :param norm_order:
+             The Layer Normalization order.
+         :param layer_norm_factory:
+             The factory to construct the Layer Normalization modules.
+         :param modulator_input_dim:
+             The input dimension of the AdaLN modulator (defaults to `model_dim`).
+         """
+         model_dim = self_attn.model_dim
+
+         super().__init__(model_dim)
+
+         if layer_norm_factory is None:
+             layer_norm_factory = create_standard_layer_norm
+
+         self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+
+         if norm_order != TransformerNormOrder.POST:
+             self.self_attn_layer_norm = self_attn_layer_norm
+
+         self.self_attn = self_attn
+
+         if norm_order == TransformerNormOrder.PRE_WITH_NORMFORMER:
+             self.self_attn_norm = layer_norm_factory(
+                 model_dim, device=device, dtype=dtype
+             )
+         else:
+             self.register_module("self_attn_norm", None)
+
+         if dropout_p > 0.0:
+             self.self_attn_dropout = Dropout(dropout_p)
+         else:
+             self.register_module("self_attn_dropout", None)
+
+         if norm_order == TransformerNormOrder.POST:
+             self.self_attn_layer_norm = self_attn_layer_norm
+
+         # Deal with the cross-attention layers:
+         if cross_attention is None:
+             self.register_module("cross_attention", None)
+             self.register_module("cross_attention_layer_norm", None)
+         else:
+             cross_attention_layer_norm = layer_norm_factory(
+                 model_dim, device=device, dtype=dtype
+             )
+
+             if norm_order != TransformerNormOrder.POST:
+                 self.cross_attention_layer_norm = cross_attention_layer_norm
+
+             self.cross_attention = cross_attention
+
+             if dropout_p > 0.0:
+                 self.cross_attention_dropout = Dropout(dropout_p)
+             else:
+                 self.register_module("cross_attention_dropout", None)
+
+             if norm_order == TransformerNormOrder.POST:
+                 self.cross_attention_layer_norm = cross_attention_layer_norm
+         # / deal with cross-attention
+
+         ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+
+         if norm_order != TransformerNormOrder.POST:
+             self.ffn_layer_norm = ffn_layer_norm
+
+         self.ffn = ffn
+
+         if dropout_p > 0.0:
+             self.ffn_dropout = Dropout(dropout_p)
+         else:
+             self.register_module("ffn_dropout", None)
+
+         if norm_order == TransformerNormOrder.POST:
+             self.ffn_layer_norm = ffn_layer_norm
+
+         self.norm_order = norm_order
+
+         # Add a modulator:
+         modulator_input_dim = modulator_input_dim or model_dim
+         self.modulator = AdaLNModulator(
+             input_dim=modulator_input_dim,
+             output_dim=model_dim,
+             device=device,
+             dtype=dtype,
+         )
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """Reset the parameters and buffers of the module."""
+         # Zero-out the modulators:
+         self.modulator.reset_parameters()
+
+     @override
+     def forward(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         conditioning_variables: Tensor,
+         emb_timesteps: Tensor,
+         self_attn_mask: Optional[AttentionMask] = None,
+         conditioning_variables_padding_mask: Optional[PaddingMask] = None,
+         cross_attention_mask: Optional[AttentionMask] = None,
+         *,
+         state_bag: Optional[IncrementalStateBag] = None,
+     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+         # Get the (shift, scale, gate) modulators of each sub-block:
+         (modulate_san, modulate_cross_attention, modulate_ffn) = self.modulator(
+             emb_timesteps
+         )
+
+         seqs = self._forward_self_attn(
+             seqs=seqs,
+             padding_mask=padding_mask,
+             modulators=modulate_san,
+             self_attn_mask=self_attn_mask,
+             state_bag=state_bag,
+         )
+
+         seqs = self._forward_cross_attention(
+             seqs=seqs,
+             padding_mask=padding_mask,
+             conditioning_variables=conditioning_variables,
+             modulators=modulate_cross_attention,
+             cross_attention_mask=cross_attention_mask,
+             key_padding_mask=conditioning_variables_padding_mask,
+             state_bag=state_bag,
+         )
+
+         seqs = self._forward_ffn(
+             seqs=seqs,
+             modulators=modulate_ffn,
+         )
+
+         return seqs, padding_mask
+
+     def _forward_self_attn(
+         self,
+         seqs: Tensor,
+         modulators: Tensor,
+         padding_mask: Optional[PaddingMask],
+         self_attn_mask: Optional[AttentionMask],
+         state_bag: Optional[IncrementalStateBag],
+     ) -> Tensor:
+         residual = seqs
+
+         assert self.norm_order != TransformerNormOrder.POST, (
+             "DiT AdaLN expects pre-normalization"
+         )
+
+         if self.norm_order != TransformerNormOrder.POST:
+             seqs = self.self_attn_layer_norm(seqs)
+
+         # Split the modulators into shift, scale and gate:
+         shift, scale, gate = modulators.chunk(3, dim=-1)
+
+         # Modulate the input:
+         seqs = seqs * (1 + scale) + shift
+
+         seqs = self.self_attn(
+             seqs,
+             padding_mask,
+             keys=seqs,
+             key_padding_mask=None,
+             values=seqs,
+             attn_mask=self_attn_mask,
+             state_bag=state_bag,
+         )
+
+         if self.self_attn_norm is not None:
+             seqs = self.self_attn_norm(seqs)
+
+         if self.self_attn_dropout is not None:
+             seqs = self.self_attn_dropout(seqs)
+
+         # Scale the residual branch with the gate weights
+         seqs = residual + gate * seqs
+
+         return seqs
+
+     def _forward_cross_attention(
+         self,
+         seqs: Tensor,
+         modulators: Tensor,
+         padding_mask: Optional[PaddingMask],
+         conditioning_variables: Optional[Tensor],
+         key_padding_mask: Optional[PaddingMask],
+         cross_attention_mask: Optional[AttentionMask],
+         state_bag: Optional[IncrementalStateBag],
+     ) -> Tensor:
+         if conditioning_variables is None:
+             raise ValueError(
+                 "`conditioning_variables` must not be `None` for cross attention."
+             )
+
+         residual = seqs
+
+         assert self.norm_order != TransformerNormOrder.POST, (
+             "DiT AdaLN expects pre-normalization"
+         )
+
+         if self.norm_order != TransformerNormOrder.POST:
+             seqs = cast(LayerNorm, self.cross_attention_layer_norm)(seqs)
+
+         # Split the modulators into shift, scale and gate:
+         shift, scale, gate = modulators.chunk(3, dim=-1)
+
+         # Modulate the input:
+         seqs = seqs * (1 + scale) + shift
+
+         seqs = self.cross_attention(
+             seqs,
+             padding_mask,
+             keys=conditioning_variables,
+             key_padding_mask=key_padding_mask,
+             attn_mask=cross_attention_mask,
+             values=conditioning_variables,
+             state_bag=state_bag,
+         )
+
+         if self.cross_attention_dropout is not None:
+             seqs = self.cross_attention_dropout(seqs)
+
+         # Scale the residual branch with the gate weights
+         seqs = residual + gate * seqs
+
+         return seqs
+
+     def _forward_ffn(self, seqs: Tensor, modulators: Tensor) -> Tensor:
+         assert self.norm_order != TransformerNormOrder.POST, (
+             "DiT AdaLN expects pre-normalization"
+         )
+         residual = seqs
+
+         if self.norm_order != TransformerNormOrder.POST:
+             seqs = self.ffn_layer_norm(seqs)
+
+         # Split the modulators into shift, scale and gate:
+         shift, scale, gate = modulators.chunk(3, dim=-1)
+
+         # Modulate the input:
+         seqs = seqs * (1 + scale) + shift
+
+         seqs = self.ffn(seqs)
+
+         if self.ffn_dropout is not None:
+             seqs = self.ffn_dropout(seqs)
+
+         # Scale the branch with the gate weights
+         seqs = residual + gate * seqs
+
+         return seqs
lcm/nn/incremental_state.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from typing import Dict, Optional, final
+
+ from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag
+ from fairseq2.nn.transformer import FullAttentionState
+ from torch import Tensor
+ from torch.nn import Module
+
+
+ @final
+ class LCMIncrementalStateBag(IncrementalStateBag):  # type: ignore
+     """Holds the module states during incremental decoding."""
+
+     _module_states: Dict[Module, FullAttentionState]  # type: ignore
+
+     def __init__(
+         self, max_num_steps: int, *, capacity_increment: Optional[int] = 16
+     ) -> None:
+         super().__init__(
+             max_num_steps=max_num_steps, capacity_increment=capacity_increment
+         )
+
+     def reorder(self, new_order: Tensor) -> None:
+         """Reorder the module states.
+
+         See :meth:`IncrementalState.reorder` for more information.
+         """
+         # FIXME Deal with reordering diffusion state bags here
+         for state in self._module_states.values():
+             state.reorder(new_order)
+
+     def set_state(self, m: Module, state: IncrementalState) -> None:
+         """Set the state of ``m``.
+         :param m: The module.
+         :param state: The state to store.
+         There is currently no call to `set_state` while the bag
+         is frozen, but it is implemented here for completeness.
+         """
+         super().set_state(m, state)
lcm/nn/initialization.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
+
+ import math
+ from functools import partial
+ from typing import Literal, Optional
+
+ import torch
+ from fairseq2.nn.projection import Linear
+ from fairseq2.nn.transformer import TransformerNormOrder
+ from torch.nn import Module
+
+ SUPPORTED_INIT_TYPES = Literal[
+     "xavier",
+     "sonar",
+     "zero",
+     "trunc_normal",
+     "kaiming_uniform",
+     "none",
+ ]
+
+
+ SONAR_STD = 0.006
+ # Most SONAR embeddings have a distribution with a mean close to 0 and a std close to 0.006.
+ # Initializing embedding-like parameters (e.g. the end-of-text vector) from a similar
+ # distribution is recommended, to minimize their disruption of the model training.
+
+
+ def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD):
+     if style == "xavier":
+         return init_linear_xavier
+
+     if style == "kaiming_uniform":
+         return init_linear_kaiming_uniform
+
+     if style == "sonar":
+         return partial(init_linear_to_sonar, sonar_std=sonar_std)
+
+     if style == "zero":
+         return init_linear_zero
+
+     if style == "trunc_normal":
+         return init_linear_trunc_normal
+
+     if style == "none":
+         return None
+
+     raise ValueError(f"Could not recognize initialization function {style}")
+
+
+ def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
+     """
+     Initialize the post-LCM projection such that, if it is fed layer-normed
+     LCM outputs (with zero mean and unit variance), its outputs have zero
+     mean and the variance of SONAR embeddings.
+     """
+     if layer.bias is not None:
+         torch.nn.init.zeros_(layer.bias)
+
+     # Bound of the uniform distribution; its std is sonar_std / sqrt(input_dim).
+     std = sonar_std * (3 / layer.input_dim) ** 0.5
+
+     torch.nn.init.uniform_(layer.weight, a=-std, b=std)
+
+
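A quick numeric check of the `sonar` init above: `Uniform(-b, b)` has variance `b**2 / 3`, so with `b = sonar_std * sqrt(3 / input_dim)` and unit-variance inputs, each output coordinate has variance `input_dim * b**2 / 3 = sonar_std**2` (toy dimensions, illustration only):

import torch

d_in, sonar_std = 1024, 0.006
b = sonar_std * (3 / d_in) ** 0.5
w = torch.empty(4096, d_in).uniform_(-b, b)
x = torch.randn(10_000, d_in)    # stand-in for layer-normed LCM outputs
print((x @ w.T).std())           # ~0.006, the typical SONAR std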
+ def init_linear_xavier(layer: Linear) -> None:
+     torch.nn.init.xavier_uniform_(layer.weight)
+     if layer.bias is not None:
+         torch.nn.init.zeros_(layer.bias)
+
+
+ def init_linear_zero(layer: Linear) -> None:
+     torch.nn.init.zeros_(layer.weight)
+     if layer.bias is not None:
+         torch.nn.init.zeros_(layer.bias)
+
+
+ def init_linear_trunc_normal(layer: Linear) -> None:
+     torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
+     if layer.bias is not None:
+         torch.nn.init.zeros_(layer.bias)
+
+
+ def init_linear_kaiming_uniform(layer: Linear) -> None:
+     torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
+
+     if layer.bias is not None:
+         fan_in = layer.weight.size(1)
+
+         m = 1
+         if layer.weight.ndim > 2:
+             for s in layer.weight.shape[2:]:
+                 m *= s
+
+         fan_in *= m
+
+         # We do not calculate the true standard deviation of the uniform
+         # distribution (i.e. multiply with sqrt(3)). See
+         # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575.
+         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+
+         torch.nn.init.uniform_(layer.bias, -bound, bound)
+
+
+ def parse_norm_order(var: str) -> TransformerNormOrder:
+     norm_order: TransformerNormOrder
+     if var == "pre":
+         norm_order = TransformerNormOrder.PRE
+     elif var == "post":
+         norm_order = TransformerNormOrder.POST
+     elif var == "normformer":
+         norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
+     else:
+         raise ValueError(f"Unknown normalization order {var}")
+
+     return norm_order
+
+
+ def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
+     if var is None:
+         return None
+
+     activ_fn: Module
+
+     if var == "relu":
+         activ_fn = torch.nn.ReLU()
+     elif var == "tanh":
+         activ_fn = torch.nn.Tanh()
+     elif var == "elu":
+         activ_fn = torch.nn.ELU()
+     elif var == "leaky_relu":
+         activ_fn = torch.nn.LeakyReLU()
+     elif var == "prelu":
+         activ_fn = torch.nn.PReLU()
+     elif var == "selu":
+         activ_fn = torch.nn.SELU()
+     elif var == "gelu":
+         activ_fn = torch.nn.GELU()
+     elif var == "silu":
+         activ_fn = torch.nn.SiLU()
+     elif var == "softsign":
+         activ_fn = torch.nn.Softsign()
+     elif var == "sigmoid":
+         activ_fn = torch.nn.Sigmoid()
+     elif var == "hardsigmoid":
+         activ_fn = torch.nn.Hardsigmoid()
+     else:
+         raise ValueError(f"Unknown activation function {var}")
+
+     return activ_fn
lcm/nn/normalization.py ADDED
@@ -0,0 +1,88 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
+
+ from typing import Literal, Optional, final
+
+ import torch
+ from fairseq2.nn import LayerNorm, RMSNorm, StandardLayerNorm
+ from fairseq2.nn.transformer import LayerNormFactory, create_standard_layer_norm
+ from fairseq2.typing import DataType, Device, override
+
+ SUPPORTED_LN_TYPES = Literal["standard", "fp32", "rms", "unit"]
+
+
+ @final
+ class FP32LayerNorm(LayerNorm):
+     """Applies Layer Normalization in single precision."""
+
+     @override
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         w, b = self.weight, self.bias
+
+         # Cast the input and the parameters to float32:
+         fp32_x = x.float()
+         fp32_w = w.float() if w is not None else None
+         fp32_b = b.float() if b is not None else None
+
+         y = torch.nn.functional.layer_norm(
+             fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps
+         )
+
+         return y.type_as(x)
+
+
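Normalizing in float32 sidesteps the precision issues half-precision layer norm can hit on large activations, while keeping the module's external dtype unchanged. A small usage sketch (mirroring the constructor call used by `build_fp32_layer_norm` below):

import torch
from lcm.nn.normalization import FP32LayerNorm

ln = FP32LayerNorm(1024, bias=False, dtype=torch.float16)
x = torch.randn(2, 4, 1024, dtype=torch.float16)
y = ln(x)                        # normalized internally in float32...
assert y.dtype == torch.float16  # ...then cast back to the input dtype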
+ def build_rms_layer_norm(
+     model_dim: int,
+     *,
+     device: Optional[Device] = None,
+     dtype: Optional[DataType] = None,
+ ) -> LayerNorm:
+     """Build an RMS Layer Normalization module."""
+     return RMSNorm(model_dim, bias=False, device=device, dtype=dtype)
+
+
+ def build_fp32_layer_norm(
+     model_dim: int,
+     *,
+     device: Optional[Device] = None,
+     dtype: Optional[DataType] = None,
+ ) -> LayerNorm:
+     """Build a single-precision Layer Normalization module."""
+     return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype)
+
+
+ def build_unit_layer_norm(
+     model_dim: int,
+     *,
+     device: Optional[Device] = None,
+     dtype: Optional[DataType] = None,
+ ) -> LayerNorm:
+     """Create an instance of :class:`StandardLayerNorm`
+     without learnable mean and variance."""
+     return StandardLayerNorm(
+         model_dim,
+         bias=False,
+         elementwise_affine=False,
+         device=device,
+         dtype=dtype,
+     )
+
+
+ def parse_layer_norm_factory(layer_normalization_style: str) -> LayerNormFactory:
+     if layer_normalization_style == "rms":
+         # Note that RMSNorm normalizes in single precision by default
+         return build_rms_layer_norm
+
+     elif layer_normalization_style == "unit":
+         return build_unit_layer_norm
+
+     elif layer_normalization_style == "fp32":
+         return build_fp32_layer_norm
+
+     elif layer_normalization_style == "standard":
+         return create_standard_layer_norm
+
+     else:
+         raise ValueError(f"Unsupported LayerNorm style {layer_normalization_style}")
lcm/nn/projection.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ from fairseq2.nn.projection import Linear
+ from fairseq2.typing import DataType, Device
+ from torch import Tensor
+ from torch.nn import Module
+
+ from lcm.nn.initialization import (
+     SUPPORTED_INIT_TYPES,
+     get_init_fn,
+     parse_activation_fn,
+ )
+ from lcm.nn.normalization import SUPPORTED_LN_TYPES
+
+
+ @dataclass
+ class ProjectionConfig:
+     dropout_p: float = 0.0
+     """The dropout probability applied to the module's output."""
+
+     linear_bias: bool = True
+     """Whether or not the linear layer has a bias term."""
+
+     linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"
+
+     weight_normalization: bool = False
+
+     layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
+
+     activation_name: Optional[str] = None
+     """The activation function to apply after the linear layer, if any."""
+
+
+ class Projection(Module):
+     """
+     An output projection module.
+     """
+
+     def __init__(
+         self,
+         output_dim: int,
+         input_dim: int,
+         config: ProjectionConfig,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         super().__init__()
+
+         self.dtype = dtype
+
+         init_fn = get_init_fn(config.linear_init_fn)
+
+         lin = Linear(
+             input_dim,
+             output_dim,
+             bias=config.linear_bias,
+             device=device,
+             dtype=dtype,
+             init_fn=init_fn,
+         )
+         if config.weight_normalization:
+             self.fc = torch.nn.utils.parametrizations.weight_norm(lin)
+         else:
+             self.fc = lin
+
+         self.activation_fn = parse_activation_fn(config.activation_name)
+
+         if self.activation_fn is not None:
+             # Some activation functions (e.g., PReLU) have parameters,
+             # so we need to move them to the right device.
+             self.activation_fn.to(device)
+
+     def forward(self, seqs: Tensor) -> Tensor:
+         seqs = self.fc(seqs)
+
+         if self.activation_fn is not None:
+             seqs = self.activation_fn(seqs)
+
+         return seqs
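A usage sketch (dimensions are illustrative, not from any checkpoint): project model-dim hidden states to SONAR-sized embeddings with a zero-initialized linear layer:

import torch
from lcm.nn.projection import Projection, ProjectionConfig

cfg = ProjectionConfig(linear_init_fn="zero", activation_name=None)
proj = Projection(output_dim=1024, input_dim=768, config=cfg)
out = proj(torch.randn(2, 5, 768))   # -> (2, 5, 1024), all zeros at init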
lcm/nn/schedulers/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+
+ from lcm.nn.schedulers.ddim import (
+     DDIMScheduler,
+     DDIMSchedulerConfig,
+     DDIMSchedulerOutput,
+ )
+
+ __all__ = [
+     "DDIMScheduler",
+     "DDIMSchedulerConfig",
+     "DDIMSchedulerOutput",
+ ]
lcm/nn/schedulers/ddim.py ADDED
@@ -0,0 +1,741 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ # This code is based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py, which is distributed under the Apache 2.0 License.
+ # HuggingFace's diffusers DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+ # and https://github.com/hojonathanho/diffusion
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Literal, Optional, Tuple, Union
+
+ import torch
+ from fairseq2.logging import get_log_writer
+ from fairseq2.typing import CPU
+ from torch import Tensor
+
+ logger = get_log_writer(__name__)
+
+
+ def sigmoid(x):
+     return 1 / (1 + math.exp(-x))
+
+
+ def logit(x):
+     return math.log(x / (1 - x))
+
+
+ @dataclass
+ class DDIMSchedulerOutput:
+     """
+     Output class for the scheduler's `step` function.
+
+     Args:
+         prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+             The computed sample `x_{t-1}` of the previous timestep. `prev_sample` should be used
+             as the next model input in the denoising loop.
+         pred_original_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+             The predicted denoised sample `x_{0}` based on the model output from the current timestep.
+             `pred_original_sample` can be used to preview progress or for guidance.
+     """
+
+     prev_sample: Tensor
+     pred_original_sample: Tensor
+
+
+ @dataclass
+ class DDIMSchedulerConfig:
+     num_diffusion_train_steps: int = 1000
+     """The number of diffusion steps used to train the model."""
+
+     beta_start: float = 0.0001
+     """The starting `beta` value of inference."""
+
+     beta_end: float = 0.02
+     """The final `beta` value.
+     In DDPM (https://arxiv.org/pdf/2006.11239), $\beta_t$ increases linearly
+     from $\beta_1$ (`beta_start`) = 1e-4 to $\beta_T$ (`beta_end`) = 0.02.
+     These constants were chosen to be small relative to data scaled to [-1, 1],
+     ensuring that the reverse and forward processes have approximately
+     the same functional form while keeping the signal-to-noise ratio at $x_T$ as small as possible.
+     Another common choice in HF:diffusers is `beta_start=0.00085, beta_end=0.012`.
+     Note that `beta_start` and `beta_end` are irrelevant for `squaredcos_cap_v2`.
+     """
+
+     beta_schedule: Literal[
+         "linear",
+         "scaled_linear",
+         "squaredcos_cap_v2",
+         "sigmoid",
+     ] = "squaredcos_cap_v2"
+     """The beta schedule, a mapping from a beta range to a sequence of betas
+     for stepping the model (length=`num_diffusion_train_steps`).
+     Choose from:
+     - `linear`: Linearly spaced betas between `beta_start` and `beta_end`.
+       Referred to as `sqrt_linear` in stable-diffusion.
+     - `scaled_linear`: Squared values after linearly spacing from sqrt(beta_start) to sqrt(beta_end).
+       Referred to as `linear` in stable-diffusion.
+     - `squaredcos_cap_v2`: Creates a beta schedule that discretizes
+       $\bar{\alpha}(t) = \cos((t/T + s) / (1 + s) \cdot \pi/2)^2$; HF:diffusers sets `s` to 0.008.
+       For the intuition behind how a cosine schedule compares to a linear schedule,
+       see Figure 3 of https://arxiv.org/pdf/2102.09672.
+     - `sigmoid`: our sigmoid schedule (see Equation 14 of the LCM paper).
+     """
+
+     scaled_linear_exponent: float = 2.0
+     """Exponent for the scaled linear beta schedule. The default is quadratic (scaled_linear_exponent=2)."""
+
+     sigmoid_schedule_alpha: float = 1.5
+     sigmoid_schedule_beta: float = 0
+     """The alpha and beta hyper-parameters of the sigmoid beta schedule."""
+
+     clip_sample: bool = False
+     """Clip the predicted sample for numerical stability."""
+
+     clip_sample_range: float = 1.0
+     """The maximum magnitude for sample clipping. Valid only when `clip_sample=True`."""
+
+     set_alpha_to_one: bool = True
+     """Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+     there is no previous alpha. When this option is `True`, the previous alpha product is fixed to `1`,
+     otherwise it uses the alpha value at step 0."""
+
+     prediction_type: Literal["sample", "epsilon", "v_prediction"] = "sample"
+     """If `sample`, the model predicts the clean ground-truth embeddings.
+     If `epsilon`, the model predicts the added noise of the diffusion process.
+     If `v_prediction`, the model predicts an interpolation of the ground-truth clean
+     embeddings and the added noise, as introduced in Section 2.4 of the Imagen Video paper
+     (https://imagen.research.google/video/paper.pdf).
+     """
+
+     thresholding: bool = False
+     """Whether to use the "dynamic thresholding" method.
+     This is unsuitable for latent-space diffusion models such as Stable Diffusion."""
+
+     dynamic_thresholding_ratio: float = 0.995
+     """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""
+
+     sample_max_value: float = 1.0
+     """The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""
+
+     rescale_betas_zero_snr: bool = True
+     """Whether to rescale the betas to have zero terminal SNR. This enables the
+     model to generate very bright and dark samples instead of limiting it to samples
+     with medium brightness. Loosely related to
+     [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506)."""
+
+     # Inference specific
+     timestep_spacing: Literal["linspace", "leading", "trailing"] = "trailing"
+     """The way the timesteps should be scaled. Refer to Table 2 of
+     https://arxiv.org/abs/2305.08891 for more information."""
+
+
+ class DDIMScheduler:
+     def __init__(self, config: DDIMSchedulerConfig):
+         self.config = config
+
+         # Make these two arguments easily accessible:
+         self.num_diffusion_train_steps = self.config.num_diffusion_train_steps
+         self.prediction_type = self.config.prediction_type
+
+         beta_schedule = self.config.beta_schedule
+
+         if beta_schedule == "linear":
+             self.betas = torch.linspace(
+                 self.config.beta_start,
+                 self.config.beta_end,
+                 self.num_diffusion_train_steps,
+                 dtype=torch.float32,
+             )
+         elif beta_schedule == "scaled_linear":
+             # This schedule is very specific to the latent diffusion model.
+             exponent = self.config.scaled_linear_exponent
+             self.betas = (
+                 torch.linspace(
+                     self.config.beta_start ** (1 / exponent),
+                     self.config.beta_end ** (1 / exponent),
+                     self.num_diffusion_train_steps,
+                     dtype=torch.float32,
+                 )
+                 ** exponent
+             )
+         elif beta_schedule == "squaredcos_cap_v2":
+             # Cosine schedule as introduced in
+             # [Nichol and Dhariwal, 2021](https://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf)
+             self.betas = betas_for_alpha_bar(
+                 self.num_diffusion_train_steps,
+                 alpha_transform_type="cosine",
+             )
+         elif beta_schedule == "sigmoid":
+             self.betas = betas_for_alpha_bar(
+                 self.num_diffusion_train_steps,
+                 alpha_transform_type="sigmoid",
+                 sigmoid_alpha=self.config.sigmoid_schedule_alpha,
+                 sigmoid_beta=self.config.sigmoid_schedule_beta,
+             )
+         else:
+             raise NotImplementedError(
+                 f"We do not recognize beta_schedule={beta_schedule}"
+             )
+
+         # Rescale for zero terminal SNR:
+         if self.config.rescale_betas_zero_snr:
+             self.betas = rescale_zero_terminal_snr(self.betas)
+
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+         # At every DDIM step, we look at the previous alphas_cumprod.
+         # For the final step, there is no previous alphas_cumprod because we are already at 0.
+         # `set_alpha_to_one` decides whether we set this parameter simply to one or
+         # whether we use the final alpha of the "non-previous" one.
+         self.final_alpha_cumprod = (
+             torch.tensor(1.0)
+             if self.config.set_alpha_to_one
+             else self.alphas_cumprod[0]
+         )
+
+         # Standard deviation of the initial noise distribution:
+         self.init_noise_sigma = 1.0
+
+         # Timesteps for inference:
+         self.num_inference_steps: Optional[int] = None
+
+     def _get_variance(self, timestep, prev_timestep):
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = (
+             self.alphas_cumprod[prev_timestep]
+             if prev_timestep >= 0
+             else self.final_alpha_cumprod
+         )
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         variance = (beta_prod_t_prev / beta_prod_t) * (
+             1 - alpha_prod_t / alpha_prod_t_prev
+         )
+         return variance
+
+     def get_variances(self) -> Tensor:
+         alpha_prod_t = self.alphas_cumprod
+         alpha_prod_t_prev = torch.cat(
+             (torch.tensor([self.final_alpha_cumprod]), alpha_prod_t[:-1])
+         )
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         variance = (beta_prod_t_prev / beta_prod_t) * (
+             1 - alpha_prod_t / alpha_prod_t_prev
+         )
+         return variance
+
+     def get_snrs(self) -> Tensor:
+         alphas_cumprod = self.alphas_cumprod
+         snr = alphas_cumprod / (1 - alphas_cumprod)
+         return snr
+
+     def _threshold_sample(self, sample: Tensor) -> Tensor:
+         """
+         "Dynamic thresholding: At each sampling step we set s to a certain
+         percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t),
+         and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+         s. Dynamic thresholding pushes saturated pixels (those near -1 and 1)
+         inwards, thereby actively preventing pixels from saturation at each step.
+         We find that dynamic thresholding results in significantly better
+         photorealism as well as better image-text alignment,
+         especially when using very large guidance weights."
+
+         https://arxiv.org/abs/2205.11487
+         """
+         dtype = sample.dtype
+         batch_size, channels, *remaining_dims = sample.shape
+
+         if dtype not in (torch.float32, torch.float64):
+             # Upcast for the quantile calculation; clamp is not implemented for CPU half.
+             sample = sample.float()
+
+         # Flatten the sample for doing quantile calculation along each image:
+         sample = sample.reshape(batch_size, -1)
+
+         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+         s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+         s = torch.clamp(
+             s, min=1, max=self.config.sample_max_value
+         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+         sample = (
+             torch.clamp(sample, -s, s) / s
+         )  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+         sample = sample.reshape(batch_size, channels, *remaining_dims)
+         sample = sample.to(dtype)
+
+         return sample
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int,
+         device: Optional[Union[str, torch.device]] = None,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+         """
+         if num_inference_steps > self.config.num_diffusion_train_steps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_diffusion_train_steps`:"
+                 f" {self.num_diffusion_train_steps} as the model trained with this scheduler can only handle"
+                 f" maximal {self.num_diffusion_train_steps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+
+         # "linspace", "leading" and "trailing" correspond to the annotation of Table 2 of
+         # https://arxiv.org/abs/2305.08891, with T the number of training steps and S the
+         # number of inference steps.
+
+         if self.config.timestep_spacing == "linspace":
+             # Linspace: flip(round(linspace(1, T, S)))
+             # With T=1000 and S=10: [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
+             timesteps = torch.linspace(
+                 0,
+                 self.config.num_diffusion_train_steps - 1,
+                 self.num_inference_steps,
+                 device=device,
+                 dtype=torch.long,
+             )
+             timesteps = torch.flip(timesteps, dims=(0,)).round()
+
+         elif self.config.timestep_spacing == "leading":
+             # Leading: flip(arange(1, T + 1, floor(T / S)))
+             # With T=1000 and S=10: [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
+             leading_step_ratio = (
+                 self.num_diffusion_train_steps // self.num_inference_steps
+             )
+             timesteps = torch.arange(
+                 start=0,
+                 end=self.num_diffusion_train_steps,
+                 step=leading_step_ratio,
+                 device=device,
+                 dtype=torch.long,
+             )
+             timesteps = torch.flip(timesteps, dims=(0,)).round()
+
+         elif self.config.timestep_spacing == "trailing":
+             # Trailing: round(flip(arange(T, 0, -T / S)))
+             # With T=1000 and S=10: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
+             trailing_step_ratio: float = (
+                 self.num_diffusion_train_steps / self.num_inference_steps
+             )
+             # Create integer timesteps by stepping with the (float) ratio:
+             timesteps = torch.arange(
+                 self.config.num_diffusion_train_steps,
+                 0,
+                 -trailing_step_ratio,
+                 device=device,
+                 dtype=torch.long,
+             ).round()
+             timesteps -= 1
+         else:
+             raise ValueError(
+                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of"
+                 " 'linspace', 'leading' or 'trailing'."
+             )
+
+         self.timesteps = timesteps
+         logger.debug(
+             f"With `{self.config.timestep_spacing}`, setting inference timesteps to {self.timesteps}"
+         )
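For intuition, here is what the default `trailing` spacing produces, using only the classes defined in this file (a usage sketch):

from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig

sched = DDIMScheduler(DDIMSchedulerConfig(timestep_spacing="trailing"))
sched.set_timesteps(num_inference_steps=10)
print(sched.timesteps)  # tensor([999, 899, 799, 699, 599, 499, 399, 299, 199, 99])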
+     def step(
+         self,
+         model_output: Tensor,
+         timestep: int,
+         sample: Tensor,
+         eta: float = 0.0,
+         use_clipped_model_output: bool = False,
+         generator=None,
+         variance_noise: Optional[Tensor] = None,
+         prediction_type: Optional[str] = None,
+         epsilon_scaling: Optional[float] = None,
+     ) -> DDIMSchedulerOutput:
+         """
+         INFERENCE ONLY.
+         Predict the sample of the previous timestep by reversing the SDE.
+         This function propagates the diffusion process from the learned model outputs.
+
+         Args:
+             model_output (`Tensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`int`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`Tensor`):
+                 A current instance of a sample created by the diffusion process.
+             eta (`float`):
+                 The weight of noise for added noise in the diffusion step.
+             use_clipped_model_output (`bool`, defaults to `False`):
+                 If `True`, computes a "corrected" `model_output` from the clipped predicted original sample. Necessary
+                 because the predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+                 clipping has happened, the "corrected" `model_output` coincides with the one provided as input and
+                 `use_clipped_model_output` has no effect.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             variance_noise (`Tensor`):
+                 An alternative to generating noise with `generator` by directly providing the noise for the variance
+                 itself. Useful for methods such as [`CycleDiffusion`].
+             prediction_type (`Optional[str]`): if provided, we step with a different prediction type
+                 than the one in the config.
+             epsilon_scaling (`Optional[float]`): if not `None`, the predicted epsilon is scaled down by
+                 the provided factor, as introduced in https://arxiv.org/pdf/2308.15321.
+
+         Returns:
+             DDIMSchedulerOutput
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf.
+         # Ideally, read the DDIM paper for an in-depth understanding.
+
+         # Notation (<variable name> -> <name in paper>):
+         # - pred_noise_t -> e_theta(x_t, t)
+         # - pred_original_sample -> f_theta(x_t, t) or x_0
+         # - std_dev_t -> sigma_t
+         # - eta -> η
+         # - pred_sample_direction -> "direction pointing to x_t"
+         # - pred_prev_sample -> "x_t-1"
+
+         # 1. Get the previous step value (=t-1):
+         prev_timestep = (
+             timestep - self.config.num_diffusion_train_steps // self.num_inference_steps
+         )
+
+         # 2. Compute alphas, betas:
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = (
+             self.alphas_cumprod[prev_timestep]
+             if prev_timestep >= 0
+             else self.final_alpha_cumprod
+         )
+
+         beta_prod_t = 1 - alpha_prod_t
+
+         # 3. Compute the predicted original sample from the predicted noise, also called
+         # "predicted x_0" in formula (12) of https://arxiv.org/pdf/2010.02502.pdf:
+         prediction_type = prediction_type or self.prediction_type
+         if prediction_type == "epsilon":
+             pred_original_sample = (
+                 sample - beta_prod_t ** (0.5) * model_output
+             ) / alpha_prod_t ** (0.5)
+             pred_epsilon = model_output
+         elif prediction_type == "sample":
+             pred_original_sample = model_output
+             pred_epsilon = (
+                 sample - alpha_prod_t ** (0.5) * pred_original_sample
+             ) / beta_prod_t ** (0.5)
+         elif prediction_type == "v_prediction":
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (
+                 beta_prod_t**0.5
+             ) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (
+                 beta_prod_t**0.5
+             ) * sample
+         else:
+             raise ValueError(
+                 f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
+                 " `v_prediction`"
+             )
+
+         # 3.a Epsilon scaling:
+         if epsilon_scaling is not None:
+             pred_epsilon = pred_epsilon / epsilon_scaling
+
+         # 4. Clip or threshold the "predicted x_0":
+         if self.config.thresholding:
+             pred_original_sample = self._threshold_sample(pred_original_sample)
+         elif self.config.clip_sample:
+             pred_original_sample = pred_original_sample.clamp(
+                 -self.config.clip_sample_range, self.config.clip_sample_range
+             )
+
+         # 5. Compute the variance "sigma_t(η)" -> see formula (16):
+         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+         variance = self._get_variance(timestep, prev_timestep)
+         std_dev_t = eta * variance ** (0.5)
+
+         if use_clipped_model_output:
+             # pred_epsilon is always re-derived from the clipped x_0 in GLIDE:
+             pred_epsilon = (
+                 sample - alpha_prod_t ** (0.5) * pred_original_sample
+             ) / beta_prod_t ** (0.5)
+
+         # 6. Compute the "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf:
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
+             0.5
+         ) * pred_epsilon
+
+         # 7. Compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf:
+         prev_sample = (
+             alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+         )
+
+         if eta > 0:
+             if variance_noise is not None and generator is not None:
+                 raise ValueError(
+                     "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+                     " `variance_noise` stays `None`."
+                 )
+
+             if variance_noise is None:
+                 variance_noise = randn_tensor(
+                     model_output.shape,
+                     generator=generator,
+                     device=model_output.device,
+                     dtype=model_output.dtype,
+                 )
+             variance = std_dev_t * variance_noise
+             prev_sample = prev_sample + variance
+
+         return DDIMSchedulerOutput(
+             prev_sample=prev_sample, pred_original_sample=pred_original_sample
+         )
+
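A minimal denoising loop around `step` (hedged sketch: `denoiser` below is a placeholder callable that predicts x_0 directly, matching the default `prediction_type="sample"`; it is not a repo function):

import torch
from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig

sched = DDIMScheduler(DDIMSchedulerConfig())
sched.set_timesteps(num_inference_steps=40)

x = torch.randn(2, 1, 1024)                # start from pure noise
denoiser = lambda sample, t: 0.5 * sample  # stand-in for the real model
for t in sched.timesteps:
    out = sched.step(model_output=denoiser(x, t), timestep=int(t), sample=x)
    x = out.prev_sample                    # feed x_{t-1} into the next step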
+     def add_noise(
+         self,
+         original_samples: Tensor,
+         noise: Tensor,
+         timesteps: Tensor,
+     ) -> Tensor:
+         """TRAINING ONLY.
+         The forward noising process used during training."""
+         # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples.
+         # Move self.alphas_cumprod to the device once, to avoid redundant CPU-to-GPU data movement
+         # in subsequent add_noise calls:
+         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+         alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+         timesteps = timesteps.to(original_samples.device).to(torch.int32)
+
+         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+         noisy_samples = (
+             sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+         )
+         return noisy_samples
+
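`add_noise` implements the closed-form forward process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε. With the default zero-terminal-SNR rescaling, ᾱ at the last timestep is zero, so the final step is pure noise. A sketch (toy shapes only):

import torch
from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig

sched = DDIMScheduler(DDIMSchedulerConfig())
x0 = torch.randn(4, 32)
eps = torch.randn_like(x0)
t = torch.tensor([0, 10, 500, 999])
xt = sched.add_noise(x0, eps, t)          # later timesteps are closer to noise
print((xt[3] - eps[3]).abs().max())       # ~0: alpha_bar at t=999 is ~0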
+     def get_velocity(self, sample: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor:
+         # Make sure alphas_cumprod and timesteps have the same device and dtype as sample:
+         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+         timesteps = timesteps.to(sample.device).to(torch.int32)
+
+         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+         while len(sqrt_alpha_prod.shape) < len(sample.shape):
+             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+         return velocity
+
+     def get_epsilon(
+         self, model_output: Tensor, sample: Tensor, timestep: int
+     ) -> Tensor:
+         """Given the model inputs (sample) and outputs (model_output),
+         predict the noise residual according to the scheduler's
+         prediction type."""
+         pred_type = self.prediction_type
+
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         beta_prod_t = 1 - alpha_prod_t
+
+         if pred_type == "epsilon":
+             return model_output
+
+         elif pred_type == "sample":
+             return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (
+                 0.5
+             )
+
+         elif pred_type == "v_prediction":
+             return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+         else:
+             raise ValueError(
+                 f"The scheduler's prediction type {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+             )
+
+
+ def randn_tensor(
+     shape: Union[Tuple, List],
+     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+     device: Optional["torch.device"] = None,
+     dtype: Optional["torch.dtype"] = None,
+     layout: Optional["torch.layout"] = None,
+ ):
+     """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+     passing a list of generators, you can seed each batch element individually. If CPU generators are passed,
+     the tensor is always created on the CPU.
+     """
+     # The device on which the tensor is created defaults to `device`:
+     rand_device = device
+     batch_size = shape[0]
+
+     layout = layout or torch.strided
+     device = device or torch.device("cpu")
+
+     if generator is not None:
+         gen_device_type = (
+             generator.device.type
+             if not isinstance(generator, list)
+             else generator[0].device.type
+         )
+         if gen_device_type != device.type and gen_device_type == "cpu":
+             rand_device = CPU
+             if device != "mps":
+                 logger.info(
+                     f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                     f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                     f" slightly speed up this function by passing a generator that was created on the {device} device."
+                 )
+         elif gen_device_type != device.type and gen_device_type == "cuda":
+             raise ValueError(
+                 f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
+             )
+
+     # Make sure a generator list of length 1 is treated like a non-list:
+     if isinstance(generator, list) and len(generator) == 1:
+         generator = generator[0]
+
+     if isinstance(generator, list):
+         shape = (1,) + shape[1:]  # type: ignore
+         latents_list = [
+             torch.randn(
+                 shape,
+                 generator=generator[i],
+                 device=rand_device,
+                 dtype=dtype,
+                 layout=layout,
+             )
+             for i in range(batch_size)
+         ]
+         latents = torch.cat(latents_list, dim=0).to(device)
+     else:
+         latents = torch.randn(
+             shape, generator=generator, device=rand_device, dtype=dtype, layout=layout
+         ).to(device)
+
+     return latents
+
+
+ def betas_for_alpha_bar(
+     num_diffusion_timesteps: int,
+     max_beta: float = 0.999,
+     alpha_transform_type: Literal["cosine", "exp", "sigmoid"] = "cosine",
+     sigmoid_alpha: float = 1.5,
+     sigmoid_beta: float = 0,
+ ) -> Tensor:
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+     (1-beta) over time from t = [0, 1].
+
+     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+     to that part of the diffusion process.
+
+     Args:
+         num_diffusion_timesteps (`int`): the number of betas to produce.
+         max_beta (`float`): the maximum beta to use; use values lower than 1 to
+             prevent singularities.
+         alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+             Choose from `cosine`, `exp` or `sigmoid`.
+         sigmoid_alpha/sigmoid_beta: additional hyper-parameters for the sigmoid schedule.
+
+     Returns:
+         betas (`Tensor`): the betas used by the scheduler to step the model outputs
+     """
+     if alpha_transform_type == "cosine":
+
+         def alpha_bar_fn(t):
+             return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+     elif alpha_transform_type == "sigmoid":
+
+         def alpha_bar_fn(t):
+             epsilon = 1e-32
+             return sigmoid(
+                 sigmoid_beta
+                 - sigmoid_alpha
+                 * logit(torch.clamp(torch.tensor(t), min=epsilon, max=1 - epsilon))
+             )
+
+     elif alpha_transform_type == "exp":
+
+         def alpha_bar_fn(t):
+             return math.exp(t * -12.0)
+
+     else:
+         raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+     return torch.tensor(betas, dtype=torch.float32)
+
+
+ def rescale_zero_terminal_snr(betas: Tensor) -> Tensor:
+     """
+     Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+     Args:
+         betas (`Tensor`):
+             the betas that the scheduler is being initialized with.
+
+     Returns:
+         `Tensor`: rescaled betas with zero terminal SNR
+     """
+     # Convert betas to alphas_bar_sqrt:
+     alphas = 1.0 - betas
+     alphas_cumprod = torch.cumprod(alphas, dim=0)
+     alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+     # Store old values:
+     alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+     alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+     # Shift so the last timestep is zero:
+     alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+     # Scale so the first timestep is back to the old value:
+     alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+     # Convert alphas_bar_sqrt to betas:
+     alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+     alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+     alphas = torch.cat([alphas_bar[0:1], alphas])
+     betas = 1 - alphas
+
+     return betas
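To see Algorithm 1 in action: after rescaling, the cumulative alpha product at the final timestep is zero, i.e. the terminal SNR is zero (a quick check):

import torch
from lcm.nn.schedulers.ddim import rescale_zero_terminal_snr

betas = torch.linspace(1e-4, 0.02, 1000)
betas = rescale_zero_terminal_snr(betas)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1])  # ~0: zero terminal SNR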
lcm/nn/timestep_encoder.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ import math
+ from typing import Optional
+
+ import torch
+ from fairseq2.nn.projection import Linear
+ from fairseq2.typing import DataType, Device
+ from torch import Tensor
+ from torch.nn import Module
+
+ from lcm.nn.initialization import parse_activation_fn
+
+
+ class DiTTimestepEncoder(Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     Based on DiT's `TimestepEmbedder`
+     https://github.com/facebookresearch/DiT/blob/main/models.py
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         frequency_embedding_size: int = 256,
+         activation_fn_name: str = "silu",
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ):
+         super().__init__()
+
+         self.dtype = dtype
+
+         self.device = device
+
+         self.embedding_dim = embedding_dim
+
+         self.frequency_embedding_size = frequency_embedding_size
+
+         self.fc1 = Linear(
+             frequency_embedding_size,
+             embedding_dim,
+             bias=True,
+             device=device,
+             dtype=dtype,
+         )
+         self.nonlin = parse_activation_fn(activation_fn_name)
+         self.fc2 = Linear(
+             embedding_dim,
+             embedding_dim,
+             bias=True,
+             device=device,
+             dtype=dtype,
+         )
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """Reset the parameters and buffers of the module."""
+         torch.nn.init.normal_(self.fc1.weight, std=0.02)
+         torch.nn.init.normal_(self.fc2.weight, std=0.02)
+
+         if self.fc1.bias is not None:
+             torch.nn.init.zeros_(self.fc1.bias)
+
+         if self.fc2.bias is not None:
+             torch.nn.init.zeros_(self.fc2.bias)
+
+     @staticmethod
+     def sinusoidal_timestep_embedding(
+         timestep, frequency_embedding_size, max_period=10000
+     ):
+         """
+         Create sinusoidal timestep embeddings.
+         :param timestep: a 1-D Tensor of N indices, one per batch element.
+             These may be fractional.
+         :param frequency_embedding_size: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+
+         Based on DiT's `TimestepEmbedder`
+         https://github.com/facebookresearch/DiT/blob/main/models.py
+         """
+         half = frequency_embedding_size // 2
+
+         freqs = torch.exp(
+             -math.log(max_period)
+             * torch.arange(start=0, end=half, dtype=torch.float32)
+             / half
+         ).to(device=timestep.device)
+
+         args = timestep[:, None].float() * freqs[None]
+
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+         if frequency_embedding_size % 2:
+             embedding = torch.cat(
+                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+             )
+
+         return embedding
+
+     def forward(self, timesteps: Tensor) -> Tensor:
+         initial_size = timesteps.size()
+
+         flat_timesteps = timesteps.view(-1, 1)
+
+         t_freq = self.sinusoidal_timestep_embedding(
+             flat_timesteps, self.frequency_embedding_size
+         ).to(self.dtype)
+
+         t_emb = self.fc1(t_freq)
+
+         if self.nonlin is not None:
+             t_emb = self.nonlin(t_emb)
+
+         t_emb = self.fc2(t_emb)
+
+         return t_emb.view(*initial_size, self.embedding_dim)
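A minimal usage sketch (not part of the commit), assuming the lcm package is importable:

```python
import torch

from lcm.nn.timestep_encoder import DiTTimestepEncoder

encoder = DiTTimestepEncoder(
    embedding_dim=512, frequency_embedding_size=256, dtype=torch.float32
)

# One (possibly fractional) diffusion timestep per batch element.
timesteps = torch.tensor([0.0, 250.5, 999.0])

t_emb = encoder(timesteps)
print(t_emb.shape)  # torch.Size([3, 512])
```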
lcm/nn/transformer/__init__.py ADDED
@@ -0,0 +1,24 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ #
+
+ from lcm.nn.transformer.attention import (
+     QKNormMultiheadAttention,
+ )
+ from lcm.nn.transformer.decoder import (
+     LCMStandardTransformerDecoderLayer,
+     LCMTransformerDecoder,
+ )
+ from lcm.nn.transformer.factory import (
+     TransformerConfig,
+     TransformerFactory,
+ )
+
+ __all__ = [
+     "QKNormMultiheadAttention",
+     "LCMStandardTransformerDecoderLayer",
+     "LCMTransformerDecoder",
+     "TransformerConfig",
+     "TransformerFactory",
+ ]
lcm/nn/transformer/attention.py ADDED
@@ -0,0 +1,307 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from typing import Optional, Tuple, final
+
+ import torch
+ import torch.nn as nn
+ from fairseq2.nn.ops import repeat_interleave
+ from fairseq2.nn.padding import PaddingMask
+ from fairseq2.nn.position_encoder import PositionEncoder
+ from fairseq2.nn.projection import Projection
+ from fairseq2.nn.transformer import (
+     AttentionMask,
+     AttentionMaskFactory,
+     AttentionState,
+     AttentionStateFactory,
+     FullAttentionState,
+     LayerNormFactory,
+     StandardMultiheadAttention,
+     create_standard_layer_norm,
+ )
+ from fairseq2.nn.transformer.attention import SDPA
+ from fairseq2.typing import DataType, Device, override
+ from torch import Tensor
+ from torch.nn.parameter import Parameter
+
+ # FIXME revert to fs2's standard state bag if possible
+ from lcm.nn.incremental_state import (
+     LCMIncrementalStateBag,
+ )
+
+
+ @final
+ class QKNormMultiheadAttention(StandardMultiheadAttention):  # type: ignore
+     """Represents a Transformer multi-head attention as described in
+     :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`,
+     with two additional layer normalizations for the keys and queries,
+     as described in https://arxiv.org/pdf/2302.05442 and other related work.
+     """
+
+     kv_dim: int
+     num_key_value_heads: int
+     q_proj: Projection
+     k_proj: Projection
+     v_proj: Projection
+     attn_mask_factory: Optional[AttentionMaskFactory]
+     pos_encoder: Optional[PositionEncoder]
+     bias_k: Optional[Parameter]
+     bias_v: Optional[Parameter]
+     add_zero_attn: bool
+     sdpa: SDPA
+     head_scale_weight: Optional[Parameter]
+     output_proj: Projection
+     state_factory: Optional[AttentionStateFactory]
+     layer_norm_factory: Optional[LayerNormFactory]
+
+     """
+     For a full description of the parameters, see
+     fairseq2/src/fairseq2/nn/transformer/multihead_attention.py.
+     Parameters of interest to us:
+     :param num_key_value_heads:
+         The number of key/value heads for Grouped Query Attention as
+         described in :cite:t:`https://doi.org/10.48550/arXiv.2305.13245`.
+         If ``None`` or set to ``num_heads``, it is equivalent to standard
+         Multi Head Attention (MHA); if set to 1, it is equivalent to Multi
+         Query Attention (MQA).
+
+     :param enable_qk_layernorm:
+         If True, follow the Q/K projections with LayerNorms.
+
+     :param weight_normalization:
+         If True, wrap the K/Q/V projections with weight normalization for regularization.
+
+     :param pos_encoder:
+         For a RoPE positional encoder that adds positional encoding to keys
+         and queries before computing the attention scores.
+     """
+
+     def __init__(
+         self,
+         model_dim: int,
+         num_heads: int,
+         *,
+         kv_dim: Optional[int] = None,
+         num_key_value_heads: Optional[int] = None,
+         q_proj: Optional[Projection] = None,
+         k_proj: Optional[Projection] = None,
+         v_proj: Optional[Projection] = None,
+         attn_mask_factory: Optional[AttentionMaskFactory] = None,
+         pos_encoder: Optional[PositionEncoder] = None,
+         sdpa: Optional[SDPA] = None,
+         scale_heads: bool = False,
+         output_proj: Optional[Projection] = None,
+         bias: bool = True,
+         state_factory: Optional[AttentionStateFactory] = None,
+         enable_qk_layernorm: bool = False,
+         weight_normalization: bool = False,
+         layer_norm_factory: Optional[LayerNormFactory] = None,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         super().__init__(
+             model_dim=model_dim,
+             num_heads=num_heads,
+             kv_dim=kv_dim,
+             num_key_value_heads=num_key_value_heads,
+             q_proj=q_proj,
+             k_proj=k_proj,
+             v_proj=v_proj,
+             attn_mask_factory=attn_mask_factory,
+             pos_encoder=pos_encoder,
+             sdpa=sdpa,
+             scale_heads=scale_heads,
+             output_proj=output_proj,
+             bias=bias,
+             state_factory=state_factory,
+             device=device,
+             dtype=dtype,
+         )
+
+         # wrap linear layers with weight norm
+         if weight_normalization:
+             self.k_proj = nn.utils.parametrizations.weight_norm(self.k_proj)
+             self.q_proj = nn.utils.parametrizations.weight_norm(self.q_proj)
+             self.v_proj = nn.utils.parametrizations.weight_norm(self.v_proj)
+
+         self.enable_qk_layernorm = enable_qk_layernorm
+         # initialize q-k LayerNorms if needed
+         if self.enable_qk_layernorm:
+             if layer_norm_factory is None:
+                 # use default LayerNorm factory
+                 layer_norm_factory = create_standard_layer_norm
+
+             self.q_layer_norm = layer_norm_factory(
+                 model_dim, device=device, dtype=dtype
+             )
+             self.k_layer_norm = layer_norm_factory(
+                 self.kv_dim, device=device, dtype=dtype
+             )
+
+     @override
+     def _project_q(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+     ) -> Tensor:
+         # (N, S, M) -> (N, S, K_proj)
+         q = self.q_proj(seqs)
+
+         # normalize queries
+         if self.enable_qk_layernorm:
+             q = self.q_layer_norm(q)
+
+         # (N, S, K_proj) -> (N, H, S, K_h)
+         q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
+
+         if self.pos_encoder is not None:
+             q = self.pos_encoder(
+                 q,
+                 padding_mask,
+                 state_bag=state_bag,
+             )
+
+         return q  # type: ignore[no-any-return]
+
+     @override
+     def _project_kv(  # type: ignore
+         self,
+         keys: Tensor,
+         key_padding_mask: Optional[PaddingMask],
+         values: Tensor,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         # (N, S, K) -> (N, S, K_proj)
+         k = self.k_proj(keys)
+
+         # normalize keys
+         if self.enable_qk_layernorm:
+             k = self.k_layer_norm(k)
+
+         # (N, S, V) -> (N, S, V_proj)
+         v = self.v_proj(values)
+
+         # (N, S, K_proj) -> (N, H, S, K_h)
+         k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
+         # (N, S, V_proj) -> (N, H, S, V_h)
+         v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
+
+         if self.pos_encoder is not None:
+             k = self.pos_encoder(
+                 k,
+                 key_padding_mask,
+                 state_bag=state_bag,
+             )
+
+         return k, v
+
+     @override
+     def forward(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         keys: Tensor,
+         key_padding_mask: Optional[PaddingMask],
+         values: Tensor,
+         *,
+         attn_mask: Optional[AttentionMask] = None,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+     ) -> Tensor:
+         # (N, S, M) -> (N, H, S, K_h)
+         q = self._project_q(
+             seqs,
+             padding_mask,
+             state_bag,
+         )
+         if self.training or state_bag is None:
+             # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h)
+             # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h)
+             k, v = self._project_kv(
+                 keys,
+                 key_padding_mask,
+                 values,
+             )
+         else:
+             if key_padding_mask is not None:
+                 raise ValueError(
+                     "`key_padding_mask` must be `None` during incremental decoding."
+                 )
+
+             # k: (N, S_step, M) -> (N, H_kv, S_step, K_h)
+             # v: (N, S_step, M) -> (N, H_kv, S_step, V_h)
+             k, v = self._project_kv(keys, key_padding_mask, values, state_bag)
+
+             state = state_bag.get_state(self, AttentionState)  # type: ignore
+
+             if state is None:
+                 state_factory = self.state_factory or FullAttentionState
+
+                 state = state_factory(
+                     k, v, state_bag.max_num_steps, state_bag.capacity_increment
+                 )
+
+                 state_bag.set_state(self, state)
+             else:
+                 state.append(k, v)
+
+                 # k: (N, H_kv, S_kv, K_h)
+                 # v: (N, H_kv, S_kv, V_h)
+
+                 k, v = state.get()
+
+         # With Grouped Query Attention, each key/value head is repeated.
+         if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1:
+             # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
+             k = repeat_interleave(k, dim=1, repeat=num_query_groups)
+             # (N, H_kv, S_kv, V_h) -> (N, H, S_kv, V_h)
+             v = repeat_interleave(v, dim=1, repeat=num_query_groups)
+
+         if self.attn_mask_factory is not None:
+             attn_mask = self.attn_mask_factory(
+                 seqs, keys=keys, training=self.training, state_bag=state_bag
+             )
+
+         needs_weights = len(self._attn_weight_hooks) > 0
+
+         # attn:         (N, H, S, V_h)
+         # attn_weights: (N, H, S, S_kv)
+
+         attn, attn_weights = self.sdpa(
+             q,
+             k,
+             key_padding_mask,
+             v,
+             attn_mask=attn_mask,
+             needs_weights=needs_weights,
+         )
+
+         if attn_weights is not None:
+             for hook in self._attn_weight_hooks.values():
+                 hook(self, attn, attn_weights)
+
+         # (N, H, S, V_h) -> (N, S, H, V_h)
+         attn = attn.transpose(1, 2)
+
+         if self.head_scale_weight is not None:
+             attn = torch.einsum("nshv,h->nshv", attn, self.head_scale_weight)
+
+         # (N, S, H, V_h) -> (N, S, V_proj)
+         attn = attn.flatten(2, 3)
+
+         # (N, S, V_proj) -> (N, S, M)
+
+         attn = self.output_proj(attn)
+
+         return attn  # type: ignore[no-any-return]
+
+     @override
+     def extra_repr(self) -> str:
+         """:meta private:"""
+         s = super().extra_repr()
+
+         s = f"{s}, enable_qk_layernorm={self.enable_qk_layernorm}"
+
+         return s
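The core idea can be isolated in a few lines of plain PyTorch; this is a self-contained sketch (not part of the commit, and independent of fairseq2): queries and keys are layer-normalized right after their projections, before scaled dot-product attention.

```python
import torch
import torch.nn.functional as F
from torch import nn

N, S, M, H = 2, 5, 64, 8  # batch, sequence length, model dim, heads

q_proj, k_proj, v_proj = nn.Linear(M, M), nn.Linear(M, M), nn.Linear(M, M)
q_norm, k_norm = nn.LayerNorm(M), nn.LayerNorm(M)  # the extra QK LayerNorms

x = torch.randn(N, S, M)

# (N, S, M) -> (N, H, S, M / H), with LayerNorm applied to the q/k projections
q = q_norm(q_proj(x)).unflatten(-1, (H, -1)).transpose(1, 2)
k = k_norm(k_proj(x)).unflatten(-1, (H, -1)).transpose(1, 2)
v = v_proj(x).unflatten(-1, (H, -1)).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v)  # (N, H, S, M / H)
out = attn.transpose(1, 2).flatten(2, 3)  # back to (N, S, M)
```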
lcm/nn/transformer/decoder.py ADDED
@@ -0,0 +1,176 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from typing import List, Optional, Tuple
+
+ from fairseq2.nn.padding import PaddingMask
+ from fairseq2.nn.transformer import (
+     AttentionMask,
+     AttentionMaskFactory,
+     LayerNormFactory,
+     StandardTransformerDecoderLayer,
+     TransformerDecoder,
+     TransformerDecoderLayer,
+     TransformerNormOrder,
+ )
+ from fairseq2.typing import DataType, Device, override
+ from torch import Generator, Tensor
+ from torch.nn import Dropout, ModuleList
+
+ from lcm.nn.incremental_state import LCMIncrementalStateBag
+
+
+ class LCMStandardTransformerDecoderLayer(StandardTransformerDecoderLayer):  # type: ignore
+     """Pass on `state_bag` (an `LCMIncrementalStateBag`) to the attention
+     sublayers of StandardTransformerDecoderLayer's forward pass."""
+
+     @override
+     def forward(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         self_attn_mask: Optional[AttentionMask] = None,
+         encoder_output: Optional[Tensor] = None,
+         encoder_padding_mask: Optional[PaddingMask] = None,
+         *,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+         seqs = self._forward_self_attn(
+             seqs,
+             padding_mask,
+             self_attn_mask,
+             state_bag,
+         )
+
+         seqs = self._forward_encoder_decoder_attn(
+             seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
+         )
+
+         seqs = self._forward_ffn(seqs)
+
+         return seqs, padding_mask
+
+     @override
+     def _forward_self_attn(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         self_attn_mask: Optional[AttentionMask],
+         state_bag: Optional[LCMIncrementalStateBag],
+     ) -> Tensor:
+         residual = seqs
+
+         if self.norm_order != TransformerNormOrder.POST:
+             seqs = self.self_attn_layer_norm(seqs)
+
+         seqs = self.self_attn(
+             seqs,
+             padding_mask,
+             keys=seqs,
+             key_padding_mask=padding_mask,
+             values=seqs,
+             attn_mask=self_attn_mask,
+             state_bag=state_bag,
+         )
+
+         if self.self_attn_norm is not None:
+             seqs = self.self_attn_norm(seqs)
+
+         if self.self_attn_dropout is not None:
+             seqs = self.self_attn_dropout(seqs)
+
+         seqs = seqs + residual
+
+         if self.norm_order == TransformerNormOrder.POST:
+             seqs = self.self_attn_layer_norm(seqs)
+
+         return seqs
+
+
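The norm_order branches above implement the usual pre-norm / post-norm residual orderings; a minimal sketch (not part of the commit), with a linear layer standing in for self-attention:

```python
import torch
from torch import nn

norm = nn.LayerNorm(64)
sublayer = nn.Linear(64, 64)  # stand-in for the self-attention sublayer
x = torch.randn(2, 5, 64)

# pre-norm (norm_order != POST): normalize, transform, then add the residual
pre = x + sublayer(norm(x))

# post-norm (norm_order == POST): transform, add the residual, then normalize
post = norm(x + sublayer(x))
```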
+ class LCMTransformerDecoder(TransformerDecoder):
+     def __init__(
+         self,
+         layers: List[TransformerDecoderLayer],
+         layer_norm_factory: LayerNormFactory,
+         self_attn_mask_factory: AttentionMaskFactory,
+         use_causal_attn_mask: bool = True,
+         generator: Optional[Generator] = None,
+         dropout_p: float = 0.0,
+         norm_order: TransformerNormOrder = TransformerNormOrder.POST,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         layer_list = ModuleList(layers)
+
+         if not layer_list:
+             raise ValueError("`layers` must be non-empty.")
+
+         model_dim = layer_list[0].model_dim
+
+         super().__init__(model_dim)
+
+         self.self_attn_mask_factory = self_attn_mask_factory
+
+         self.layers = layer_list
+
+         self.generator = generator
+
+         if norm_order != TransformerNormOrder.POST:
+             self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+         else:
+             self.register_module("layer_norm", None)
+
+         if dropout_p > 0.0:
+             self.dropout = Dropout(dropout_p)
+         else:
+             self.register_module("dropout", None)
+
+         self.norm_order = norm_order
+
+     @override
+     def forward(  # type: ignore
+         self,
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         encoder_output: Optional[Tensor] = None,
+         encoder_padding_mask: Optional[PaddingMask] = None,
+         *,
+         state_bag: Optional[LCMIncrementalStateBag] = None,
+         **kwargs,
+     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+         """Pass the additional `state_bag` argument (an `LCMIncrementalStateBag`)
+         on to the decoder layers."""
+         num_layers = len(self.layers)
+
+         self_attn_mask: Optional[AttentionMask] = None
+         if self.self_attn_mask_factory is not None:
+             self_attn_mask = self.self_attn_mask_factory(
+                 seqs,
+                 keys=seqs,
+                 training=self.training,
+                 state_bag=state_bag,
+             )
+
+         for layer_idx, layer in enumerate(self.layers):
+             layer_output, layer_padding_mask = layer(
+                 seqs,
+                 padding_mask,
+                 self_attn_mask,
+                 encoder_output,
+                 encoder_padding_mask,
+                 state_bag=state_bag,
+             )
+
+             seqs, padding_mask = layer_output, layer_padding_mask
+
+             for hook in self._layer_output_hooks.values():
+                 if not hook(layer_idx, seqs, padding_mask, num_layers):
+                     break
+
+         if self.layer_norm is not None:
+             seqs = self.layer_norm(seqs)
+
+         if self.dropout is not None:
+             seqs = self.dropout(seqs)
+
+         return seqs, padding_mask
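A sketch (not part of the commit) of the layer-output hook contract in the loop above: each registered hook receives (layer_idx, seqs, padding_mask, num_layers), and, as the loop is written, a False return value stops the remaining hooks for that layer.

```python
def log_layer_norms(layer_idx, seqs, padding_mask, num_layers) -> bool:
    print(
        f"layer {layer_idx + 1}/{num_layers}: "
        f"mean activation norm {seqs.norm(dim=-1).mean().item():.3f}"
    )
    return True  # keep running any other registered hooks

# Hypothetical registration, assuming fairseq2's decoder hook API:
# decoder.register_layer_output_hook(log_layer_norms)
```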
lcm/nn/transformer/factory.py ADDED
@@ -0,0 +1,300 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ #
+
+ from dataclasses import dataclass
+ from typing import Literal, Optional
+
+ import torch
+ from fairseq2.logging import get_log_writer
+ from fairseq2.nn import PositionEncoder
+ from fairseq2.nn.position_encoder import (
+     LearnedPositionEncoder,
+     RotaryEncoder,
+     SinusoidalPositionEncoder,
+ )
+ from fairseq2.nn.projection import Linear
+ from fairseq2.nn.transformer import (
+     FeedForwardNetwork,
+     GLUFeedForwardNetwork,
+     MultiheadAttention,
+     StandardFeedForwardNetwork,
+     TransformerDecoderLayer,
+     create_default_sdpa,
+ )
+ from fairseq2.typing import DataType, Device
+
+ from lcm.nn.initialization import (
+     SUPPORTED_INIT_TYPES,
+     get_init_fn,
+     parse_activation_fn,
+     parse_norm_order,
+ )
+ from lcm.nn.normalization import SUPPORTED_LN_TYPES, parse_layer_norm_factory
+ from lcm.nn.transformer import LCMStandardTransformerDecoderLayer
+ from lcm.nn.transformer.attention import (
+     FullAttentionState,
+     QKNormMultiheadAttention,
+ )
+
+ SUPPORTED_NORM_ORDERS = Literal["pre", "post", "normformer"]
+
+
+ logger = get_log_writer(__name__)
+
+
+ @dataclass
+ class TransformerConfig:
+     """A config object to group all config
+     hyper-parameters of an LCMTransformerDecoder"""
+
+     num_layers: int = 2
+
+     num_attn_heads: int = 8
+
+     # Dropout rates
+     dropout_p: float = 0.1
+     """The dropout probability on the outputs of the attention layers and the
+     feed-forward network (before joining the residual stream)"""
+
+     final_dropout_p: float = 0.1
+     """The dropout probability on decoder outputs"""
+
+     attention_dropout_p: float = 0.0
+     """The dropout rate on attention weights in SDPA"""
+
+     # FFN
+     ffn_inner_dim: int = 1024 * 4
+
+     use_swiglu: bool = False
+     """Use GLUFeedForwardNetwork instead of regular FFN blocks"""
+
+     ffn_inner_activation_name: str = "relu"
+     """The activation to apply to outputs of the FFN inner projection layer.
+     Default is `relu`, i.e., `torch.nn.ReLU`. Only relevant when `use_swiglu=False`"""
+
+     # positional embedding
+     pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "learned"
+     """If `rope`: a rotary positional encoder is used in the attention layers.
+     If `sine`: sinusoidal positional embeddings will be added in
+     the frontend before heading into the decoder.
+     If `learned`: learned positional embeddings will be added in
+     the frontend before heading into the decoder.
+     If `none`: no positional embeddings will be used (e.g. in the case
+     of unconditional diffusion of a single vector)."""
+
+     rope_theta: float = 10_000.0
+     """The coefficient of the long-term decay of RoPE embeddings."""
+
+     # Normalization
+     layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
+
+     norm_order_style: SUPPORTED_NORM_ORDERS = "pre"
+     """LayerNorm order in the transformer decoder;
+     default is pre-normalization (`pre`). Other options are post-normalization (`post`)
+     and normformer-style normalization (`normformer`)"""
+
+     final_norm_order_style: Optional[SUPPORTED_NORM_ORDERS] = None
+     """Controls the lcm-level norm order; using ``post`` here with a ``pre`` layer-level norm order
+     means that we will skip the last layernorm in the stack"""
+
+     enable_qk_layernorm: bool = False
+     """If ``True``, LayerNorms will be applied to queries and keys in self-attention layers.
+     QK-LayerNorm, described in https://arxiv.org/pdf/2302.05442 and subsequent work,
+     is recommended to alleviate Transformer training instabilities
+     """
+     mha_qkv_weight_normalization: bool = False
+     """If ``True``, wrap the K/Q/V linears of MHA in weight normalization"""
+
+     mha_output_weight_normalization: bool = False
+     """If ``True``, wrap the output projection of MHA with weight normalization.
+     This is a temporary fix to resume training some models and will be removed"""
+
+     # Miscellaneous
+     mha_output_proj_bias: bool = False
+     """If ``True``, add a bias term to the MHA output projection"""
+
+     scale_residual: Optional[float] = None
+     """Scale to multiply the residual in the Transformer decoder"""
+
+     attention_output_init_fn: SUPPORTED_INIT_TYPES = "xavier"
+
+
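A configuration sketch (not part of the commit) showing the fields above for a pre-norm decoder with RoPE, SwiGLU FFNs, and QK-LayerNorm:

```python
from lcm.nn.transformer import TransformerConfig

config = TransformerConfig(
    num_layers=12,
    num_attn_heads=16,
    dropout_p=0.1,
    ffn_inner_dim=4 * 2048,
    use_swiglu=True,
    pos_embedding_style="rope",
    norm_order_style="pre",
    enable_qk_layernorm=True,
)
```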
+ class TransformerFactory:
+     def __init__(
+         self,
+         model_dim: int,
+         max_seq_len: int,
+         config: TransformerConfig,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ) -> None:
+         """
+         :param model_dim:
+             The hidden model dimension of the Transformer
+         :param max_seq_len:
+             Maximum sequence length supported by the model
+         :param config:
+             The configuration.
+         :param device:
+             The device on which to initialize modules.
+         :param dtype:
+             The data type of module parameters and buffers.
+         """
+         self.model_dim = model_dim
+         self.max_seq_len = max_seq_len
+         self.config = config
+         self.device, self.dtype = device, dtype
+
+     def build_layer(self) -> TransformerDecoderLayer:
+         """Build a Transformer decoder layer based on the provided config."""
+
+         self_attn = self.build_attention()
+
+         ffn = self.build_ffn()
+
+         norm_order = parse_norm_order(self.config.norm_order_style)
+
+         layer_norm_factory = parse_layer_norm_factory(
+             self.config.layer_normalization_style
+         )
+
+         layer = LCMStandardTransformerDecoderLayer(
+             self_attn=self_attn,
+             encoder_decoder_attn=None,
+             ffn=ffn,
+             dropout_p=self.config.dropout_p,
+             norm_order=norm_order,
+             layer_norm_factory=layer_norm_factory,
+             scale_residual=self.config.scale_residual is not None,
+             device=self.device,
+             dtype=self.dtype,
+         )
+         # reset residual_scale
+         if layer.residual_scale is not None:
+             assert self.config.scale_residual is not None, (
+                 f"Layer has a residual scale but scale={self.config.scale_residual}"
+             )
+             torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual)
+             logger.info(
+                 f"Initializing the residual scale at {self.config.scale_residual}"
+             )
+         return layer
+
+     def build_pos_encoder(self) -> Optional[PositionEncoder]:
+         """Build the positional encoder (learned or sinusoidal, if any)
+         that will be used in the frontend"""
+         pos_encoder: Optional[PositionEncoder]
+
+         if self.config.pos_embedding_style == "learned":
+             pos_encoder = LearnedPositionEncoder(
+                 self.model_dim,
+                 self.max_seq_len,
+                 device=self.device,
+                 dtype=self.dtype,
+             )
+         elif self.config.pos_embedding_style == "sine":
+             pos_encoder = SinusoidalPositionEncoder(
+                 self.model_dim,
+                 self.max_seq_len,
+                 device=self.device,
+             )
+
+         else:
+             pos_encoder = None
+
+         return pos_encoder
+
+     def build_attention_pos_encoder(self) -> Optional[PositionEncoder]:
+         """Build the position encoder that can
+         potentially be used in the MHA module"""
+
+         pos_encoder: Optional[PositionEncoder]
+
+         if self.config.pos_embedding_style == "rope":
+             pos_encoder = RotaryEncoder(
+                 encoding_dim=self.model_dim // self.config.num_attn_heads,
+                 max_seq_len=self.max_seq_len,
+                 theta=self.config.rope_theta,
+                 device=self.device,
+             )
+         else:
+             pos_encoder = None
+         return pos_encoder
+
+     def build_attention(self) -> MultiheadAttention:
+         """Build a Transformer multi-head attention layer."""
+
+         # allow for a different kv_dim
+         kv_dim = self.model_dim
+
+         # fairseq2.nn.transformer.attention.TorchSDPA
+         sdpa = create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p)
+
+         init_fn = get_init_fn(self.config.attention_output_init_fn)
+
+         # How does RoPE play with encoder-decoder attention?
+         pos_encoder = self.build_attention_pos_encoder()
+
+         layer_norm_factory = parse_layer_norm_factory(
+             self.config.layer_normalization_style
+         )
+
+         # build output_proj:
+         output_proj = Linear(
+             self.model_dim,
+             self.model_dim,
+             bias=self.config.mha_output_proj_bias,
+             init_fn=init_fn,
+             device=self.device,
+             dtype=self.dtype,
+         )
+         if self.config.mha_output_weight_normalization:
+             output_proj = torch.nn.utils.parametrizations.weight_norm(output_proj)
+
+         return QKNormMultiheadAttention(
+             self.model_dim,
+             self.config.num_attn_heads,
+             kv_dim=kv_dim,
+             pos_encoder=pos_encoder,
+             sdpa=sdpa,
+             output_proj=output_proj,
+             enable_qk_layernorm=self.config.enable_qk_layernorm,
+             weight_normalization=self.config.mha_qkv_weight_normalization,
+             layer_norm_factory=layer_norm_factory,
+             state_factory=FullAttentionState,
+             device=self.device,
+             dtype=self.dtype,
+         )
+
+     def build_ffn(self) -> FeedForwardNetwork:
+         """Build a Transformer feed-forward network."""
+         if self.config.use_swiglu:
+             # Default gate_activation is torch.nn.SiLU
+             return GLUFeedForwardNetwork(
+                 self.model_dim,
+                 self.config.ffn_inner_dim,
+                 bias=True,
+                 inner_dim_scale=2 / 3,
+                 inner_dim_to_multiple=256,
+                 device=self.device,
+                 dtype=self.dtype,
+             )
+
+         ffn_inner_activation = parse_activation_fn(
+             self.config.ffn_inner_activation_name
+         )
+         norm_order = parse_norm_order(self.config.norm_order_style)
+
+         return StandardFeedForwardNetwork(
+             self.model_dim,
+             self.config.ffn_inner_dim,
+             inner_activation=ffn_inner_activation,
+             bias=True,
+             norm_order=norm_order,
+             device=self.device,
+             dtype=self.dtype,
+         )
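Putting the pieces together, a usage sketch (not part of the commit): the factory builds the per-layer modules, which are then stacked into an LCMTransformerDecoder. The causal mask factory is an assumption taken from fairseq2's public API.

```python
from fairseq2.nn.transformer import CausalAttentionMaskFactory

from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

config = TransformerConfig(num_layers=2, num_attn_heads=8, pos_embedding_style="rope")
factory = TransformerFactory(model_dim=1024, max_seq_len=4096, config=config)

decoder = LCMTransformerDecoder(
    layers=[factory.build_layer() for _ in range(config.num_layers)],
    layer_norm_factory=parse_layer_norm_factory(config.layer_normalization_style),
    self_attn_mask_factory=CausalAttentionMaskFactory(),  # assumed fairseq2 helper
    dropout_p=config.final_dropout_p,
)
```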