Lexa committed
Commit · 3d79eb3
Parent(s): 0
Initial commit
This view is limited to 50 files because the commit contains too many changes.
- .gitignore +118 -0
- .pre-commit-config.yaml +16 -0
- LICENSE +21 -0
- README.md +29 -0
- lcm/__init__.py +22 -0
- lcm/cards/Normalizer_Wikipedia_En_1M.pt +0 -0
- lcm/cards/sonar_normalizer.yaml +4 -0
- lcm/datacards/datacards.yaml +5 -0
- lcm/datasets/__init__.py +4 -0
- lcm/datasets/batch.py +425 -0
- lcm/inference/lcm/__init__.py +9 -0
- lcm/inference/lcm/generator.py +448 -0
- lcm/inference/lcm/scorer.py +198 -0
- lcm/inference/two_tower_diffusion_lcm/__init__.py +16 -0
- lcm/inference/two_tower_diffusion_lcm/generator.py +466 -0
- lcm/inference/two_tower_diffusion_lcm/scorer.py +314 -0
- lcm/models/__init__.py +15 -0
- lcm/models/abstract_lcm/__init__.py +16 -0
- lcm/models/abstract_lcm/builder.py +106 -0
- lcm/models/base_lcm/__init__.py +20 -0
- lcm/models/base_lcm/archs.py +49 -0
- lcm/models/base_lcm/builder.py +285 -0
- lcm/models/base_lcm/frontend.py +183 -0
- lcm/models/base_lcm/loader.py +55 -0
- lcm/models/base_lcm/normalization.py +50 -0
- lcm/models/sonar_normalizer/__init__.py +20 -0
- lcm/models/sonar_normalizer/archs.py +40 -0
- lcm/models/sonar_normalizer/builder.py +210 -0
- lcm/models/sonar_normalizer/loader.py +28 -0
- lcm/models/two_tower_diffusion_lcm/__init__.py +7 -0
- lcm/models/two_tower_diffusion_lcm/archs.py +207 -0
- lcm/models/two_tower_diffusion_lcm/builder.py +628 -0
- lcm/models/two_tower_diffusion_lcm/frontend.py +152 -0
- lcm/models/two_tower_diffusion_lcm/loader.py +44 -0
- lcm/nn/__init__.py +4 -0
- lcm/nn/denoisers/__init__.py +17 -0
- lcm/nn/denoisers/attention_masks.py +228 -0
- lcm/nn/denoisers/factory.py +192 -0
- lcm/nn/denoisers/lcm_denoiser.py +546 -0
- lcm/nn/incremental_state.py +43 -0
- lcm/nn/initialization.py +152 -0
- lcm/nn/normalization.py +88 -0
- lcm/nn/projection.py +86 -0
- lcm/nn/schedulers/__init__.py +17 -0
- lcm/nn/schedulers/ddim.py +741 -0
- lcm/nn/timestep_encoder.py +122 -0
- lcm/nn/transformer/__init__.py +24 -0
- lcm/nn/transformer/attention.py +307 -0
- lcm/nn/transformer/decoder.py +176 -0
- lcm/nn/transformer/factory.py +300 -0
.gitignore
ADDED
@@ -0,0 +1,118 @@

# JetBrains PyCharm IDE
.idea/

# Byte-compiled / optimized / DLL files
**/*/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# macOS dir files
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# mkdocs documentation
/site

# mypy
.mypy_cache/

.pytest_cache
.ruff_cache

# VSCODE
.vscode/ftp-sync.json
.vscode/settings.json
.vscode/launch.json

# stopes logs
executor_logs/
config_logs/
outputs/

logs/
**/dask_jobqueue_logs
core.*
mortimer_env.txt

# datasets
_LexaLCM_Block0/Datasets/

# UV
uv.lock
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,16 @@
repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    rev: 0.5.7
    hooks:
      - id: uv-lock
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.8.2
    hooks:
      # Lint
      - id: ruff
        args: [ --fix ]
      # sort imports
      - id: ruff
        args: ["check", "--select", "I", "--fix"]
      # format
      - id: ruff-format
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Alexandra 'Lexa' Baldwin
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,29 @@
# LexaLCM Pre0 288M Pre-trained Large Concept Model
A pre-trained LCM model with 288M parameters based on Meta FAIR's LCM architecture.

[[Paper]](https://ai.meta.com/research/publications/large-concept-models-language-modeling-in-a-sentence-representation-space/)

Note: These instructions are for running the model on a single machine with a single GPU. If your system does not have a GPU that supports at least CUDA 12.1, or if you intend to execute this in the cloud, you'll need to modify the code per your requirements.

## 1. Install the Intel MKL runtime
```bash
sudo apt update
sudo apt install libmkl-rt
export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH
source ~/.bashrc
```

## 2. Install dependencies
```bash
uv sync --extra gpu --extra eval --extra data
```

## 3. Update the model cards' paths
The paths in these two model cards must be updated to match where the files live in your local filesystem:
* `_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/model_card.yaml`
* `_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml`

## 4. Test the model's inference
```bash
uv run --extra gpu scripts/run_inference.py
```
lcm/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

"""
LCM: Modular and Extensible Reasoning in an Embedding Space
Code base for training different LCM models.
"""

from fairseq2 import setup_extensions
from fairseq2.assets import default_asset_store

__version__ = "0.1.0.dev0"


def setup_fairseq2() -> None:
    default_asset_store.add_package_metadata_provider("lcm.cards")


# This call activates setup_fairseq2 and potentially other extensions.
setup_extensions()
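
Editor's note: the module above registers `lcm.cards` as a package metadata provider, which is what makes the YAML asset cards below discoverable through fairseq2's default asset store. A minimal lookup sketch follows; it assumes `setup_fairseq2` is wired up as a fairseq2 extension entry point (the packaging metadata that would do so is not shown in this diff), so treat it as illustrative rather than confirmed behavior of this repo.

```python
# Hypothetical usage sketch: importing lcm runs setup_extensions(), which
# (assuming the extension entry point is registered) makes the cards under
# lcm/cards visible to the default asset store.
import lcm  # noqa: F401

from fairseq2.assets import default_asset_store

# Retrieve the normalizer card defined in lcm/cards/sonar_normalizer.yaml.
card = default_asset_store.retrieve_card("sonar_normalizer_wikipedia_en_1m")
print(card.field("model_family").as_(str))  # expected: "sonar_normalizer"
```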
lcm/cards/Normalizer_Wikipedia_En_1M.pt
ADDED
Binary file (9.99 kB)
lcm/cards/sonar_normalizer.yaml
ADDED
@@ -0,0 +1,4 @@
name: sonar_normalizer_wikipedia_en_1m
model_family: sonar_normalizer
model_arch: base
checkpoint: Normalizer_Wikipedia_En_1M.pt
lcm/datacards/datacards.yaml
ADDED
@@ -0,0 +1,5 @@
name: "Data_Wikipedia_En_1M"
parquet_path:
  local: "./_LexaLCM_Pre0/Datasets/Wikipedia_En_1M"
source_column: "text_sentences_sonar_emb"
source_text_column: "text_sentences"
lcm/datasets/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
lcm/datasets/batch.py
ADDED
@@ -0,0 +1,425 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.
#
#

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from fairseq2.logging import get_log_writer
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import PaddingMask, pad_seqs
from fairseq2.typing import Device
from torch import Tensor
from torch.nn import Module

from lcm.utils.common import Batched

logger = get_log_writer(__name__)


DOC_LENGTHS = "__doc_lengths"


class LCMStyle(Enum):
    """Specifies a style for preparing the LCM input."""

    SUPERVISED = 1
    """For when the model is fed supervised data with source & target sentences."""

    UNSUPERVISED = 2
    """For when the model is fed unsupervised data with source sentences only."""

    PACKED_UNSUPERVISED = 3
    """For when the model is fed ``packed`` unsupervised data with source sentences only.
    This means that we will look for document_lengths and propagate them to the
    packed causal masked attention and the packed position encoders"""


@dataclass
class EmbeddingsBatch:
    """Represents a batch of embedding sequences.
    Resembles fairseq2's SequenceBatch with additional properties"""

    seqs: Tensor
    """The sequences. *Shape:* :math:`(B,S,D)`, where :math:`B` is the batch
    size, :math:`S` is the sequence length (in sentences per document),
    and :math:`D` the embedding dimension
    """

    padding_mask: Optional[PaddingMask] = None
    """The padding mask of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
    the batch size and :math:`S` is the sequence length."""

    diffusion_timesteps: Optional[Tensor] = None
    """Diffusion timesteps of the noising process of ``seqs``. *Shape:* :math:`(B,S)`, where :math:`B` is
    the batch size and :math:`S` is the sequence length."""

    document_lengths: Optional[Tensor] = None
    """Lengths of the documents (in sentences) present in the batch.
    Shape: (len_doc, )
    """

    source_lengths: Optional[Tensor] = None
    """Lengths of the source part for each element in the batch, so that `seqs[i, :source_lengths[i]]` corresponds to the source for each i in [0, batch_size).
    Shape: (batch_size, )
    """

    def __post_init__(self):
        if self.document_lengths is not None:
            assert self.document_lengths.sum() == self.seqs.size(
                1
            ) or 2 * self.document_lengths.sum() == self.seqs.size(1), (
                "The lengths do not sum up to the sequence length "
                "(nor half the length for doubled diffusion sequences). "
                f"We have seqs.size={self.seqs.size()} and lengths={self.document_lengths} "
                f"summing to {self.document_lengths.sum()}"
            )

    def __len__(self) -> int:
        return self.batch_size

    @property
    def batch_size(self) -> int:
        """The size of the batch."""
        return self.seqs.size(0)

    @property
    def shape(self) -> torch.Size:
        """The shape of the batch."""
        return self.seqs.shape

    @property
    def device(self) -> Device:
        """The device of the batch."""
        return self.seqs.device

    def clone(self):
        return deepcopy(self)

    def __getitem__(self, i: int) -> Any:
        raise NotImplementedError(
            "Access to each item in EmbeddingsBatch not allowed yet"
        )

    def unbatch(self) -> List[Tensor]:
        if self.padding_mask is None:
            return list(self.seqs)
        else:
            return [
                tt[:length] for tt, length in zip(self.seqs, self.padding_mask.seq_lens)
            ]

    def get_last_element(self) -> Tensor:
        if self.padding_mask:
            return self.seqs[
                torch.arange(len(self.padding_mask.seq_lens), device=self.seqs.device),
                (self.padding_mask.seq_lens - 1),
            ]
        else:
            return self.seqs[:, -1]

    def set_last_element(self, element: Tensor) -> None:
        element = element.to(self.seqs.device)
        if self.padding_mask:
            for i, slen in enumerate(self.padding_mask.seq_lens):
                self.seqs[i, slen - 1] = element[i]
        else:
            self.seqs[:, -1] = element

    def normalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
        if normalizer is None:
            logger.warning(
                "The normalizer is None, as such, the features will remain unchanged"
            )
            return self

        return EmbeddingsBatch(
            seqs=normalizer.normalize(self.seqs),
            padding_mask=self.padding_mask,
            diffusion_timesteps=self.diffusion_timesteps,
            document_lengths=self.document_lengths,
            source_lengths=self.source_lengths,
        )

    def denormalize_seqs(self, normalizer: Optional[Module]) -> "EmbeddingsBatch":
        if normalizer is None:
            logger.warning(
                "The normalizer is None, as such, the features will remain unchanged"
            )
            return self

        return EmbeddingsBatch(
            seqs=normalizer.denormalize(self.seqs),
            padding_mask=self.padding_mask,
            diffusion_timesteps=self.diffusion_timesteps,
            document_lengths=self.document_lengths,
            source_lengths=self.source_lengths,
        )

    def double_seqs(self) -> "EmbeddingsBatch":
        """
        Repeats sequence elements along the sequence dim:
        1, 2, 3 -> 1, 1, 2, 2, 3, 3
        x, y -> x, x, y, y
        """
        if self.padding_mask is not None:
            doubled_padding_mask = PaddingMask(
                seq_lens=2 * self.padding_mask._seq_lens,
                batch_seq_len=2 * self.padding_mask._batch_seq_len,
            )
        else:
            doubled_padding_mask = None

        return EmbeddingsBatch(
            seqs=torch.repeat_interleave(self.seqs, 2, dim=1),
            padding_mask=doubled_padding_mask,
            diffusion_timesteps=self.diffusion_timesteps,
            document_lengths=self.document_lengths,
            source_lengths=(
                torch.repeat_interleave(self.source_lengths, 2, dim=0)
                if self.source_lengths is not None
                else None
            ),
        )

    def flatten_to_sentences(self) -> Tensor:
        """Flatten the sequence of embeddings
        from B, S, D to B*~S, D after removing the padded positions
        """

        embed_dim = self.seqs.size(-1)

        if self.padding_mask is not None:
            seq_lens = self.padding_mask.seq_lens

            embeds_mask = self.padding_mask.materialize().unsqueeze(-1)

            # Remove padded positions and reshape as B*~S, D
            flat_embeds = torch.masked_select(self.seqs, embeds_mask).reshape(
                (-1, embed_dim)
            )

            # split per document/paragraph
            flat_embeds_per_doc = list(torch.split(flat_embeds, seq_lens.tolist()))

            # Concatenate back
            flat_embeds = torch.concat(flat_embeds_per_doc)

        else:
            embeds = self.seqs

            flat_embeds = embeds.reshape((-1, embed_dim))

        return flat_embeds


@dataclass
class LCMInput(Batched):
    """Dataclass for a pair of source/target sequences of SONAR embeddings"""

    source: List[Tensor]
    """source: SONAR embeddings of the source text
    i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]"""

    target: Union[None, List[Tensor]]
    """target: If supervised data: SONAR embeddings of the target text"""

    tokens: Union[None, SequenceBatch] = None
    """tokens: Tokenized flattened sentences for the SONAR decoder
    (see the dataloader `_prepare_subword_tokens`)"""

    target_tokens: Union[None, SequenceBatch] = None
    """target_tokens: a sequence of the same shape as `tokens`, but shifted, to serve as the target.
    (see the dataloader `_prepare_subword_tokens`)"""

    name: Optional[str] = None
    """
    dataset name from which the input is coming
    """
    batch: Optional[Dict[str, Any]] = None
    """raw batch of the dataloader, used for tracking and debugging"""

    def __post_init__(self):
        assert self.source is not None

        length = len(self.source)

        assert (self.target is None) or (len(self.target) == length), (
            f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
        )

    def __len__(self) -> int:
        return len(self.source)

    def __getitem__(self, i: int) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Return the content of the item in the batch
        """
        if self.target is None:
            return self.source[i]
        else:
            return self.source[i], self.target[i]

    def prepare_input(
        self,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ) -> EmbeddingsBatch:
        """
        Adds special tokens to the source (& target) and prepares
        the EmbeddingsBatch (tensor & its padding mask) that will be
        forwarded in the LCM model.

        `style`: LCMStyle is either supervised (which requires
        target embeddings) or unsupervised
        """

        if style == LCMStyle.UNSUPERVISED:
            return get_embeddings_sequence(src_seqs=self.source)

        elif style == LCMStyle.PACKED_UNSUPERVISED:
            # If using PACKED_UNSUPERVISED, document_lengths will be added to `EmbeddingsBatch`
            document_lengths = None
            if self.batch is not None and self.batch.get(DOC_LENGTHS, None) is not None:
                # document_lengths will only be consumed if the batch_size is 1
                assert len(self.batch[DOC_LENGTHS]) == 1, "Expecting batch size of 1"

                document_lengths = self.batch[DOC_LENGTHS][0].type(torch.int64)

            return get_embeddings_sequence(
                src_seqs=self.source,
                document_lengths=document_lengths,
            )

        elif style == LCMStyle.SUPERVISED:
            assert self.target is not None, (
                "Missing target embeddings for a supervised batch"
            )
            return get_embeddings_sequence(
                src_seqs=self.source,
                tgt_seqs=self.target,
            )

        raise ValueError(f"Unsupported style={style} - could not prepare input")

    def prepare_target_mask(
        self,
        embeddings: EmbeddingsBatch,
        style: LCMStyle,
        min_context_size: Optional[int] = None,
    ) -> Tensor:
        """Prepare a target mask signaling what positions
        we should predict and optimize the model for

        Args:
            - min_context_size: the minimum context used to predict the next
              concept (only used for unsupervised training)

        """

        batch_size, maxlen, _ = embeddings.seqs.size()

        device = embeddings.seqs.device

        if style == LCMStyle.UNSUPERVISED:
            # A target mask for unsupervised next sentence prediction
            # All positions are optimized/predicted starting from min_context_size
            target_mask = torch.ones(
                (batch_size, maxlen),
                dtype=torch.bool,
                device=device,
            )
            if min_context_size is not None:
                target_mask[:, : min(min_context_size, target_mask.size(1))] = False

        elif style == LCMStyle.PACKED_UNSUPERVISED:
            # A target mask for unsupervised next sentence prediction when the data is packed
            # All positions are optimized starting from min_context_size in each document
            document_lengths = embeddings.document_lengths
            if document_lengths is not None:  # training

                def get_document_target_mask(doc_length):
                    mask = torch.ones(doc_length, dtype=torch.bool, device=device)
                    mask[: min(min_context_size, doc_length)] = False
                    return mask

                target_mask = torch.cat(
                    [get_document_target_mask(length) for length in document_lengths]
                ).unsqueeze(0)

            else:  # validation with unpacked data:
                target_mask = torch.ones(
                    (batch_size, maxlen),
                    dtype=torch.bool,
                    device=device,
                )
                if min_context_size is not None:
                    target_mask[:, : min(min_context_size, target_mask.size(1))] = False

        elif style == LCMStyle.SUPERVISED:
            # A target mask for target prediction
            indices = torch.arange(maxlen, device=device).expand(batch_size, -1)

            source_lengths = torch.tensor(
                [seq.size(0) for seq in self.source],
                device=device,
            )

            target_mask = indices >= source_lengths.unsqueeze(1).expand(-1, maxlen)

        # Factor in padded positions:
        if embeddings.padding_mask is not None:
            target_mask = target_mask * embeddings.padding_mask.materialize()

        return target_mask.detach()


def get_embeddings_sequence(
    src_seqs: List[Tensor],
    tgt_seqs: Optional[List[Tensor]] = None,
    document_lengths: Optional[Tensor] = None,
    double_target: bool = False,
) -> EmbeddingsBatch:
    seqs_lst: List[Tensor] = []
    for src_seq, tgt_seq in zip(src_seqs, tgt_seqs or [None] * len(src_seqs)):  # type: ignore
        embeds: List[Tensor] = []
        device, dtype = src_seq.device, src_seq.dtype

        # mandatory src_seq
        embeds.append(src_seq)

        # supervised tgt_seq
        if tgt_seq is not None:
            tgt_seq = tgt_seq.to(device).type(dtype)

            if double_target:
                embeds.append(torch.repeat_interleave(tgt_seq, 2, dim=0))
            else:
                embeds.append(tgt_seq)

        seqs_lst.append(torch.concat(embeds))

    seqs, padding_mask = pad_seqs(seqs_lst)

    if document_lengths is not None:
        document_lengths = document_lengths.to(seqs.device)

    if tgt_seqs is not None:
        source_lengths = torch.tensor(
            [seq.size(0) for seq in src_seqs], device=seqs.device
        )
    else:
        source_lengths = None

    output = EmbeddingsBatch(
        seqs,
        padding_mask=padding_mask,
        document_lengths=document_lengths,
        source_lengths=source_lengths,
    )

    return output
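
Editor's note: as a quick illustration of the padding round trip defined above, here is a minimal sketch (the sentence counts and embedding dimension are made up for the example): two "documents" with 3 and 5 sentence embeddings are padded into one `(B, S, D)` `EmbeddingsBatch` by `get_embeddings_sequence`, and `unbatch()` strips the padding back off.

```python
import torch

from lcm.datasets.batch import get_embeddings_sequence

# Two documents with 3 and 5 sentence embeddings of dimension 16.
docs = [torch.randn(3, 16), torch.randn(5, 16)]

batch = get_embeddings_sequence(src_seqs=docs)
assert batch.shape == (2, 5, 16)  # padded to the longest document

# unbatch() uses the padding mask's seq_lens to recover the per-document tensors.
recovered = batch.unbatch()
assert [r.shape for r in recovered] == [d.shape for d in docs]
```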
lcm/inference/lcm/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.inference.lcm.generator import LCMGenerator as LCMGenerator
from lcm.inference.lcm.generator import LCMGeneratorOptions as LCMGeneratorOptions

__all__ = ["LCMGenerator", "LCMGeneratorOptions"]
lcm/inference/lcm/generator.py
ADDED
@@ -0,0 +1,448 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from fairseq2.generation.generator import (
    GenerationCounters,
    Hypothesis,
    SequenceGeneratorOutput,
)
from fairseq2.logging import get_log_writer

from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.nn.incremental_state import LCMIncrementalStateBag

logger = get_log_writer(__name__)


"""
This generator follows the style of existing generators in fairseq2
"""


@dataclass
class LCMGeneratorOptions:
    """Holds the options to pass to a sequence generator."""

    max_seq_len: int = 200
    """The hard limit on the maximum length of generated sequences."""

    min_seq_len: int = 1
    """The minimum length of generated sequences."""

    eos_threshold: Optional[float] = 0.9
    """Threshold for cosine similarity to the EOS vector"""

    sample_latent_variable: bool = True
    """When using VAE models, whether to return the mean or sample"""

    stop_on_repetition_cosine_threshold: Optional[float] = None
    """Stop the generation when the similarity of two consecutive concepts is above the threshold."""

    include_eos_token: bool = False
    """Whether the eos token should be included in the hypotheses (matters only if they are trimmed)."""

    trim_hypotheses: bool = False
    """Whether the tokens after the EOS token should be included in the hypotheses."""

    seed: Optional[int] = None
    """Seed to make generation deterministic"""

    lcm_temperature: float = 1.0
    """Temperature for decoding in the LCM"""


class LCMGenerator:
    """Generates with an LCM model."""

    def __init__(
        self,
        model: AbstractLCModel,
        options: Optional[LCMGeneratorOptions] = None,
        eos_vec: Optional[torch.Tensor] = None,
    ) -> None:
        """
        :param model:
            The LC model to use for generation.
        """
        model.eval()
        self.model = model

        if options is None:
            options = LCMGeneratorOptions()

        self.eos_vec = eos_vec
        if self.eos_vec is None and options.eos_threshold:
            logger.warning(
                f"eos_threshold is set to {options.eos_threshold}, but eos_vec is not provided"
            )
        if options.eos_threshold:
            logger.debug(f"The eos_vec in the generator has been set to {self.eos_vec}")

        self.options = options

        self.max_seq_len = options.max_seq_len
        self.min_seq_len = options.min_seq_len

        assert self.min_seq_len >= 1, (
            f"min_seq_len must be greater than or equal to 1, min_seq_len={options.min_seq_len}"
        )

        self.eos_threshold = options.eos_threshold

        self.seqs: torch.Tensor
        self.step_nr = 0
        self.min_prompt_len: int
        self.max_prompt_len: int
        self.sample_indices: torch.Tensor
        self.state_bag: Optional[LCMIncrementalStateBag] = None
        self.prompt_seq_lens: Optional[torch.Tensor] = None
        self.prompt_padding_mask: Optional[torch.Tensor] = None
        self.lengths: torch.Tensor
        self.step_scores: torch.Tensor

    @torch.inference_mode()
    def __call__(
        self,
        batch_input: EmbeddingsBatch,
        max_gen_len: Optional[int] = None,
        min_gen_len: Optional[int] = None,
        temperature: float = 0.0,
        disable_cache: bool = False,
        **kwargs,
    ) -> SequenceGeneratorOutput:
        """
        :param input:
            `batch_input` embedded and padded tensor sequence of the inputs
            `max_gen_len` max length to be generated for the given input
            `min_gen_len` minimum length to be generated for the given input
            `temperature` temperature to control the generation
            `disable_cache` if True, do not use kv-caching
        :returns:
            The output of the LCM generator, consisting of :math:`N` lists of
            hypotheses for :math:`N` prompts. Each list has 1 Hypothesis
            (beam size = 1), of which `seq` has the *Shape:* math:`(S+T, D)`
            (:math:`S` is the prompt length, :math:`T` the length of the
            generated sequence after the prompt and :math:`D` the model
            dimension.)

        """
        if self.options.seed:
            torch.manual_seed(self.options.seed)

        # Set up the variables
        batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size()
        prompt_padding_mask = batch_input.padding_mask
        if prompt_padding_mask is None:
            self.min_prompt_len = self.max_prompt_len
            self.prompt_padding_mask = None
            self.prompt_seq_lens = None
        else:
            self.prompt_seq_lens = prompt_padding_mask.seq_lens
            assert self.prompt_seq_lens is not None, (
                "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`"
            )
            self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item())

            # Keep the materialized mask
            self.prompt_padding_mask = prompt_padding_mask.materialize()

        if not max_gen_len:
            max_gen_len = self.max_seq_len

        # Make sure we do not accidentally set a max_gen_len that exceeds
        # the generator's model capability
        assert max_gen_len <= self.max_seq_len, (
            f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}"
        )
        self.max_gen_len = max_gen_len

        if not min_gen_len:
            min_gen_len = self.min_seq_len

        assert min_gen_len > 0, (
            f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
        )
        self.min_gen_len = min_gen_len

        if temperature == 0.0:
            # If the call doesn't pass a specific temperature,
            # use the default one from the decoding options
            temperature = self.options.lcm_temperature

        self.temperature = temperature

        for k, v in kwargs.items():
            if hasattr(self.options, k) and v:
                setattr(self.options, k, v)

        # Holds the generated sequences, scores and sample-dependent variables
        dtype = self.model.dtype
        device = batch_input.seqs.device

        if disable_cache:
            self.state_bag = None
        else:
            self.state_bag = LCMIncrementalStateBag(
                self.max_prompt_len + self.max_gen_len
            )

        # reserving full sequence capacity
        self.seqs = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim),
            device=device,
            dtype=dtype,
        )
        self.step_scores = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len),
            device=device,
        )
        self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1

        # Hold the sample indices to return in order
        self.sample_indices = torch.arange(batch_size, device=device)
        # Output buffer
        self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]

        # Bootstrap the sequences with the provided prompt.
        self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len]
        self.step_nr = self.min_prompt_len
        self.prefill(**kwargs)

        for self.step_nr in range(
            self.min_prompt_len, self.max_prompt_len + self.max_gen_len
        ):
            if not self._step():
                break

        return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())

    @torch.inference_mode()
    def prefill(self, **kwargs) -> None:
        """The initial forward pass in the decoder with the prefix/prompt
        to populate the KV-cache"""

        if self.state_bag is None:
            return

        # Prefilling with -1 since the next call to step will use the last token in the prefix
        prefill_len = self.step_nr - 1

        if prefill_len > 0:
            _ = self._decode(
                self.seqs[:, :prefill_len],
                padding_mask=None,
            )
            self.state_bag.increment_step_nr(prefill_len)  # type: ignore
        else:
            logger.warning(
                f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
            )

    @torch.inference_mode()
    def _decode(
        self,
        seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask],
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output = self.model.predict_next_sentence(
            EmbeddingsBatch(seqs, padding_mask),
            sample=self.options.sample_latent_variable,
            temperature=self.temperature,
            state_bag=self.state_bag,
            **kwargs,
        )

        # Dummy scores
        scores = torch.zeros(seqs.shape[:-1])
        return output.seqs, scores

    def _step(self) -> bool:
        # Generate the next step output.

        if self.state_bag is None:
            # Without a state_bag, we're forwarding the full prefix
            model_output, step_score = self._decode(
                seqs=self.seqs[:, : self.step_nr],
                padding_mask=None,
            )
        else:
            # Since we're using a state_bag, we're only forwarding the last embedding
            model_output, step_score = self._decode(
                seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
                padding_mask=None,
            )

            self.state_bag.increment_step_nr()

        # model_output: EmbeddingBag
        return self.finalize_step(model_output, step_score)

    def finalize_step(
        self, model_output: torch.Tensor, step_score: torch.Tensor
    ) -> bool:
        """Post-process and finalize a step
        by checking all stopping criteria.
        Takes the model's output embeddings (model_output)
        and their associated scores (step_score).
        If we're stepping, return True, else return False
        """
        already_finished = self.lengths > -1
        should_finish_now = torch.zeros_like(already_finished)

        model_last_output = model_output[:, -1]
        device = model_last_output.device

        # Ignore prompt positions between min-max prompt_len
        must_keep_going = None
        if self.step_nr < self.max_prompt_len:
            assert self.prompt_padding_mask is not None, (
                f"If self.prompt_padding_mask is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}"
            )
            mask = self.prompt_padding_mask[:, self.step_nr]
            model_last_output[mask] = self.seqs[mask, self.step_nr]
            must_keep_going = mask

        # Check stopping based on EOS similarity.
        if self.eos_threshold is not None and self.eos_vec is not None:
            sim2eos = torch.nn.functional.cosine_similarity(
                self.eos_vec.to(device), model_last_output
            )
            logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}")
            should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold)

        # Check stopping based on repetition.
        if (
            self.options.stop_on_repetition_cosine_threshold is not None
            and self.step_nr > 0
        ):
            sim2prev = torch.nn.functional.cosine_similarity(
                self.seqs[:, self.step_nr - 1], model_last_output
            )
            logger.debug(
                f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
            )
            should_finish_now = should_finish_now | sim2prev.ge(
                self.options.stop_on_repetition_cosine_threshold
            )

        if must_keep_going is not None:
            logger.debug(
                f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}"
            )
            should_finish_now = should_finish_now & ~must_keep_going

        # Keep going if the output is shorter than min_gen_len:
        if self.prompt_seq_lens is not None:
            longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
                self.min_gen_len
            )
        else:
            longer_than_min_gen_len = (
                self.step_nr - self.max_prompt_len
            ) >= self.min_gen_len

        logger.debug(
            f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
        )
        should_finish_now = should_finish_now & longer_than_min_gen_len
        stopped_on_eos = should_finish_now

        # Stop hypotheses that reached max_gen_len
        if self.prompt_seq_lens is not None:
            exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
                self.max_gen_len
            )
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
            )

        else:
            exceeds_max_gen_len = (
                self.step_nr - self.max_prompt_len + 1
            ) >= self.max_gen_len
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
            )

        logger.debug(
            f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
        )

        should_finish_now = should_finish_now | exceeds_max_gen_len

        # Assign lengths to the sequences that have just finished.
        should_finish_now = should_finish_now & ~already_finished
        self.lengths[should_finish_now] = self.step_nr + 1

        # Record the current step.
        self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
        self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]

        # Save completed hypotheses
        finished_mask = self.lengths.ne(-1)
        finished_indices = finished_mask.nonzero()

        # Remove finished hypotheses and reorder variables/state_bag if any are left
        if len(finished_indices) > 0:
            for idx in finished_indices:
                self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))

        active_mask = ~finished_mask
        active_indices = active_mask.nonzero().squeeze(-1)

        if len(active_indices) == 0:
            return False

        self.reorder_state(active_indices)

        return True

    def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
        seq_len = int(self.lengths[idx].item())

        if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
            seq_len = int(self.lengths[idx].item()) - int(
                not self.options.include_eos_token
            )

        sample_idx = int(self.sample_indices[idx])
        self.hypotheses[sample_idx] = [
            Hypothesis(
                seq=self.seqs[idx, :seq_len],
                score=None,
                step_scores=self.step_scores[idx],  # Trim it as well?
            )
        ]

    def state_bag_reorder(self, new_order: torch.Tensor) -> None:
        if self.state_bag is not None:
            self.state_bag.reorder(new_order)

    def reorder_state(self, new_order: torch.Tensor) -> None:
        self.state_bag_reorder(new_order)

        self.seqs = self.seqs.index_select(dim=0, index=new_order)

        self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)

        self.step_scores = self.step_scores.index_select(dim=0, index=new_order)

        self.lengths = self.lengths.index_select(dim=0, index=new_order)

        if self.prompt_padding_mask is not None:
            self.prompt_padding_mask = self.prompt_padding_mask.index_select(
                dim=0, index=new_order
            )

        if self.prompt_seq_lens is not None:
            self.prompt_seq_lens = self.prompt_seq_lens.index_select(
                dim=0, index=new_order
            )
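
Editor's note: a minimal usage sketch of the generator above. It assumes `model` is an already-loaded `AbstractLCModel` and `prompt` is a `(B, S, D)` tensor of SONAR embeddings; neither is defined in this diff, so both are placeholders.

```python
import torch

from lcm.datasets.batch import EmbeddingsBatch
from lcm.inference.lcm import LCMGenerator, LCMGeneratorOptions

# `model` and `prompt` are assumed to exist; see the note above.
options = LCMGeneratorOptions(
    max_seq_len=64,
    stop_on_repetition_cosine_threshold=0.98,  # stop on near-duplicate concepts
)
generator = LCMGenerator(model, options=options)

output = generator(EmbeddingsBatch(prompt), max_gen_len=32)
hyp = output.hypotheses[0][0]  # beam size is fixed to 1
print(hyp.seq.shape)           # (prompt_len + generated_len, D)
```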
lcm/inference/lcm/scorer.py
ADDED
@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from typing import List, Optional

import torch
from fairseq2.generation.generator import (
    GenerationCounters,
    Hypothesis,
    SequenceGeneratorOutput,
)

from lcm.datasets.batch import EmbeddingsBatch
from lcm.inference.lcm.generator import LCMGenerator, LCMGeneratorOptions
from lcm.nn.incremental_state import LCMIncrementalStateBag


class LCMScorer(LCMGenerator):
    """Generates with an LCM model in teacher-forcing mode."""

    options: LCMGeneratorOptions

    @torch.inference_mode()
    def __call__(  # type: ignore
        self,
        batch_input: EmbeddingsBatch,
        max_gen_len: Optional[int] = None,
        min_gen_len: Optional[int] = None,
        min_context_len: int = 1,
        temperature: float = 0.0,
        disable_cache: bool = False,
    ) -> SequenceGeneratorOutput:
        """
        :param input:
            `batch_input` embedded and padded tensor sequence of the inputs
            `max_gen_len` max length to be generated for the given input
            `min_gen_len` minimum length to be generated for the given input
            `disable_cache` if True, do not use kv-caching
        :returns:
            The output of the LCM generator, consisting of :math:`N` lists of
            hypotheses for :math:`N` documents. Each list has 1 Hypothesis
            (beam size = 1), of which `seq` has the *Shape:* math:`(T, D)`
            (:math:`T` the length of the document and :math:`D` the model
            dimension)

        """
        if self.options.seed:
            torch.manual_seed(self.options.seed)

        # Set up the variables
        self.min_context_len = min_context_len
        batch_size, self.max_text_len, embed_dim = batch_input.seqs.size()
        text_padding_mask = batch_input.padding_mask
        if text_padding_mask is None:
            self.text_padding_mask = None
            self.text_seq_lens = self.max_text_len * torch.ones(
                batch_size,
                dtype=torch.long,
                device=batch_input.seqs.device,
            )
        else:
            self.text_seq_lens = text_padding_mask.seq_lens
            assert self.text_seq_lens is not None, (
                "Expecting a valid `self.text_seq_lens` Tensor, found `None`"
            )

            # Keep the materialized mask
            self.text_padding_mask = text_padding_mask.materialize()

        if not max_gen_len:
            max_gen_len = self.max_seq_len

        max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len)
        assert max_gen_len is not None, "max_gen_len is None"

        # Make sure we do not accidentally set a max_gen_len that exceeds
        # the generator's model capability
        assert max_gen_len <= self.max_seq_len, (
            f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}"
        )
        self.max_gen_len = max_gen_len

        if not min_gen_len:
            min_gen_len = self.min_seq_len

        assert min_gen_len > 0, (
            f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
        )
        self.min_gen_len = min_gen_len

        if temperature == 0.0:
            # If the call doesn't pass a specific temperature,
            # use the default one from the decoding options
            temperature = self.options.lcm_temperature

        # Holds the generated sequences, scores and sample-dependent variables
        dtype = self.model.dtype
        device = batch_input.seqs.device
        self.temperature = temperature

        if disable_cache:
            self.state_bag = None
        else:
            self.state_bag = LCMIncrementalStateBag(self.max_text_len)

        # reserving full sequence capacity
        self.seqs = batch_input.seqs
        self.preds = torch.zeros(
            (batch_size, self.max_text_len - self.min_context_len, embed_dim),
            device=device,
            dtype=dtype,
        )
        self.step_scores = torch.zeros(
            (batch_size, self.max_text_len),
            device=device,
        )

        # Hold the sample indices to return in order
        self.sample_indices = torch.arange(batch_size, device=device)
        # Output buffer
        self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]

        # Bootstrap the sequences with the provided prompt.
        self.step_nr = self.min_context_len
        self.prefill()

        for self.step_nr in range(self.min_context_len, self.max_text_len):
            if not self._step():
                break

        return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())

    def finalize_step(
        self, model_output: torch.Tensor, step_score: torch.Tensor
    ) -> bool:
        """Post-process and finalize a step
        by checking all stopping criteria.
        Takes the model's output embeddings (model_output)
        and their associated scores (step_score).
        If we're stepping, return True, else return False
        """
        model_last_output = model_output[:, -1]
        must_keep_going = self.text_seq_lens.gt(self.step_nr + 1)
        should_finish_now = ~must_keep_going

        # Record the current step prediction.
        self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze(
            1
        )
        self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1]

        # Save completed hypotheses
        finished_indices = should_finish_now.nonzero()

        # Remove finished hypotheses and reorder variables/state_bag if any are left
        if len(finished_indices) > 0:
            for idx in finished_indices:
                self.finish_sequence(int(idx))

        active_mask = must_keep_going
        active_indices = active_mask.nonzero().squeeze(-1)

        if len(active_indices) == 0:
            return False

        self.reorder_state(active_indices)

        return True

    def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
        seq_len = int(self.text_seq_lens[idx].item())
        sample_idx = int(self.sample_indices[idx])
        self.hypotheses[sample_idx] = [
            Hypothesis(
                seq=self.preds[idx, : seq_len - self.min_context_len],
                score=None,
                step_scores=self.step_scores[idx],  # Trim it as well?
            )
        ]

    def reorder_state(self, new_order: torch.Tensor) -> None:
        self.state_bag_reorder(new_order)

        self.seqs = self.seqs.index_select(dim=0, index=new_order)
|
| 187 |
+
self.preds = self.preds.index_select(dim=0, index=new_order)
|
| 188 |
+
|
| 189 |
+
self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)
|
| 190 |
+
|
| 191 |
+
self.step_scores = self.step_scores.index_select(dim=0, index=new_order)
|
| 192 |
+
|
| 193 |
+
if self.text_padding_mask is not None:
|
| 194 |
+
self.text_padding_mask = self.text_padding_mask.index_select(
|
| 195 |
+
dim=0, index=new_order
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order)
|
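In short, `LCMScorer` replays each document through the model with the ground-truth prefix at every step and collects the per-step predictions. A minimal usage sketch, assuming `model` is a loaded base LCM and `embedded_docs` is an `EmbeddingsBatch` of SONAR sentence embeddings (both names are hypothetical placeholders, not part of this commit):

from lcm.inference.lcm.generator import LCMGeneratorOptions
from lcm.inference.lcm.scorer import LCMScorer

scorer = LCMScorer(model, options=LCMGeneratorOptions())
output = scorer(embedded_docs, min_context_len=1)

# One hypothesis per document (beam size = 1); `seq` holds the
# teacher-forced predictions for steps min_context_len..T-1.
pred_embeddings = output.hypotheses[0][0].seq  # shape: (T - min_context_len, D)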
lcm/inference/two_tower_diffusion_lcm/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.inference.two_tower_diffusion_lcm.generator import (
    DiffusionLCMGeneratorOptions as DiffusionLCMGeneratorOptions,
)
from lcm.inference.two_tower_diffusion_lcm.generator import (
    TwoTowerDiffusionLCMGenerator as TwoTowerDiffusionLCMGenerator,
)

__all__ = [
    "TwoTowerDiffusionLCMGenerator",
    "DiffusionLCMGeneratorOptions",
]
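The redundant-looking `X as X` form is the re-export idiom that type checkers recognize as explicitly public, so downstream code can import these names from the package root rather than from the `generator` module:

from lcm.inference.two_tower_diffusion_lcm import (
    DiffusionLCMGeneratorOptions,
    TwoTowerDiffusionLCMGenerator,
)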
lcm/inference/two_tower_diffusion_lcm/generator.py
ADDED
@@ -0,0 +1,466 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from fairseq2.generation.generator import (
    GenerationCounters,
    Hypothesis,
    SequenceGeneratorOutput,
)
from fairseq2.logging import get_log_writer

from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
from lcm.inference.lcm.generator import (
    LCMGenerator,
    LCMGeneratorOptions,
)
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
from lcm.nn.incremental_state import LCMIncrementalStateBag

logger = get_log_writer(__name__)


@dataclass
class DiffusionLCMGeneratorOptions(LCMGeneratorOptions):
    """Holds the options to pass to a diffusion-based sequence generator."""

    guidance_scale: float = 1.0
    """The weight of the regular classifier-free guidance.
    Here `guidance_scale` is defined as the guidance weight `w` of
    Equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf.
    `guidance_scale = 1` corresponds to doing no classifier-free guidance.
    A higher guidance scale value encourages the model to generate outputs
    closely related to the `prompt` at the expense of lower quality."""

    guidance_rescale: float = 0.0
    """The rescaling factor for Classifier-Free Guidance with Rescale
    (Algorithm 2 of https://arxiv.org/pdf/2305.08891)"""

    ddim_eta: float = 0.0
    """The weight of the noise added in a diffusion step.
    It controls the level of interpolation between a deterministic
    DDIM (at eta=0.0) and a stochastic DDPM (at eta=1.0).
    See Section 5 of the DDIM paper: https://arxiv.org/pdf/2010.02502"""

    epsilon_scaling: Optional[float] = None
    """If not None, the predicted epsilon will be scaled down by the
    provided factor, as introduced in https://arxiv.org/pdf/2308.15321"""

    initial_noise_scale: float = 1.0
    """For diffusion models, scaling of the initial noise"""

    inference_timesteps: int = 100
    """For diffusion models, number of denoising timesteps"""

    clip_noise: int = 100
    """For diffusion models, factor to clip the noise of the sampling steps"""

    thresholding: bool = False
    """Whether to use the "dynamic thresholding" method.
    This is unsuitable for latent-space diffusion models such as Stable Diffusion."""

    dynamic_thresholding_ratio: float = 0.995
    """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""

    sample_max_value: float = 6.0
    """The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""


class TwoTowerDiffusionLCMGenerator(LCMGenerator):
    """Generates with a Two-tower Diffusion LCM model."""

    options: DiffusionLCMGeneratorOptions

    def __init__(
        self,
        model: AbstractLCModel,
        options: Optional[LCMGeneratorOptions] = None,
        eos_vec: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__(model, options, eos_vec)

        assert isinstance(self.model, TwoTowerDiffusionLCModel), (
            "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM"
        )

        logger.info(
            f"Setting up the model with decoding_options: {options} -- {type(options)}"
        )
        model.prep_for_denoising(options)

    @torch.inference_mode()
    def __call__(
        self,
        batch_input: EmbeddingsBatch,
        max_gen_len: Optional[int] = None,
        min_gen_len: Optional[int] = None,
        temperature: float = 0.0,
        disable_cache: bool = False,
        **kwargs,
    ) -> SequenceGeneratorOutput:
        """
        :param batch_input:
            Embedded and padded tensor sequence of the inputs.
        :param max_gen_len:
            Maximum length to be generated for the given input.
        :param min_gen_len:
            Minimum length to be generated for the given input.
        :param temperature:
            Temperature to control the generation.
        :param disable_cache:
            If True, do not use kv-caching.
        :returns:
            The output of the LCM generator, consisting of :math:`N` lists of
            hypotheses for :math:`N` prompts. Each list has 1 Hypothesis
            (beam size = 1), of which `seq` has the *Shape:* :math:`(S+T, D)`
            (:math:`S` is the prompt length, :math:`T` the length of the
            generated sequence after the prompt and :math:`D` the model
            dimension).
        """
        if self.options.seed:
            torch.manual_seed(self.options.seed)

        # Set up the variables
        batch_size, self.max_prompt_len, embed_dim = batch_input.seqs.size()
        prompt_padding_mask = batch_input.padding_mask
        if prompt_padding_mask is None:
            self.min_prompt_len = self.max_prompt_len
            self.prompt_padding_mask = None
            self.prompt_seq_lens = None
        else:
            self.prompt_seq_lens = prompt_padding_mask.seq_lens
            assert self.prompt_seq_lens is not None, (
                "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`"
            )
            self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item())

            # Keep the materialized mask
            self.prompt_padding_mask = prompt_padding_mask.materialize()

        if not max_gen_len:
            max_gen_len = self.max_seq_len

        # Make sure we do not accidentally set a max_gen_len that exceeds
        # the generator's model capability
        assert max_gen_len <= self.max_seq_len, (
            f"The generator can generate sequences of at most {self.max_seq_len} steps, max_gen_len={max_gen_len}"
        )
        self.max_gen_len = max_gen_len

        if not min_gen_len:
            min_gen_len = self.min_seq_len

        assert min_gen_len > 0, (
            f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
        )
        self.min_gen_len = min_gen_len

        if temperature == 0.0:
            # If the call doesn't pass a specific temperature,
            # use the default one from the decoding options
            temperature = self.options.lcm_temperature

        # Holds the generated sequences, scores and sample-dependent variables
        dtype = self.model.dtype
        device = batch_input.seqs.device
        self.temperature = temperature

        if disable_cache:
            self.state_bag = None
            self.context_state_bag = None
        else:
            self.state_bag = LCMIncrementalStateBag(
                self.max_prompt_len + self.max_gen_len
            )
            self.context_state_bag = LCMIncrementalStateBag(
                self.max_prompt_len + self.max_gen_len
            )

        # Reserve full sequence capacity
        self.seqs = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len, embed_dim),
            device=device,
            dtype=dtype,
        )
        self.step_scores = torch.zeros(
            (batch_size, self.max_prompt_len + self.max_gen_len),
            device=device,
        )
        self.lengths = torch.zeros(batch_size, dtype=torch.int, device=device) - 1

        # Hold the sample indices to return in order
        self.sample_indices = torch.arange(batch_size, device=device)
        # Output buffer
        self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]

        # Bootstrap the sequences with the provided prompt.
        self.seqs[:, : self.max_prompt_len] = batch_input.seqs[:, : self.max_prompt_len]
        self.step_nr = self.min_prompt_len

        # A context we keep growing in each decoding step
        self.prefill()

        for self.step_nr in range(
            self.min_prompt_len, self.max_prompt_len + self.max_gen_len
        ):
            if not self._step():
                break

        return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())

    def state_bag_reorder(self, new_order: torch.Tensor) -> None:
        if self.state_bag is not None:
            self.state_bag.reorder(new_order)

        if self.context_state_bag is not None:
            self.context_state_bag.reorder(new_order)

    @torch.inference_mode()
    def prefill(self, **kwargs) -> None:
        """Encode the prefix with the context encoder."""

        assert self.context_state_bag is not None, (
            "Expecting a context state bag to prefill"
        )

        context: EmbeddingsBatch

        prefill_len = self.step_nr - 1
        if prefill_len > 0:
            # Normalize then encode
            input_seqs = self.seqs[:, :prefill_len]
            if self.model.config.sonar_normalizer_name is not None:
                input_seqs = self.model.sonar_normalizer.normalize(input_seqs)

            context = self.model.encode(
                EmbeddingsBatch(input_seqs, None),
                state_bag=self.context_state_bag,
                **kwargs,
            )

            self.context_state_bag.increment_step_nr(prefill_len)

        else:
            logger.warning(
                f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
            )
            context = EmbeddingsBatch(
                torch.empty(
                    (self.seqs.shape[0], 0, self.model.model_dim),
                    dtype=self.seqs.dtype,
                    device=self.seqs.device,
                )
            )

        self.context = context

    @torch.inference_mode()
    def _decode(
        self,
        seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output, context = self.model.predict_next_sentence(
            batch=EmbeddingsBatch(seqs, padding_mask),
            context=self.context,
            temperature=self.temperature,
            state_bag=self.state_bag,
            context_state_bag=self.context_state_bag,
            **kwargs,
        )
        self.context = context

        # Dummy scores
        scores = torch.zeros(seqs.shape[:-1])
        return output.seqs, scores

    def _step(self) -> bool:
        # Generate the next step output.

        if self.state_bag is None:
            # Without a state_bag, we're forwarding the full prefix.
            # Encode the full context:
            model_output, step_score = self._decode(
                seqs=self.seqs[:, : self.step_nr],
                padding_mask=None,
            )
        else:
            # Since we're using a state_bag, we're only forwarding the last embedding
            model_output, step_score = self._decode(
                seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
                padding_mask=None,
            )

            self.state_bag.increment_step_nr()

        # model_output holds the predicted embeddings
        return self.finalize_step(model_output, step_score)

    def finalize_step(
        self, model_output: torch.Tensor, step_score: torch.Tensor
    ) -> bool:
        """Post-process and finalize a step by checking all stopping criteria.
        Takes the model's output embeddings (model_output)
        and their associated scores (step_score).
        Returns True if we keep stepping, False otherwise.
        """
        already_finished = self.lengths > -1
        should_finish_now = torch.zeros_like(already_finished)

        model_last_output = model_output[:, -1]
        device = model_last_output.device

        # Ignore prompt positions between min-max prompt_len
        must_keep_going = None
        if self.step_nr < self.max_prompt_len:
            assert self.prompt_padding_mask is not None, (
                f"If self.prompt_padding_mask is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}"
            )
            mask = self.prompt_padding_mask[:, self.step_nr]
            model_last_output[mask] = self.seqs[mask, self.step_nr]
            must_keep_going = mask

        # Check stopping based on EOS similarity.
        if self.eos_threshold is not None and self.eos_vec is not None:
            sim2eos = torch.nn.functional.cosine_similarity(
                self.eos_vec.to(device), model_last_output
            )
            logger.debug(f"Similarity to eos vector: {sim2eos} vs {self.eos_threshold}")
            should_finish_now = should_finish_now | sim2eos.ge(self.eos_threshold)

        # Check stopping based on repetition.
        if (
            self.options.stop_on_repetition_cosine_threshold is not None
            and self.step_nr > 0
        ):
            sim2prev = torch.nn.functional.cosine_similarity(
                self.seqs[:, self.step_nr - 1], model_last_output
            )
            logger.debug(
                f"Similarity to prev vector: {sim2prev} vs {self.options.stop_on_repetition_cosine_threshold}"
            )
            should_finish_now = should_finish_now | sim2prev.ge(
                self.options.stop_on_repetition_cosine_threshold
            )

        if must_keep_going is not None:
            logger.debug(
                f"Must keep going (to cover max_prompt_len={self.max_prompt_len}) is not None = {must_keep_going}"
            )
            should_finish_now = should_finish_now & ~must_keep_going

        # Keep going if the output is shorter than min_gen_len:
        if self.prompt_seq_lens is not None:
            longer_than_min_gen_len = (self.step_nr - self.prompt_seq_lens).ge(
                self.min_gen_len
            )
        else:
            longer_than_min_gen_len = (
                self.step_nr - self.max_prompt_len
            ) >= self.min_gen_len

        logger.debug(
            f"Longer than min_gen_len ({self.min_gen_len}) = {longer_than_min_gen_len}"
        )
        should_finish_now = should_finish_now & longer_than_min_gen_len
        stopped_on_eos = should_finish_now

        # Stop hypotheses that reached max_gen_len
        if self.prompt_seq_lens is not None:
            exceeds_max_gen_len = (self.step_nr - self.prompt_seq_lens + 1).ge(
                self.max_gen_len
            )
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: {self.prompt_seq_lens}; steps exceeded: {self.max_gen_len + self.prompt_seq_lens}"
            )

        else:
            exceeds_max_gen_len = (
                self.step_nr - self.max_prompt_len + 1
            ) >= self.max_gen_len
            logger.debug(
                f"step: {self.step_nr}; max_gen_len: {self.max_gen_len}; prompt_lens: None (unique length: {self.max_prompt_len}); steps exceeded: {self.max_prompt_len + self.max_gen_len}"
            )

        logger.debug(
            f"Stopping criteria: {should_finish_now}; exceeds max len: {exceeds_max_gen_len}; already finished: {already_finished}"
        )

        should_finish_now = should_finish_now | exceeds_max_gen_len

        # Assign lengths to the sequences that have just finished.
        should_finish_now = should_finish_now & ~already_finished
        self.lengths[should_finish_now] = self.step_nr + 1

        # Record the current step.
        self.seqs[:, self.step_nr] = model_last_output.squeeze(1)
        self.step_scores[:, self.step_nr - self.min_prompt_len] = step_score[:, -1]

        # Save completed hypotheses
        finished_mask = self.lengths.ne(-1)
        finished_indices = finished_mask.nonzero()

        # Remove finished hypotheses and reorder variables/state_bag if any are left
        if len(finished_indices) > 0:
            for idx in finished_indices:
                self.finish_sequence(int(idx), is_eos=bool(stopped_on_eos[int(idx)]))

        active_mask = ~finished_mask
        active_indices = active_mask.nonzero().squeeze(-1)

        if len(active_indices) == 0:
            return False

        self.reorder_state(active_indices)

        return True

    def finish_sequence(self, idx: int, is_eos: bool = False) -> None:
        seq_len = int(self.lengths[idx].item())

        if self.options.trim_hypotheses and self.lengths[idx].item() > -1 and is_eos:
            seq_len = int(self.lengths[idx].item()) - int(
                not self.options.include_eos_token
            )

        sample_idx = int(self.sample_indices[idx])
        self.hypotheses[sample_idx] = [
            Hypothesis(
                seq=self.seqs[idx, :seq_len],
                score=None,
                step_scores=self.step_scores[idx],  # Trim it as well?
            )
        ]

    def reorder_state(self, new_order: torch.Tensor) -> None:
        self.state_bag_reorder(new_order)

        self.context = EmbeddingsBatch(
            self.context.seqs.index_select(dim=0, index=new_order),
            self.context.padding_mask,
        )

        self.seqs = self.seqs.index_select(dim=0, index=new_order)

        self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)

        self.step_scores = self.step_scores.index_select(dim=0, index=new_order)

        self.lengths = self.lengths.index_select(dim=0, index=new_order)

        if self.prompt_padding_mask is not None:
            self.prompt_padding_mask = self.prompt_padding_mask.index_select(
                dim=0, index=new_order
            )

        if self.prompt_seq_lens is not None:
            self.prompt_seq_lens = self.prompt_seq_lens.index_select(
                dim=0, index=new_order
            )
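The guidance options above are consumed by the denoiser in `lcm/nn`, not in this file. Purely as a reference for the cited equations, here is a minimal sketch (not this repository's implementation) of how a guidance weight `w` and a rescale factor are conventionally combined in classifier-free guidance; `eps_cond`/`eps_uncond` stand for the conditional and unconditional epsilon predictions:

import torch

def apply_cfg(
    eps_cond: torch.Tensor,
    eps_uncond: torch.Tensor,
    guidance_scale: float,
    guidance_rescale: float = 0.0,
) -> torch.Tensor:
    # Imagen Eq. (2): guidance_scale == 1.0 returns eps_cond unchanged.
    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    if guidance_rescale > 0.0:
        # Algorithm 2 of https://arxiv.org/pdf/2305.08891: rescale the guided
        # prediction towards the conditional prediction's standard deviation,
        # then blend; assumes batched inputs with at least one feature dim.
        dims = list(range(1, eps.ndim))
        std_cond = eps_cond.std(dim=dims, keepdim=True)
        std_cfg = eps.std(dim=dims, keepdim=True)
        rescaled = eps * (std_cond / std_cfg)
        eps = guidance_rescale * rescaled + (1.0 - guidance_rescale) * eps
    return eps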
lcm/inference/two_tower_diffusion_lcm/scorer.py
ADDED
@@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import List, Optional, Tuple

import torch
from fairseq2.generation.generator import (
    GenerationCounters,
    Hypothesis,
    SequenceGeneratorOutput,
)
from fairseq2.logging import get_log_writer

from lcm.datasets.batch import EmbeddingsBatch, PaddingMask
from lcm.inference.lcm.generator import LCMGeneratorOptions
from lcm.inference.two_tower_diffusion_lcm import (
    TwoTowerDiffusionLCMGenerator,
)
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.nn.incremental_state import LCMIncrementalStateBag

logger = get_log_writer(__name__)


class TwoTowerDiffusionLCMScorer(TwoTowerDiffusionLCMGenerator):
    """Score by generating in teacher-forcing mode with a Two-tower Diffusion LCM model."""

    def __init__(
        self,
        model: AbstractLCModel,
        options: Optional[LCMGeneratorOptions] = None,
        eos_vec: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__(model, options, eos_vec)

    @torch.inference_mode()
    def __call__(  # type: ignore
        self,
        batch_input: EmbeddingsBatch,
        max_gen_len: Optional[int] = None,
        min_gen_len: Optional[int] = None,
        min_context_len: int = 1,
        temperature: float = 0.0,
        disable_cache: bool = False,
    ) -> SequenceGeneratorOutput:
        """
        :param batch_input:
            Embedded and padded tensor sequence of the inputs.
        :param max_gen_len:
            Maximum length to be generated for the given input.
        :param min_gen_len:
            Minimum length to be generated for the given input.
        :param disable_cache:
            If True, do not use kv-caching.
        :returns:
            The output of the LCM generator, consisting of :math:`N` lists of
            hypotheses for :math:`N` documents. Each list has 1 Hypothesis
            (beam size = 1), of which `seq` has the *Shape:* :math:`(T, D)`
            (:math:`T` the length of the document and :math:`D` the model
            dimension).
        """
        if self.options.seed:
            torch.manual_seed(self.options.seed)

        # Set up the variables
        self.min_context_len = min_context_len
        batch_size, self.max_text_len, embed_dim = batch_input.seqs.size()
        text_padding_mask = batch_input.padding_mask
        if text_padding_mask is None:
            self.text_padding_mask = None
            self.text_seq_lens = self.max_text_len * torch.ones(
                batch_size,
                dtype=torch.long,
                device=batch_input.seqs.device,
            )
        else:
            self.text_seq_lens = text_padding_mask.seq_lens
            assert self.text_seq_lens is not None, (
                "Expecting a valid `self.text_seq_lens` Tensor, found `None`"
            )

            # Keep the materialized mask
            self.text_padding_mask = text_padding_mask.materialize()

        if not max_gen_len:
            max_gen_len = self.max_seq_len

        max_gen_len = min(max_gen_len, self.max_text_len - self.min_context_len)
        assert max_gen_len is not None, "max_gen_len is None"

        # Make sure we do not accidentally set a max_gen_len that exceeds
        # the generator's model capability
        assert max_gen_len <= self.max_seq_len, (
            f"The generator can generate sequences of at most {self.max_seq_len} steps, max_gen_len={max_gen_len}"
        )
        self.max_gen_len = max_gen_len

        if not min_gen_len:
            min_gen_len = self.min_seq_len

        assert min_gen_len is not None, "A `min_gen_len` is required"

        assert min_gen_len > 0, (
            f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}"
        )

        self.min_gen_len = min_gen_len

        if temperature == 0.0:
            # If the call doesn't pass a specific temperature,
            # use the default one from the decoding options
            temperature = self.options.lcm_temperature

        # Holds the generated sequences, scores and sample-dependent variables
        dtype = self.model.dtype
        device = batch_input.seqs.device
        self.temperature = temperature

        if disable_cache:
            self.state_bag = None
            self.context_state_bag = None
        else:
            self.state_bag = LCMIncrementalStateBag(self.max_text_len)
            self.context_state_bag = LCMIncrementalStateBag(self.max_text_len)

        # Reserve full sequence capacity
        self.seqs = batch_input.seqs
        self.preds = torch.zeros(
            (batch_size, self.max_text_len - self.min_context_len, embed_dim),
            device=device,
            dtype=dtype,
        )

        self.step_scores = torch.zeros(
            (batch_size, self.max_text_len),
            device=device,
        )
        # Hold the sample indices to return in order
        self.sample_indices = torch.arange(batch_size, device=device)
        # Output buffer
        self.hypotheses: List[List[Hypothesis]] = [[] for _ in range(batch_size)]

        # Bootstrap the sequences with the provided prompt.
        self.step_nr = self.min_context_len

        # A context we keep growing in each decoding step
        self.prefill()

        for self.step_nr in range(self.min_context_len, self.max_text_len):
            if not self._step():
                break

        return SequenceGeneratorOutput(self.hypotheses, counters=GenerationCounters())

    def state_bag_reorder(self, new_order: torch.Tensor) -> None:
        if self.state_bag is not None:
            self.state_bag.reorder(new_order)

        if self.context_state_bag is not None:
            self.context_state_bag.reorder(new_order)

    @torch.inference_mode()
    def prefill(self, **kwargs) -> None:
        """Encode the prefix with the context encoder."""

        assert self.context_state_bag is not None, (
            "Expecting a context state bag to prefill"
        )

        context: EmbeddingsBatch

        # FIXME for this model we can prefill with self.step_nr
        prefill_len = self.step_nr - 1
        if prefill_len > 0:
            # Normalize then encode
            input_seqs = self.seqs[:, :prefill_len]
            if self.model.config.sonar_normalizer_name is not None:
                input_seqs = self.model.sonar_normalizer.normalize(input_seqs)

            context = self.model.encode(
                EmbeddingsBatch(input_seqs, None),
                state_bag=self.context_state_bag,
                **kwargs,
            )

            self.context_state_bag.increment_step_nr(prefill_len)

        else:
            logger.warning(
                f"Skipping prefill since only a context size of {self.step_nr} is provided in the prefix"
            )
            context = EmbeddingsBatch(
                torch.empty(
                    (self.seqs.shape[0], 0, self.model.model_dim),
                    dtype=self.seqs.dtype,
                    device=self.seqs.device,
                )
            )

        self.context = context

    @torch.inference_mode()
    def _decode(
        self,
        seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output, context = self.model.predict_next_sentence(
            batch=EmbeddingsBatch(seqs, padding_mask),
            context=self.context,
            temperature=self.temperature,
            state_bag=self.state_bag,
            context_state_bag=self.context_state_bag,
            **kwargs,
        )
        self.context = context

        # Dummy scores
        scores = torch.zeros(seqs.shape[:-1])
        return output.seqs, scores

    def _step(self) -> bool:
        # Generate the next step output.

        if self.state_bag is None:
            # Without a state_bag, we're forwarding the full prefix.
            # Encode the full context:
            model_output, step_score = self._decode(
                seqs=self.seqs[:, : self.step_nr],
                padding_mask=None,
            )
        else:
            # Since we're using a state_bag, we're only forwarding the last embedding
            model_output, step_score = self._decode(
                seqs=self.seqs[:, self.step_nr - 1 : self.step_nr],
                padding_mask=None,
            )

            self.state_bag.increment_step_nr()

        # model_output holds the predicted embeddings
        return self.finalize_step(model_output, step_score)

    def finalize_step(
        self, model_output: torch.Tensor, step_score: torch.Tensor
    ) -> bool:
        """Post-process and finalize a step by checking all stopping criteria.
        Takes the model's output embeddings (model_output)
        and their associated scores (step_score).
        Returns True if we keep stepping, False otherwise.
        """
        model_last_output = model_output[:, -1]
        must_keep_going = self.text_seq_lens.gt(self.step_nr + 1)
        should_finish_now = ~must_keep_going

        # Record the current step prediction.
        self.preds[:, self.step_nr - self.min_context_len] = model_last_output.squeeze(
            1
        )
        self.step_scores[:, self.step_nr - self.min_context_len] = step_score[:, -1]

        # Save completed hypotheses
        finished_indices = should_finish_now.nonzero()

        # Remove finished hypotheses and reorder variables/state_bag if any are left
        if len(finished_indices) > 0:
            for idx in finished_indices:
                self.finish_sequence(int(idx))

        active_mask = must_keep_going
        active_indices = active_mask.nonzero().squeeze(-1)

        if len(active_indices) == 0:
            return False

        self.reorder_state(active_indices)

        return True

    def finish_sequence(self, idx: int) -> None:  # type: ignore
        seq_len = int(self.text_seq_lens[idx].item())
        sample_idx = int(self.sample_indices[idx])
        self.hypotheses[sample_idx] = [
            Hypothesis(
                seq=self.preds[idx, : seq_len - self.min_context_len],
                score=None,
                step_scores=self.step_scores[idx],  # Trim it as well?
            )
        ]

    def reorder_state(self, new_order: torch.Tensor) -> None:
        self.state_bag_reorder(new_order)

        self.context = EmbeddingsBatch(
            self.context.seqs.index_select(dim=0, index=new_order),
            self.context.padding_mask,
        )

        self.seqs = self.seqs.index_select(dim=0, index=new_order)
        self.preds = self.preds.index_select(dim=0, index=new_order)

        self.sample_indices = self.sample_indices.index_select(dim=0, index=new_order)

        self.step_scores = self.step_scores.index_select(dim=0, index=new_order)

        if self.text_padding_mask is not None:
            self.text_padding_mask = self.text_padding_mask.index_select(
                dim=0, index=new_order
            )

        self.text_seq_lens = self.text_seq_lens.index_select(dim=0, index=new_order)
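This scorer differs from the generator above in that `self.seqs` always holds the ground-truth embeddings (teacher forcing) while predictions accumulate separately in `self.preds`, and sequences stop exactly at their true lengths instead of on EOS/repetition heuristics. A hypothetical driving sketch, mirroring the base `LCMScorer` (names `model` and `embedded_docs` are placeholders, not part of this commit):

from lcm.inference.two_tower_diffusion_lcm.generator import (
    DiffusionLCMGeneratorOptions,
)
from lcm.inference.two_tower_diffusion_lcm.scorer import TwoTowerDiffusionLCMScorer

scorer = TwoTowerDiffusionLCMScorer(model, options=DiffusionLCMGeneratorOptions())
output = scorer(embedded_docs, min_context_len=1)
per_step_predictions = output.hypotheses[0][0].seq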
lcm/models/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

# We import all the model types in order to populate the model type registry
from lcm.models.base_lcm.loader import BASE_LCM_MODEL_TYPE
from lcm.models.two_tower_diffusion_lcm.loader import (
    TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
)

__all__ = [
    "BASE_LCM_MODEL_TYPE",
    "TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE",
]
lcm/models/abstract_lcm/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from lcm.models.abstract_lcm.builder import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)

__all__ = [
    "AbstractLCModel",
    "AbstractLCModelBuilder",
    "AbstractLCModelConfig",
]
lcm/models/abstract_lcm/builder.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.typing import DataType, Device
from torch.nn import Module

from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model

logger = get_log_writer(__name__)


"""
An abstract LCM model class defining the bare minimum interface
"""

ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm"


@dataclass
class AbstractLCModelConfig:
    model_type: str = ABSTRACT_LCM_MODEL_TYPE

    sonar_embed_dim: int = 1024

    sonar_normalizer_name: Optional[str] = None


lcm_archs = ConfigRegistry[AbstractLCModelConfig]()
lcm_arch = lcm_archs.decorator


class AbstractLCModel(Module):
    """Abstract class for LCM models"""

    def __init__(
        self,
        config: AbstractLCModelConfig,
    ) -> None:
        """
        Abstract LCM model
        """
        super().__init__()

        self.config = config

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device


class AbstractLCModelBuilder:
    """Builds modules of an LCM"""

    config: AbstractLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: AbstractLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config

        self.device, self.dtype = device, dtype

    def build_sonar_normalizer(
        self,
    ) -> Optional[SonarNormalizer]:
        if self.config.sonar_normalizer_name is not None:
            logger.info(
                f"Building sonar_normalizer = {self.config.sonar_normalizer_name}"
            )
            return load_sonar_normalizer_model(
                self.config.sonar_normalizer_name,
                device=self.device,
                dtype=self.dtype,
            )
        return None

    @abstractmethod
    def build_model(self) -> AbstractLCModel:
        """Build a model."""
        ...
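A concrete builder only has to implement `build_model`; the config, device/dtype plumbing, and optional SONAR normalizer loading are inherited. A minimal hypothetical subclass, for illustration only (the `Tiny*` names are invented, not part of this commit):

class TinyLCModel(AbstractLCModel):
    """A parameter-free stand-in model, just to exercise the builder contract."""

class TinyLCModelBuilder(AbstractLCModelBuilder):
    def build_model(self) -> AbstractLCModel:
        # Returns None here, since the default config names no normalizer.
        _normalizer = self.build_sonar_normalizer()
        return TinyLCModel(self.config)

model = TinyLCModelBuilder(AbstractLCModelConfig()).build_model()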
lcm/models/base_lcm/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

# Register architectures
import lcm.models.base_lcm.archs  # noqa
from lcm.models.base_lcm.builder import (
    BaseLCModel,
    BaseLCModelBuilder,
    BaseLCModelConfig,
    create_base_lcm_model,
)

__all__ = [
    "BaseLCModel",
    "BaseLCModelBuilder",
    "BaseLCModelConfig",
    "create_base_lcm_model",
]
lcm/models/base_lcm/archs.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.models.base_lcm.builder import (
    BaseLCModelConfig,
    LCMFrontendConfig,
    ProjectionConfig,
    TransformerConfig,
    lcm_arch,
)


# Every model must register a toy_{model_family}
@lcm_arch("toy_base_lcm")
def toy_base_lcm() -> BaseLCModelConfig:
    return BaseLCModelConfig(
        lcm=TransformerConfig(num_layers=2),
    )


@lcm_arch("base_lcm_1_6B")
def base_lcm_1_6B() -> BaseLCModelConfig:
    """Base 1.6B model
    Parameter Size: 1,647,635,456
    """
    model_dim: int = 2048
    num_attn_heads: int = 16
    return BaseLCModelConfig(
        max_seq_len=4096,
        model_dim=model_dim,
        sonar_embed_dim=1024,
        sonar_normalizer_name="dummy_sonar_normalizer",
        frontend=LCMFrontendConfig(),
        lcm=TransformerConfig(
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            ffn_inner_dim=model_dim * 4,
            num_attn_heads=num_attn_heads,
            num_layers=32,
            pos_embedding_style="rope",
            use_swiglu=True,
            layer_normalization_style="rms",
        ),
        postnet=ProjectionConfig(),
    )
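Registering a config under a name makes it retrievable through the `lcm_archs` registry (for example when instantiating a model from a card). A hypothetical extra architecture would follow the exact same pattern as the two above:

@lcm_arch("base_lcm_toy_wide")  # hypothetical name, not part of this commit
def base_lcm_toy_wide() -> BaseLCModelConfig:
    return BaseLCModelConfig(
        model_dim=512,
        lcm=TransformerConfig(num_layers=4, ffn_inner_dim=512 * 4),
    )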
lcm/models/base_lcm/builder.py
ADDED
@@ -0,0 +1,285 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Optional

import torch.nn
from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)
from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.projection import Projection, ProjectionConfig
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

logger = get_log_writer(__name__)

BASE_LCM_MODEL_TYPE = "base_lcm"


@dataclass
class BaseLCModelConfig(AbstractLCModelConfig):
    model_type: str = BASE_LCM_MODEL_TYPE

    max_seq_len: int = 2048

    model_dim: int = 1024

    model_output_dim: Optional[int] = None
    """If ``None`` use SONAR dimension as output_dim."""

    frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig())
    """The frontend config. This module maps from `sonar_embed_dim` to `model_dim`
    and potentially adds positional embeddings"""

    lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig())
    """The core LCM config. This is a causal Transformer decoder"""

    postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
    """The postnet config. A module mapping the output of the core LCM
    back to `sonar_embed_dim`"""


lcm_archs = ConfigRegistry[BaseLCModelConfig]()
lcm_arch = lcm_archs.decorator


class BaseLCModel(AbstractLCModel):
    """Base class for LCM models"""

    config: BaseLCModelConfig

    def __init__(
        self,
        config: BaseLCModelConfig,
        lcm: LCMTransformerDecoder,
        frontend: LCMFrontend,
        postnet: Projection,
    ) -> None:
        """
        Basic LCM model with:
        - frontend
        - lcm
        - postnet
        """
        super().__init__(config)

        self.frontend = frontend

        self.lcm = lcm

        self.postnet = postnet

        self.model_dim = lcm.model_dim

        self.sonar_embed_dim = config.sonar_embed_dim

    def forward(
        self,
        batch: EmbeddingsBatch,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        Scaling + positions.
        If a normalizer is provided, the features will be normalized in the
        frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM)
        """
        seqs, padding_mask = self.frontend(
            batch.seqs,
            batch.padding_mask,
            diffusion_timesteps=batch.diffusion_timesteps,
            state_bag=state_bag,
            **kwargs,
        )

        # Core LCM
        seqs, padding_mask = self.lcm(
            seqs,
            padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        # Postnet:
        seqs = self.postnet(seqs)  # type: ignore

        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def predict_next_sentence(
        self,
        batch: EmbeddingsBatch,
        sample: bool = False,
        temperature: float = 1.0,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        The method for predicting the next sentence embeddings.
        In the basic LCM, this is equivalent to just the forward method,
        but the derived architectures may have a different implementation.
        E.g. in VAE LCM, we run the VAE decoder on top of the `forward` results.

        Args:
            batch (EmbeddingsBatch): the sequence of concepts which
                the model should continue.
            sample (bool): whether to predict the single most probable next sentence
                or to sample from the predicted distribution.
            temperature (float): a positive float indicating the degree of diversity
                for the sampling (active only if `sample is True`).
        Returns:
            EmbeddingsBatch: the batch with predicted SONAR sentences.
        """
        # Normalize the input embeddings if we're expected to
        # normalize outside of the model's forward pass
        if self.frontend.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.frontend.sonar_normalizer)

        # TODO: implement efficient sampling of multiple candidates
        predicted_means = self.forward(batch, state_bag=state_bag, **kwargs)

        if sample and temperature > 0:
            noise = torch.randn_like(predicted_means.seqs) * temperature
            predicted_means.seqs = predicted_means.seqs + noise

        if self.frontend.sonar_normalizer is not None:
            predicted_means = predicted_means.denormalize_seqs(
                self.frontend.sonar_normalizer
            )

        return predicted_means


class BaseLCModelBuilder(AbstractLCModelBuilder):
    """Builds modules of a base LCM model"""

    config: BaseLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: BaseLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__(config=config, device=device, dtype=dtype)
        self.lcm_factory = TransformerFactory(
            model_dim=self.config.model_dim,
            max_seq_len=self.config.max_seq_len,
            config=self.config.lcm,
            device=device,
            dtype=dtype,
        )

        if config.model_output_dim is None:
            self.model_output_dim = self.config.sonar_embed_dim
        else:
            self.model_output_dim = config.model_output_dim

    def build_model(self) -> BaseLCModel:
        """Build a model."""

        frontend = self.build_frontend()

        lcm = self.build_core_lcm()

        postnet = self.build_postnet()

        return BaseLCModel(
            config=self.config,
            frontend=frontend,
            lcm=lcm,
            postnet=postnet,
        )

    def build_frontend(self) -> LCMFrontend:
        """Build the LCM front-end (i.e., prenet)."""

        return LCMFrontend(
            sonar_embed_dim=self.config.sonar_embed_dim,
            model_dim=self.config.model_dim,
            config=self.config.frontend,
            pos_encoder=self.lcm_factory.build_pos_encoder(),
            sonar_normalizer=self.build_sonar_normalizer(),
            device=self.device,
            dtype=self.dtype,
        )

    def build_postnet(self) -> Projection:
        return Projection(
            output_dim=self.model_output_dim,
            input_dim=self.config.model_dim,
            config=self.config.postnet,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention_mask_factory(self):
        self_attn_mask_factory: AttentionMaskFactory

        self_attn_mask_factory = CausalAttentionMaskFactory()

        return self_attn_mask_factory

    def build_core_lcm(self) -> LCMTransformerDecoder:
        """Build the core LCM module."""

        config = self.config.lcm

        layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)]

        self_attn_mask_factory = self.build_attention_mask_factory()
|
| 251 |
+
|
| 252 |
+
if config.final_norm_order_style is None:
|
| 253 |
+
# The final norm order style will be that of the layer-level norm order
|
| 254 |
+
final_norm_order = parse_norm_order(config.norm_order_style)
|
| 255 |
+
else:
|
| 256 |
+
final_norm_order = parse_norm_order(config.final_norm_order_style)
|
| 257 |
+
|
| 258 |
+
layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)
|
| 259 |
+
|
| 260 |
+
return LCMTransformerDecoder(
|
| 261 |
+
layers, # type: ignore
|
| 262 |
+
self_attn_mask_factory=self_attn_mask_factory,
|
| 263 |
+
norm_order=final_norm_order,
|
| 264 |
+
layer_norm_factory=layer_norm_factory,
|
| 265 |
+
dropout_p=config.final_dropout_p,
|
| 266 |
+
device=self.device,
|
| 267 |
+
dtype=self.dtype,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def create_base_lcm_model(
|
| 272 |
+
config: BaseLCModelConfig,
|
| 273 |
+
*,
|
| 274 |
+
device: Optional[Device] = None,
|
| 275 |
+
dtype: Optional[DataType] = None,
|
| 276 |
+
) -> BaseLCModel:
|
| 277 |
+
"""Create an LCM model.
|
| 278 |
+
:param config:
|
| 279 |
+
The configuration.
|
| 280 |
+
:param device:
|
| 281 |
+
The device on which to initialize modules.
|
| 282 |
+
:param dtype:
|
| 283 |
+
The data type of module parameters and buffers.
|
| 284 |
+
"""
|
| 285 |
+
return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model()
|
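A minimal usage sketch for the builder above (not part of the commit): it assumes the default `BaseLCModelConfig` field values and that `EmbeddingsBatch` can be constructed from a (batch, seq, dim) tensor with `padding_mask=None` and no diffusion timesteps, as the forward pass suggests.

import torch

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.base_lcm.builder import BaseLCModelConfig, create_base_lcm_model

# Build a base LCM with default architecture settings
config = BaseLCModelConfig()
model = create_base_lcm_model(config, dtype=torch.float32)

# A batch of 2 documents, each a prefix of 5 SONAR sentence embeddings
prefix = torch.randn(2, 5, config.sonar_embed_dim)
batch = EmbeddingsBatch(seqs=prefix, padding_mask=None)

# Greedy prediction (sample=False) of the next concept at each position
predicted = model.predict_next_sentence(batch, sample=False)
print(predicted.seqs.shape)  # expected: (2, 5, sonar_embed_dim)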
lcm/models/base_lcm/frontend.py
ADDED
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from fairseq2.logging import get_log_writer
from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Dropout, Module

from lcm.models.sonar_normalizer.builder import SonarNormalizer
from lcm.nn.initialization import SONAR_STD, SUPPORTED_INIT_TYPES, get_init_fn

logger = get_log_writer(__name__)


@dataclass
class LCMFrontendConfig:
    dropout_p: float = 0.0
    """The dropout probability applied to the module's output"""

    pre_linear_bias: bool = True
    """Whether or not the pre-linear layer has a bias term"""

    pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"

    scale_embeddings: bool = False
    """Scale the embeddings by model_dim before
    adding positions (and before the pre_linear)"""

    weight_normalization: bool = False

    embedding_std: float = SONAR_STD
    """Most SONAR embeddings have a distribution with the mean close to 0
    and std close to 0.006. Initializing embedding-like parameters (e.g. the end-of-text vector)
    from a similar distribution is recommended, to minimize their disruption of the model training
    """


class LCMFrontend(Module):
    """
    A frontend for the LCM with positional embeddings
    """

    embed: Embedding
    scale: float
    pos_encoder: Optional[PositionEncoder]
    dropout: Optional[Dropout]

    def __init__(
        self,
        sonar_embed_dim: int,
        model_dim: int,
        config: LCMFrontendConfig,
        pos_encoder: Optional[PositionEncoder],
        timestep_embed_dim: int = 0,
        sonar_normalizer: Optional[SonarNormalizer] = None,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param sonar_embed_dim:
            The embedding dimension of the sentence encoder, in this case SONAR
        :param model_dim:
            The model embedding dimension
        :param timestep_embed_dim:
            The embedding dimension of diffusion timesteps (if relevant, defaults to 0)
        :param config:
            A frontend config. See `LCMFrontendConfig`
        :param pos_encoder:
            An optional position encoder.
        """

        super().__init__()

        self.sonar_embed_dim = sonar_embed_dim

        self.model_dim = model_dim

        self.device = device

        self.embed_scale: float = model_dim**0.5 if config.scale_embeddings else 1.0

        logger.info(f"Using LCMFrontend with embeddings scaler = {self.embed_scale}")

        # Optional sonar normalizer
        self.sonar_normalizer = sonar_normalizer

        # Pre-linear to map to model dimension

        init_fn = get_init_fn(config.pre_linear_init_fn)

        lin = Linear(
            sonar_embed_dim + timestep_embed_dim,
            model_dim,
            bias=config.pre_linear_bias,
            device=device,
            dtype=dtype,
            init_fn=init_fn,
        )

        if config.weight_normalization:
            self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin)
        else:
            self.pre_linear = lin

        if pos_encoder is not None:
            if pos_encoder.encoding_dim != self.model_dim:
                raise ValueError(
                    f"`encoding_dim` of `pos_encoder` and `embedding_dim` of "
                    f"`embed` must be equal, but are {pos_encoder.encoding_dim} "
                    f"and {self.model_dim} instead."
                )

            self.pos_encoder = pos_encoder
        else:
            self.register_module("pos_encoder", None)

        if config.dropout_p > 0.0:
            self.dropout = Dropout(config.dropout_p)
        else:
            self.register_module("dropout", None)

        self.reset_parameters(embedding_std=config.embedding_std)

    def reset_parameters(self, embedding_std: float) -> None:
        """Initialize module parameters.
        The positional embeddings should be initialized with the
        same order of magnitude as the semantic embeddings, in order
        to make the early training as stable as possible.
        Otherwise, the positional and special token embeddings would
        flood out the semantic information.
        """
        logger.info(
            f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})"
        )
        if isinstance(self.pos_encoder, LearnedPositionEncoder):
            torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std)

    def pre_forward(
        self, seqs: Tensor, diffusion_timesteps: Optional[Tensor] = None, **kwargs
    ) -> Tensor:
        return seqs

    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        state_bag: Optional[IncrementalStateBag] = None,
        diffusion_timesteps: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """
        Apply pre-linear (if relevant) and add positional embeddings
        """

        # Normalize in the standard LCM or add timestep embeddings in the diffusion frontend
        seqs = self.pre_forward(seqs, diffusion_timesteps, **kwargs)

        # pre-linear if any:
        seqs = self.pre_linear(self.embed_scale * seqs)

        if self.pos_encoder is not None:
            seqs = self.pos_encoder(
                seqs,
                padding_mask,
                state_bag=state_bag,
                **kwargs,
            )

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs, padding_mask
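For orientation, the forward pass above restated in plain torch; the dimensions and the 0.006 std are illustrative stand-ins (SONAR_STD itself is defined in lcm/nn/initialization.py, which is not part of this hunk):

import torch

sonar_dim, model_dim, max_len = 1024, 2048, 128
pre_linear = torch.nn.Linear(sonar_dim, model_dim)
pos_embed = torch.nn.Parameter(torch.empty(max_len, model_dim))
torch.nn.init.normal_(pos_embed, std=0.006)  # ~SONAR_STD, cf. reset_parameters

seqs = torch.randn(2, 5, sonar_dim)    # (batch, seq, sonar_dim)
h = pre_linear(1.0 * seqs)             # embed_scale == 1.0 unless scale_embeddings=True
h = h + pos_embed[: seqs.size(1)]      # additive learned positions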
lcm/models/base_lcm/loader.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import logging
from typing import Any, Dict

from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

from lcm.models.base_lcm.builder import (
    BASE_LCM_MODEL_TYPE,
    BaseLCModelConfig,
    create_base_lcm_model,
    lcm_archs,
)
from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry

logger = logging.getLogger(__name__)


def convert_lcm_checkpoint(
    checkpoint: Dict[str, Any], config: BaseLCModelConfig
) -> Dict[str, Any]:
    # For DDP checkpoints
    # We need to first remove the prefix "module." from state dict keys.
    consume_prefix_in_state_dict_if_present(checkpoint["model"], "module.")
    return checkpoint


load_base_lcm_config = StandardModelConfigLoader(
    family=BASE_LCM_MODEL_TYPE,
    config_kls=BaseLCModelConfig,
    arch_configs=lcm_archs,
)

load_base_lcm_model = StandardModelLoader(
    config_loader=load_base_lcm_config,
    factory=create_base_lcm_model,
    checkpoint_converter=convert_lcm_checkpoint,
    restrict_checkpoints=False,
)

load_model.register(BASE_LCM_MODEL_TYPE, load_base_lcm_model)

lcm_model_type_registry.register(
    ModelTypeConfig(
        model_type=BASE_LCM_MODEL_TYPE,
        config_loader=load_base_lcm_config,
        model_factory=create_base_lcm_model,
        model_loader=load_base_lcm_model,
    )
)
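The prefix handling in `convert_lcm_checkpoint` is plain torch; for illustration, with hypothetical parameter names:

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# DDP wraps the model in a `module.` namespace when saving
state_dict = {"module.frontend.pre_linear.weight": 0, "module.postnet.weight": 1}
consume_prefix_in_state_dict_if_present(state_dict, "module.")  # modifies in place
print(sorted(state_dict))  # ['frontend.pre_linear.weight', 'postnet.weight']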
lcm/models/base_lcm/normalization.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from typing import Optional, final

import torch
from fairseq2.nn import LayerNorm, RMSNorm
from fairseq2.typing import DataType, Device, override


@final
class FP32LayerNorm(LayerNorm):
    """Applies Layer Normalization in single-precision."""

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w, b = self.weight, self.bias

        # cast input and params to float32
        fp32_x = x.float()
        fp32_w = w.float() if w is not None else None
        fp32_b = b.float() if b is not None else None

        y = torch.nn.functional.layer_norm(
            fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps
        )

        return y.type_as(x)


def build_rms_layer_norm(
    model_dim: int,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> LayerNorm:
    """Build an RMS Layer Normalization module."""
    return RMSNorm(model_dim, bias=False, device=device, dtype=dtype)


def build_fp32_layer_norm(
    model_dim: int,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> LayerNorm:
    """Build a single-precision Layer Normalization module."""
    return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype)
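The point of `FP32LayerNorm` is numerical stability under mixed precision: a quick check of its dtype contract, as a sketch built on the factory above.

import torch

from lcm.models.base_lcm.normalization import build_fp32_layer_norm

ln = build_fp32_layer_norm(1024, dtype=torch.bfloat16)
x = torch.randn(2, 8, 1024, dtype=torch.bfloat16)
y = ln(x)  # normalization math runs in float32 internally
assert y.dtype == torch.bfloat16  # cast back via `type_as(x)`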
lcm/models/sonar_normalizer/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

# Register architectures
import lcm.models.sonar_normalizer.archs  # noqa
from lcm.models.sonar_normalizer.builder import (
    SonarNormalizer,
    SonarNormalizerConfig,
    create_sonar_normalizer,
)
from lcm.models.sonar_normalizer.loader import load_sonar_normalizer_model

__all__ = [
    "SonarNormalizer",
    "SonarNormalizerConfig",
    "create_sonar_normalizer",
    "load_sonar_normalizer_model",
]
lcm/models/sonar_normalizer/archs.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.models.sonar_normalizer.builder import (
    SonarNormalizerConfig,
    sonar_normalizer_arch,
)


@sonar_normalizer_arch("base")
def _base_sonar_normalizer() -> SonarNormalizerConfig:
    """The base architecture for all center-and-scale normalizers,
    regardless of how the center/scale are estimated"""
    return SonarNormalizerConfig(
        dim=1024,
    )


@sonar_normalizer_arch("base_page4k")
def _base_page_normalizer() -> SonarNormalizerConfig:
    return SonarNormalizerConfig(
        dim=4 * 1024,
    )


@sonar_normalizer_arch("base_fft")
def _base_fft_sonar_normalizer() -> SonarNormalizerConfig:
    return SonarNormalizerConfig(dim=1024, with_fft=True)


@sonar_normalizer_arch("clipping")
def _clipping_sonar_normalizer() -> SonarNormalizerConfig:
    return SonarNormalizerConfig(dim=1024, clip_proba=1e-4)


@sonar_normalizer_arch("clipping_fft")
def _clipping_fft_sonar_normalizer() -> SonarNormalizerConfig:
    return SonarNormalizerConfig(dim=1024, clip_proba=1e-4, with_fft=True)
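The arch entries above are just named presets over `SonarNormalizerConfig`; the equivalent direct construction, as a sketch using the factory from builder.py below:

import torch

from lcm.models.sonar_normalizer.builder import (
    SonarNormalizerConfig,
    create_sonar_normalizer,
)

# Same settings as the "clipping_fft" arch above
config = SonarNormalizerConfig(dim=1024, clip_proba=1e-4, with_fft=True)
normalizer = create_sonar_normalizer(config, dtype=torch.float32)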
lcm/models/sonar_normalizer/builder.py
ADDED
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Module


@dataclass
class SonarNormalizerConfig:
    dim: int = 1024
    """The dimension of the features to be normalized"""

    clip_proba: Optional[float] = None
    """
    If `clip_proba` is not None, `clip_min` and `clip_max` will
    be used to clip the features before normalizing.
    `clip_min` and `clip_max` correspond to the pre-computed `clip_proba`
    and `1 - clip_proba` quantiles respectively.
    """

    with_fft: bool = False
    """
    Apply an FFT transform to the raw input before all other transforms.
    """

    quantile_min: float = 0.25
    """The lower quantile used to measure the IQR when estimating the scale with a robust scaler"""

    quantile_max: float = 0.75
    """The upper quantile used to measure the IQR when estimating the scale with a robust scaler"""

    normalization_method: Literal["standard", "robust", "gaussian_robust"] = "gaussian_robust"
    """
    Dictates how the normalizer's scale is evaluated when fitting.
    (1) 'standard': center = mean, scale = std
    (2) 'robust': center = median, scale = IQR = Qmax - Qmin
    (3) 'gaussian_robust': center = median, scale = IQR / k,
        where k = `stats.norm.ppf(q_max / 100.0) - stats.norm.ppf(q_min / 100.0)`,
        i.e. scale = 0.7413 x IQR if q_min=0.25 and q_max=0.75.
        This is the robust normalization of https://arxiv.org/pdf/2307.05445
    """


sonar_normalizer_archs = ConfigRegistry[SonarNormalizerConfig]()
sonar_normalizer_arch = sonar_normalizer_archs.decorator


class FFTInterface:
    @staticmethod
    def fft_transform(embeddings: Tensor) -> Tensor:
        dtype = embeddings.dtype
        if dtype in [torch.float16, torch.bfloat16]:
            embeddings = embeddings.to(dtype=torch.float32)
        embeddings = torch.fft.rfft(embeddings, norm="backward")
        return torch.concat(
            [torch.real(embeddings), torch.imag(embeddings)[..., 1:-1]], dim=-1
        ).to(dtype)

    @staticmethod
    def fft_inverse_transform(embeddings: Tensor) -> Tensor:
        assert embeddings.shape[-1] % 2 == 0
        dtype = embeddings.dtype
        if dtype in [torch.float16, torch.bfloat16]:
            embeddings = embeddings.to(dtype=torch.float32)
        rr, im = torch.split(
            embeddings,
            [embeddings.shape[-1] // 2 + 1, embeddings.shape[-1] // 2 - 1],
            dim=-1,
        )
        im = torch.concat(
            [torch.zeros_like(im[..., :1]), im, torch.zeros_like(im[..., :1])], dim=-1
        )
        embeddings = torch.fft.irfft(rr + im * 1j)
        return embeddings.to(dtype)


class SonarNormalizer(FFTInterface, Module):
    """
    To perform efficient diffusion modeling, SONAR embeddings need to be
    normalized. This SonarNormalizer follows the robust normalization introduced in
    https://arxiv.org/abs/2307.05445
    Quoting from the paper: "Due to the very long-tailed feature distribution,
    typical mean and standard deviation statistics will be heavily biased. We thus
    propose a robust alternative based on the feature distribution quantiles. We
    take the median as the center of the distribution and approximate its scale
    using the Normalized InterQuartile Range (IQR) for a normal distribution:
    0.7413 x IQR."
    """

    def __init__(
        self,
        config: SonarNormalizerConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.register_buffer(
            "center", torch.zeros(config.dim, dtype=dtype, device=device)
        )
        self.register_buffer(
            "scale", torch.ones(config.dim, dtype=dtype, device=device)
        )
        if self.config.clip_proba is not None:
            self.register_buffer(
                "clip_min", torch.ones(config.dim, dtype=dtype, device=device)
            )
            self.register_buffer(
                "clip_max", torch.ones(config.dim, dtype=dtype, device=device)
            )

    def normalize(self, embeddings: Tensor) -> Tensor:
        if self.config.with_fft:
            embeddings = self.fft_transform(embeddings)

        embeddings = (embeddings - self.center) / self.scale
        if self.config.clip_proba is not None:
            embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)
        return embeddings

    def denormalize(self, embeddings: Tensor) -> Tensor:
        if self.config.clip_proba is not None:
            embeddings = torch.clamp(embeddings, min=self.clip_min, max=self.clip_max)

        embeddings = (embeddings * self.scale) + self.center
        if self.config.with_fft:
            embeddings = self.fft_inverse_transform(embeddings)
        return embeddings

    @torch.no_grad()
    def fit(self, embeddings: Tensor):
        if self.config.normalization_method in [
            "robust",
            "gaussian_robust",
        ]:
            from sklearn.preprocessing import RobustScaler

            _scaler = RobustScaler(
                unit_variance=self.config.normalization_method == "gaussian_robust",
                quantile_range=(self.config.quantile_min, self.config.quantile_max),
            )

        elif self.config.normalization_method == "standard":
            from sklearn.preprocessing import StandardScaler

            _scaler = StandardScaler()
        else:
            raise ValueError(
                f"Unrecognized method {self.config.normalization_method} for scaling input features"
            )

        assert embeddings.shape[-1] == self.config.dim
        assert len(embeddings.shape) == 2

        if self.config.with_fft:
            embeddings = self.fft_transform(embeddings)

        embeddings = _scaler.fit_transform(embeddings.cpu().float().numpy())

        if self.config.normalization_method in [
            "robust",
            "gaussian_robust",
        ]:
            _center = _scaler.center_
            _scale = _scaler.scale_

        elif self.config.normalization_method == "standard":
            _center = _scaler.mean_
            _scale = _scaler.scale_

        self.center[:] = torch.tensor(
            _center, dtype=self.center.dtype, device=self.center.device
        )
        self.scale[:] = torch.tensor(
            _scale, dtype=self.scale.dtype, device=self.scale.device
        )

        if self.config.clip_proba is not None:
            self.clip_min[:] = torch.quantile(
                torch.tensor(embeddings), self.config.clip_proba, dim=0
            ).to(dtype=self.clip_min.dtype, device=self.clip_min.device)
            self.clip_max[:] = torch.quantile(
                torch.tensor(embeddings), 1 - self.config.clip_proba, dim=0
            ).to(dtype=self.clip_max.dtype, device=self.clip_max.device)


def create_sonar_normalizer(
    config: SonarNormalizerConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> SonarNormalizer:
    """Create a SONAR normalizer model.
    :param config:
        The configuration.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    return SonarNormalizer(config, device=device, dtype=dtype)
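A fitting and round-trip sketch for the class above, assuming scikit-learn is installed (and scipy for checking the constant); random toy data stands in for real SONAR embeddings:

import torch
from scipy import stats

from lcm.models.sonar_normalizer.builder import SonarNormalizer, SonarNormalizerConfig

embeddings = 0.006 * torch.randn(10_000, 1024)  # SONAR-scale toy data, shape (N, dim)
norm = SonarNormalizer(SonarNormalizerConfig(dim=1024))
norm.fit(embeddings)

z = norm.normalize(embeddings)        # centered / rescaled features
back = norm.denormalize(z)            # inverse transform
assert torch.allclose(back, embeddings, atol=1e-5)

# 'gaussian_robust' divides the IQR by k = ppf(q_max/100) - ppf(q_min/100);
# for the 25th/75th percentiles, 1/k is the paper's 0.7413 factor.
k = stats.norm.ppf(0.75) - stats.norm.ppf(0.25)
print(1 / k)  # ~0.7413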
lcm/models/sonar_normalizer/loader.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model

from lcm.models.sonar_normalizer.builder import (
    SonarNormalizerConfig,
    create_sonar_normalizer,
    sonar_normalizer_archs,
)

load_sonar_normalizer_config = StandardModelConfigLoader(
    family="sonar_normalizer",
    config_kls=SonarNormalizerConfig,
    arch_configs=sonar_normalizer_archs,
)

load_sonar_normalizer_model = StandardModelLoader(
    config_loader=load_sonar_normalizer_config,
    factory=create_sonar_normalizer,
    restrict_checkpoints=False,
)

load_model.register("sonar_normalizer", load_sonar_normalizer_model)
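With the registration above and the `sonar_normalizer.yaml` card shipped under lcm/cards/, loading by name should look roughly as follows; this is a hedged sketch of the fairseq2 loader call, and keyword names may differ across fairseq2 versions:

import torch

from lcm.models.sonar_normalizer.loader import load_sonar_normalizer_model

# "sonar_normalizer_wikipedia_en_1m" is the card name referenced by the archs in
# two_tower_diffusion_lcm/archs.py; it must resolve to a card known to fairseq2.
normalizer = load_sonar_normalizer_model(
    "sonar_normalizer_wikipedia_en_1m", device=torch.device("cpu")
)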
lcm/models/two_tower_diffusion_lcm/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

# Register architectures
import lcm.models.two_tower_diffusion_lcm.archs  # noqa
lcm/models/two_tower_diffusion_lcm/archs.py
ADDED
@@ -0,0 +1,207 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.models.two_tower_diffusion_lcm.builder import (
    DenoiserConfig,
    EncoderFrontendConfig,
    TransformerConfig,
    TwoTowerDiffusionLCModelConfig,
    lcm_arch,
)
from lcm.nn.projection import ProjectionConfig
from lcm.nn.schedulers import DDIMSchedulerConfig


@lcm_arch("toy_two_tower_diffusion_lcm")
def toy_lcm() -> TwoTowerDiffusionLCModelConfig:
    return TwoTowerDiffusionLCModelConfig(
        context_encoder=TransformerConfig(num_layers=2),
        denoiser=DenoiserConfig(num_layers=2),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_A",
    )


@lcm_arch("arch_lexa_lcm_pre0_toy")
def lexa_lcm_pre0_toy() -> TwoTowerDiffusionLCModelConfig:
    return TwoTowerDiffusionLCModelConfig(
        context_encoder=TransformerConfig(num_layers=2),
        denoiser=DenoiserConfig(num_layers=2),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
    )


@lcm_arch("arch_lexa_lcm_pre0_minimal")
def lexa_lcm_pre0_minimal() -> TwoTowerDiffusionLCModelConfig:
    """3-layer encoder / 6-layer denoiser / model dim 768"""
    model_dim: int = 768  # Reduced from 2048 to 768
    num_attn_heads: int = 12  # Reduced from 16 to 12
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=2048,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=3,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=6,  # Reduced from 13 to 6
            timestep_embed_dim=model_dim,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )


@lcm_arch("arch_lexa_lcm_pre0")
def lexa_lcm_pre0() -> TwoTowerDiffusionLCModelConfig:
    """4-layer encoder / 10-layer denoiser / model dim 1024
    Parameter size: 287,880,192"""
    model_dim: int = 1024  # Reduced from 2048 to 1024
    num_attn_heads: int = 16
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=2048,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=4,  # Reduced from 5 to 4
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=10,  # Reduced from 13 to 10
            timestep_embed_dim=model_dim,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )


@lcm_arch("two_tower_diffusion_lcm_1_6B")
def two_tower_diffusion_lcm_1_6B() -> TwoTowerDiffusionLCModelConfig:
    """5-layer encoder / 13-layer denoiser / model dim 2048
    Parameter size: 1,635,101,696"""
    model_dim: int = 2048
    num_attn_heads: int = 16
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=4096,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=5,
            ffn_inner_dim=4 * model_dim,
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=13,
            timestep_embed_dim=model_dim,
            ffn_inner_dim=4 * model_dim,
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_B",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )


@lcm_arch("two_tower_diffusion_lcm_7B")
def two_tower_diffusion_lcm_7B() -> TwoTowerDiffusionLCModelConfig:
    # 5-layer encoder / 14-layer denoiser / model dim 4096
    # Parameter size: 6,930,781,696
    model_dim: int = 4096
    num_attn_heads: int = 32
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=4096,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=5,
            ffn_inner_dim=4 * model_dim,
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=14,
            timestep_embed_dim=model_dim,
            ffn_inner_dim=4 * model_dim,
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_C",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )
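The larger archs above all set `trained_with_cf_guidance=True`, which pairs with the `guidance_scale` handling in builder.py below. As a reminder, this is the standard classifier-free guidance combination those options refer to; a generic sketch, not the exact code path of this repo:

import torch

def cfg_combine(uncond: torch.Tensor, cond: torch.Tensor, scale: float) -> torch.Tensor:
    """Standard classifier-free guidance: extrapolate from the unconditional
    prediction toward the conditional one. scale == 1.0 disables guidance,
    matching `do_classifier_free_guidance = guidance_scale != 1.0` below."""
    return uncond + scale * (cond - uncond)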
lcm/models/two_tower_diffusion_lcm/builder.py
ADDED
@@ -0,0 +1,628 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Optional, Tuple

import torch
from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.padding import PaddingMask, get_seq_lens
from fairseq2.nn.transformer import CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device
from torch import Tensor

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)
from lcm.models.sonar_normalizer.builder import SonarNormalizer
from lcm.models.two_tower_diffusion_lcm.frontend import (
    EncoderFrontend,
    EncoderFrontendConfig,
)
from lcm.nn.denoisers import (
    DenoiserConfig,
    LCMDenoiser,
    LCMDenoiserTransformerFactory,
)
from lcm.nn.incremental_state import LCMIncrementalStateBag
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

logger = get_log_writer(__name__)


TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE = "two_tower_diffusion_lcm"


@dataclass
class TwoTowerDiffusionLCModelConfig(AbstractLCModelConfig):
    model_type: str = TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE

    max_seq_len: int = 2048

    model_dim: int = 1024

    frontend: EncoderFrontendConfig = field(
        default_factory=lambda: EncoderFrontendConfig()
    )
    """The frontend config. This module maps from `sonar_embed_dim` to `model_dim`
    and potentially adds positional embeddings"""

    context_encoder: TransformerConfig = field(
        default_factory=lambda: TransformerConfig()
    )
    """The context encoder config. This is a causal Transformer decoder"""

    noise_scheduler: DDIMSchedulerConfig = field(
        default_factory=lambda: DDIMSchedulerConfig()
    )
    """The config of the noise scheduler.
    See lcm/diffusion_schedulers/ddim for more"""

    denoiser: DenoiserConfig = field(default_factory=lambda: DenoiserConfig())
    """The config of the denoiser"""

    trained_with_cf_guidance: bool = False
    """If `True`, the model will be trained with classifier-free guidance i.e.,
    unconditional embedding generation.
    The CF-guidance probability is set in
    DiffusionLCMCriterionConfig.cf_guidance_probability"""


lcm_archs = ConfigRegistry[TwoTowerDiffusionLCModelConfig]()
lcm_arch = lcm_archs.decorator


class TwoTowerDiffusionLCModel(AbstractLCModel):
    """Class for a diffusion-based LCM model"""

    config: TwoTowerDiffusionLCModelConfig

    def __init__(
        self,
        config: TwoTowerDiffusionLCModelConfig,
        sonar_normalizer: SonarNormalizer,
        encoder_frontend: EncoderFrontend,
        context_encoder: LCMTransformerDecoder,
        denoiser: LCMDenoiser,
        noise_scheduler: DDIMScheduler,
    ) -> None:
        super().__init__(config)

        self.model_dim = context_encoder.model_dim

        self.sonar_embed_dim = config.sonar_embed_dim

        self.sonar_normalizer = sonar_normalizer

        self.encoder_frontend = encoder_frontend
        """The frontend of the context encoder.
        This frontend simply applies a pre-linear projection
        (to increase dimensionality) then adds positional embeddings"""

        self.context_encoder = context_encoder
        """A causal Transformer decoder"""

        self.noise_scheduler = noise_scheduler
        """The diffusion noise scheduler"""

        self.denoiser = denoiser

    def extra_repr(self) -> str:
        """:meta private:"""
        s = super().extra_repr()
        return f"{s}, dtype={self.dtype}"

    def forward(
        self,
        batch: EmbeddingsBatch,
        noisy_batch: EmbeddingsBatch,
        cf_guidance_prob: float = 0.0,
    ) -> EmbeddingsBatch:
        """
        Arguments:
            - batch (`EmbeddingsBatch`): The clean batch of embeddings to encode the context.
                If `unsupervised`, this is the source embeddings.
                If `supervised`, this is the source+target embeddings.

            - noisy_batch (`EmbeddingsBatch`): the embeddings noised by the noise scheduler.
                If `unsupervised`, this is the noised source embeddings.
                If `supervised`, this is the noised target-only embeddings.

            - cf_guidance_prob: probability of training without any guiding context
        """
        # Get source lengths if any:
        source_lengths = batch.source_lengths

        # Encode as context:
        context = self.encode(batch)

        # Predict denoised output
        output_batch = self.denoise(
            noisy_batch=noisy_batch,
            context=context,
            source_lengths=source_lengths,
            cf_guidance_prob=cf_guidance_prob,
        )
        return output_batch

    def encode(
        self,
        batch: EmbeddingsBatch,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        The main context encoder that takes in a sequence of sonar embeddings in (B, T, D)
        and returns a sequence of the same shape after causal contextualization.

        Main modules:
            `frontend`: linear projection to model_dim + optional positional embeddings,
            `context_encoder`: causal Transformer decoder to causally encode the context
        """
        # Frontend
        seqs, padding_mask = self.encoder_frontend(
            batch.seqs,
            batch.padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        # Main Transformer
        seqs, padding_mask = self.context_encoder(
            seqs,
            padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def denoise(
        self,
        noisy_batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        source_lengths: Optional[Tensor] = None,
        cf_guidance_prob: float = 0.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        inference: bool = False,
    ) -> EmbeddingsBatch:
        """Denoise a noised sonar embedding conditioned on the encoded context"""
        seqs, padding_mask = self.denoiser(
            seqs=noisy_batch.seqs,
            diffusion_timesteps=noisy_batch.diffusion_timesteps,
            padding_mask=noisy_batch.padding_mask,
            conditioning_variables=context.seqs,
            conditioning_variables_padding_mask=context.padding_mask,
            source_lengths=source_lengths,
            cf_guidance_prob=cf_guidance_prob,
            inference=inference,
        )
        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def prep_for_denoising(self, decoding_options):
        """This setup is done once when we initialize the generator"""
        self.guidance_scale = decoding_options.guidance_scale
        self.guidance_rescale = decoding_options.guidance_rescale
        self.initial_noise_scale = decoding_options.initial_noise_scale
        self.timesteps = decoding_options.inference_timesteps
        self.clip_noise = decoding_options.clip_noise
        self.ddim_eta = decoding_options.ddim_eta
        self.epsilon_scaling = decoding_options.epsilon_scaling

        # if guidance_scale > 1.0 we will duplicate batches
        self.do_classifier_free_guidance = self.guidance_scale != 1.0

        # Set up the diffusion training-like noise scheduler
        # by updating the timesteps according to the decoding `inference_timesteps`
        self.noise_scheduler.set_timesteps(self.timesteps, device=self.device)

        # Override the initial noise scale
        self.noise_scheduler.init_noise_sigma = self.initial_noise_scale
        # Override thresholding options:
        if decoding_options.thresholding:
            self.noise_scheduler.config.thresholding = decoding_options.thresholding
            self.noise_scheduler.config.dynamic_thresholding_ratio = (
                decoding_options.dynamic_thresholding_ratio
            )
            self.noise_scheduler.config.sample_max_value = (
                decoding_options.sample_max_value
            )

    def sample_initial_noise_vectors(self, batch_size: int):
        # Check that we called `prep_for_denoising`:
        assert hasattr(self, "clip_noise"), (
            "The model is not properly set up for decoding, make sure to call `model.prep_for_denoising()`"
        )

        # Sample a noise vector for next embedding prediction
        latents = torch.randn(
            batch_size, 1, self.config.sonar_embed_dim, device=self.device
        )

        # Scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.noise_scheduler.init_noise_sigma

        # clip?
        latents = latents.clip(-self.clip_noise, self.clip_noise)
        return latents

    @torch.inference_mode()
    def predict_next_sentence(  # type: ignore
        self,
        batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        temperature: float = 1.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        context_state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
        assert context_state_bag is not None, (
            "Expected a state_bag to incrementally encode the context"
        )

        if self.do_classifier_free_guidance:
            logger.debug("Running inference with CF-guidance...")
            return self.predict_next_sentence_with_cf_guidance(
                batch=batch,
                context=context,
                temperature=temperature,
                state_bag=state_bag,
                context_state_bag=context_state_bag,
                **kwargs,
            )

        # Normalize the input embeddings if we're expected to
        # normalize outside of the model's forward pass
        if self.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.sonar_normalizer)

        # Encode context:
        new_context = self.encode(batch, context_state_bag)
        context_state_bag.increment_step_nr(1)

        # Append to context
        context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))

        # Sample latents:
        latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))

        # Denoise
        diffusion_timesteps_schedule = self.noise_scheduler.timesteps

        for diffusion_timestep in diffusion_timesteps_schedule:
            input_batch = EmbeddingsBatch(
                seqs=latents,
                diffusion_timesteps=diffusion_timestep.long().repeat(
                    (latents.shape[0], 1)
                ),
            )
            # Get model output
            model_prediction = self.denoise(
                noisy_batch=input_batch,
                context=context,
                state_bag=None,
                inference=True,
            )

            scheduler_outputs = self.noise_scheduler.step(
                model_output=model_prediction.seqs,
                timestep=diffusion_timestep,
                sample=latents,
                eta=self.ddim_eta,
                epsilon_scaling=self.epsilon_scaling,
            )

            # set up latents for the next diffusion step
            latents = scheduler_outputs.prev_sample
            # clip?
            latents = latents.clip(-self.clip_noise, self.clip_noise)

        # Take the final predicted denoised sample (x_0 in the DDIM paper) and denormalize if needed:
        final_seqs = scheduler_outputs.pred_original_sample

        final_seqs = self.sonar_normalizer.denormalize(final_seqs)

        return EmbeddingsBatch(final_seqs, None), context

    @torch.inference_mode()
    def predict_next_sentence_with_cf_guidance(  # type: ignore
        self,
        batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        temperature: float = 1.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        context_state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
        assert context_state_bag is not None, (
            "Expected a state_bag to incrementally encode the context"
        )

        # Normalize the input embeddings if we're expected to
        # normalize outside of the model's forward pass
        if self.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.sonar_normalizer)

        # Encode context:
        new_context = self.encode(batch, context_state_bag)
        context_state_bag.increment_step_nr(1)

        # Append to context
        context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))

        # Sample latents:
        latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))

        # Denoise
|
| 370 |
+
diffusion_timesteps_schedule = self.noise_scheduler.timesteps
|
| 371 |
+
|
| 372 |
+
# Duplicate the context and its padding mask, the second half will be ignored
|
| 373 |
+
_seq_lens = get_seq_lens(context.seqs, context.padding_mask)
|
| 374 |
+
|
| 375 |
+
# add zeros:
|
| 376 |
+
_seq_lens = torch.concat((_seq_lens, torch.zeros_like(_seq_lens)), dim=0)
|
| 377 |
+
|
| 378 |
+
context = EmbeddingsBatch(
|
| 379 |
+
torch.concat((context.seqs, torch.zeros_like(context.seqs)), dim=0),
|
| 380 |
+
PaddingMask(_seq_lens, batch_seq_len=context.seqs.size(1)),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
batch_multiplier = 2
|
| 384 |
+
for diffusion_timestep in diffusion_timesteps_schedule:
|
| 385 |
+
is_max_diffusion_step = (
|
| 386 |
+
diffusion_timestep == self.noise_scheduler.num_diffusion_train_steps - 1
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
input_batch = EmbeddingsBatch(
|
| 390 |
+
torch.concat(batch_multiplier * [latents], dim=0),
|
| 391 |
+
diffusion_timesteps=diffusion_timestep.long().repeat(
|
| 392 |
+
(latents.shape[0] * batch_multiplier, 1)
|
| 393 |
+
),
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
model_prediction = self.denoise(
|
| 397 |
+
noisy_batch=input_batch,
|
| 398 |
+
context=context,
|
| 399 |
+
state_bag=None,
|
| 400 |
+
inference=True,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# If at the max step, do not step in the epsilon_scheduler
|
| 404 |
+
if is_max_diffusion_step:
|
| 405 |
+
# if beta_prod_t (denominator) is null i.e.,
|
| 406 |
+
# the diffusion timestep is at its max value (num_training_stesp-1)
|
| 407 |
+
# no denoising will be performed.
|
| 408 |
+
|
| 409 |
+
# Note that since the batch might be doubled because
|
| 410 |
+
# we're doing classifier-free guidance, we chunk the model output
|
| 411 |
+
# by batch_multiplier. If not at max_diffusion_step
|
| 412 |
+
# this chunking is performed in apply_classifier_free_guidance
|
| 413 |
+
scheduler_outputs = self.noise_scheduler.step(
|
| 414 |
+
model_output=model_prediction.seqs.chunk(batch_multiplier)[0],
|
| 415 |
+
timestep=diffusion_timestep,
|
| 416 |
+
sample=latents,
|
| 417 |
+
eta=self.ddim_eta,
|
| 418 |
+
epsilon_scaling=self.epsilon_scaling,
|
| 419 |
+
)
|
| 420 |
+
else:
|
| 421 |
+
# Predict the noise residual according to the prediction type
|
| 422 |
+
predicted_noise = self.noise_scheduler.get_epsilon(
|
| 423 |
+
model_output=model_prediction.seqs,
|
| 424 |
+
sample=input_batch.seqs,
|
| 425 |
+
timestep=diffusion_timestep,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
if self.do_classifier_free_guidance:
|
| 429 |
+
# Perform guidance if trained with cf-guidance:
|
| 430 |
+
# The returned predicted noise will combine the conditional and
|
| 431 |
+
# unconditional predictions i.e., from (2 x batch_size, 1, C)
|
| 432 |
+
# to: (batch_size, 1, C)
|
| 433 |
+
predicted_noise = self.apply_classifier_free_guidance(
|
| 434 |
+
predicted_noise
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# The cf-guidance operates on predicted noises and although we
|
| 438 |
+
# can go back and forth between epsilon and predicted sample
|
| 439 |
+
# once we combine cond and uncond we cannot go back to predicted_x0
|
| 440 |
+
|
| 441 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 442 |
+
scheduler_outputs = self.noise_scheduler.step(
|
| 443 |
+
model_output=predicted_noise,
|
| 444 |
+
timestep=diffusion_timestep,
|
| 445 |
+
sample=latents,
|
| 446 |
+
eta=self.ddim_eta,
|
| 447 |
+
epsilon_scaling=self.epsilon_scaling,
|
| 448 |
+
prediction_type="epsilon",
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# setup latents for the next diffusion step
|
| 452 |
+
latents = scheduler_outputs.prev_sample
|
| 453 |
+
# clip?
|
| 454 |
+
latents = latents.clip(-self.clip_noise, self.clip_noise)
|
| 455 |
+
|
| 456 |
+
# Take the final predicted denoised sample (x_0 in the ddim paper) and denormalize if needed:
|
| 457 |
+
final_seqs = scheduler_outputs.pred_original_sample
|
| 458 |
+
|
| 459 |
+
final_seqs = self.sonar_normalizer.denormalize(final_seqs)
|
| 460 |
+
|
| 461 |
+
return EmbeddingsBatch(final_seqs, None), context
|
| 462 |
+
|
| 463 |
+
def apply_classifier_free_guidance(self, predicted_noise: Tensor) -> Tensor:
|
| 464 |
+
""" "
|
| 465 |
+
Apply Classifier-Free Guidance with Rescale as introduced in Algorithm 2 of https://arxiv.org/pdf/2305.08891
|
| 466 |
+
`pos` would be the conditional prediction `cond_prediction`
|
| 467 |
+
and `neg` the unconditional prediction `uncond_prediction`:
|
| 468 |
+
The batch during prefilling is prepared with the conditioning prefix in
|
| 469 |
+
the first half
|
| 470 |
+
"""
|
| 471 |
+
# Chunk and follow algorithm 2
|
| 472 |
+
cond_prediction, uncond_prediction = predicted_noise.chunk(2)
|
| 473 |
+
|
| 474 |
+
# Regular classifier-free guidance:
|
| 475 |
+
guided_noise_prediction = uncond_prediction + self.guidance_scale * (
|
| 476 |
+
cond_prediction - uncond_prediction
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Rescale classifier-free guidance to prevent over-exposure
|
| 480 |
+
# Calculate standard deviations.
|
| 481 |
+
std_pos = cond_prediction.std(dim=-1, keepdim=True)
|
| 482 |
+
std_cfg = guided_noise_prediction.std(dim=-1, keepdim=True)
|
| 483 |
+
|
| 484 |
+
# Apply guidance rescale with fused operations.
|
| 485 |
+
factor = std_pos / std_cfg
|
| 486 |
+
factor = self.guidance_rescale * factor + (1 - self.guidance_rescale)
|
| 487 |
+
|
| 488 |
+
return factor * guided_noise_prediction
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class TwoTowerDiffusionLCModelBuilder(AbstractLCModelBuilder):
|
| 492 |
+
"""Builds modules of a diffusion-based LCM"""
|
| 493 |
+
|
| 494 |
+
config: TwoTowerDiffusionLCModelConfig
|
| 495 |
+
denoiser_factory: LCMDenoiserTransformerFactory
|
| 496 |
+
|
| 497 |
+
def __init__(
|
| 498 |
+
self,
|
| 499 |
+
config: TwoTowerDiffusionLCModelConfig,
|
| 500 |
+
*,
|
| 501 |
+
device: Optional[Device] = None,
|
| 502 |
+
dtype: Optional[DataType] = None,
|
| 503 |
+
) -> None:
|
| 504 |
+
"""
|
| 505 |
+
:param config:
|
| 506 |
+
The configuration.
|
| 507 |
+
:param device:
|
| 508 |
+
The device on which to initialize modules.
|
| 509 |
+
:param dtype:
|
| 510 |
+
The data type of module parameters and buffers.
|
| 511 |
+
"""
|
| 512 |
+
super().__init__(config=config, device=device, dtype=dtype)
|
| 513 |
+
|
| 514 |
+
self.context_encoder_factory = TransformerFactory(
|
| 515 |
+
model_dim=self.config.model_dim,
|
| 516 |
+
max_seq_len=self.config.max_seq_len,
|
| 517 |
+
config=self.config.context_encoder,
|
| 518 |
+
device=device,
|
| 519 |
+
dtype=dtype,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
self.denoiser_factory = LCMDenoiserTransformerFactory(
|
| 523 |
+
model_dim=self.config.model_dim,
|
| 524 |
+
num_diffusion_train_timesteps=self.config.noise_scheduler.num_diffusion_train_steps,
|
| 525 |
+
max_seq_len=self.config.max_seq_len,
|
| 526 |
+
config=self.config.denoiser,
|
| 527 |
+
input_dim=self.config.sonar_embed_dim,
|
| 528 |
+
device=device,
|
| 529 |
+
dtype=dtype,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
def build_model(self) -> TwoTowerDiffusionLCModel:
|
| 533 |
+
"""Build a model."""
|
| 534 |
+
|
| 535 |
+
sonar_normalizer = self.build_sonar_normalizer()
|
| 536 |
+
assert sonar_normalizer is not None, (
|
| 537 |
+
"TwoTowerDiffusionLCModel expects a `sonar_normalizer`"
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
# the context encoder
|
| 541 |
+
encoder_frontend = self.build_frontend()
|
| 542 |
+
|
| 543 |
+
context_encoder = self.build_context_encoder()
|
| 544 |
+
|
| 545 |
+
# the denoiser
|
| 546 |
+
denoiser = self.build_denoiser()
|
| 547 |
+
|
| 548 |
+
noise_scheduler = self.build_noise_scheduler()
|
| 549 |
+
|
| 550 |
+
return TwoTowerDiffusionLCModel(
|
| 551 |
+
config=self.config,
|
| 552 |
+
sonar_normalizer=sonar_normalizer,
|
| 553 |
+
context_encoder=context_encoder,
|
| 554 |
+
encoder_frontend=encoder_frontend,
|
| 555 |
+
denoiser=denoiser,
|
| 556 |
+
noise_scheduler=noise_scheduler,
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def build_frontend(self) -> EncoderFrontend:
|
| 560 |
+
"""Build the context encoder front-end."""
|
| 561 |
+
|
| 562 |
+
return EncoderFrontend(
|
| 563 |
+
sonar_embed_dim=self.config.sonar_embed_dim,
|
| 564 |
+
model_dim=self.config.model_dim,
|
| 565 |
+
config=self.config.frontend,
|
| 566 |
+
pos_encoder=self.context_encoder_factory.build_pos_encoder(),
|
| 567 |
+
device=self.device,
|
| 568 |
+
dtype=self.dtype,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
def build_context_encoder(self) -> LCMTransformerDecoder:
|
| 572 |
+
"""Build the context encoder."""
|
| 573 |
+
|
| 574 |
+
config = self.config.context_encoder
|
| 575 |
+
|
| 576 |
+
num_layers = config.num_layers
|
| 577 |
+
assert num_layers > 0, "The context encoder needs a non-zero number of layers"
|
| 578 |
+
|
| 579 |
+
layers = [self.context_encoder_factory.build_layer() for _ in range(num_layers)]
|
| 580 |
+
|
| 581 |
+
self_attn_mask_factory = CausalAttentionMaskFactory()
|
| 582 |
+
|
| 583 |
+
if config.final_norm_order_style is None:
|
| 584 |
+
# The final norm order style will be that of
|
| 585 |
+
# the layer-level norm order
|
| 586 |
+
final_norm_order = parse_norm_order(config.norm_order_style)
|
| 587 |
+
else:
|
| 588 |
+
final_norm_order = parse_norm_order(config.final_norm_order_style)
|
| 589 |
+
|
| 590 |
+
layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)
|
| 591 |
+
|
| 592 |
+
return LCMTransformerDecoder(
|
| 593 |
+
layers,
|
| 594 |
+
self_attn_mask_factory=self_attn_mask_factory,
|
| 595 |
+
norm_order=final_norm_order,
|
| 596 |
+
layer_norm_factory=layer_norm_factory,
|
| 597 |
+
dropout_p=config.final_dropout_p,
|
| 598 |
+
device=self.device,
|
| 599 |
+
dtype=self.dtype,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
def build_noise_scheduler(self) -> DDIMScheduler:
|
| 603 |
+
return DDIMScheduler(self.config.noise_scheduler)
|
| 604 |
+
|
| 605 |
+
def build_denoiser(self) -> LCMDenoiser:
|
| 606 |
+
"""Build a Transformer for diffusing noised latents."""
|
| 607 |
+
return self.denoiser_factory.build_model()
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def create_two_tower_diffusion_lcm_model(
|
| 611 |
+
config: TwoTowerDiffusionLCModelConfig,
|
| 612 |
+
*,
|
| 613 |
+
device: Optional[Device] = None,
|
| 614 |
+
dtype: Optional[DataType] = None,
|
| 615 |
+
) -> TwoTowerDiffusionLCModel:
|
| 616 |
+
"""Create a DiffusionLCM model.
|
| 617 |
+
:param config:
|
| 618 |
+
The configuration.
|
| 619 |
+
:param device:
|
| 620 |
+
The device on which to initialize modules.
|
| 621 |
+
:param dtype:
|
| 622 |
+
The data type of module parameters and buffers.
|
| 623 |
+
"""
|
| 624 |
+
return TwoTowerDiffusionLCModelBuilder(
|
| 625 |
+
config,
|
| 626 |
+
device=device,
|
| 627 |
+
dtype=dtype, # type: ignore
|
| 628 |
+
).build_model()
|
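A note on the guidance-rescale arithmetic in `apply_classifier_free_guidance`: the sketch below reproduces the same computation on dummy tensors, independent of the model classes above. The tensor sizes and the scale/rescale values are illustrative, not defaults shipped by the repo; the batch layout (conditional predictions in the first half) mirrors the method above.

import torch

guidance_scale, guidance_rescale = 3.0, 0.7

# (2 * batch_size, 1, C): conditional predictions in the first half.
predicted_noise = torch.randn(4, 1, 1024)
cond, uncond = predicted_noise.chunk(2)

# Regular classifier-free guidance.
guided = uncond + guidance_scale * (cond - uncond)

# Rescale (Algorithm 2, https://arxiv.org/pdf/2305.08891): pull the guided
# prediction's per-channel std back toward the conditional prediction's,
# interpolated by guidance_rescale to avoid over-exposure.
factor = cond.std(dim=-1, keepdim=True) / guided.std(dim=-1, keepdim=True)
factor = guidance_rescale * factor + (1 - guidance_rescale)
guided = factor * guided  # (batch_size, 1, C)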
lcm/models/two_tower_diffusion_lcm/frontend.py
ADDED
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from fairseq2.logging import get_log_writer
from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Dropout, Module

from lcm.nn.initialization import SUPPORTED_INIT_TYPES, get_init_fn

logger = get_log_writer(__name__)


@dataclass
class EncoderFrontendConfig:
    dropout_p: float = 0.0
    """The dropout probability applied to the module's output"""

    pre_linear_bias: bool = True
    """Whether or not the pre-linear layer has a bias term"""

    pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"

    weight_normalization: bool = False

    embedding_std: float = 1.0


class EncoderFrontend(Module):
    """
    A frontend for the context encoder in encoder-decoder LCMs
    """

    embed: Embedding
    pos_encoder: Optional[PositionEncoder]
    dropout: Optional[Dropout]

    def __init__(
        self,
        sonar_embed_dim: int,
        model_dim: int,
        config: EncoderFrontendConfig,
        pos_encoder: Optional[PositionEncoder],
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param sonar_embed_dim:
            The embedding dimension of the sentence encoder, in this case SONAR
        :param model_dim:
            The model embedding dimension
        :param config:
            A frontend config. See `EncoderFrontendConfig`
        :param pos_encoder:
            An optional position encoder.
        """

        super().__init__()

        self.sonar_embed_dim = sonar_embed_dim

        self.model_dim = model_dim

        self.device = device

        # Pre-linear to map to model dimension
        init_fn = get_init_fn(config.pre_linear_init_fn)

        lin = Linear(
            sonar_embed_dim,
            model_dim,
            bias=config.pre_linear_bias,
            device=device,
            dtype=dtype,
            init_fn=init_fn,
        )

        if config.weight_normalization:
            self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin)
        else:
            self.pre_linear = lin

        if pos_encoder is not None:
            if pos_encoder.encoding_dim != self.model_dim:
                raise ValueError(
                    f"`encoding_dim` of `pos_encoder` and `embedding_dim` of "
                    f"`embed` must be equal, but are {pos_encoder.encoding_dim} "
                    f"and {self.model_dim} instead."
                )

            self.pos_encoder = pos_encoder
        else:
            self.register_module("pos_encoder", None)

        if config.dropout_p > 0.0:
            self.dropout = Dropout(config.dropout_p)
        else:
            self.register_module("dropout", None)

        self.reset_parameters(embedding_std=config.embedding_std)

    def reset_parameters(self, embedding_std: float) -> None:
        """Initialize module parameters.
        The positional embeddings should be initialized with the
        same order of magnitude as the semantic embeddings, in order
        to make the early training as stable as possible.
        Otherwise, the positional and special token embeddings would
        flood out the semantic information.
        """
        logger.info(
            f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})"
        )
        if isinstance(self.pos_encoder, LearnedPositionEncoder):
            torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std)

    def forward(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """
        Apply the pre-linear projection (if relevant) and add positional embeddings
        """

        # pre-linear if any:
        seqs = self.pre_linear(seqs)

        if self.pos_encoder is not None:
            seqs = self.pos_encoder(
                seqs,
                padding_mask,
                state_bag=state_bag,
                **kwargs,
            )

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs, padding_mask
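In plain PyTorch, without the fairseq2 scaffolding, the frontend's computation reduces to a projection from `sonar_embed_dim` to `model_dim` plus (optionally learned) position codes of comparable magnitude. A minimal sketch; all names and sizes here are illustrative, not the module's API:

import torch
import torch.nn as nn

sonar_embed_dim, model_dim, max_seq_len = 1024, 2048, 128

pre_linear = nn.Linear(sonar_embed_dim, model_dim)
pos_embed = nn.Parameter(torch.randn(max_seq_len, model_dim))  # ~ N(0, 1), matching embedding_std=1.0

seqs = torch.randn(2, 10, sonar_embed_dim)          # a batch of sentence embeddings
out = pre_linear(seqs) + pos_embed[: seqs.size(1)]  # (2, 10, model_dim)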
lcm/models/two_tower_diffusion_lcm/loader.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model

from lcm.models.base_lcm.loader import convert_lcm_checkpoint
from lcm.models.two_tower_diffusion_lcm.builder import (
    TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
    TwoTowerDiffusionLCModelConfig,
    create_two_tower_diffusion_lcm_model,
    lcm_archs,
)
from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry

load_two_tower_diffusion_lcm_config = StandardModelConfigLoader(
    family=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
    config_kls=TwoTowerDiffusionLCModelConfig,
    arch_configs=lcm_archs,
)


load_two_tower_diffusion_lcm_model = StandardModelLoader(  # type: ignore # FIXME
    config_loader=load_two_tower_diffusion_lcm_config,
    factory=create_two_tower_diffusion_lcm_model,
    checkpoint_converter=convert_lcm_checkpoint,
    restrict_checkpoints=False,
)

load_model.register(
    TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, load_two_tower_diffusion_lcm_model
)

lcm_model_type_registry.register(
    ModelTypeConfig(
        model_type=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
        config_loader=load_two_tower_diffusion_lcm_config,
        model_factory=create_two_tower_diffusion_lcm_model,
        model_loader=load_two_tower_diffusion_lcm_model,
    )
)
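With these registrations in place, loading goes through fairseq2's generic entry point: `load_model` dispatches to `load_two_tower_diffusion_lcm_model` via the registered `TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE` family. A minimal usage sketch; the card name is a hypothetical placeholder, not one shipped by the repo:

from fairseq2.models.loader import load_model

# "my_two_tower_lcm" is an illustrative model card name.
model = load_model("my_two_tower_lcm")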
lcm/nn/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
lcm/nn/denoisers/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


from lcm.nn.denoisers.factory import (
    DenoiserConfig,
    LCMDenoiser,
    LCMDenoiserTransformerFactory,
)

__all__ = [
    "DenoiserConfig",
    "LCMDenoiser",
    "LCMDenoiserTransformerFactory",
]
lcm/nn/denoisers/attention_masks.py
ADDED
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import math
from typing import Optional, final

import torch
from fairseq2.nn.transformer import (
    AbstractAttentionMask,
    AttentionMask,
    AttentionMaskFactory,
)
from fairseq2.typing import DataType, Device, override
from torch import Tensor

from lcm.nn.incremental_state import LCMIncrementalStateBag


def _get_shifted_causal_mask(
    seq_len: int,
    key_len: int,
    shift: int = 0,
    cf_guidance_prob: float = 0.0,
    zero_vector: bool = False,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> Tensor:
    causal_mask = torch.ones(
        (seq_len, key_len),
        device=device,
        dtype=dtype,
    )
    causal_mask.tril_(diagonal=shift)

    if cf_guidance_prob > 0.0:
        num_rows_to_drop = math.floor((seq_len - 1) * cf_guidance_prob)
        if num_rows_to_drop > 0:
            rows_to_drop = 1 + torch.randperm(seq_len - 1)[:num_rows_to_drop]
            if zero_vector:
                causal_mask[rows_to_drop, 1:] = 0
            else:
                causal_mask[rows_to_drop, :] = 0

    return causal_mask


class NoAttentionMaskFactory(AttentionMaskFactory):
    """Constructs instances of :class:`NoAttentionMask`."""

    @override
    def __call__(  # type: ignore
        self,
        seqs: Tensor,
        keys: Tensor,
        *,
        training: bool = True,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        inference_without_caching: Optional[bool] = False,
        **kwargs,
    ) -> Optional[AttentionMask]:
        mask: NoAttentionMask

        attn_len: Optional[int] = seqs.size(1)
        seq_len = seqs.size(1)
        key_len = keys.size(1)

        mask = NoAttentionMask(
            seq_len=seq_len,
            key_len=key_len,
            attn_len=attn_len,
            device=seqs.device,
            dtype=seqs.dtype,
        )
        return mask

    def __repr__(self) -> str:
        return "NoAttentionMaskFactory()"


@final
class NoAttentionMask(AbstractAttentionMask):
    """
    Represents a diagonal attention mask, i.e. attention
    on the current position only.
    This turns the self-attention layer into an FFN.
    """

    def __init__(
        self,
        seq_len: int,
        key_len: int,
        attn_len: Optional[int],
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param seq_len:
            The sequence length.
        """
        super().__init__()

        self.seq_len = seq_len

        self._device, self._dtype = device, dtype

    @override
    def _do_materialize(self) -> Tensor:
        mask = torch.eye(self.seq_len, device=self._device, dtype=self._dtype)
        mask.log_()
        return mask


class ShiftedCausalAttentionMaskFactory(AttentionMaskFactory):
    """
    Constructs instances of :class:`ShiftedCausalAttentionMask`
    """

    @override
    def __call__(  # type: ignore
        self,
        seqs: Tensor,
        keys: Tensor,
        *,
        source_lengths: Optional[Tensor] = None,
        cf_guidance_prob: float = 0.0,
        training: bool = True,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        inference: bool = False,
    ) -> Optional[AttentionMask]:
        mask: Optional[ShiftedCausalAttentionMask]

        attn_len: Optional[int] = seqs.size(1)
        seq_len = seqs.size(1)
        key_len = keys.size(1)

        if inference:
            mask = None
        else:
            mask = ShiftedCausalAttentionMask(
                seq_len=seq_len,
                key_len=key_len,
                attn_len=attn_len,
                source_lengths=source_lengths,
                cf_guidance_prob=cf_guidance_prob,
                device=seqs.device,
                dtype=seqs.dtype,
            )

        return mask

    def __repr__(self) -> str:
        return "ShiftedCausalAttentionMaskFactory()"


@final
class ShiftedCausalAttentionMask(AbstractAttentionMask):
    """
    Represents a causal mask shifted by source_lengths.

    At training time, without source_lengths, the mask looks like (e.g. seq_len = 5):

    [ 0., -inf, -inf, -inf, -inf, -inf],
    [ 0., 0., -inf, -inf, -inf, -inf],
    [ 0., 0., 0., -inf, -inf, -inf],
    [ 0., 0., 0., 0., -inf, -inf],
    [ 0., 0., 0., 0., 0., -inf]

    """

    def __init__(
        self,
        seq_len: int,
        key_len: int,
        attn_len: Optional[int],
        *,
        source_lengths: Optional[Tensor] = None,
        cf_guidance_prob: float = 0.0,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param seq_len:
            The sequence length.
        """
        super().__init__()

        self.seq_len = seq_len
        self.key_len = key_len
        self._source_lengths = source_lengths
        self._cf_guidance_prob = cf_guidance_prob
        self._device, self._dtype = device, dtype

    @override
    def _do_materialize(self) -> Tensor:
        if self._source_lengths is None:
            causal_mask = _get_shifted_causal_mask(
                seq_len=self.seq_len,
                key_len=self.key_len,
                shift=0,
                cf_guidance_prob=self._cf_guidance_prob,
                zero_vector=True,
                device=self._device,
                dtype=self._dtype,
            )

        else:
            causal_mask = torch.stack(
                [
                    _get_shifted_causal_mask(
                        seq_len=self.seq_len,
                        key_len=self.key_len,
                        shift=src_len,
                        cf_guidance_prob=self._cf_guidance_prob,
                        zero_vector=True,
                        device=self._device,
                        dtype=self._dtype,
                    )
                    for src_len in self._source_lengths
                ]
            ).unsqueeze(1)
            # bs x 1 (head) x seq_len x seq_len

        causal_mask.log_()

        return causal_mask
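The masks above are built as 0/1 matrices and then mapped in place to additive form: `log_()` sends 1 to 0 (attend) and 0 to -inf (block), and `tril_(diagonal=shift)` moves the causal boundary right by the source length. A minimal sketch of the same construction on toy sizes (values chosen for illustration):

import torch

# e.g. a source prefix of length 2 shifts the causal diagonal by 2.
seq_len, key_len, shift = 3, 6, 2
mask = torch.ones((seq_len, key_len)).tril_(diagonal=shift)
mask.log_()  # 1 -> 0 (attend), 0 -> -inf (block)
print(mask)
# tensor([[0., 0., 0., -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., -inf]])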
lcm/nn/denoisers/factory.py
ADDED
@@ -0,0 +1,192 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Literal, Optional

from fairseq2.logging import get_log_writer
from fairseq2.typing import DataType, Device

from lcm.nn.denoisers.attention_masks import (
    NoAttentionMaskFactory,
    ShiftedCausalAttentionMaskFactory,
)
from lcm.nn.denoisers.lcm_denoiser import (
    LCMDenoiser,
    LCMDenoiserLayer,
)
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.projection import (
    Projection,
    ProjectionConfig,
)
from lcm.nn.timestep_encoder import DiTTimestepEncoder
from lcm.nn.transformer import TransformerConfig, TransformerFactory

logger = get_log_writer(__name__)


@dataclass
class DenoiserConfig(TransformerConfig):
    """Config for building the LCM's denoiser"""

    pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "none"
    """By default, a denoiser does not have a positional embedder"""

    pre_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
    """the initial projection at the top of the denoiser"""

    post_denoiser: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
    """the final output projection at the end of the denoiser"""

    timestep_embed_dim: int = 1024
    """Diffusion timestep embedding dimension"""


class LCMDenoiserTransformerFactory(TransformerFactory):
    """Denoiser with hybrid AdaLN and cross-attention"""

    config: DenoiserConfig

    def __init__(
        self,
        model_dim: int,
        max_seq_len: int,
        num_diffusion_train_timesteps: int,
        config: DenoiserConfig,
        input_dim: int = 1024,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The hidden model dimension of the Transformer
        :param max_seq_len:
            Maximum supported sequence length by the model
        :param config:
            The configuration.
        :param input_dim:
            The input embedding dimension i.e. `sonar_embed_dim`
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        super().__init__(
            model_dim=model_dim,
            max_seq_len=max_seq_len,
            config=config,
            device=device,
            dtype=dtype,
        )

        self.input_dim = input_dim

        self.num_diffusion_train_timesteps = num_diffusion_train_timesteps

    def build_cross_attention_mask(self):
        return ShiftedCausalAttentionMaskFactory()

    def build_timestep_embedder(self):
        return DiTTimestepEncoder(
            embedding_dim=self.config.timestep_embed_dim,
            dtype=self.dtype,
            device=self.device,
        )

    def build_initial_proj(self) -> Projection:
        # We will be concatenating context and timesteps embeddings
        assert self.config.timestep_embed_dim == self.model_dim, (
            "Since the timestep embeddings will be added to the sequence of "
            "conditioning variables, they need to be of the same dimension. "
            f"Found timestep_embed_dim={self.config.timestep_embed_dim} "
            f"and model_dim={self.model_dim}"
        )

        return Projection(
            output_dim=self.model_dim,
            input_dim=self.input_dim,
            config=self.config.pre_denoiser,
            device=self.device,
            dtype=self.dtype,
        )

    def build_final_proj(self) -> Projection:
        return Projection(
            output_dim=self.input_dim,
            input_dim=self.model_dim,
            config=self.config.post_denoiser,
            device=self.device,
            dtype=self.dtype,
        )

    def build_model(self) -> LCMDenoiser:
        """Build the denoiser with its layers and initial/final projections"""
        embed_time = self.build_timestep_embedder()

        layers = [self.build_layer() for _ in range(self.config.num_layers)]

        norm_order = parse_norm_order(self.config.norm_order_style)

        # Self-attention here does not contextualize
        self_attn_mask_factory = NoAttentionMaskFactory()

        cross_attention_mask_factory = self.build_cross_attention_mask()

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        pos_encoder = self.build_pos_encoder()

        return LCMDenoiser(
            embed_time=embed_time,
            layers=layers,
            initial_proj=self.build_initial_proj(),
            final_proj=self.build_final_proj(),
            dropout_p=self.config.final_dropout_p,
            norm_order=norm_order,
            layer_norm_factory=layer_norm_factory,
            self_attn_mask_factory=self_attn_mask_factory,
            cross_attention_mask_factory=cross_attention_mask_factory,
            pos_encoder=pos_encoder,
            device=self.device,
            dtype=self.dtype,
        )

    def build_layer(self) -> LCMDenoiserLayer:
        """Build a Transformer decoder layer based on the provided config."""

        assert isinstance(self.config, DenoiserConfig), (
            "Expecting a DenoiserConfig in the DenoiserTransformerFactory"
        )

        self_attn = self.build_attention()

        cross_attn = self.build_attention()

        ffn = self.build_ffn()

        norm_order = parse_norm_order(self.config.norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        modulator_input_dim = self_attn.model_dim

        layer = LCMDenoiserLayer(
            self_attn=self_attn,
            cross_attention=cross_attn,
            ffn=ffn,
            modulator_input_dim=modulator_input_dim,
            dropout_p=self.config.dropout_p,
            norm_order=norm_order,
            layer_norm_factory=layer_norm_factory,
            device=self.device,
            dtype=self.dtype,
        )
        return layer
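The factory delegates timestep embedding to `DiTTimestepEncoder` with `embedding_dim == model_dim` (the assert in `build_initial_proj` enforces this). The sketch below shows the common DiT-style recipe for such encoders, sinusoidal features of the integer timestep; whether `DiTTimestepEncoder` matches this exactly (e.g. the frequency base, or a follow-up MLP) is an assumption for illustration only.

import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    # Geometric frequency ladder, as in Transformer/DiT timestep embeddings
    # (base 10000 is the conventional choice, assumed here).
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
    args = timesteps.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_timestep_embedding(torch.tensor([0, 10, 999]), dim=1024)
print(emb.shape)  # torch.Size([3, 1024])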
lcm/nn/denoisers/lcm_denoiser.py
ADDED
@@ -0,0 +1,546 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import Iterable, Optional, Tuple, cast

import torch
import torch.nn as nn
from fairseq2.nn import PositionEncoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    AttentionMaskFactory,
    FeedForwardNetwork,
    LayerNormFactory,
    MultiheadAttention,
    TransformerDecoderLayer,
    TransformerNormOrder,
    create_standard_layer_norm,
)
from fairseq2.typing import DataType, Device, override
from torch import Tensor
from torch.nn import Dropout, Module, ModuleList
from torch.nn.parameter import Parameter

from lcm.nn.projection import Projection
from lcm.nn.timestep_encoder import DiTTimestepEncoder


class AdaLNModulator(Module):
    """An adaptive LayerNorm modulator to estimate
    shift, gate and scale for all 3 sub-modules."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
        super().__init__()

        self.activate = nn.SiLU()
        self.fc = nn.Linear(
            input_dim,
            9 * output_dim,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def reset_parameters(self):
        # zero-init
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)

    def forward(self, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        (modulate_san, modulate_cross_attention, modulate_ffn) = self.fc(
            self.activate(context)
        ).chunk(3, dim=-1)
        return modulate_san, modulate_cross_attention, modulate_ffn


class LCMDenoiser(Module):
    """
    The main denoiser module of the two-tower diffusion LCM.
    """

    model_dim: int
    layers: ModuleList
    self_attn_mask_factory: AttentionMaskFactory
    layer_norm: Optional[LayerNorm]
    dropout_p: float
    norm_order: TransformerNormOrder
    cross_attention_mask_factory: AttentionMaskFactory

    def __init__(
        self,
        embed_time: DiTTimestepEncoder,
        layers: Iterable[TransformerDecoderLayer],
        initial_proj: Projection,
        final_proj: Projection,
        *,
        self_attn_mask_factory: AttentionMaskFactory,
        cross_attention_mask_factory: AttentionMaskFactory,
        dropout_p: float = 0.0,
        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
        pos_encoder: Optional[PositionEncoder] = None,
        layer_norm_factory: Optional[LayerNormFactory] = None,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param layers:
            The decoder layers.
        :param self_attn_mask_factory:
            The self attention mask factory.
        :param cross_attention_mask_factory:
            The cross attention mask factory.
        :param dropout_p:
            The dropout probability on decoder outputs.
        :param norm_order:
            The Layer Normalization order.
        :param pos_encoder:
            An optional positional encoding module.
        :param layer_norm_factory:
            The factory to construct the Layer Normalization module.
        """
        layer_list = ModuleList(layers)

        if not layer_list:
            raise ValueError("`layers` must be non-empty.")

        model_dim = layer_list[0].model_dim

        super().__init__()

        self.model_dim = model_dim

        self.embed_time = embed_time

        self.initial_proj = initial_proj

        self.final_proj = final_proj

        self.pos_encoder = pos_encoder

        if layer_norm_factory is None:
            layer_norm_factory = create_standard_layer_norm

        self.self_attn_mask_factory = self_attn_mask_factory

        self.layers = layer_list

        if norm_order != TransformerNormOrder.POST:
            self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
        else:
            self.register_module("layer_norm", None)

        if dropout_p > 0.0:
            self.dropout = Dropout(dropout_p)
        else:
            self.register_module("dropout", None)

        self.norm_order = norm_order

        self.cross_attention_mask_factory = cross_attention_mask_factory

    def forward(
        self,
        seqs: Tensor,
        diffusion_timesteps: Tensor,
        padding_mask: Optional[PaddingMask],
        conditioning_variables: Optional[Tensor] = None,
        conditioning_variables_padding_mask: Optional[PaddingMask] = None,
        source_lengths: Optional[Tensor] = None,
        cf_guidance_prob: float = 0.0,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
        inference: Optional[bool] = False,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """
        Arguments:
            - seqs (`Tensor`): the sequence of latents to denoise
            - diffusion_timesteps (`Tensor`): the indices of the diffusion timesteps
              to be embedded and fed as a conditioning variable.
            - padding_mask (`PaddingMask`): mask of padded positions in the latents (seqs)
            - conditioning_variables (`Tensor`): the sequence of conditioning
              variables that will be combined with the timestep embedding to
              guide the diffusion process
            - conditioning_variables_padding_mask (`PaddingMask`): the mask of padded
              positions in `conditioning_variables`
            - source_lengths (`Optional[Tensor]`): the lengths of the source embeddings
              in `conditioning_variables` to properly shift the cross-attention mask
            - cf_guidance_prob: probability rate with which to drop all conditioning variables when denoising
            - state_bag (`IncrementalStateBag`): the incremental state bag of the denoiser to enable kv-caching
            - inference (`bool`): if `True` the cross-attention mask will be adjusted accordingly
        """

        emb_timesteps = self.embed_time(diffusion_timesteps)
        assert conditioning_variables is not None, (
            "Expected conditioning_variables, found None"
        )

        conditioning_variables = torch.cat(
            [
                torch.zeros_like(conditioning_variables[:, 0:1]),
                conditioning_variables,
            ],
            dim=1,
        )

        if conditioning_variables_padding_mask is not None:
            # shift by the length of the prepended timesteps
            conditioning_variables_padding_mask = PaddingMask(
                conditioning_variables_padding_mask._seq_lens + 1,
                conditioning_variables_padding_mask._batch_seq_len + 1,
            )

        # project to model_dim and add optional position codes:
        seqs = self.initial_proj(seqs)

        if self.pos_encoder is not None:
            seqs = self.pos_encoder(seqs, padding_mask)

        self_attn_mask = self.self_attn_mask_factory(
            seqs, keys=seqs, training=self.training, state_bag=state_bag
        )

        assert conditioning_variables is not None
        cross_attention_mask = self.cross_attention_mask_factory(
            seqs,
            keys=conditioning_variables,
            source_lengths=source_lengths,
            cf_guidance_prob=cf_guidance_prob,
            training=self.training,
            state_bag=state_bag,
            inference=inference,  # type: ignore
        )

        for layer in self.layers:
            layer_output, layer_padding_mask = layer(
                seqs=seqs,
                padding_mask=padding_mask,
                self_attn_mask=self_attn_mask,
                emb_timesteps=emb_timesteps,
                conditioning_variables=conditioning_variables,
                conditioning_variables_padding_mask=conditioning_variables_padding_mask,
                cross_attention_mask=cross_attention_mask,
                state_bag=state_bag,
            )

            seqs, padding_mask = layer_output, layer_padding_mask

        if self.layer_norm is not None:
            seqs = self.layer_norm(seqs)

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        seqs = self.final_proj(seqs)

        return seqs, padding_mask


class LCMDenoiserLayer(TransformerDecoderLayer):
    """A single layer of the hybrid denoiser"""

    self_attn: MultiheadAttention
    self_attn_norm: Optional[LayerNorm]
    self_attn_dropout: Optional[Dropout]
    self_attn_layer_norm: LayerNorm
    cross_attention: MultiheadAttention
    cross_attention_dropout: Optional[Dropout]
    cross_attention_layer_norm: Optional[LayerNorm]
    ffn: FeedForwardNetwork
    ffn_dropout: Optional[Dropout]
    residual_scale: Optional[Parameter]
    ffn_layer_norm: LayerNorm
    norm_order: TransformerNormOrder

    def __init__(
        self,
        self_attn: MultiheadAttention,
        ffn: FeedForwardNetwork,
        cross_attention: MultiheadAttention,
        *,
        scale_residual: bool = False,
        dropout_p: float = 0.0,
        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
        layer_norm_factory: Optional[LayerNormFactory] = None,
        modulator_input_dim: Optional[int] = None,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param self_attn:
            The self attention layer.
        :param cross_attention:
            The cross attention layer if denoiser-type is `cross-attention`.
        :param ffn:
            The feed-forward network.
        :param scale_residual:
            If ``True``, scales residuals before adding them to the output of
            the feed-forward network as described in
            :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`.
        :param dropout_p:
            The dropout probability on outputs of the attention layers and the
            feed-forward network.
        :param norm_order:
            The Layer Normalization order.
        :param layer_norm_factory:
            The factory to construct the Layer Normalization modules.
        """
        model_dim = self_attn.model_dim

        super().__init__(model_dim)

        if layer_norm_factory is None:
            layer_norm_factory = create_standard_layer_norm

        self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)

        if norm_order != TransformerNormOrder.POST:
            self.self_attn_layer_norm = self_attn_layer_norm

        self.self_attn = self_attn

        if norm_order == TransformerNormOrder.PRE_WITH_NORMFORMER:
            self.self_attn_norm = layer_norm_factory(
                model_dim, device=device, dtype=dtype
            )
        else:
            self.register_module("self_attn_norm", None)

        if dropout_p > 0.0:
            self.self_attn_dropout = Dropout(dropout_p)
        else:
            self.register_module("self_attn_dropout", None)

        if norm_order == TransformerNormOrder.POST:
            self.self_attn_layer_norm = self_attn_layer_norm

        # Deal with the cross-attention layers:
        if cross_attention is None:
            self.register_module("cross_attention", None)
            self.register_module("cross_attention_layer_norm", None)
        else:
            cross_attention_layer_norm = layer_norm_factory(
                model_dim, device=device, dtype=dtype
            )

            if norm_order != TransformerNormOrder.POST:
                self.cross_attention_layer_norm = cross_attention_layer_norm

            self.cross_attention = cross_attention

            if dropout_p > 0.0:
                self.cross_attention_dropout = Dropout(dropout_p)
            else:
                self.register_module("cross_attention_dropout", None)

            if norm_order == TransformerNormOrder.POST:
                self.cross_attention_layer_norm = cross_attention_layer_norm
        # / deal with cross-attention

        ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)

        if norm_order != TransformerNormOrder.POST:
            self.ffn_layer_norm = ffn_layer_norm

        self.ffn = ffn

        if dropout_p > 0.0:
            self.ffn_dropout = Dropout(dropout_p)
        else:
            self.register_module("ffn_dropout", None)

        if norm_order == TransformerNormOrder.POST:
            self.ffn_layer_norm = ffn_layer_norm

        self.norm_order = norm_order

        # Add a modulator:
        modulator_input_dim = modulator_input_dim or model_dim
        self.modulator = AdaLNModulator(
            input_dim=modulator_input_dim,
            output_dim=model_dim,
            device=device,
            dtype=dtype,
        )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters and buffers of the module."""
        # Zero-out the modulators:
        self.modulator.reset_parameters()

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        conditioning_variables: Tensor,
        emb_timesteps: Tensor,
        self_attn_mask: Optional[AttentionMask] = None,
        conditioning_variables_padding_mask: Optional[PaddingMask] = None,
        cross_attention_mask: Optional[AttentionMask] = None,
        *,
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        # Get modulator output:
        (modulate_san, modulate_cross_attention, modulate_ffn) = self.modulator(
            emb_timesteps
        )

        seqs = self._forward_self_attn(
            seqs=seqs,
            padding_mask=padding_mask,
            modulators=modulate_san,
            self_attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        seqs = self._forward_cross_attention(
            seqs=seqs,
            padding_mask=padding_mask,
            conditioning_variables=conditioning_variables,
            modulators=modulate_cross_attention,
            cross_attention_mask=cross_attention_mask,
            key_padding_mask=conditioning_variables_padding_mask,
            state_bag=state_bag,
        )

        seqs = self._forward_ffn(
            seqs=seqs,
            modulators=modulate_ffn,
        )
+
|
| 429 |
+
return seqs, padding_mask
|
| 430 |
+
|
| 431 |
+
def _forward_self_attn(
|
| 432 |
+
self,
|
| 433 |
+
seqs: Tensor,
|
| 434 |
+
modulators: Tensor,
|
| 435 |
+
padding_mask: Optional[PaddingMask],
|
| 436 |
+
self_attn_mask: Optional[AttentionMask],
|
| 437 |
+
state_bag: Optional[IncrementalStateBag],
|
| 438 |
+
) -> Tensor:
|
| 439 |
+
residual = seqs
|
| 440 |
+
|
| 441 |
+
assert self.norm_order != TransformerNormOrder.POST, (
|
| 442 |
+
"DiT AdaLN expect pre-normalization"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if self.norm_order != TransformerNormOrder.POST:
|
| 446 |
+
seqs = self.self_attn_layer_norm(seqs)
|
| 447 |
+
|
| 448 |
+
# split modulators into shift, scale and gate:
|
| 449 |
+
shift, scale, gate = modulators.chunk(3, dim=-1)
|
| 450 |
+
|
| 451 |
+
# modulate the input:
|
| 452 |
+
seqs = seqs * (1 + scale) + shift
|
| 453 |
+
|
| 454 |
+
seqs = self.self_attn(
|
| 455 |
+
seqs,
|
| 456 |
+
padding_mask,
|
| 457 |
+
keys=seqs,
|
| 458 |
+
key_padding_mask=None,
|
| 459 |
+
values=seqs,
|
| 460 |
+
attn_mask=self_attn_mask,
|
| 461 |
+
state_bag=state_bag,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
if self.self_attn_norm is not None:
|
| 465 |
+
seqs = self.self_attn_norm(seqs)
|
| 466 |
+
|
| 467 |
+
if self.self_attn_dropout is not None:
|
| 468 |
+
seqs = self.self_attn_dropout(seqs)
|
| 469 |
+
|
| 470 |
+
# Scale the residual with the gate weights
|
| 471 |
+
seqs = residual + gate * seqs
|
| 472 |
+
|
| 473 |
+
return seqs
|
| 474 |
+
|
| 475 |
+
def _forward_cross_attention(
|
| 476 |
+
self,
|
| 477 |
+
seqs: Tensor,
|
| 478 |
+
modulators: Tensor,
|
| 479 |
+
padding_mask: Optional[PaddingMask],
|
| 480 |
+
conditioning_variables: Optional[Tensor],
|
| 481 |
+
key_padding_mask: Optional[PaddingMask],
|
| 482 |
+
cross_attention_mask: Optional[AttentionMask],
|
| 483 |
+
state_bag: Optional[IncrementalStateBag],
|
| 484 |
+
) -> Tensor:
|
| 485 |
+
if conditioning_variables is None:
|
| 486 |
+
raise ValueError(
|
| 487 |
+
"`conditioning_variables` must not be `None` for cross attention."
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
residual = seqs
|
| 491 |
+
|
| 492 |
+
assert self.norm_order != TransformerNormOrder.POST, (
|
| 493 |
+
"DiT AdaLN expect pre-normalization"
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
if self.norm_order != TransformerNormOrder.POST:
|
| 497 |
+
seqs = cast(LayerNorm, self.cross_attention_layer_norm)(seqs)
|
| 498 |
+
|
| 499 |
+
# split modulators into shift, scale and gate:
|
| 500 |
+
shift, scale, gate = modulators.chunk(3, dim=-1)
|
| 501 |
+
|
| 502 |
+
# modulate the input:
|
| 503 |
+
seqs = seqs * (1 + scale) + shift
|
| 504 |
+
|
| 505 |
+
seqs = self.cross_attention(
|
| 506 |
+
seqs,
|
| 507 |
+
padding_mask,
|
| 508 |
+
keys=conditioning_variables,
|
| 509 |
+
key_padding_mask=key_padding_mask,
|
| 510 |
+
attn_mask=cross_attention_mask,
|
| 511 |
+
values=conditioning_variables,
|
| 512 |
+
state_bag=state_bag,
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if self.cross_attention_dropout is not None:
|
| 516 |
+
seqs = self.cross_attention_dropout(seqs)
|
| 517 |
+
|
| 518 |
+
# Scale the residual with the gate weights
|
| 519 |
+
seqs = residual + gate * seqs
|
| 520 |
+
|
| 521 |
+
return seqs
|
| 522 |
+
|
| 523 |
+
def _forward_ffn(self, seqs: Tensor, modulators: Tensor) -> Tensor:
|
| 524 |
+
assert self.norm_order != TransformerNormOrder.POST, (
|
| 525 |
+
"DiT AdaLN expects pre-normalization"
|
| 526 |
+
)
|
| 527 |
+
residual = seqs
|
| 528 |
+
|
| 529 |
+
if self.norm_order != TransformerNormOrder.POST:
|
| 530 |
+
seqs = self.ffn_layer_norm(seqs)
|
| 531 |
+
|
| 532 |
+
# split modulators into shift, scale and gate:
|
| 533 |
+
shift, scale, gate = modulators.chunk(3, dim=-1)
|
| 534 |
+
|
| 535 |
+
# modulate the input:
|
| 536 |
+
seqs = seqs * (1 + scale) + shift
|
| 537 |
+
|
| 538 |
+
seqs = self.ffn(seqs)
|
| 539 |
+
|
| 540 |
+
if self.ffn_dropout is not None:
|
| 541 |
+
seqs = self.ffn_dropout(seqs)
|
| 542 |
+
|
| 543 |
+
# Scale the branch with the gate weights
|
| 544 |
+
seqs = residual + gate * seqs
|
| 545 |
+
|
| 546 |
+
return seqs
|
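The shift/scale/gate ("AdaLN") modulation that every sub-block of `LCMDenoiserLayer` applies can be summarized in a few lines. Below is a minimal, self-contained sketch; the helper names are illustrative, not the repository's API:

import torch

def adaln_block(seqs, modulators, norm, sublayer):
    # modulators: (..., 3 * model_dim), produced from the timestep embedding
    shift, scale, gate = modulators.chunk(3, dim=-1)
    residual = seqs
    seqs = norm(seqs)                  # pre-normalization (required by DiT AdaLN)
    seqs = seqs * (1 + scale) + shift  # modulate the normalized input
    seqs = sublayer(seqs)              # self-attn / cross-attn / FFN branch
    return residual + gate * seqs      # gate the branch before the residual add

x = torch.randn(2, 8, 16)
mods = torch.zeros(2, 8, 3 * 16)  # zeroed modulators, as in reset_parameters()
out = adaln_block(x, mods, torch.nn.LayerNorm(16), torch.nn.Linear(16, 16))
assert torch.allclose(out, x)  # gate == 0 makes the layer an identity at init

This is why `reset_parameters` zeroes the modulators: each layer starts out as an identity mapping and only gradually learns to contribute.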
lcm/nn/incremental_state.py
ADDED
@@ -0,0 +1,43 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import Dict, Optional, final

from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag
from fairseq2.nn.transformer import FullAttentionState
from torch import Tensor
from torch.nn import Module


@final
class LCMIncrementalStateBag(IncrementalStateBag):  # type: ignore
    """Holds the module states during incremental decoding."""

    _module_states: Dict[Module, FullAttentionState]  # type: ignore

    def __init__(
        self, max_num_steps: int, *, capacity_increment: Optional[int] = 16
    ) -> None:
        super().__init__(
            max_num_steps=max_num_steps, capacity_increment=capacity_increment
        )

    def reorder(self, new_order: Tensor) -> None:
        """Reorder the module states.

        See :meth:`IncrementalState.reorder` for more information.
        """
        # FIXME Deal with reordering diffusion state bags here
        for state in self._module_states.values():
            state.reorder(new_order)

    def set_state(self, m: Module, state: IncrementalState) -> None:
        """Set the state of ``m``.
        :param m: The module.
        :param state: The state to store.
        There is no current call to `set_state` when the bag
        is frozen, but it's implemented here for completeness.
        """
        super().set_state(m, state)
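The `reorder` hook above exists for search procedures that permute or prune hypotheses mid-decode: whenever the beam order changes, every cached attention state has to be permuted the same way. Conceptually, each stored state performs an index-select over the batch dimension; a sketch on a single batch-first cache tensor (names illustrative):

import torch

cached_keys = torch.arange(6.0).reshape(3, 2)  # three cached hypotheses
new_order = torch.tensor([2, 0, 1])            # e.g. beams re-sorted by score
cached_keys = cached_keys.index_select(0, new_order)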
lcm/nn/initialization.py
ADDED
@@ -0,0 +1,152 @@

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

import math
from functools import partial
from typing import Literal, Optional

import torch
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerNormOrder
from torch.nn import Module

SUPPORTED_INIT_TYPES = Literal[
    "xavier",
    "sonar",
    "zero",
    "trunc_normal",
    "kaiming_uniform",
    "none",
]


SONAR_STD = 0.006
# Most SONAR embeddings have a distribution with the mean close to 0 and std close to 0.006.
# Initializing embedding-like parameters (e.g. the end-of-text vector) from a similar distribution
# is recommended, to minimize their disruption of the model training.


def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD):
    if style == "xavier":
        return init_linear_xavier

    if style == "kaiming_uniform":
        return init_linear_kaiming_uniform

    if style == "sonar":
        return partial(init_linear_to_sonar, sonar_std=sonar_std)

    if style == "zero":
        return init_linear_zero

    if style == "trunc_normal":
        return init_linear_trunc_normal

    if style == "none":
        return None

    raise ValueError(f"Could not recognize initialization function {style}")


def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
    """
    Initialize the post-LCM projection in such a way that, if it is fed layer-normed
    LCM outputs (with zero mean and unit variance), its outputs have zero
    mean and the variance of SONAR embeddings.
    """
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)

    std = sonar_std * (3 / layer.input_dim) ** 0.5

    torch.nn.init.uniform_(layer.weight, a=-std, b=std)


def init_linear_xavier(layer: Linear) -> None:
    torch.nn.init.xavier_uniform_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_zero(layer: Linear) -> None:
    torch.nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_trunc_normal(layer: Linear) -> None:
    torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_kaiming_uniform(layer: Linear) -> None:
    torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))

    if layer.bias is not None:
        fan_in = layer.weight.size(1)

        m = 1
        if layer.weight.ndim > 2:
            for s in layer.weight.shape[2:]:
                m *= s

            fan_in *= m

        # We do not calculate the true standard deviation of the uniform
        # distribution (i.e. multiply with sqrt(3)). See
        # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

        torch.nn.init.uniform_(layer.bias, -bound, bound)


def parse_norm_order(var: str) -> TransformerNormOrder:
    norm_order: TransformerNormOrder
    if var == "pre":
        norm_order = TransformerNormOrder.PRE
    elif var == "post":
        norm_order = TransformerNormOrder.POST
    elif var == "normformer":
        norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
    else:
        raise ValueError(f"Unknown normalization order {var}")

    return norm_order


def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
    if var is None:
        return None

    activ_fn: Module

    if var == "relu":
        activ_fn = torch.nn.ReLU()
    elif var == "tanh":
        activ_fn = torch.nn.Tanh()
    elif var == "elu":
        activ_fn = torch.nn.ELU()
    elif var == "leaky_relu":
        activ_fn = torch.nn.LeakyReLU()
    elif var == "prelu":
        activ_fn = torch.nn.PReLU()
    elif var == "selu":
        activ_fn = torch.nn.SELU()
    elif var == "gelu":
        activ_fn = torch.nn.GELU()
    elif var == "silu":
        activ_fn = torch.nn.SiLU()
    elif var == "softsign":
        activ_fn = torch.nn.Softsign()
    elif var == "sigmoid":
        activ_fn = torch.nn.Sigmoid()
    elif var == "hardsigmoid":
        activ_fn = torch.nn.Hardsigmoid()
    else:
        raise ValueError(f"Unknown activation function {var}")

    return activ_fn
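A quick sanity check of the `sonar` initialization above: for weights drawn from U(-a, a), Var(W) = a**2 / 3, so with a = sonar_std * sqrt(3 / fan_in) a linear layer fed unit-variance inputs emits outputs whose std is approximately sonar_std. A minimal empirical verification (dimensions are illustrative):

import torch

fan_in, sonar_std = 1024, 0.006
a = sonar_std * (3 / fan_in) ** 0.5
w = torch.empty(512, fan_in).uniform_(-a, a)  # same bound as init_linear_to_sonar
x = torch.randn(10_000, fan_in)               # unit-variance, layer-normed-like inputs
print((x @ w.T).std())                        # ~0.006, the SONAR embedding scale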
lcm/nn/normalization.py
ADDED
@@ -0,0 +1,88 @@

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from typing import Literal, Optional, final

import torch
from fairseq2.nn import LayerNorm, RMSNorm, StandardLayerNorm
from fairseq2.nn.transformer import LayerNormFactory, create_standard_layer_norm
from fairseq2.typing import DataType, Device, override

SUPPORTED_LN_TYPES = Literal["standard", "fp32", "rms", "unit"]


@final
class FP32LayerNorm(LayerNorm):
    """Applies Layer Normalization in single precision."""

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w, b = self.weight, self.bias

        # cast input and params to float32
        fp32_x = x.float()
        fp32_w = w.float() if w is not None else None
        fp32_b = b.float() if b is not None else None

        y = torch.nn.functional.layer_norm(
            fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps
        )

        return y.type_as(x)


def build_rms_layer_norm(
    model_dim: int,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> LayerNorm:
    """Build an RMS Layer Normalization module."""
    return RMSNorm(model_dim, bias=False, device=device, dtype=dtype)


def build_fp32_layer_norm(
    model_dim: int,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> LayerNorm:
    """Build a single-precision Layer Normalization module."""
    return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype)


def build_unit_layer_norm(
    model_dim: int,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> LayerNorm:
    """Create an instance of :class:`StandardLayerNorm`
    without learnable mean and variance."""
    return StandardLayerNorm(
        model_dim,
        bias=False,
        elementwise_affine=False,
        device=device,
        dtype=dtype,
    )


def parse_layer_norm_factory(layer_normalization_style: str) -> LayerNormFactory:
    if layer_normalization_style == "rms":
        # Note that RMSNorm normalizes in single precision by default
        return build_rms_layer_norm

    elif layer_normalization_style == "unit":
        return build_unit_layer_norm

    elif layer_normalization_style == "fp32":
        return build_fp32_layer_norm

    elif layer_normalization_style == "standard":
        return create_standard_layer_norm

    else:
        raise ValueError(f"Unsupported LayerNorm style {layer_normalization_style}")
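The upcast-normalize-downcast pattern implemented by `FP32LayerNorm` above, shown in functional form as a minimal sketch (the module additionally casts its learnable weight and bias to float32 before the reduction):

import torch

x = torch.randn(4, 512, dtype=torch.float16)
y = torch.nn.functional.layer_norm(x.float(), (512,)).type_as(x)
assert y.dtype == torch.float16  # statistics computed in fp32, output in fp16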
lcm/nn/projection.py
ADDED
@@ -0,0 +1,86 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Optional

import torch
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Module

from lcm.nn.initialization import (
    SUPPORTED_INIT_TYPES,
    get_init_fn,
    parse_activation_fn,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES


@dataclass
class ProjectionConfig:
    dropout_p: float = 0.0
    """The dropout probability applied to the module's output."""

    linear_bias: bool = True
    """Whether or not the pre-linear layer has a bias term."""

    linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"

    weight_normalization: bool = False

    layer_normalization_style: SUPPORTED_LN_TYPES = "standard"

    activation_name: Optional[str] = None
    """The activation function to apply after the linear layer, if any."""


class Projection(Module):
    """
    An output projection module.
    """

    def __init__(
        self,
        output_dim: int,
        input_dim: int,
        config: ProjectionConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()

        self.dtype = dtype

        init_fn = get_init_fn(config.linear_init_fn)

        lin = Linear(
            input_dim,
            output_dim,
            bias=config.linear_bias,
            device=device,
            dtype=dtype,
            init_fn=init_fn,
        )
        if config.weight_normalization:
            self.fc = torch.nn.utils.parametrizations.weight_norm(lin)
        else:
            self.fc = lin

        self.activation_fn = parse_activation_fn(config.activation_name)

        if self.activation_fn is not None:
            # some activation functions (e.g., PReLU) have parameters
            # and so we need to move them to the right device
            self.activation_fn.to(device)

    def forward(self, seqs: Tensor):
        seqs = self.fc(seqs)

        if self.activation_fn is not None:
            seqs = self.activation_fn(seqs)

        return seqs
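A short usage sketch of the module above; all dimensions and config values are illustrative:

import torch

from lcm.nn.projection import Projection, ProjectionConfig

config = ProjectionConfig(activation_name="tanh", weight_normalization=True)
proj = Projection(output_dim=1024, input_dim=512, config=config)
out = proj(torch.randn(2, 16, 512))  # -> (2, 16, 1024)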
lcm/nn/schedulers/__init__.py
ADDED
@@ -0,0 +1,17 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


from lcm.nn.schedulers.ddim import (
    DDIMScheduler,
    DDIMSchedulerConfig,
    DDIMSchedulerOutput,
)

__all__ = [
    "DDIMScheduler",
    "DDIMSchedulerConfig",
    "DDIMSchedulerOutput",
]
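For orientation, a minimal sampling loop against the scheduler API exported above. The `model` stand-in and all dimensions are illustrative; a real LCM denoiser conditions on context and timestep embeddings:

import torch

from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig

scheduler = DDIMScheduler(DDIMSchedulerConfig(prediction_type="sample"))
scheduler.set_timesteps(num_inference_steps=40)

sample = torch.randn(2, 1024)                 # start from pure noise
model = lambda x, t: torch.zeros_like(x)      # hypothetical denoiser stand-in
for t in scheduler.timesteps:
    out = scheduler.step(model(sample, t), int(t), sample)
    sample = out.prev_sample                  # x_{t-1} becomes the next input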
lcm/nn/schedulers/ddim.py
ADDED
@@ -0,0 +1,741 @@

| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
# This code is based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py, which is distributed under the Apache 2.0 License.
|
| 7 |
+
# HuggingFace's diffusers DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
| 8 |
+
# and https://github.com/hojonathanho/diffusion
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import List, Literal, Optional, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from fairseq2.logging import get_log_writer
|
| 16 |
+
from fairseq2.typing import CPU
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
logger = get_log_writer(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def sigmoid(x):
|
| 23 |
+
return 1 / (1 + math.exp(-x))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def logit(x):
|
| 27 |
+
return math.log(x / (1 - x))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class DDIMSchedulerOutput:
|
| 32 |
+
"""
|
| 33 |
+
Output class for the scheduler's `step` function output.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
| 37 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
| 38 |
+
denoising loop.
|
| 39 |
+
pred_original_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
| 40 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
| 41 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
prev_sample: Tensor
|
| 45 |
+
pred_original_sample: Tensor
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class DDIMSchedulerConfig:
|
| 50 |
+
num_diffusion_train_steps: int = 1000
|
| 51 |
+
"""The number of diffusion steps to train the model."""
|
| 52 |
+
|
| 53 |
+
beta_start: float = 0.0001
|
| 54 |
+
"""The starting `beta` value of inference."""
|
| 55 |
+
|
| 56 |
+
beta_end: float = 0.02
|
| 57 |
+
"""The final `beta` value."""
|
| 58 |
+
"""In DDPM (https://arxiv.org/pdf/2006.11239), $\beta_t$ is increasing
|
| 59 |
+
linearly from $\beta_1$ (`beta_start`)=1e−4 to $\beta_T$ (`beta_end`)=0.02.
|
| 60 |
+
These constants were chosen to be small relative to data scaled to [−1, 1],
|
| 61 |
+
ensuring that reverse and forward processes have approximately
|
| 62 |
+
the same functional form while keeping the signal-to-noise ratio at $x_T$ as small as possible.
|
| 63 |
+
Another common choice in HF:diffusers `beta_start=0.00085, beta_end=0.012,`
|
| 64 |
+
Note that `beta_start` and `beta_end` are irrelevant for `squaredcos_cap_v2`
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
beta_schedule: Literal[
|
| 68 |
+
"linear",
|
| 69 |
+
"scaled_linear",
|
| 70 |
+
"squaredcos_cap_v2",
|
| 71 |
+
"sigmoid",
|
| 72 |
+
] = "squaredcos_cap_v2"
|
| 73 |
+
"""The beta schedule, a mapping from a beta range to a sequence of betas
|
| 74 |
+
for stepping the model (length=`num_diffusion_train_steps`).
|
| 75 |
+
Choose from:
|
| 76 |
+
- `linear`: Linearly spaced betas between `beta_start` and `beta_end`.
|
| 77 |
+
Referred to as `sqrt_linear` in stable-diffusion.
|
| 78 |
+
- `scaled_linear`: Squared values after linearly spacing form sqrt(beta_start) to sqrt(beta_end).
|
| 79 |
+
Referred to as `linear` in stable-diffusion.
|
| 80 |
+
-`squaredcos_cap_v2`: Creates a beta schedule that discretizes
|
| 81 |
+
math:: $\bar alpha(t) = {cos((t/T + s) / (1+s) * \pi/2)}^2$, HF:diffusers sets `s` to 0.008.
|
| 82 |
+
For the intuition behind how a cosine schedule compares to a linear schedule
|
| 83 |
+
see Figure 3 of https://arxiv.org/pdf/2102.09672
|
| 84 |
+
- `sigmoid` our sigmoid schedule (see Equation 14 of the LCM paper).
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
scaled_linear_exponent: float = 2.0
|
| 88 |
+
"""Exponent for the scaled linear beta schedule. Default is quadratic (scaled_linear_exponent=2)"""
|
| 89 |
+
|
| 90 |
+
sigmoid_schedule_alpha: float = 1.5
|
| 91 |
+
sigmoid_schedule_beta: float = 0
|
| 92 |
+
"""alpha and beta hyper-parameters of the sigmoid beta-schedule"""
|
| 93 |
+
|
| 94 |
+
clip_sample: bool = False
|
| 95 |
+
"""Clip the predicted sample for numerical stability."""
|
| 96 |
+
|
| 97 |
+
clip_sample_range: float = 1.0
|
| 98 |
+
"""The maximum magnitude for sample clipping. Valid only when `clip_sample=True`."""
|
| 99 |
+
|
| 100 |
+
set_alpha_to_one: bool = True
|
| 101 |
+
"""Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
| 102 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
| 103 |
+
otherwise it uses the alpha value at step 0."""
|
| 104 |
+
|
| 105 |
+
prediction_type: Literal["sample", "epsilon", "v_prediction"] = "sample"
|
| 106 |
+
"""If `sample`, the model predicts the clean ground truth embeddings.
|
| 107 |
+
If `epsilon`, the model predicts the added noise of the diffusion process.
|
| 108 |
+
If `v_epsilon`, the model predicts an interpolation of the ground truth clean
|
| 109 |
+
embeddings and the added noise. As introduced in section 2.4 of the Imagen paper
|
| 110 |
+
(https://imagen.research.google/video/paper.pdf)
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
thresholding: bool = False
|
| 114 |
+
"""Whether to use the "dynamic thresholding" method.
|
| 115 |
+
This is unsuitable for latent-space diffusion models such as Stable Diffusion."""
|
| 116 |
+
|
| 117 |
+
dynamic_thresholding_ratio: float = 0.995
|
| 118 |
+
"""The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""
|
| 119 |
+
|
| 120 |
+
sample_max_value: float = 1.0
|
| 121 |
+
"""The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""
|
| 122 |
+
|
| 123 |
+
rescale_betas_zero_snr: bool = True
|
| 124 |
+
"""Whether to rescale the betas to have zero terminal SNR. This enables the
|
| 125 |
+
model to generate very bright and dark samples instead of limiting it to samples
|
| 126 |
+
with medium brightness. Loosely related to
|
| 127 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506)."""
|
| 128 |
+
|
| 129 |
+
# Inference specific
|
| 130 |
+
timestep_spacing: Literal["linspace", "leading", "trailing"] = "trailing"
|
| 131 |
+
"""The way the timesteps should be scaled. Refer to Table 2 of
|
| 132 |
+
https://arxiv.org/abs/2305.08891 for more information."""
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class DDIMScheduler:
|
| 136 |
+
def __init__(self, config: DDIMSchedulerConfig):
|
| 137 |
+
self.config = config
|
| 138 |
+
|
| 139 |
+
# Make these 2 arguments easily accessible
|
| 140 |
+
self.num_diffusion_train_steps = self.config.num_diffusion_train_steps
|
| 141 |
+
|
| 142 |
+
self.prediction_type = self.config.prediction_type
|
| 143 |
+
|
| 144 |
+
beta_schedule = self.config.beta_schedule
|
| 145 |
+
|
| 146 |
+
if beta_schedule == "linear":
|
| 147 |
+
self.betas = torch.linspace(
|
| 148 |
+
self.config.beta_start,
|
| 149 |
+
self.config.beta_end,
|
| 150 |
+
self.num_diffusion_train_steps,
|
| 151 |
+
dtype=torch.float32,
|
| 152 |
+
)
|
| 153 |
+
elif beta_schedule == "scaled_linear":
|
| 154 |
+
# This schedule is very specific to the latent diffusion model.
|
| 155 |
+
exponent = self.config.scaled_linear_exponent
|
| 156 |
+
self.betas = (
|
| 157 |
+
torch.linspace(
|
| 158 |
+
self.config.beta_start ** (1 / exponent),
|
| 159 |
+
self.config.beta_end ** (1 / exponent),
|
| 160 |
+
self.num_diffusion_train_steps,
|
| 161 |
+
dtype=torch.float32,
|
| 162 |
+
)
|
| 163 |
+
** exponent
|
| 164 |
+
)
|
| 165 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
| 166 |
+
# Cosine schedule as introduced in
|
| 167 |
+
# [Nichol and Dhariwal, 2021](https://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf)
|
| 168 |
+
self.betas = betas_for_alpha_bar(
|
| 169 |
+
self.num_diffusion_train_steps,
|
| 170 |
+
alpha_transform_type="cosine",
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
elif beta_schedule == "sigmoid":
|
| 174 |
+
self.betas = betas_for_alpha_bar(
|
| 175 |
+
self.num_diffusion_train_steps,
|
| 176 |
+
alpha_transform_type="sigmoid",
|
| 177 |
+
sigmoid_alpha=self.config.sigmoid_schedule_alpha,
|
| 178 |
+
sigmoid_beta=self.config.sigmoid_schedule_beta,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
else:
|
| 182 |
+
raise NotImplementedError(
|
| 183 |
+
f"We do not recognize beta_schedule={beta_schedule}"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Rescale for zero SNR
|
| 187 |
+
if self.config.rescale_betas_zero_snr:
|
| 188 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
| 189 |
+
|
| 190 |
+
self.alphas = 1.0 - self.betas
|
| 191 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
| 192 |
+
|
| 193 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
| 194 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
| 195 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
| 196 |
+
# whether we use the final alpha of the "non-previous" one.
|
| 197 |
+
self.final_alpha_cumprod = (
|
| 198 |
+
torch.tensor(1.0)
|
| 199 |
+
if self.config.set_alpha_to_one
|
| 200 |
+
else self.alphas_cumprod[0]
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# standard deviation of the initial noise distribution
|
| 204 |
+
self.init_noise_sigma = 1.0
|
| 205 |
+
|
| 206 |
+
# timesteps for inference
|
| 207 |
+
self.num_inference_steps: Optional[int] = None
|
| 208 |
+
|
| 209 |
+
def _get_variance(self, timestep, prev_timestep):
|
| 210 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
| 211 |
+
alpha_prod_t_prev = (
|
| 212 |
+
self.alphas_cumprod[prev_timestep]
|
| 213 |
+
if prev_timestep >= 0
|
| 214 |
+
else self.final_alpha_cumprod
|
| 215 |
+
)
|
| 216 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 217 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
| 218 |
+
|
| 219 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (
|
| 220 |
+
1 - alpha_prod_t / alpha_prod_t_prev
|
| 221 |
+
)
|
| 222 |
+
return variance
|
| 223 |
+
|
| 224 |
+
def get_variances(self) -> Tensor:
|
| 225 |
+
alpha_prod_t = self.alphas_cumprod
|
| 226 |
+
alpha_prod_t_prev = torch.cat(
|
| 227 |
+
(torch.tensor([self.final_alpha_cumprod]), alpha_prod_t[:-1])
|
| 228 |
+
)
|
| 229 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 230 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
| 231 |
+
|
| 232 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (
|
| 233 |
+
1 - alpha_prod_t / alpha_prod_t_prev
|
| 234 |
+
)
|
| 235 |
+
return variance
|
| 236 |
+
|
| 237 |
+
def get_snrs(self) -> Tensor:
|
| 238 |
+
alphas_cumprod = self.alphas_cumprod
|
| 239 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
| 240 |
+
return snr
|
| 241 |
+
|
| 242 |
+
def _threshold_sample(self, sample: Tensor) -> Tensor:
|
| 243 |
+
"""
|
| 244 |
+
"Dynamic thresholding: At each sampling step we set s to a certain
|
| 245 |
+
percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t),
|
| 246 |
+
and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 247 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1)
|
| 248 |
+
inwards, thereby actively preventing pixels from saturation at each step.
|
| 249 |
+
We find that dynamic thresholding results in significantly better
|
| 250 |
+
photorealism as well as better image-text alignment,
|
| 251 |
+
especially when using very large guidance weights."
|
| 252 |
+
|
| 253 |
+
https://arxiv.org/abs/2205.11487
|
| 254 |
+
"""
|
| 255 |
+
dtype = sample.dtype
|
| 256 |
+
batch_size, channels, *remaining_dims = sample.shape
|
| 257 |
+
|
| 258 |
+
if dtype not in (torch.float32, torch.float64):
|
| 259 |
+
sample = (
|
| 260 |
+
sample.float()
|
| 261 |
+
) # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 262 |
+
|
| 263 |
+
# Flatten sample for doing quantile calculation along each image
|
| 264 |
+
sample = sample.reshape(batch_size, -1)
|
| 265 |
+
|
| 266 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 267 |
+
|
| 268 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 269 |
+
s = torch.clamp(
|
| 270 |
+
s, min=1, max=self.config.sample_max_value
|
| 271 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 272 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 273 |
+
sample = (
|
| 274 |
+
torch.clamp(sample, -s, s) / s
|
| 275 |
+
) # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 276 |
+
|
| 277 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 278 |
+
sample = sample.to(dtype)
|
| 279 |
+
|
| 280 |
+
return sample
|
| 281 |
+
|
| 282 |
+
def set_timesteps(
|
| 283 |
+
self, num_inference_steps: int, device: Union[str, torch.device] = None
|
| 284 |
+
):
|
| 285 |
+
"""
|
| 286 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
num_inference_steps (`int`):
|
| 290 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
if num_inference_steps > self.config.num_diffusion_train_steps:
|
| 294 |
+
raise ValueError(
|
| 295 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_diffusion_train_steps`:"
|
| 296 |
+
f" {self.num_diffusion_train_steps} as the unet model trained with this scheduler can only handle"
|
| 297 |
+
f" maximal {self.num_diffusion_train_steps} timesteps."
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
self.num_inference_steps = num_inference_steps
|
| 301 |
+
|
| 302 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
| 303 |
+
# With T the number of training steps and S the number of inference steps
|
| 304 |
+
|
| 305 |
+
if self.config.timestep_spacing == "linspace":
|
| 306 |
+
# Linspace: flip round(linspace(1, T, S))
|
| 307 |
+
# With T=1000 and S=10; [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
|
| 308 |
+
timesteps = torch.linspace(
|
| 309 |
+
0,
|
| 310 |
+
self.config.num_diffusion_train_steps - 1,
|
| 311 |
+
self.num_inference_steps,
|
| 312 |
+
device=device,
|
| 313 |
+
dtype=torch.long,
|
| 314 |
+
)
|
| 315 |
+
timesteps = torch.flip(timesteps, dims=(0,)).round()
|
| 316 |
+
|
| 317 |
+
elif self.config.timestep_spacing == "leading":
|
| 318 |
+
# Leading: flip arange(1, T + 1, floor(T /S))
|
| 319 |
+
# With T=1000 and S=10: [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
|
| 320 |
+
|
| 321 |
+
leading_step_ratio = (
|
| 322 |
+
self.num_diffusion_train_steps // self.num_inference_steps
|
| 323 |
+
)
|
| 324 |
+
timesteps = torch.arange(
|
| 325 |
+
start=0,
|
| 326 |
+
end=self.num_diffusion_train_steps,
|
| 327 |
+
step=leading_step_ratio,
|
| 328 |
+
device=device,
|
| 329 |
+
dtype=torch.long,
|
| 330 |
+
)
|
| 331 |
+
timesteps = torch.flip(timesteps, dims=(0,)).round()
|
| 332 |
+
|
| 333 |
+
elif self.config.timestep_spacing == "trailing":
|
| 334 |
+
# Trailing: round(flip(arange(T, 0, −T /S)))
|
| 335 |
+
# With T=1000 and S=10: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
|
| 336 |
+
trailing_step_ratio: float = (
|
| 337 |
+
self.num_diffusion_train_steps / self.num_inference_steps
|
| 338 |
+
)
|
| 339 |
+
# creates integer timesteps by multiplying by ratio
|
| 340 |
+
timesteps = torch.arange(
|
| 341 |
+
self.config.num_diffusion_train_steps,
|
| 342 |
+
0,
|
| 343 |
+
-trailing_step_ratio,
|
| 344 |
+
device=device,
|
| 345 |
+
dtype=torch.long,
|
| 346 |
+
).round()
|
| 347 |
+
timesteps -= 1
|
| 348 |
+
else:
|
| 349 |
+
raise ValueError(
|
| 350 |
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
self.timesteps = timesteps
|
| 354 |
+
logger.debug(
|
| 355 |
+
f"With `{self.config.timestep_spacing}`, setting inference timesteps to {self.timesteps}"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def step(
|
| 359 |
+
self,
|
| 360 |
+
model_output: Tensor,
|
| 361 |
+
timestep: int,
|
| 362 |
+
sample: Tensor,
|
| 363 |
+
eta: float = 0.0,
|
| 364 |
+
use_clipped_model_output: bool = False,
|
| 365 |
+
generator=None,
|
| 366 |
+
variance_noise: Optional[Tensor] = None,
|
| 367 |
+
prediction_type: Optional[str] = None,
|
| 368 |
+
epsilon_scaling: Optional[float] = None,
|
| 369 |
+
) -> DDIMSchedulerOutput:
|
| 370 |
+
"""
|
| 371 |
+
INFERENCE ONLY.
|
| 372 |
+
Predict the sample from the previous timestep by reversing the SDE.
|
| 373 |
+
This function propagates the diffusion
|
| 374 |
+
process from the learned model outputs.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
model_output (`Tensor`):
|
| 378 |
+
The direct output from learned diffusion model.
|
| 379 |
+
timestep (`float`):
|
| 380 |
+
The current discrete timestep in the diffusion chain.
|
| 381 |
+
sample (`Tensor`):
|
| 382 |
+
A current instance of a sample created by the diffusion process.
|
| 383 |
+
eta (`float`):
|
| 384 |
+
The weight of noise for added noise in diffusion step.
|
| 385 |
+
use_clipped_model_output (`bool`, defaults to `False`):
|
| 386 |
+
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
| 387 |
+
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
| 388 |
+
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
| 389 |
+
`use_clipped_model_output` has no effect.
|
| 390 |
+
generator (`torch.Generator`, *optional*):
|
| 391 |
+
A random number generator.
|
| 392 |
+
variance_noise (`Tensor`):
|
| 393 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
| 394 |
+
itself. Useful for methods such as [`CycleDiffusion`].
|
| 395 |
+
prediction_type: Optional[str] if provided we step with a different prediction_type
|
| 396 |
+
than the one in the config
|
| 397 |
+
epsilon_scaling: Optional[float] if not None, the predicted epsilon will be scaled down by
|
| 398 |
+
the provided factor as introduced in https://arxiv.org/pdf/2308.15321
|
| 399 |
+
|
| 400 |
+
Returns:
|
| 401 |
+
DDIMSchedulerOutput
|
| 402 |
+
|
| 403 |
+
"""
|
| 404 |
+
if self.num_inference_steps is None:
|
| 405 |
+
raise ValueError(
|
| 406 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
| 410 |
+
# Ideally, read DDIM paper in-detail understanding
|
| 411 |
+
|
| 412 |
+
# Notation (<variable name> -> <name in paper>
|
| 413 |
+
# - pred_noise_t -> e_theta(x_t, t)
|
| 414 |
+
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
| 415 |
+
# - std_dev_t -> sigma_t
|
| 416 |
+
# - eta -> η
|
| 417 |
+
# - pred_sample_direction -> "direction pointing to x_t"
|
| 418 |
+
# - pred_prev_sample -> "x_t-1"
|
| 419 |
+
|
| 420 |
+
# 1. Get previous step value (=t-1)
|
| 421 |
+
prev_timestep = (
|
| 422 |
+
timestep - self.config.num_diffusion_train_steps // self.num_inference_steps
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# 2. Compute alphas, betas
|
| 426 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
| 427 |
+
alpha_prod_t_prev = (
|
| 428 |
+
self.alphas_cumprod[prev_timestep]
|
| 429 |
+
if prev_timestep >= 0
|
| 430 |
+
else self.final_alpha_cumprod
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 434 |
+
|
| 435 |
+
# 3. Compute predicted original sample from predicted noise also called
|
| 436 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 437 |
+
prediction_type = prediction_type or self.prediction_type
|
| 438 |
+
if prediction_type == "epsilon":
|
| 439 |
+
pred_original_sample = (
|
| 440 |
+
sample - beta_prod_t ** (0.5) * model_output
|
| 441 |
+
) / alpha_prod_t ** (0.5)
|
| 442 |
+
pred_epsilon = model_output
|
| 443 |
+
elif prediction_type == "sample":
|
| 444 |
+
pred_original_sample = model_output
|
| 445 |
+
pred_epsilon = (
|
| 446 |
+
sample - alpha_prod_t ** (0.5) * pred_original_sample
|
| 447 |
+
) / beta_prod_t ** (0.5)
|
| 448 |
+
elif prediction_type == "v_prediction":
|
| 449 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (
|
| 450 |
+
beta_prod_t**0.5
|
| 451 |
+
) * model_output
|
| 452 |
+
pred_epsilon = (alpha_prod_t**0.5) * model_output + (
|
| 453 |
+
beta_prod_t**0.5
|
| 454 |
+
) * sample
|
| 455 |
+
else:
|
| 456 |
+
raise ValueError(
|
| 457 |
+
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
|
| 458 |
+
" `v_prediction`"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# 3.a epsilon scaling:
|
| 462 |
+
if epsilon_scaling is not None:
|
| 463 |
+
pred_epsilon = pred_epsilon / epsilon_scaling
|
| 464 |
+
|
| 465 |
+
# 4. Clip or threshold "predicted x_0"
|
| 466 |
+
if self.config.thresholding:
|
| 467 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
| 468 |
+
elif self.config.clip_sample:
|
| 469 |
+
pred_original_sample = pred_original_sample.clamp(
|
| 470 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
# 5. Compute variance: "sigma_t(η)" -> see formula (16)
|
| 474 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
| 475 |
+
variance = self._get_variance(timestep, prev_timestep)
|
| 476 |
+
std_dev_t = eta * variance ** (0.5)
|
| 477 |
+
if use_clipped_model_output:
|
| 478 |
+
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
| 479 |
+
pred_epsilon = (
|
| 480 |
+
sample - alpha_prod_t ** (0.5) * pred_original_sample
|
| 481 |
+
) / beta_prod_t ** (0.5)
|
| 482 |
+
# 6. Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 483 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
|
| 484 |
+
0.5
|
| 485 |
+
) * pred_epsilon
|
| 486 |
+
# 7. Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
| 487 |
+
prev_sample = (
|
| 488 |
+
alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if eta > 0:
|
| 492 |
+
if variance_noise is not None and generator is not None:
|
| 493 |
+
raise ValueError(
|
| 494 |
+
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
| 495 |
+
" `variance_noise` stays `None`."
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
if variance_noise is None:
|
| 499 |
+
variance_noise = randn_tensor(
|
| 500 |
+
model_output.shape,
|
| 501 |
+
generator=generator,
|
| 502 |
+
device=model_output.device,
|
| 503 |
+
dtype=model_output.dtype,
|
| 504 |
+
)
|
| 505 |
+
variance = std_dev_t * variance_noise
|
| 506 |
+
prev_sample = prev_sample + variance
|
| 507 |
+
|
| 508 |
+
return DDIMSchedulerOutput(
|
| 509 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
def add_noise(
|
| 513 |
+
self,
|
| 514 |
+
original_samples: Tensor,
|
| 515 |
+
noise: Tensor,
|
| 516 |
+
timesteps: Tensor,
|
| 517 |
+
) -> Tensor:
|
| 518 |
+
"""TRAINING ONLY
|
| 519 |
+
Forward noising process during training"""
|
| 520 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
| 521 |
+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
| 522 |
+
# for the subsequent add_noise calls
|
| 523 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
|
| 524 |
+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
|
| 525 |
+
timesteps = timesteps.to(original_samples.device).to(torch.int32)
|
| 526 |
+
|
| 527 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
| 528 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
| 529 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
| 530 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
| 531 |
+
|
| 532 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
| 533 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
| 534 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
| 535 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
| 536 |
+
|
| 537 |
+
noisy_samples = (
|
| 538 |
+
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
| 539 |
+
)
|
| 540 |
+
return noisy_samples
|
| 541 |
+
|
| 542 |
+
def get_velocity(self, sample: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor:
|
| 543 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
| 544 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
|
| 545 |
+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
|
| 546 |
+
timesteps = timesteps.to(sample.device).to(torch.int32)
|
| 547 |
+
|
| 548 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
| 549 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
| 550 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
| 551 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
| 552 |
+
|
| 553 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
| 554 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
| 555 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
| 556 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
| 557 |
+
|
| 558 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
| 559 |
+
return velocity
|
| 560 |
+
|
| 561 |
+
def get_epsilon(
|
| 562 |
+
self, model_output: Tensor, sample: Tensor, timestep: int
|
| 563 |
+
) -> Tensor:
|
| 564 |
+
"""Given model inputs (sample) and outputs (model_output)
|
| 565 |
+
Predict the noise residual according to the scheduler's
|
| 566 |
+
prediction type"""
|
| 567 |
+
|
| 568 |
+
pred_type = self.prediction_type
|
| 569 |
+
|
| 570 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
| 571 |
+
|
| 572 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 573 |
+
|
| 574 |
+
if pred_type == "epsilon":
|
| 575 |
+
return model_output
|
| 576 |
+
|
| 577 |
+
elif pred_type == "sample":
|
| 578 |
+
return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (
|
| 579 |
+
0.5
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
elif pred_type == "v_prediction":
|
| 583 |
+
return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
| 584 |
+
else:
|
| 585 |
+
raise ValueError(
|
| 586 |
+
f"The scheduler's prediction type {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch item individually. If CPU generators are passed,
    the tensor is always created on the CPU.
    """
    # the device on which the tensor is created defaults to `device`
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = (
            generator.device.type
            if not isinstance(generator, list)
            else generator[0].device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    # make sure a generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]  # type: ignore
        latents_list = [
            torch.randn(
                shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        latents = torch.cat(latents_list, dim=0).to(device)
    else:
        latents = torch.randn(
            shape, generator=generator, device=rand_device, dtype=dtype, layout=layout
        ).to(device)

    return latents
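
# Usage sketch (illustrative values): drawing reproducible per-item noise
# with one CPU generator per batch element; as handled above, the tensor is
# created on CPU and then moved to the requested device.
#
#   generators = [torch.Generator().manual_seed(seed) for seed in range(4)]
#   latents = randn_tensor(
#       (4, 32, 1024), generator=generators, device=torch.device("cuda")
#   )
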
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp", "sigmoid"] = "cosine",
    sigmoid_alpha: float = 1.5,
    sigmoid_beta: float = 0,
) -> Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine`, `exp`, or `sigmoid`.
        sigmoid_alpha/sigmoid_beta: additional hyper-parameters for the sigmoid schedule.

    Returns:
        betas (`Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "sigmoid":

        def alpha_bar_fn(t):
            epsilon = 1e-32
            return sigmoid(
                sigmoid_beta
                - sigmoid_alpha
                * logit(torch.clamp(torch.tensor(t), min=epsilon, max=1 - epsilon))
            )

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas: Tensor) -> Tensor:
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
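
# A quick sanity check (illustrative sketch): after rescaling, the terminal
# cumulative alpha is zero, i.e. the last timestep carries no signal.
#
#   betas = rescale_zero_terminal_snr(betas_for_alpha_bar(1000))
#   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
#   assert torch.isclose(alphas_cumprod[-1], torch.tensor(0.0), atol=1e-6)
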
lcm/nn/timestep_encoder.py
ADDED
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import math
from typing import Optional

import torch
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Module

from lcm.nn.initialization import parse_activation_fn


class DiTTimestepEncoder(Module):
    """
    Embeds scalar timesteps into vector representations.
    Based on DiT's `TimestepEmbedder`
    https://github.com/facebookresearch/DiT/blob/main/models.py
    """

    def __init__(
        self,
        embedding_dim: int,
        frequency_embedding_size: int = 256,
        activation_fn_name: str = "silu",
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
        super().__init__()

        self.dtype = dtype

        self.device = device

        self.embedding_dim = embedding_dim

        self.frequency_embedding_size = frequency_embedding_size

        self.fc1 = Linear(
            frequency_embedding_size,
            embedding_dim,
            bias=True,
            device=device,
            dtype=dtype,
        )
        self.nonlin = parse_activation_fn(activation_fn_name)
        self.fc2 = Linear(
            embedding_dim,
            embedding_dim,
            bias=True,
            device=device,
            dtype=dtype,
        )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters and buffers of the module."""
        torch.nn.init.normal_(self.fc1.weight, std=0.02)
        torch.nn.init.normal_(self.fc2.weight, std=0.02)

        if self.fc1.bias is not None:
            torch.nn.init.zeros_(self.fc1.bias)

        if self.fc2.bias is not None:
            torch.nn.init.zeros_(self.fc2.bias)

    @staticmethod
    def sinusoidal_timestep_embedding(
        timestep, frequency_embedding_size, max_period=10000
    ):
        """
        Create sinusoidal timestep embeddings.
        :param timestep: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        :param frequency_embedding_size: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.

        Based on DiT's `TimestepEmbedder`
        https://github.com/facebookresearch/DiT/blob/main/models.py
        """
        half = frequency_embedding_size // 2

        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timestep.device)

        args = timestep[:, None].float() * freqs[None]

        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

        if frequency_embedding_size % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )

        return embedding

    def forward(self, timesteps: Tensor) -> Tensor:
        initial_size = timesteps.size()

        flat_timesteps = timesteps.view(-1, 1)

        t_freq = self.sinusoidal_timestep_embedding(
            flat_timesteps, self.frequency_embedding_size
        ).to(self.dtype)

        t_emb = self.fc1(t_freq)

        if self.nonlin is not None:
            t_emb = self.nonlin(t_emb)

        t_emb = self.fc2(t_emb)

        return t_emb.view(*initial_size, self.embedding_dim)
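
# Usage sketch (illustrative values): encoding a batch of diffusion
# timesteps into per-element conditioning vectors.
#
#   encoder = DiTTimestepEncoder(embedding_dim=1024)
#   timesteps = torch.randint(0, 1000, (8,))  # one timestep per batch element
#   embeddings = encoder(timesteps)           # -> shape (8, 1024)
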
lcm/nn/transformer/__init__.py
ADDED
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from lcm.nn.transformer.attention import (
    QKNormMultiheadAttention,
)
from lcm.nn.transformer.decoder import (
    LCMStandardTransformerDecoderLayer,
    LCMTransformerDecoder,
)
from lcm.nn.transformer.factory import (
    TransformerConfig,
    TransformerFactory,
)

__all__ = [
    "QKNormMultiheadAttention",
    "LCMStandardTransformerDecoderLayer",
    "LCMTransformerDecoder",
    "TransformerConfig",
    "TransformerFactory",
]
lcm/nn/transformer/attention.py
ADDED
@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import Optional, Tuple, final

import torch
import torch.nn as nn
from fairseq2.nn.ops import repeat_interleave
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.position_encoder import PositionEncoder
from fairseq2.nn.projection import Projection
from fairseq2.nn.transformer import (
    AttentionMask,
    AttentionMaskFactory,
    AttentionState,
    AttentionStateFactory,
    FullAttentionState,
    LayerNormFactory,
    StandardMultiheadAttention,
    create_standard_layer_norm,
)
from fairseq2.nn.transformer.attention import SDPA
from fairseq2.typing import DataType, Device, override
from torch import Tensor
from torch.nn.parameter import Parameter

# FIXME: revert to fairseq2's standard state bag if possible
from lcm.nn.incremental_state import (
    LCMIncrementalStateBag,
)


@final
class QKNormMultiheadAttention(StandardMultiheadAttention):  # type: ignore
    """Represents a Transformer multi-head attention layer as described in
    :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`,
    with two additional layer normalizations for keys and queries
    as described in https://arxiv.org/pdf/2302.05442
    and other related work.
    """

    kv_dim: int
    num_key_value_heads: int
    q_proj: Projection
    k_proj: Projection
    v_proj: Projection
    attn_mask_factory: Optional[AttentionMaskFactory]
    pos_encoder: Optional[PositionEncoder]
    bias_k: Optional[Parameter]
    bias_v: Optional[Parameter]
    add_zero_attn: bool
    sdpa: SDPA
    head_scale_weight: Optional[Parameter]
    output_proj: Projection
    state_factory: Optional[AttentionStateFactory]
    layer_norm_factory: Optional[LayerNormFactory]

    """
    For a full description of the parameters see fairseq2/src/fairseq2/nn/transformer/multihead_attention.py
    Parameters of interest to us:
    :param num_key_value_heads:
        The number of key/value heads for Grouped Query Attention as
        described in :cite:t:`https://doi.org/10.48550/arXiv.2305.13245`.
        If ``None`` or set to ``num_heads``, it is equivalent to standard
        Multi Head Attention (MHA); if set to 1, it is equivalent to Multi
        Query Attention (MQA).

    :param enable_qk_layernorm:
        If ``True``, follow the Q/K projections with LayerNorms.

    :param weight_normalization:
        If ``True``, wrap the K/Q/V projections with weight normalization for regularization.

    :param pos_encoder:
        A RoPE positional encoder that adds positional encoding to keys
        and queries before computing the attention scores.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        *,
        kv_dim: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        q_proj: Optional[Projection] = None,
        k_proj: Optional[Projection] = None,
        v_proj: Optional[Projection] = None,
        attn_mask_factory: Optional[AttentionMaskFactory] = None,
        pos_encoder: Optional[PositionEncoder] = None,
        sdpa: Optional[SDPA] = None,
        scale_heads: bool = False,
        output_proj: Optional[Projection] = None,
        bias: bool = True,
        state_factory: Optional[AttentionStateFactory] = None,
        enable_qk_layernorm: bool = False,
        weight_normalization: bool = False,
        layer_norm_factory: Optional[LayerNormFactory] = None,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__(
            model_dim=model_dim,
            num_heads=num_heads,
            kv_dim=kv_dim,
            num_key_value_heads=num_key_value_heads,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            attn_mask_factory=attn_mask_factory,
            pos_encoder=pos_encoder,
            sdpa=sdpa,
            scale_heads=scale_heads,
            output_proj=output_proj,
            bias=bias,
            state_factory=state_factory,
            device=device,
            dtype=dtype,
        )

        # wrap the linear layers with weight norm
        if weight_normalization:
            self.k_proj = nn.utils.parametrizations.weight_norm(self.k_proj)
            self.q_proj = nn.utils.parametrizations.weight_norm(self.q_proj)
            self.v_proj = nn.utils.parametrizations.weight_norm(self.v_proj)

        self.enable_qk_layernorm = enable_qk_layernorm
        # initialize the Q/K LayerNorms if needed
        if self.enable_qk_layernorm:
            if layer_norm_factory is None:
                # use the default LayerNorm factory
                layer_norm_factory = create_standard_layer_norm

            self.q_layer_norm = layer_norm_factory(
                model_dim, device=device, dtype=dtype
            )
            self.k_layer_norm = layer_norm_factory(
                self.kv_dim, device=device, dtype=dtype
            )

    @override
    def _project_q(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tensor:
        # (N, S, M) -> (N, S, K_proj)
        q = self.q_proj(seqs)

        # normalize queries
        if self.enable_qk_layernorm:
            q = self.q_layer_norm(q)

        # (N, S, K_proj) -> (N, H, S, K_h)
        q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

        if self.pos_encoder is not None:
            q = self.pos_encoder(
                q,
                padding_mask,
                state_bag=state_bag,
            )

        return q  # type: ignore[no-any-return]

    @override
    def _project_kv(  # type: ignore
        self,
        keys: Tensor,
        key_padding_mask: Optional[PaddingMask],
        values: Tensor,
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tuple[Tensor, Tensor]:
        # (N, S, K) -> (N, S, K_proj)
        k = self.k_proj(keys)

        # normalize keys
        if self.enable_qk_layernorm:
            k = self.k_layer_norm(k)

        # (N, S, V) -> (N, S, V_proj)
        v = self.v_proj(values)

        # (N, S, K_proj) -> (N, H, S, K_h)
        k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
        # (N, S, V_proj) -> (N, H, S, V_h)
        v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)

        if self.pos_encoder is not None:
            k = self.pos_encoder(
                k,
                key_padding_mask,
                state_bag=state_bag,
            )

        return k, v

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        keys: Tensor,
        key_padding_mask: Optional[PaddingMask],
        values: Tensor,
        *,
        attn_mask: Optional[AttentionMask] = None,
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tensor:
        # (N, S, M) -> (N, H, S, K_h)
        q = self._project_q(
            seqs,
            padding_mask,
            state_bag,
        )
        if self.training or state_bag is None:
            # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h)
            # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h)
            k, v = self._project_kv(
                keys,
                key_padding_mask,
                values,
            )
        else:
            if key_padding_mask is not None:
                raise ValueError(
                    "`key_padding_mask` must be `None` during incremental decoding."
                )

            # k: (N, S_step, M) -> (N, H_kv, S_step, K_h)
            # v: (N, S_step, M) -> (N, H_kv, S_step, V_h)
            k, v = self._project_kv(keys, key_padding_mask, values, state_bag)

            state = state_bag.get_state(self, AttentionState)  # type: ignore

            if state is None:
                state_factory = self.state_factory or FullAttentionState

                state = state_factory(
                    k, v, state_bag.max_num_steps, state_bag.capacity_increment
                )

                state_bag.set_state(self, state)
            else:
                state.append(k, v)

                # k: (N, H_kv, S_kv, K_h)
                # v: (N, H_kv, S_kv, V_h)
                k, v = state.get()

        # With Grouped Query Attention, each key/value head is repeated.
        if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1:
            # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
            k = repeat_interleave(k, dim=1, repeat=num_query_groups)
            # (N, H_kv, S_kv, V_h) -> (N, H, S_kv, V_h)
            v = repeat_interleave(v, dim=1, repeat=num_query_groups)

        if self.attn_mask_factory is not None:
            attn_mask = self.attn_mask_factory(
                seqs, keys=keys, training=self.training, state_bag=state_bag
            )

        needs_weights = len(self._attn_weight_hooks) > 0

        # attn:         (N, H, S, V_h)
        # attn_weights: (N, H, S, S_kv)
        attn, attn_weights = self.sdpa(
            q,
            k,
            key_padding_mask,
            v,
            attn_mask=attn_mask,
            needs_weights=needs_weights,
        )

        if attn_weights is not None:
            for hook in self._attn_weight_hooks.values():
                hook(self, attn, attn_weights)

        # (N, H, S, V_h) -> (N, S, H, V_h)
        attn = attn.transpose(1, 2)

        if self.head_scale_weight is not None:
            attn = torch.einsum("nshv,h->nshv", attn, self.head_scale_weight)

        # (N, S, H, V_h) -> (N, S, V_proj)
        attn = attn.flatten(2, 3)

        # (N, S, V_proj) -> (N, S, M)
        attn = self.output_proj(attn)

        return attn  # type: ignore[no-any-return]

    @override
    def extra_repr(self) -> str:
        """:meta private:"""
        s = super().extra_repr()

        s = f"{s}, enable_qk_layernorm={self.enable_qk_layernorm}"

        return s
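
# What QK-LayerNorm adds, in isolation (a plain-torch sketch; the module
# above delegates the actual projections and SDPA to fairseq2):
#
#   q = torch.nn.functional.layer_norm(q_proj(x), (model_dim,))
#   k = torch.nn.functional.layer_norm(k_proj(x), (model_dim,))
#
# Normalizing queries and keys bounds the scale of the attention logits
# q @ k.transpose(-1, -2) / sqrt(head_dim), which is what helps stabilize
# training at scale.
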
lcm/nn/transformer/decoder.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import List, Optional, Tuple

from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
    AttentionMask,
    AttentionMaskFactory,
    LayerNormFactory,
    StandardTransformerDecoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerNormOrder,
)
from fairseq2.typing import DataType, Device, override
from torch import Generator, Tensor
from torch.nn import Dropout, ModuleList

from lcm.nn.incremental_state import LCMIncrementalStateBag


class LCMStandardTransformerDecoderLayer(StandardTransformerDecoderLayer):  # type: ignore
    """A StandardTransformerDecoderLayer that threads the LCM incremental
    `state_bag` through its forward pass."""

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask] = None,
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        seqs = self._forward_self_attn(
            seqs,
            padding_mask,
            self_attn_mask,
            state_bag,
        )

        seqs = self._forward_encoder_decoder_attn(
            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
        )

        seqs = self._forward_ffn(seqs)

        return seqs, padding_mask

    @override
    def _forward_self_attn(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        self_attn_mask: Optional[AttentionMask],
        state_bag: Optional[LCMIncrementalStateBag],
    ) -> Tensor:
        residual = seqs

        if self.norm_order != TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

        seqs = self.self_attn(
            seqs,
            padding_mask,
            keys=seqs,
            key_padding_mask=padding_mask,
            values=seqs,
            attn_mask=self_attn_mask,
            state_bag=state_bag,
        )

        if self.self_attn_norm is not None:
            seqs = self.self_attn_norm(seqs)

        if self.self_attn_dropout is not None:
            seqs = self.self_attn_dropout(seqs)

        seqs = seqs + residual

        if self.norm_order == TransformerNormOrder.POST:
            seqs = self.self_attn_layer_norm(seqs)

        return seqs


class LCMTransformerDecoder(TransformerDecoder):
    def __init__(
        self,
        layers: List[TransformerDecoderLayer],
        layer_norm_factory: LayerNormFactory,
        self_attn_mask_factory: AttentionMaskFactory,
        use_causal_attn_mask: bool = True,
        generator: Optional[Generator] = None,
        dropout_p: float = 0.0,
        norm_order: TransformerNormOrder = TransformerNormOrder.POST,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        layer_list = ModuleList(layers)

        if not layer_list:
            raise ValueError("`layers` must be non-empty.")

        model_dim = layer_list[0].model_dim

        super().__init__(model_dim)

        self.self_attn_mask_factory = self_attn_mask_factory

        self.layers = layer_list

        self.generator = generator

        if norm_order != TransformerNormOrder.POST:
            self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
        else:
            self.register_module("layer_norm", None)

        if dropout_p > 0.0:
            self.dropout = Dropout(dropout_p)
        else:
            self.register_module("dropout", None)

        self.norm_order = norm_order

    @override
    def forward(  # type: ignore
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        encoder_output: Optional[Tensor] = None,
        encoder_padding_mask: Optional[PaddingMask] = None,
        *,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
        """Run the decoder stack, threading the incremental `state_bag` through each layer."""
        num_layers = len(self.layers)

        self_attn_mask: Optional[AttentionMask] = None
        if self.self_attn_mask_factory is not None:
            self_attn_mask = self.self_attn_mask_factory(
                seqs,
                keys=seqs,
                training=self.training,
                state_bag=state_bag,
            )

        for layer_idx, layer in enumerate(self.layers):
            layer_output, layer_padding_mask = layer(
                seqs,
                padding_mask,
                self_attn_mask,
                encoder_output,
                encoder_padding_mask,
                state_bag=state_bag,
            )

            seqs, padding_mask = layer_output, layer_padding_mask

            for hook in self._layer_output_hooks.values():
                if not hook(layer_idx, seqs, padding_mask, num_layers):
                    break

        if self.layer_norm is not None:
            seqs = self.layer_norm(seqs)

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs, padding_mask
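
# Wiring sketch (illustrative; the layers are normally built by the
# TransformerFactory defined in the next file, and the layer-norm and mask
# factories are assumed to be fairseq2's create_standard_layer_norm and
# CausalAttentionMaskFactory):
#
#   layers = [factory.build_layer() for _ in range(num_layers)]
#   decoder = LCMTransformerDecoder(
#       layers,
#       layer_norm_factory=create_standard_layer_norm,
#       self_attn_mask_factory=CausalAttentionMaskFactory(),
#       norm_order=TransformerNormOrder.PRE,
#   )
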
lcm/nn/transformer/factory.py
ADDED
@@ -0,0 +1,300 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
from fairseq2.logging import get_log_writer
from fairseq2.nn import PositionEncoder
from fairseq2.nn.position_encoder import (
    LearnedPositionEncoder,
    RotaryEncoder,
    SinusoidalPositionEncoder,
)
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import (
    FeedForwardNetwork,
    GLUFeedForwardNetwork,
    MultiheadAttention,
    StandardFeedForwardNetwork,
    TransformerDecoderLayer,
    create_default_sdpa,
)
from fairseq2.typing import DataType, Device

from lcm.nn.initialization import (
    SUPPORTED_INIT_TYPES,
    get_init_fn,
    parse_activation_fn,
    parse_norm_order,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES, parse_layer_norm_factory
from lcm.nn.transformer import LCMStandardTransformerDecoderLayer
from lcm.nn.transformer.attention import (
    FullAttentionState,
    QKNormMultiheadAttention,
)

SUPPORTED_NORM_ORDERS = Literal["pre", "post", "normformer"]


logger = get_log_writer(__name__)


@dataclass
class TransformerConfig:
    """A config object grouping all the configuration
    hyper-parameters of an LCMTransformerDecoder."""

    num_layers: int = 2

    num_attn_heads: int = 8

    # Dropout rates
    dropout_p: float = 0.1
    """The dropout probability applied to the outputs of the attention layers and the
    feed-forward network (before joining the residual stream)."""

    final_dropout_p: float = 0.1
    """The dropout probability on decoder outputs."""

    attention_dropout_p: float = 0.0
    """The dropout rate on attention weights in SDPA."""

    # FFN
    ffn_inner_dim: int = 1024 * 4

    use_swiglu: bool = False
    """Use GLUFeedForwardNetwork instead of regular FFN blocks."""

    ffn_inner_activation_name: str = "relu"
    """The activation to apply to the outputs of the FFN inner projection layer.
    Default is `relu`, i.e., `torch.nn.ReLU`. This is only relevant when `use_swiglu=False`."""

    # positional embedding
    pos_embedding_style: Literal["rope", "sine", "learned", "none"] = "learned"
    """If `rope`: a rotary positional encoder is used in the attention layers.
    If `sine`: sinusoidal positional embeddings will be added in
        the frontend before heading into the decoder.
    If `learned`: learned positional embeddings will be added in
        the frontend before heading into the decoder.
    If `none`: no positional embeddings will be used (e.g. in the case
        of unconditional diffusion of a single vector)."""

    rope_theta: float = 10_000.0
    """The coefficient of the long-term decay of RoPE embeddings."""

    # Normalization
    layer_normalization_style: SUPPORTED_LN_TYPES = "standard"

    norm_order_style: SUPPORTED_NORM_ORDERS = "pre"
    """LayerNorm order in the transformer decoder;
    default is pre-normalization (`pre`). Other options are post-normalization (`post`)
    and normformer-style normalization (`normformer`)."""

    final_norm_order_style: Optional[SUPPORTED_NORM_ORDERS] = None
    """Controls the lcm-level norm order; using ``post`` here with a ``pre`` layer-level
    norm order means that we will skip the last layernorm in the stack."""

    enable_qk_layernorm: bool = False
    """If ``True``, LayerNorms will be applied to queries and keys in self-attention layers.
    QK-LayerNorm, described in https://arxiv.org/pdf/2302.05442 and subsequent work,
    is recommended to alleviate Transformer training instabilities.
    """
    mha_qkv_weight_normalization: bool = False
    """If ``True``, wrap the K/Q/V linears of MHA in weight normalization."""

    mha_output_weight_normalization: bool = False
    """If ``True``, wrap the output projection of MHA with weight normalization.
    This is a temporary fix to resume training some models and will be removed."""

    # Miscellaneous
    mha_output_proj_bias: bool = False
    """If ``True``, add a bias term to the MHA output projection."""

    scale_residual: Optional[float] = None
    """A scale to multiply the residual by in the Transformer decoder."""

    attention_output_init_fn: SUPPORTED_INIT_TYPES = "xavier"


class TransformerFactory:
    def __init__(
        self,
        model_dim: int,
        max_seq_len: int,
        config: TransformerConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The hidden model dimension of the Transformer.
        :param max_seq_len:
            The maximum sequence length supported by the model.
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.model_dim = model_dim
        self.max_seq_len = max_seq_len
        self.config = config
        self.device, self.dtype = device, dtype

    def build_layer(self) -> TransformerDecoderLayer:
        """Build a Transformer decoder layer based on the provided config."""

        self_attn = self.build_attention()

        ffn = self.build_ffn()

        norm_order = parse_norm_order(self.config.norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        layer = LCMStandardTransformerDecoderLayer(
            self_attn=self_attn,
            encoder_decoder_attn=None,
            ffn=ffn,
            dropout_p=self.config.dropout_p,
            norm_order=norm_order,
            layer_norm_factory=layer_norm_factory,
            scale_residual=self.config.scale_residual is not None,
            device=self.device,
            dtype=self.dtype,
        )
        # reset the residual scale
        if layer.residual_scale is not None:
            assert self.config.scale_residual is not None, (
                f"Layer has a residual scale but scale={self.config.scale_residual}"
            )
            torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual)
            logger.info(
                f"Initializing the residual scale at {self.config.scale_residual}"
            )
        return layer

    def build_pos_encoder(self) -> Optional[PositionEncoder]:
        """Build the positional encoder (learned or sinusoidal, if any)
        that will be used in the frontend."""
        pos_encoder: Optional[PositionEncoder]

        if self.config.pos_embedding_style == "learned":
            pos_encoder = LearnedPositionEncoder(
                self.model_dim,
                self.max_seq_len,
                device=self.device,
                dtype=self.dtype,
            )
        elif self.config.pos_embedding_style == "sine":
            pos_encoder = SinusoidalPositionEncoder(
                self.model_dim,
                self.max_seq_len,
                device=self.device,
            )

        else:
            pos_encoder = None

        return pos_encoder

    def build_attention_pos_encoder(self) -> Optional[PositionEncoder]:
        """Build the position encoder that can
        potentially be used in the MHA module."""

        pos_encoder: Optional[PositionEncoder]

        if self.config.pos_embedding_style == "rope":
            pos_encoder = RotaryEncoder(
                encoding_dim=self.model_dim // self.config.num_attn_heads,
                max_seq_len=self.max_seq_len,
                theta=self.config.rope_theta,
                device=self.device,
            )
        else:
            pos_encoder = None
        return pos_encoder

    def build_attention(self) -> MultiheadAttention:
        """Build a Transformer multi-head attention layer."""

        # allow for a different kv_dim
        kv_dim = self.model_dim

        # fairseq2.nn.transformer.attention.TorchSDPA
        sdpa = create_default_sdpa(attn_dropout_p=self.config.attention_dropout_p)

        init_fn = get_init_fn(self.config.attention_output_init_fn)

        # How does RoPE play with encoder-decoder attention?
        pos_encoder = self.build_attention_pos_encoder()

        layer_norm_factory = parse_layer_norm_factory(
            self.config.layer_normalization_style
        )

        # build output_proj:
        output_proj = Linear(
            self.model_dim,
            self.model_dim,
            bias=self.config.mha_output_proj_bias,
            init_fn=init_fn,
            device=self.device,
            dtype=self.dtype,
        )
        if self.config.mha_output_weight_normalization:
            output_proj = torch.nn.utils.parametrizations.weight_norm(output_proj)

        return QKNormMultiheadAttention(
            self.model_dim,
            self.config.num_attn_heads,
            kv_dim=kv_dim,
            pos_encoder=pos_encoder,
            sdpa=sdpa,
            output_proj=output_proj,
            enable_qk_layernorm=self.config.enable_qk_layernorm,
            weight_normalization=self.config.mha_qkv_weight_normalization,
            layer_norm_factory=layer_norm_factory,
            state_factory=FullAttentionState,
            device=self.device,
            dtype=self.dtype,
        )

    def build_ffn(self) -> FeedForwardNetwork:
        """Build a Transformer feed-forward network."""
        if self.config.use_swiglu:
            # Default gate_activation is torch.nn.SiLU
            return GLUFeedForwardNetwork(
                self.model_dim,
                self.config.ffn_inner_dim,
                bias=True,
                inner_dim_scale=2 / 3,
                inner_dim_to_multiple=256,
                device=self.device,
                dtype=self.dtype,
            )

        ffn_inner_activation = parse_activation_fn(
            self.config.ffn_inner_activation_name
        )
        norm_order = parse_norm_order(self.config.norm_order_style)

        return StandardFeedForwardNetwork(
            self.model_dim,
            self.config.ffn_inner_dim,
            inner_activation=ffn_inner_activation,
            bias=True,
            norm_order=norm_order,
            device=self.device,
            dtype=self.dtype,
        )
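
# End-to-end usage sketch (illustrative values):
#
#   config = TransformerConfig(
#       num_layers=4, num_attn_heads=8, pos_embedding_style="rope"
#   )
#   factory = TransformerFactory(model_dim=1024, max_seq_len=2048, config=config)
#   layers = [factory.build_layer() for _ in range(config.num_layers)]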