Spaces:
No application file
No application file
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +15 -0
- .gitignore +228 -0
- .gitmodules +6 -0
- .gradio/certificate.pem +31 -0
- Discrete-Diffusion-Forcing/.gitignore +3 -0
- Discrete-Diffusion-Forcing/.python-version +1 -0
- Discrete-Diffusion-Forcing/.vscode/launch.json +37 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_dream.py +1155 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_dream.sh +158 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_dream_d2f_vllm.py +764 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_dream_d2f_vllm.sh +135 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_llada.py +1198 -0
- Discrete-Diffusion-Forcing/D2F-eval/eval_llada.sh +155 -0
- Discrete-Diffusion-Forcing/D2F-eval/generate_llada_demo_ar.py +660 -0
- Discrete-Diffusion-Forcing/D2F-eval/generate_llada_demo_block.py +630 -0
- Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/configuration_dream.py +88 -0
- Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/generation_utils.py +463 -0
- Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/model_dream.py +1029 -0
- Discrete-Diffusion-Forcing/D2F-eval/model_cache/llada/configuration_llada.py +463 -0
- Discrete-Diffusion-Forcing/D2F-eval/model_cache/llada/modeling_llada.py +1504 -0
- Discrete-Diffusion-Forcing/D2F-eval/postprocess_code.py +62 -0
- Discrete-Diffusion-Forcing/D2F-eval/sanitize.py +147 -0
- Discrete-Diffusion-Forcing/D2F-train/config/acc_config +23 -0
- Discrete-Diffusion-Forcing/D2F-train/config/dream_eagle.yaml +59 -0
- Discrete-Diffusion-Forcing/D2F-train/config/llada.yaml +59 -0
- Discrete-Diffusion-Forcing/D2F-train/model/configuration_llada.py +463 -0
- Discrete-Diffusion-Forcing/D2F-train/model/modeling_llada.py +1500 -0
- Discrete-Diffusion-Forcing/D2F-train/train.py +215 -0
- Discrete-Diffusion-Forcing/D2F-train/train.sh +4 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/__init__.py +0 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/data.py +313 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/generation.py +144 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/loss.py +193 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/model.py +59 -0
- Discrete-Diffusion-Forcing/D2F-train/utils/util.py +159 -0
- Discrete-Diffusion-Forcing/Discrete Diffusion Forcing.pdf +3 -0
- Discrete-Diffusion-Forcing/LICENCE +21 -0
- Discrete-Diffusion-Forcing/README.md +376 -0
- Discrete-Diffusion-Forcing/docs/_config.yml +1 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig1_main_result.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig2_tradeoff.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig3_overview.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig4_pipeline.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_lr.jpg +0 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_lr.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_ud.jpg +0 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_ud.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/table1_llada_results.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/table2_dream_results.png +3 -0
- Discrete-Diffusion-Forcing/docs/assets/img/d2f/wechat.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Discrete-Diffusion-Forcing/Discrete[[:space:]]Diffusion[[:space:]]Forcing.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig1_main_result.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig2_tradeoff.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig3_overview.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig4_pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_lr.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_ud.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/table1_llada_results.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/table2_dream_results.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
Discrete-Diffusion-Forcing/docs/assets/img/d2f/wechat.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
Discrete-Diffusion-Forcing/docs/assets/video/block_demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
Discrete-Diffusion-Forcing/docs/assets/video/block_demo_small.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
Discrete-Diffusion-Forcing/docs/assets/video/d2f_vs_ar_demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
Discrete-Diffusion-Forcing/docs/assets/video/d2f_vs_ar_demo_small.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
wandb/run-20251017_233420-gok04idh/run-gok04idh.wandb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
# Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
# poetry.lock
|
| 109 |
+
# poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
# pdm.lock
|
| 116 |
+
# pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
# pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# Redis
|
| 135 |
+
*.rdb
|
| 136 |
+
*.aof
|
| 137 |
+
*.pid
|
| 138 |
+
|
| 139 |
+
# RabbitMQ
|
| 140 |
+
mnesia/
|
| 141 |
+
rabbitmq/
|
| 142 |
+
rabbitmq-data/
|
| 143 |
+
|
| 144 |
+
# ActiveMQ
|
| 145 |
+
activemq-data/
|
| 146 |
+
|
| 147 |
+
# SageMath parsed files
|
| 148 |
+
*.sage.py
|
| 149 |
+
|
| 150 |
+
# Environments
|
| 151 |
+
.env
|
| 152 |
+
.envrc
|
| 153 |
+
.venv
|
| 154 |
+
env/
|
| 155 |
+
venv/
|
| 156 |
+
ENV/
|
| 157 |
+
env.bak/
|
| 158 |
+
venv.bak/
|
| 159 |
+
|
| 160 |
+
# Spyder project settings
|
| 161 |
+
.spyderproject
|
| 162 |
+
.spyproject
|
| 163 |
+
|
| 164 |
+
# Rope project settings
|
| 165 |
+
.ropeproject
|
| 166 |
+
|
| 167 |
+
# mkdocs documentation
|
| 168 |
+
/site
|
| 169 |
+
|
| 170 |
+
# mypy
|
| 171 |
+
.mypy_cache/
|
| 172 |
+
.dmypy.json
|
| 173 |
+
dmypy.json
|
| 174 |
+
|
| 175 |
+
# Pyre type checker
|
| 176 |
+
.pyre/
|
| 177 |
+
|
| 178 |
+
# pytype static type analyzer
|
| 179 |
+
.pytype/
|
| 180 |
+
|
| 181 |
+
# Cython debug symbols
|
| 182 |
+
cython_debug/
|
| 183 |
+
|
| 184 |
+
# PyCharm
|
| 185 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 186 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 188 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 189 |
+
# .idea/
|
| 190 |
+
|
| 191 |
+
# Abstra
|
| 192 |
+
# Abstra is an AI-powered process automation framework.
|
| 193 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 194 |
+
# Learn more at https://abstra.io/docs
|
| 195 |
+
.abstra/
|
| 196 |
+
|
| 197 |
+
# Visual Studio Code
|
| 198 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 199 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 200 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 201 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 202 |
+
# .vscode/
|
| 203 |
+
|
| 204 |
+
# Ruff stuff:
|
| 205 |
+
.ruff_cache/
|
| 206 |
+
|
| 207 |
+
# PyPI configuration file
|
| 208 |
+
.pypirc
|
| 209 |
+
|
| 210 |
+
# Marimo
|
| 211 |
+
marimo/_static/
|
| 212 |
+
marimo/_lsp/
|
| 213 |
+
__marimo__/
|
| 214 |
+
|
| 215 |
+
# Streamlit
|
| 216 |
+
.streamlit/secrets.toml
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
.vscode/*
|
| 220 |
+
!.vscode/settings.json
|
| 221 |
+
!.vscode/tasks.json
|
| 222 |
+
!.vscode/launch.json
|
| 223 |
+
!.vscode/extensions.json
|
| 224 |
+
!.vscode/*.code-snippets
|
| 225 |
+
!*.code-workspace
|
| 226 |
+
|
| 227 |
+
# Built Visual Studio Code Extensions
|
| 228 |
+
*.vsix
|
.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "FlexMDM"]
|
| 2 |
+
path = FlexMDM
|
| 3 |
+
url = git@github.com:brianlck/FlexMDM.git
|
| 4 |
+
[submodule "Discrete-Diffusion-Forcing"]
|
| 5 |
+
path = Discrete-Diffusion-Forcing
|
| 6 |
+
url = git@github.com:zhijie-group/Discrete-Diffusion-Forcing.git
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
Discrete-Diffusion-Forcing/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
d2f_vllm
|
| 2 |
+
evals_*
|
| 3 |
+
*.pyc
|
Discrete-Diffusion-Forcing/.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
Discrete-Diffusion-Forcing/.vscode/launch.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"configurations": [
|
| 3 |
+
|
| 4 |
+
{
|
| 5 |
+
"name": "PyDbg: `Dream` Accelerate Launch Debug",
|
| 6 |
+
"type": "debugpy",
|
| 7 |
+
"request": "launch",
|
| 8 |
+
"module": "accelerate.commands.launch",
|
| 9 |
+
"args": [
|
| 10 |
+
"--main_process_port",
|
| 11 |
+
"29520",
|
| 12 |
+
"--num_processes",
|
| 13 |
+
"1",
|
| 14 |
+
"D2F-eval/eval_dream_d2f_vllm.py",
|
| 15 |
+
"--model",
|
| 16 |
+
"dream_lora",
|
| 17 |
+
"--model_args",
|
| 18 |
+
"pretrained=/data1/ckpts/Dream-org/Dream-v0-Base-7B,lora_path=/data1/xck/ckpt/wx_dream_base/Decoder-ddt_test-20k,max_new_tokens=256,diffusion_steps=256,temperature=0,add_bos_token=true,escape_until=true,block_size=32,block_add_threshold=0.9,skip_threshold=0.95,decoded_token_threshold=0.9,dtype=bfloat16,sampling_strategy=default,save_dir=evals_dream_single/Decoder-ddt_test-20k/humaneval-ns0-len256-temp0-limit10000-diffsteps256-block32-thresh0.9-decodethresh0.7-skip0.7-toppnone-dtypebfloat16-samplingdefault",
|
| 19 |
+
"--tasks",
|
| 20 |
+
"humaneval",
|
| 21 |
+
"--num_fewshot",
|
| 22 |
+
"0",
|
| 23 |
+
"--batch_size",
|
| 24 |
+
"1",
|
| 25 |
+
"--output_path",
|
| 26 |
+
"evals_dream_single/Decoder-ddt_test-20k/humaneval-ns0-len256-temp0-limit10000-diffsteps256-block32-thresh0.9-decodethresh0.9-skip0.95-toppnone-dtypebfloat16-samplingdefault",
|
| 27 |
+
"--log_samples",
|
| 28 |
+
"--confirm_run_unsafe_code"
|
| 29 |
+
],
|
| 30 |
+
"env": {
|
| 31 |
+
"HF_ALLOW_CODE_EVAL": "1",
|
| 32 |
+
"CUDA_VISIBLE_DEVICES": "0"
|
| 33 |
+
},
|
| 34 |
+
"cwd": "${workspaceFolder}"
|
| 35 |
+
},
|
| 36 |
+
]
|
| 37 |
+
}
|
Discrete-Diffusion-Forcing/D2F-eval/eval_dream.py
ADDED
|
@@ -0,0 +1,1155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
from datetime import timedelta
|
| 6 |
+
from typing import List, Optional, Tuple, Type, TypeVar, Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.distributions as dists
|
| 10 |
+
import transformers
|
| 11 |
+
from accelerate import (
|
| 12 |
+
Accelerator,
|
| 13 |
+
InitProcessGroupKwargs,
|
| 14 |
+
)
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from packaging import version
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from peft import PeftConfig, PeftModel
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from lm_eval import utils
|
| 22 |
+
from lm_eval.api.instance import Instance
|
| 23 |
+
from lm_eval.api.model import LM
|
| 24 |
+
from lm_eval.api.registry import register_model
|
| 25 |
+
from lm_eval.models.utils import get_dtype
|
| 26 |
+
from lm_eval.__main__ import cli_evaluate
|
| 27 |
+
|
| 28 |
+
eval_logger = logging.getLogger(__name__)
|
| 29 |
+
T = TypeVar("T", bound="LM")
|
| 30 |
+
import random
|
| 31 |
+
def set_seed(seed):
|
| 32 |
+
torch.manual_seed(seed)
|
| 33 |
+
random.seed(seed)
|
| 34 |
+
np.random.seed(seed)
|
| 35 |
+
|
| 36 |
+
torch.backends.cudnn.deterministic = True
|
| 37 |
+
torch.backends.cudnn.benchmark = False
|
| 38 |
+
|
| 39 |
+
def shift_logits(logits):
|
| 40 |
+
shifted_logits = torch.zeros_like(logits)
|
| 41 |
+
shifted_logits[:, 1:, :] = logits[:, :-1, :]
|
| 42 |
+
shifted_logits[:, 0, :] = 1.0
|
| 43 |
+
return shifted_logits
|
| 44 |
+
|
| 45 |
+
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
|
| 46 |
+
"""
|
| 47 |
+
Creates a complete attention mask for the entire sequence with block-based causal attention.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
prompt_length: Length of the prompt (first irregular block)
|
| 51 |
+
max_length: Maximum total sequence length
|
| 52 |
+
block_size: Size of each regular block
|
| 53 |
+
device: Device to create tensor on
|
| 54 |
+
dtype: Data type for the attention mask
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
attention_mask: Tensor of shape [1, 1, max_length, max_length]
|
| 58 |
+
"""
|
| 59 |
+
# Use the provided dtype or default to bfloat16
|
| 60 |
+
if dtype is None:
|
| 61 |
+
dtype = torch.bfloat16
|
| 62 |
+
|
| 63 |
+
# Initialize mask with -inf (no attention)
|
| 64 |
+
attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
|
| 65 |
+
|
| 66 |
+
# Block 0: Prompt (can see itself)
|
| 67 |
+
attention_mask[:, :, :prompt_length, :prompt_length] = 0
|
| 68 |
+
|
| 69 |
+
# Calculate the number of regular blocks after prompt
|
| 70 |
+
remaining_length = max_length - prompt_length
|
| 71 |
+
num_blocks = (remaining_length + block_size - 1) // block_size
|
| 72 |
+
|
| 73 |
+
# Process each regular block
|
| 74 |
+
for b in range(num_blocks):
|
| 75 |
+
block_start = prompt_length + b * block_size
|
| 76 |
+
block_end = min(prompt_length + (b + 1) * block_size, max_length)
|
| 77 |
+
|
| 78 |
+
# Current block can see the prompt
|
| 79 |
+
attention_mask[:, :, block_start:block_end, :prompt_length] = 0
|
| 80 |
+
|
| 81 |
+
# Current block can see all previous regular blocks
|
| 82 |
+
for prev_b in range(b):
|
| 83 |
+
prev_start = prompt_length + prev_b * block_size
|
| 84 |
+
prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
|
| 85 |
+
attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
|
| 86 |
+
|
| 87 |
+
# Current block can see itself (full attention within block)
|
| 88 |
+
attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
|
| 89 |
+
|
| 90 |
+
return attention_mask
|
| 91 |
+
|
| 92 |
+
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
|
| 93 |
+
"""
|
| 94 |
+
Extract the relevant portion of attention mask for current forward pass.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
full_mask: Complete attention mask [1, 1, max_length, max_length]
|
| 98 |
+
start_pos: Starting position in the full sequence
|
| 99 |
+
input_length: Length of current input sequence
|
| 100 |
+
cache_length: Length of cached sequence
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
attention_mask: Extracted mask [1, 1, input_length, cache_length + input_length]
|
| 104 |
+
"""
|
| 105 |
+
end_pos = start_pos + input_length
|
| 106 |
+
total_length = cache_length + input_length
|
| 107 |
+
|
| 108 |
+
# Extract the relevant rows (current input positions)
|
| 109 |
+
# and columns (cache + current input positions)
|
| 110 |
+
extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf,
|
| 111 |
+
device=full_mask.device, dtype=full_mask.dtype)
|
| 112 |
+
|
| 113 |
+
# Copy cache columns (0 to cache_length in the extracted mask corresponds to 0 to cache_length in full mask)
|
| 114 |
+
extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
|
| 115 |
+
|
| 116 |
+
# Copy current input columns
|
| 117 |
+
extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
|
| 118 |
+
|
| 119 |
+
return extracted_mask
|
| 120 |
+
|
| 121 |
+
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
|
| 122 |
+
B, seq_len = input_ids.shape
|
| 123 |
+
# Use the provided dtype or default to float32
|
| 124 |
+
if dtype is None:
|
| 125 |
+
dtype = torch.float32
|
| 126 |
+
# Initialize to all -inf
|
| 127 |
+
attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)
|
| 128 |
+
# 1. Prompt part: each token can attend to the entire prompt
|
| 129 |
+
for i in range(B):
|
| 130 |
+
attn_mask[i, :, :, :prompt_length[i]] = 0.0 # Allow all tokens to see the prompt
|
| 131 |
+
|
| 132 |
+
# 2. Block division: divide into blocks starting from prompt_length
|
| 133 |
+
num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
|
| 134 |
+
|
| 135 |
+
for b in range(num_blocks):
|
| 136 |
+
block_start = prompt_length[i] + b * block_size
|
| 137 |
+
block_end = min(block_start + block_size, seq_len)
|
| 138 |
+
|
| 139 |
+
# Full attention within the block
|
| 140 |
+
attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
|
| 141 |
+
|
| 142 |
+
# Causal attention between blocks (can only see previous blocks)
|
| 143 |
+
for prev_b in range(b):
|
| 144 |
+
prev_start = prompt_length[i] + prev_b * block_size
|
| 145 |
+
prev_end = min(prev_start + block_size, seq_len)
|
| 146 |
+
|
| 147 |
+
# Current block can see previous blocks
|
| 148 |
+
attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
|
| 149 |
+
|
| 150 |
+
return attn_mask
|
| 151 |
+
|
| 152 |
+
def top_p_logits(logits, top_p=None):
|
| 153 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 154 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 155 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 156 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 157 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 158 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 159 |
+
|
| 160 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 161 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 162 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 163 |
+
return logits
|
| 164 |
+
|
| 165 |
+
def top_k_logits(logits, top_k=None):
|
| 166 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 167 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 168 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 169 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 170 |
+
return logits
|
| 171 |
+
|
| 172 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 173 |
+
if temperature > 0:
|
| 174 |
+
logits = logits / temperature
|
| 175 |
+
if top_p is not None and top_p < 1:
|
| 176 |
+
logits = top_p_logits(logits, top_p)
|
| 177 |
+
if top_k is not None:
|
| 178 |
+
logits = top_k_logits(logits, top_k)
|
| 179 |
+
probs = torch.softmax(logits, dim=-1)
|
| 180 |
+
|
| 181 |
+
if temperature > 0:
|
| 182 |
+
try:
|
| 183 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 184 |
+
initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 185 |
+
except:
|
| 186 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 187 |
+
else:
|
| 188 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 189 |
+
|
| 190 |
+
# Save initial confidence
|
| 191 |
+
confidence = initial_confidence.clone()
|
| 192 |
+
|
| 193 |
+
if margin_confidence:
|
| 194 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 195 |
+
# Extract top1 and top2 probabilities
|
| 196 |
+
top1_probs = sorted_probs[:, 0]
|
| 197 |
+
top2_probs = sorted_probs[:, 1]
|
| 198 |
+
# Calculate confidence as top1 - top2
|
| 199 |
+
confidence = top1_probs - top2_probs
|
| 200 |
+
|
| 201 |
+
if neg_entropy:
|
| 202 |
+
epsilon = 1e-10
|
| 203 |
+
log_probs = torch.log(probs + epsilon)
|
| 204 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 205 |
+
|
| 206 |
+
return confidence, x0, initial_confidence
|
| 207 |
+
|
| 208 |
+
@register_model("dream_lora")
|
| 209 |
+
class DreamLoRA(LM):
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
pretrained: Union[str, transformers.PreTrainedModel],
|
| 213 |
+
lora_path: str,
|
| 214 |
+
batch_size: Optional[Union[int, str]] = 1,
|
| 215 |
+
device: Optional[str] = "cuda",
|
| 216 |
+
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
| 217 |
+
max_new_tokens: Optional[int] = 128,
|
| 218 |
+
max_length: Optional[int] = 2048, # Updated to match example code
|
| 219 |
+
add_bos_token: Optional[bool] = False,
|
| 220 |
+
nll_type: Optional[str] = "mc",
|
| 221 |
+
log_type: Optional[str] = "ftb",
|
| 222 |
+
mc_num: Optional[int] = 128,
|
| 223 |
+
classifier_free_guidance: Optional[float] = 1.0,
|
| 224 |
+
sampling_eps: Optional[float] = 1e-3,
|
| 225 |
+
diffusion_steps: Optional[int] = 128,
|
| 226 |
+
trust_remote_code: Optional[bool] = True,
|
| 227 |
+
parallelize: Optional[bool] = False,
|
| 228 |
+
autogptq: Optional[Union[bool, str]] = False,
|
| 229 |
+
temperature: Optional[float] = 0.2, # Updated default
|
| 230 |
+
top_p: Optional[float] = None, # Updated default
|
| 231 |
+
top_k: Optional[float] = None,
|
| 232 |
+
alg: Optional[str] = "entropy",
|
| 233 |
+
alg_temp: Optional[float] = 0.0,
|
| 234 |
+
escape_until: Optional[bool] = False,
|
| 235 |
+
block_size: Optional[int] = 4, # Updated to match example code
|
| 236 |
+
mask_token_id: Optional[int] = 151666, # Added mask_token_id parameter
|
| 237 |
+
block_add_threshold: Optional[float] = 0.5, # Added block_add_threshold parameter
|
| 238 |
+
decoded_token_threshold: Optional[int] = 0.9, # Added decoded_token_threshold parameter
|
| 239 |
+
skip_threshold: Optional[float] = 1.0, # Added skip_threshold parameter
|
| 240 |
+
sampling_strategy: Optional[str] = "default", # Added sampling_strategy parameter
|
| 241 |
+
save_dir: Optional[str] = None,
|
| 242 |
+
**kwargs,
|
| 243 |
+
) -> None:
|
| 244 |
+
super().__init__()
|
| 245 |
+
|
| 246 |
+
# prepare for parallelism
|
| 247 |
+
assert isinstance(device, str)
|
| 248 |
+
assert isinstance(pretrained, str)
|
| 249 |
+
assert isinstance(batch_size, (int, str))
|
| 250 |
+
|
| 251 |
+
gpus = torch.cuda.device_count()
|
| 252 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 253 |
+
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
|
| 254 |
+
if accelerator.num_processes > 1:
|
| 255 |
+
self.accelerator = accelerator
|
| 256 |
+
|
| 257 |
+
if "npu" in accelerator.device.type:
|
| 258 |
+
gpus = torch.npu.device_count()
|
| 259 |
+
|
| 260 |
+
# using one process with no model parallelism
|
| 261 |
+
if not (parallelize or accelerator.num_processes > 1):
|
| 262 |
+
# use user-passed device
|
| 263 |
+
device_list = set(
|
| 264 |
+
["cuda", "cpu"]
|
| 265 |
+
+ [f"cuda:{i}" for i in range(gpus)]
|
| 266 |
+
+ ["mps", "mps:0"]
|
| 267 |
+
+ [f"npu:{i}" for i in range(gpus)]
|
| 268 |
+
)
|
| 269 |
+
if device and device in device_list:
|
| 270 |
+
self._device = torch.device(device)
|
| 271 |
+
eval_logger.info(f"Using device '{device}'")
|
| 272 |
+
if device in ("mps", "mps:0") and version.parse(
|
| 273 |
+
torch.__version__
|
| 274 |
+
) < version.parse("2.1"):
|
| 275 |
+
raise RuntimeError(
|
| 276 |
+
f"mps requires torch >= 2.1. You have {torch.__version__}"
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
eval_logger.info("Device not specified")
|
| 280 |
+
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
|
| 281 |
+
self._device = (
|
| 282 |
+
torch.device("cuda")
|
| 283 |
+
if torch.cuda.is_available()
|
| 284 |
+
else torch.device("cpu")
|
| 285 |
+
)
|
| 286 |
+
else: # Parallelism managed by accelerate
|
| 287 |
+
if device != "cuda":
|
| 288 |
+
eval_logger.info(
|
| 289 |
+
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
|
| 290 |
+
)
|
| 291 |
+
# TODO: include in warning that `load_in_8bit` etc. affect this too
|
| 292 |
+
self._device = (
|
| 293 |
+
self.accelerator.device
|
| 294 |
+
if hasattr(self, "accelerator")
|
| 295 |
+
else torch.device(device)
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.batch_size_per_gpu = batch_size
|
| 299 |
+
if isinstance(batch_size, str):
|
| 300 |
+
self.batch_size_per_gpu = int(batch_size)
|
| 301 |
+
|
| 302 |
+
# Save LoRA path and block_size
|
| 303 |
+
self.lora_path = lora_path
|
| 304 |
+
self.block_size = block_size
|
| 305 |
+
self.block_add_threshold = block_add_threshold # New block_add_threshold attribute
|
| 306 |
+
self.skip_threshold = skip_threshold # New skip_threshold attribute
|
| 307 |
+
self.sampling_strategy = sampling_strategy # Save sampling strategy parameter
|
| 308 |
+
self.decoded_token_threshold = decoded_token_threshold # New decoded_token_threshold attribute
|
| 309 |
+
self.save_dir = save_dir
|
| 310 |
+
|
| 311 |
+
# Add metric tracking
|
| 312 |
+
self.total_forward_passes = 0
|
| 313 |
+
self.total_generated_tokens = 0
|
| 314 |
+
self.total_prompts = 0
|
| 315 |
+
# Add time and token statistics
|
| 316 |
+
self.total_generation_time = 0.0
|
| 317 |
+
self.total_block_tokens = 0 # Number of blocks * block_size
|
| 318 |
+
self.total_actual_tokens = 0 # Actual generated tokens (excluding EOS)
|
| 319 |
+
self.total_non_eos_tokens = 0 # Total non-EOS tokens in the entire sequence
|
| 320 |
+
self.all_generation_times = []
|
| 321 |
+
self.all_block_tokens = []
|
| 322 |
+
self.all_actual_tokens = []
|
| 323 |
+
self.all_non_eos_tokens = []
|
| 324 |
+
|
| 325 |
+
# Save target_dtype for later use
|
| 326 |
+
self.target_dtype = get_dtype(dtype)
|
| 327 |
+
|
| 328 |
+
self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
|
| 329 |
+
|
| 330 |
+
if isinstance(pretrained, str):
|
| 331 |
+
if gpus >= 1 or str(self.device) == "mps":
|
| 332 |
+
# TODO: can remove this whole snippet except in the mps case, perhaps?
|
| 333 |
+
if not (parallelize or autogptq or hasattr(self, "accelerator")):
|
| 334 |
+
# place model onto device requested manually,
|
| 335 |
+
# if not using HF Accelerate or device_map
|
| 336 |
+
# or any other option that preloads model onto device
|
| 337 |
+
try:
|
| 338 |
+
self.model.to(self.device)
|
| 339 |
+
except ValueError:
|
| 340 |
+
eval_logger.debug(
|
| 341 |
+
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
|
| 342 |
+
)
|
| 343 |
+
# multigpu data-parallel support when launched with accelerate
|
| 344 |
+
if gpus > 1:
|
| 345 |
+
if accelerator.num_processes > 1:
|
| 346 |
+
if parallelize:
|
| 347 |
+
eval_logger.warning(
|
| 348 |
+
"You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
|
| 349 |
+
)
|
| 350 |
+
elif gpus > accelerator.num_processes:
|
| 351 |
+
eval_logger.warning(
|
| 352 |
+
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
|
| 353 |
+
"If you would like to use data parallelism, please launch the script "
|
| 354 |
+
"with 'accelerate launch *script*'. "
|
| 355 |
+
f"Current run will proceed with {accelerator.num_processes} devices."
|
| 356 |
+
)
|
| 357 |
+
if self.accelerator.is_local_main_process:
|
| 358 |
+
eval_logger.info(
|
| 359 |
+
f"Using {gpus} devices with data parallelism"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
self._device = torch.device(f"{accelerator.device}")
|
| 363 |
+
self.accelerator = accelerator
|
| 364 |
+
|
| 365 |
+
self._rank = self.accelerator.local_process_index
|
| 366 |
+
self._world_size = self.accelerator.num_processes
|
| 367 |
+
else:
|
| 368 |
+
# if we aren't launching via accelerate, ditch
|
| 369 |
+
self._rank = 0
|
| 370 |
+
self._world_size = 1
|
| 371 |
+
else:
|
| 372 |
+
# if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
|
| 373 |
+
eval_logger.warning(
|
| 374 |
+
"Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
|
| 375 |
+
)
|
| 376 |
+
self._rank = 0
|
| 377 |
+
self._world_size = 1
|
| 378 |
+
|
| 379 |
+
self.max_length = max_length
|
| 380 |
+
self.add_bos_token = add_bos_token
|
| 381 |
+
# generation params
|
| 382 |
+
self.max_new_tokens = max_new_tokens
|
| 383 |
+
self.diffusion_steps = diffusion_steps
|
| 384 |
+
self.temperature = temperature
|
| 385 |
+
self.top_p = top_p
|
| 386 |
+
self.top_k = top_k
|
| 387 |
+
self.alg = alg
|
| 388 |
+
self.alg_temp = alg_temp
|
| 389 |
+
self.escape_until = escape_until
|
| 390 |
+
self.block_size = block_size
|
| 391 |
+
self.mask_token_id = mask_token_id
|
| 392 |
+
|
| 393 |
+
# loglikelihood params
|
| 394 |
+
self.nll_type = nll_type
|
| 395 |
+
self.log_type = log_type
|
| 396 |
+
self.mc_num = mc_num
|
| 397 |
+
self.classifier_free_guidance = classifier_free_guidance
|
| 398 |
+
self.sampling_eps = sampling_eps
|
| 399 |
+
|
| 400 |
+
@property
|
| 401 |
+
def batch_size(self):
|
| 402 |
+
return self.batch_size_per_gpu
|
| 403 |
+
|
| 404 |
+
@property
|
| 405 |
+
def device(self):
|
| 406 |
+
return self._device
|
| 407 |
+
|
| 408 |
+
@property
|
| 409 |
+
def rank(self):
|
| 410 |
+
return self._rank
|
| 411 |
+
|
| 412 |
+
@property
|
| 413 |
+
def world_size(self):
|
| 414 |
+
return self._world_size
|
| 415 |
+
|
| 416 |
+
def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
|
| 417 |
+
# Get correct data type
|
| 418 |
+
from model_cache.dream.model_dream import DreamModel
|
| 419 |
+
from model_cache.dream.configuration_dream import DreamConfig
|
| 420 |
+
target_dtype = get_dtype(dtype)
|
| 421 |
+
|
| 422 |
+
# Load base model, using DreamModel and DreamConfig
|
| 423 |
+
model_config = DreamConfig.from_pretrained(pretrained)
|
| 424 |
+
self.model = DreamModel.from_pretrained(
|
| 425 |
+
pretrained,
|
| 426 |
+
config=model_config,
|
| 427 |
+
torch_dtype=target_dtype,
|
| 428 |
+
trust_remote_code=False,
|
| 429 |
+
).eval()
|
| 430 |
+
|
| 431 |
+
# Load LoRA config and model
|
| 432 |
+
config = PeftConfig.from_pretrained(self.lora_path)
|
| 433 |
+
self.model = PeftModel.from_pretrained(self.model, self.lora_path)
|
| 434 |
+
|
| 435 |
+
# Only convert data type if target_dtype is not None and not "auto"
|
| 436 |
+
if target_dtype is not None and target_dtype != "auto":
|
| 437 |
+
self.model = self.model.to(target_dtype)
|
| 438 |
+
|
| 439 |
+
# Move to specified device
|
| 440 |
+
self.model = self.model.to(self.device)
|
| 441 |
+
|
| 442 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 443 |
+
pretrained, trust_remote_code=trust_remote_code
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
def tok_decode(self, tokens, skip_special_tokens=True):
|
| 447 |
+
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
| 448 |
+
|
| 449 |
+
def tok_encode(self, text, add_special_tokens=True):
|
| 450 |
+
return self.tokenizer(
|
| 451 |
+
text, return_tensors="pt", add_special_tokens=add_special_tokens
|
| 452 |
+
).input_ids
|
| 453 |
+
|
| 454 |
+
@classmethod
|
| 455 |
+
def create_from_arg_string(
|
| 456 |
+
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
|
| 457 |
+
) -> T:
|
| 458 |
+
"""
|
| 459 |
+
Creates an instance of the LM class using the given argument string and additional config.
|
| 460 |
+
|
| 461 |
+
Parameters:
|
| 462 |
+
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 463 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 464 |
+
|
| 465 |
+
Returns:
|
| 466 |
+
- Instance of the LM class.
|
| 467 |
+
"""
|
| 468 |
+
additional_config = {} if additional_config is None else additional_config
|
| 469 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 470 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 471 |
+
return cls(**args, **args2)
|
| 472 |
+
|
| 473 |
+
def apply_chat_template(
|
| 474 |
+
self, chat_history, add_generation_prompt: bool = True
|
| 475 |
+
) -> str:
|
| 476 |
+
"""
|
| 477 |
+
Method to apply a chat template to a list of chat history between user and model.
|
| 478 |
+
"""
|
| 479 |
+
chat_templated = self.tokenizer.apply_chat_template(
|
| 480 |
+
chat_history,
|
| 481 |
+
tokenize=False,
|
| 482 |
+
add_generation_prompt=add_generation_prompt,
|
| 483 |
+
continue_final_message=not add_generation_prompt,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
return chat_templated
|
| 487 |
+
|
| 488 |
+
@property
|
| 489 |
+
def tokenizer_name(self) -> str:
|
| 490 |
+
return self.tokenizer.name_or_path.replace("/", "__")
|
| 491 |
+
|
| 492 |
+
def _count_non_eos_tokens_before_truncation(self, generated_sequence, prompt_length):
|
| 493 |
+
"""
|
| 494 |
+
Unified token counting function: counts non-EOS tokens in the generated sequence (before truncation).
|
| 495 |
+
"""
|
| 496 |
+
# Get the generated part (excluding the prompt)
|
| 497 |
+
generated_tokens = generated_sequence[prompt_length:]
|
| 498 |
+
# Count non-EOS tokens
|
| 499 |
+
eos_token_id = self.tokenizer.eos_token_id
|
| 500 |
+
if eos_token_id is not None:
|
| 501 |
+
# If it's a tensor, convert to list for counting
|
| 502 |
+
if hasattr(generated_tokens, 'tolist'):
|
| 503 |
+
generated_tokens_list = generated_tokens.tolist()
|
| 504 |
+
else:
|
| 505 |
+
generated_tokens_list = generated_tokens
|
| 506 |
+
non_eos_count = sum(1 for token in generated_tokens_list if token != eos_token_id)
|
| 507 |
+
else:
|
| 508 |
+
non_eos_count = len(generated_tokens)
|
| 509 |
+
return non_eos_count
|
| 510 |
+
|
| 511 |
+
def _generate_batch(self, prompts: List[str]) -> List[str]:
|
| 512 |
+
if self.add_bos_token:
|
| 513 |
+
prompts = [self.tokenizer.bos_token + p for p in prompts]
|
| 514 |
+
|
| 515 |
+
responses = []
|
| 516 |
+
|
| 517 |
+
# Generate for each prompt individually (block generation usually processes one by one)
|
| 518 |
+
for i, prompt in enumerate(prompts):
|
| 519 |
+
# tokenize
|
| 520 |
+
prompt_ids = self.tokenizer.encode(prompt)
|
| 521 |
+
prompt_tensor = torch.tensor([prompt_ids], device=self.device, dtype=torch.long)
|
| 522 |
+
|
| 523 |
+
if len(prompt_ids) > self.max_length - self.max_new_tokens:
|
| 524 |
+
eval_logger.warning(f"Prompt length {len(prompt_ids)} is larger than {self.max_length-self.max_new_tokens}, cutoff on the left side")
|
| 525 |
+
prompt_tensor = prompt_tensor[:, -(self.max_length-self.max_new_tokens):]
|
| 526 |
+
|
| 527 |
+
# Use generate_block_single method to generate, returns EOS-truncated response text
|
| 528 |
+
response = self._generate_block_single(prompt_tensor)
|
| 529 |
+
responses.append(response)
|
| 530 |
+
|
| 531 |
+
return responses
|
| 532 |
+
|
| 533 |
+
def _generate_block_single(self, prompt):
|
| 534 |
+
"""
|
| 535 |
+
Generates a response for a single prompt using parallel block generation, based on KV cache,
|
| 536 |
+
and using pre-generated attention masks.
|
| 537 |
+
Returns: EOS-truncated response text.
|
| 538 |
+
"""
|
| 539 |
+
self.model.eval()
|
| 540 |
+
|
| 541 |
+
mask_id = self.mask_token_id
|
| 542 |
+
block_size = self.block_size
|
| 543 |
+
block_add_threshold = self.block_add_threshold
|
| 544 |
+
skip_threshold = self.skip_threshold
|
| 545 |
+
decoded_token_threshold = self.decoded_token_threshold
|
| 546 |
+
|
| 547 |
+
# Pre-generate full attention mask, using model's data type
|
| 548 |
+
prompt_length = prompt.shape[1]
|
| 549 |
+
full_attention_mask = create_full_block_attention_mask(
|
| 550 |
+
prompt_length=prompt_length,
|
| 551 |
+
max_length=self.max_length,
|
| 552 |
+
block_size=block_size,
|
| 553 |
+
device=self.device,
|
| 554 |
+
dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
with torch.inference_mode():
|
| 558 |
+
# Initialization
|
| 559 |
+
x_t = prompt.to(self.device)
|
| 560 |
+
|
| 561 |
+
# Track block states - state can be: 'active', 'to_cache', 'in_cache'
|
| 562 |
+
# Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete (False)
|
| 563 |
+
block_states = {
|
| 564 |
+
0: {
|
| 565 |
+
'start_pos': 0,
|
| 566 |
+
'end_pos': prompt.shape[1],
|
| 567 |
+
'mask_count': 0,
|
| 568 |
+
'total_masks': prompt.shape[1],
|
| 569 |
+
'state': 'to_cache', # prompt ready for caching immediately
|
| 570 |
+
'is_complete': True, # prompt is always in a complete state
|
| 571 |
+
},
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
# Initialize cache
|
| 575 |
+
past_key_values = None
|
| 576 |
+
last_logits = None
|
| 577 |
+
|
| 578 |
+
current_blocks = 0 # Number of active blocks
|
| 579 |
+
step = 0
|
| 580 |
+
eos_detected = False # EOS detection flag
|
| 581 |
+
|
| 582 |
+
while current_blocks >= 0:
|
| 583 |
+
step += 1
|
| 584 |
+
|
| 585 |
+
# Check if a new block needs to be added
|
| 586 |
+
if len(block_states)-1 < (self.max_new_tokens // block_size) and not eos_detected:
|
| 587 |
+
last_block_id = len(block_states) - 1
|
| 588 |
+
current_progress = (block_states[last_block_id]['total_masks'] -
|
| 589 |
+
block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks']
|
| 590 |
+
if current_progress >= block_add_threshold:
|
| 591 |
+
# Add new block - defaults to incomplete state
|
| 592 |
+
new_block_id = len(block_states)
|
| 593 |
+
new_start_pos = x_t.shape[1]
|
| 594 |
+
x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1)
|
| 595 |
+
|
| 596 |
+
block_states[new_block_id] = {
|
| 597 |
+
'start_pos': new_start_pos,
|
| 598 |
+
'end_pos': new_start_pos + block_size,
|
| 599 |
+
'mask_count': block_size,
|
| 600 |
+
'total_masks': block_size,
|
| 601 |
+
'state': 'active',
|
| 602 |
+
'is_complete': False, # New block defaults to incomplete state
|
| 603 |
+
}
|
| 604 |
+
current_blocks += 1
|
| 605 |
+
|
| 606 |
+
# At the beginning of each loop, update block completion states
|
| 607 |
+
self._update_block_completion_states(block_states, decoded_token_threshold)
|
| 608 |
+
# Check if there are still mask tokens
|
| 609 |
+
mask_index = (x_t == mask_id)
|
| 610 |
+
if mask_index.sum() == 0 and current_blocks == 0:
|
| 611 |
+
break
|
| 612 |
+
|
| 613 |
+
# Determine which blocks need to be added to cache
|
| 614 |
+
blocks_to_cache = [bid for bid, state in block_states.items()
|
| 615 |
+
if state['state'] == 'to_cache']
|
| 616 |
+
|
| 617 |
+
# Determine the part to process
|
| 618 |
+
cache_length = 0 if past_key_values is None else past_key_values.get_seq_length()
|
| 619 |
+
|
| 620 |
+
# Determine content to add to cache
|
| 621 |
+
update_kvcache = 0
|
| 622 |
+
if blocks_to_cache:
|
| 623 |
+
# Find the earliest block that needs to be cached
|
| 624 |
+
earliest_block_id = min(blocks_to_cache)
|
| 625 |
+
earliest_pos = block_states[earliest_block_id]['start_pos']
|
| 626 |
+
|
| 627 |
+
# Find the latest block that needs to be cached
|
| 628 |
+
latest_block_id = max(blocks_to_cache)
|
| 629 |
+
latest_pos = block_states[latest_block_id]['end_pos']
|
| 630 |
+
|
| 631 |
+
# Update cache for all blocks within this range
|
| 632 |
+
update_kvcache = latest_pos - earliest_pos
|
| 633 |
+
|
| 634 |
+
# Create input sequence for forward pass
|
| 635 |
+
process_start_pos = cache_length
|
| 636 |
+
|
| 637 |
+
if update_kvcache > 0:
|
| 638 |
+
# Need to update cache - use completed blocks
|
| 639 |
+
earliest_block_to_cache = min(blocks_to_cache)
|
| 640 |
+
input_seq = x_t[:, block_states[earliest_block_to_cache]['start_pos']:]
|
| 641 |
+
process_start_pos = block_states[earliest_block_to_cache]['start_pos']
|
| 642 |
+
else:
|
| 643 |
+
# Only process active blocks
|
| 644 |
+
active_blocks = [bid for bid in block_states.keys() if block_states[bid]['state'] == 'active']
|
| 645 |
+
if active_blocks:
|
| 646 |
+
# Get all active blocks after the cache
|
| 647 |
+
earliest_active_after_cache = float('inf')
|
| 648 |
+
for bid in active_blocks:
|
| 649 |
+
if block_states[bid]['start_pos'] >= cache_length:
|
| 650 |
+
earliest_active_after_cache = min(earliest_active_after_cache, block_states[bid]['start_pos'])
|
| 651 |
+
|
| 652 |
+
if earliest_active_after_cache < float('inf'):
|
| 653 |
+
input_seq = x_t[:, earliest_active_after_cache:]
|
| 654 |
+
process_start_pos = earliest_active_after_cache
|
| 655 |
+
else:
|
| 656 |
+
# No active blocks after cache, this shouldn't happen
|
| 657 |
+
input_seq = x_t[:, cache_length:]
|
| 658 |
+
# If cache length is already equal to or exceeds sequence length, exit
|
| 659 |
+
if cache_length >= x_t.shape[1]:
|
| 660 |
+
print(f"Cache length ({cache_length}) >= sequence length ({x_t.shape[1]}) at step {step}. Exiting generation loop.")
|
| 661 |
+
raise Exception("Cache length >= sequence length")
|
| 662 |
+
else:
|
| 663 |
+
# No active blocks, but might have blocks to cache in next iteration
|
| 664 |
+
break
|
| 665 |
+
|
| 666 |
+
# Check if input_seq is empty
|
| 667 |
+
if input_seq.shape[1] == 0:
|
| 668 |
+
print(f"Warning: input_seq is empty at step {step}. Breaking generation loop.")
|
| 669 |
+
raise Exception("input_seq is empty")
|
| 670 |
+
|
| 671 |
+
# Extract attention mask for current input from the pre-generated full mask
|
| 672 |
+
input_length = input_seq.shape[1]
|
| 673 |
+
attention_mask = extract_attention_mask(
|
| 674 |
+
full_mask=full_attention_mask,
|
| 675 |
+
start_pos=process_start_pos,
|
| 676 |
+
input_length=input_length,
|
| 677 |
+
cache_length=cache_length
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Forward pass
|
| 681 |
+
outputs = self.model(
|
| 682 |
+
input_seq,
|
| 683 |
+
attention_mask=attention_mask,
|
| 684 |
+
past_key_values=past_key_values,
|
| 685 |
+
use_cache=True,
|
| 686 |
+
update_kvcache=update_kvcache,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
# If needed, update cache
|
| 690 |
+
if update_kvcache > 0:
|
| 691 |
+
# Store logits of the last position for next token prediction
|
| 692 |
+
cache_end_idx = update_kvcache - 1
|
| 693 |
+
last_logits = outputs.logits[:, cache_end_idx, :].unsqueeze(1)
|
| 694 |
+
|
| 695 |
+
# Update cache
|
| 696 |
+
past_key_values = outputs.past_key_values
|
| 697 |
+
|
| 698 |
+
# Mark blocks as cached
|
| 699 |
+
for block_id in blocks_to_cache:
|
| 700 |
+
block_states[block_id]['state'] = 'in_cache'
|
| 701 |
+
|
| 702 |
+
# Get correctly shifted logits for prediction
|
| 703 |
+
logits = self._shift_logits(outputs.logits, last_logit=last_logits)
|
| 704 |
+
|
| 705 |
+
# Process mask tokens for each active block
|
| 706 |
+
blocks_to_deactivate = []
|
| 707 |
+
|
| 708 |
+
for block_id in sorted(block_states.keys()):
|
| 709 |
+
if block_states[block_id]['state'] != 'active':
|
| 710 |
+
continue
|
| 711 |
+
|
| 712 |
+
# Get mask positions for this block
|
| 713 |
+
block_start = block_states[block_id]['start_pos']
|
| 714 |
+
block_end = block_states[block_id]['end_pos']
|
| 715 |
+
block_mask_index = mask_index.clone()
|
| 716 |
+
block_mask_index[:, :block_start] = False
|
| 717 |
+
block_mask_index[:, block_end:] = False
|
| 718 |
+
|
| 719 |
+
# If the current block has no masks, skip it
|
| 720 |
+
if block_mask_index.sum() == 0:
|
| 721 |
+
blocks_to_deactivate.append(block_id)
|
| 722 |
+
continue
|
| 723 |
+
|
| 724 |
+
# Calculate relative position for logits
|
| 725 |
+
logit_offset = block_start - process_start_pos
|
| 726 |
+
block_rel_positions = torch.where(block_mask_index[0, block_start:block_end])[0]
|
| 727 |
+
|
| 728 |
+
if block_rel_positions.size(0) > 0:
|
| 729 |
+
# Get logits for masked positions
|
| 730 |
+
block_mask_logits = logits[:, logit_offset + block_rel_positions, :]
|
| 731 |
+
|
| 732 |
+
# Sample tokens
|
| 733 |
+
confidence, x0, initial_confidence = sample_tokens(
|
| 734 |
+
block_mask_logits.squeeze(0),
|
| 735 |
+
self.temperature,
|
| 736 |
+
top_p=self.top_p,
|
| 737 |
+
top_k=self.top_k,
|
| 738 |
+
neg_entropy=(self.sampling_strategy == "neg_entropy"),
|
| 739 |
+
margin_confidence=(self.sampling_strategy == "margin_confidence")
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# Apply different sampling strategies based on the block's complete/incomplete state
|
| 743 |
+
is_complete = block_states[block_id]['is_complete']
|
| 744 |
+
|
| 745 |
+
if is_complete:
|
| 746 |
+
# Complete state: apply confidence threshold, if no high confidence, select highest
|
| 747 |
+
high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
|
| 748 |
+
|
| 749 |
+
if len(high_conf_indices) == 0:
|
| 750 |
+
number_transfer_tokens = 1
|
| 751 |
+
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
|
| 752 |
+
else:
|
| 753 |
+
transfer_index = torch.tensor([], device=self.device, dtype=torch.long)
|
| 754 |
+
|
| 755 |
+
# Merge indices
|
| 756 |
+
all_indices = torch.unique(torch.cat([transfer_index, high_conf_indices]))
|
| 757 |
+
else:
|
| 758 |
+
# Incomplete state: only apply confidence threshold, if none exceed, select no tokens
|
| 759 |
+
high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
|
| 760 |
+
all_indices = high_conf_indices
|
| 761 |
+
|
| 762 |
+
# Update tokens
|
| 763 |
+
if len(all_indices) > 0:
|
| 764 |
+
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_id
|
| 765 |
+
x0_[all_indices] = x0[all_indices].clone()
|
| 766 |
+
|
| 767 |
+
# Map indices back to original positions
|
| 768 |
+
for i, idx in enumerate(all_indices):
|
| 769 |
+
abs_pos = block_start + block_rel_positions[idx]
|
| 770 |
+
x_t[0, abs_pos] = x0_[idx]
|
| 771 |
+
|
| 772 |
+
# Update block state
|
| 773 |
+
block_states[block_id]['mask_count'] -= len(all_indices)
|
| 774 |
+
|
| 775 |
+
# Check EOS token
|
| 776 |
+
eos_token_id = self.tokenizer.eos_token_id
|
| 777 |
+
if eos_token_id is not None:
|
| 778 |
+
for idx in all_indices:
|
| 779 |
+
if x0[idx].item() == eos_token_id:
|
| 780 |
+
eos_detected = True
|
| 781 |
+
break
|
| 782 |
+
|
| 783 |
+
# If no masks remain in this block, deactivate it
|
| 784 |
+
mask_index = (x_t == mask_id)
|
| 785 |
+
block_mask_index = mask_index.clone()
|
| 786 |
+
block_mask_index[:, :block_start] = False
|
| 787 |
+
block_mask_index[:, block_end:] = False
|
| 788 |
+
if block_mask_index.sum() == 0:
|
| 789 |
+
blocks_to_deactivate.append(block_id)
|
| 790 |
+
continue
|
| 791 |
+
|
| 792 |
+
# Deactivate completed blocks and mark them for caching in the next iteration
|
| 793 |
+
for block_id in blocks_to_deactivate:
|
| 794 |
+
if block_states[block_id]['state'] == 'active':
|
| 795 |
+
# Check if all preceding blocks are already non-active
|
| 796 |
+
can_deactivate = True
|
| 797 |
+
for prev_block_id in range(block_id):
|
| 798 |
+
if prev_block_id in block_states and block_states[prev_block_id]['state'] == 'active':
|
| 799 |
+
can_deactivate = False
|
| 800 |
+
break
|
| 801 |
+
|
| 802 |
+
# Only mark the current block as 'to_cache' if all preceding blocks are non-active
|
| 803 |
+
if can_deactivate:
|
| 804 |
+
block_states[block_id]['state'] = 'to_cache'
|
| 805 |
+
current_blocks -= 1
|
| 806 |
+
# If there are active blocks before, keep current block as active (do nothing)
|
| 807 |
+
|
| 808 |
+
# Safety check
|
| 809 |
+
if step > 10000:
|
| 810 |
+
print(f"WARNING: Hit safety check at step {step}. Exiting generation loop.")
|
| 811 |
+
break
|
| 812 |
+
|
| 813 |
+
# First, calculate non-EOS tokens for the full generated sequence
|
| 814 |
+
generated_sequence = x_t[0, prompt.shape[1]:].tolist()
|
| 815 |
+
non_eos_tokens = self._count_non_eos_tokens_before_truncation(
|
| 816 |
+
x_t[0].tolist(), prompt.shape[1]
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
# Accumulate to total tokens
|
| 820 |
+
if not hasattr(self, 'total_generated_tokens'):
|
| 821 |
+
self.total_generated_tokens = 0
|
| 822 |
+
self.total_generated_tokens += non_eos_tokens
|
| 823 |
+
|
| 824 |
+
# Generate EOS-truncated response text (consistent with other file logic)
|
| 825 |
+
response = self.tokenizer.decode(generated_sequence).split(self.tokenizer.eos_token)[0]
|
| 826 |
+
|
| 827 |
+
return response
|
| 828 |
+
|
| 829 |
+
def _update_block_completion_states(self, block_states, decoded_token_threshold):
|
| 830 |
+
"""
|
| 831 |
+
Updates the complete/incomplete state of blocks.
|
| 832 |
+
Iterates through blocks from front to back. If a block's decoded token count
|
| 833 |
+
is greater than the threshold, the next block to its right (if it exists)
|
| 834 |
+
is set to a complete state.
|
| 835 |
+
"""
|
| 836 |
+
for block_id in sorted(block_states.keys()):
|
| 837 |
+
# if block_id == 0: # Skip prompt block
|
| 838 |
+
# continue
|
| 839 |
+
|
| 840 |
+
# Calculate decoded tokens for the current block
|
| 841 |
+
decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
|
| 842 |
+
decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
|
| 843 |
+
# If the current block's decoded token count is greater than the threshold,
|
| 844 |
+
# then the next block (if it exists) is set to a complete state.
|
| 845 |
+
# print("decode_ratio",decode_ratio)
|
| 846 |
+
# print("decoded_token_threshold",decoded_token_threshold)
|
| 847 |
+
if decode_ratio >= decoded_token_threshold:
|
| 848 |
+
next_block_id = block_id + 1
|
| 849 |
+
if next_block_id in block_states:
|
| 850 |
+
block_states[next_block_id]['is_complete'] = True
|
| 851 |
+
|
| 852 |
+
def _shift_logits(self, logits, last_logit=None, block_size=None):
|
| 853 |
+
"""Shifts logits to the right by one position, for autoregressive generation"""
|
| 854 |
+
# Check if logits are empty
|
| 855 |
+
if logits.shape[1] == 0:
|
| 856 |
+
print("Warning: logits sequence length is 0, returning empty logits")
|
| 857 |
+
raise Exception("logits sequence length is 0")
|
| 858 |
+
|
| 859 |
+
shifted_logits = torch.zeros_like(logits)
|
| 860 |
+
shifted_logits[:, 1:, :] = logits[:, :-1, :]
|
| 861 |
+
if last_logit is not None:
|
| 862 |
+
shifted_logits[:, 0, :] = last_logit
|
| 863 |
+
return shifted_logits
|
| 864 |
+
shifted_logits[:, 0, :] = 1.0
|
| 865 |
+
return shifted_logits
|
| 866 |
+
|
| 867 |
+
def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
|
| 868 |
+
res = []
|
| 869 |
+
|
| 870 |
+
# Initialize statistics counters
|
| 871 |
+
if not hasattr(self, 'total_generated_tokens'):
|
| 872 |
+
self.total_generated_tokens = 0
|
| 873 |
+
num_tokens = 0
|
| 874 |
+
num_nfe = 0 # Number of Forward Evaluations
|
| 875 |
+
|
| 876 |
+
pbar = tqdm(
|
| 877 |
+
total=len(requests),
|
| 878 |
+
disable=(disable_tqdm or (self.rank != 0)),
|
| 879 |
+
desc="Running generate_until requests",
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
start_time = time.time()
|
| 883 |
+
|
| 884 |
+
for batch_idx in range(0, len(requests), self.batch_size):
|
| 885 |
+
batch_requests = requests[batch_idx : batch_idx + self.batch_size]
|
| 886 |
+
contexts, gen_args = zip(*[req.arguments for req in batch_requests])
|
| 887 |
+
responses = self._generate_batch(contexts)
|
| 888 |
+
if not self.escape_until:
|
| 889 |
+
for i, r in enumerate(responses):
|
| 890 |
+
for s in gen_args[0]['until']:
|
| 891 |
+
r = r.split(s)[0]
|
| 892 |
+
responses[i] = r
|
| 893 |
+
|
| 894 |
+
res.extend(responses)
|
| 895 |
+
pbar.update(len(contexts))
|
| 896 |
+
|
| 897 |
+
end_time = time.time()
|
| 898 |
+
total_time = end_time - start_time
|
| 899 |
+
|
| 900 |
+
# Accumulate statistics
|
| 901 |
+
num_tokens = self.total_generated_tokens
|
| 902 |
+
num_nfe = self.diffusion_steps * len(requests) # Estimate NFE
|
| 903 |
+
|
| 904 |
+
# Save final statistics
|
| 905 |
+
final_stats = {
|
| 906 |
+
'processed_samples': len(requests),
|
| 907 |
+
'total_samples': len(requests),
|
| 908 |
+
'total_tokens': num_tokens,
|
| 909 |
+
'total_nfe': num_nfe,
|
| 910 |
+
'total_time': total_time,
|
| 911 |
+
'tokens_per_second': num_tokens / total_time if total_time > 0 else 0,
|
| 912 |
+
'nfe_per_token': num_nfe / num_tokens if num_tokens > 0 else 0,
|
| 913 |
+
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
|
| 914 |
+
}
|
| 915 |
+
|
| 916 |
+
# Save statistics to file
|
| 917 |
+
if self.save_dir is not None:
|
| 918 |
+
import os
|
| 919 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 920 |
+
|
| 921 |
+
# Save response results
|
| 922 |
+
save_path = os.path.join(self.save_dir, f'rank_{self.rank}_responses.jsonl')
|
| 923 |
+
with open(save_path, 'w', encoding='utf-8') as f:
|
| 924 |
+
for r in res:
|
| 925 |
+
f.write(json.dumps(r, ensure_ascii=False) + '\n')
|
| 926 |
+
|
| 927 |
+
# Save statistics results
|
| 928 |
+
stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
|
| 929 |
+
with open(stats_path, 'w', encoding='utf-8') as f:
|
| 930 |
+
json.dump(final_stats, f, ensure_ascii=False, indent=2)
|
| 931 |
+
|
| 932 |
+
# Print final statistics
|
| 933 |
+
print("\n" + "="*60)
|
| 934 |
+
print("=== Final Statistics ===")
|
| 935 |
+
print("="*60)
|
| 936 |
+
print(f"Processed Samples: {final_stats['processed_samples']}")
|
| 937 |
+
print(f"Total Samples: {final_stats['total_samples']}")
|
| 938 |
+
print(f"Total Tokens: {final_stats['total_tokens']}")
|
| 939 |
+
print(f"Total NFE: {final_stats['total_nfe']}")
|
| 940 |
+
print(f"Total Time: {final_stats['total_time']:.4f}s")
|
| 941 |
+
print(f"Tokens/Second: {final_stats['tokens_per_second']:.2f}")
|
| 942 |
+
print(f"NFE/Token: {final_stats['nfe_per_token']:.4f}")
|
| 943 |
+
print(f"Completion Time: {final_stats['timestamp']}")
|
| 944 |
+
print("="*60)
|
| 945 |
+
|
| 946 |
+
return res
|
| 947 |
+
|
| 948 |
+
def _forward_process(self, batch):
|
| 949 |
+
b, l = batch.shape
|
| 950 |
+
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
|
| 951 |
+
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
|
| 952 |
+
indices = torch.arange(b, device=batch.device).float()
|
| 953 |
+
t = (u0 + indices / b) % 1
|
| 954 |
+
|
| 955 |
+
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
|
| 956 |
+
|
| 957 |
+
p_mask = p_mask[:, None].repeat(1, l)
|
| 958 |
+
|
| 959 |
+
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
|
| 960 |
+
# always unmask bos and eos
|
| 961 |
+
mask_indices[:, 0] = False
|
| 962 |
+
mask_indices[:, -1] = False
|
| 963 |
+
|
| 964 |
+
noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
|
| 965 |
+
return noisy_batch, p_mask
|
| 966 |
+
|
| 967 |
+
@torch.no_grad()
|
| 968 |
+
def get_logits(self, batch, prompt_index):
|
| 969 |
+
'''
|
| 970 |
+
prompt_index : 1D bool tensor, length=batch.shape[1]
|
| 971 |
+
'''
|
| 972 |
+
if self.classifier_free_guidance > 1.:
|
| 973 |
+
assert len(prompt_index) == batch.shape[1]
|
| 974 |
+
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
| 975 |
+
un_batch = batch.clone()
|
| 976 |
+
un_batch[prompt_index] = self.mask_token_id
|
| 977 |
+
batch = torch.cat([batch, un_batch])
|
| 978 |
+
|
| 979 |
+
input = batch
|
| 980 |
+
|
| 981 |
+
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 982 |
+
logits = self.model(input).logits
|
| 983 |
+
# since bos always unmask, the first logits will not be used
|
| 984 |
+
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
| 985 |
+
|
| 986 |
+
if self.classifier_free_guidance > 1.:
|
| 987 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 988 |
+
logits = un_logits + self.cfg * (logits - un_logits)
|
| 989 |
+
return logits[:, :batch.shape[1]]
|
| 990 |
+
|
| 991 |
+
@torch.no_grad()
|
| 992 |
+
def _eval_target_nll_mc(self, prefix, target):
|
| 993 |
+
if prefix is None:
|
| 994 |
+
seq = target[None, :]
|
| 995 |
+
else:
|
| 996 |
+
seq = torch.concatenate([prefix, target])[None, :]
|
| 997 |
+
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
| 998 |
+
|
| 999 |
+
if self.log_type == 'ftb':
|
| 1000 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 1001 |
+
else:
|
| 1002 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
|
| 1003 |
+
|
| 1004 |
+
loss_acc = []
|
| 1005 |
+
for _ in range(max(self.mc_num // self.batch_size, 1)):
|
| 1006 |
+
perturbed_seq = seq.clone()
|
| 1007 |
+
# eval_logger.info("before noising")
|
| 1008 |
+
perturbed_seq_, p_mask = self._forward_process(seq)
|
| 1009 |
+
# eval_logger.info("end noising")
|
| 1010 |
+
if self.log_type == 'ftb':
|
| 1011 |
+
perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
|
| 1012 |
+
elif self.log_type == 'btf':
|
| 1013 |
+
perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
|
| 1014 |
+
elif self.log_type == 'union':
|
| 1015 |
+
perturbed_seq = perturbed_seq_
|
| 1016 |
+
else:
|
| 1017 |
+
raise NotImplementedError(self.log_type)
|
| 1018 |
+
|
| 1019 |
+
mask_indices = perturbed_seq == self.mask_token_id
|
| 1020 |
+
logits = self.get_logits(perturbed_seq, prompt_index)
|
| 1021 |
+
loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
|
| 1022 |
+
loss = loss.sum() / self.batch_size
|
| 1023 |
+
loss_acc.append(loss.item())
|
| 1024 |
+
|
| 1025 |
+
return sum(loss_acc) / len(loss_acc)
|
| 1026 |
+
|
| 1027 |
+
@torch.no_grad()
|
| 1028 |
+
def _eval_target_nll_ar(self, prefix, target):
|
| 1029 |
+
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
|
| 1030 |
+
assert self.log_type in ['ftb', 'btf']
|
| 1031 |
+
assert self.nll_type in ['ar_ftb', 'ar_btf']
|
| 1032 |
+
|
| 1033 |
+
if self.log_type == 'ftb':
|
| 1034 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
|
| 1035 |
+
else:
|
| 1036 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
|
| 1037 |
+
|
| 1038 |
+
if self.log_type == 'ftb':
|
| 1039 |
+
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
|
| 1040 |
+
else:
|
| 1041 |
+
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
|
| 1042 |
+
|
| 1043 |
+
mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 1044 |
+
if self.nll_type == 'ar_ftb':
|
| 1045 |
+
mask_index = torch.triu(mask_index)
|
| 1046 |
+
else:
|
| 1047 |
+
mask_index = torch.tril(mask_index)
|
| 1048 |
+
perturbed_[mask_index] = self.mask_token_id
|
| 1049 |
+
if self.log_type == 'ftb':
|
| 1050 |
+
perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
|
| 1051 |
+
else:
|
| 1052 |
+
perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
|
| 1053 |
+
|
| 1054 |
+
logits_ = []
|
| 1055 |
+
num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
|
| 1056 |
+
for i in range(num):
|
| 1057 |
+
end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
|
| 1058 |
+
perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
|
| 1059 |
+
perturbed_seq_ = perturbed_seq_.to(self.device)
|
| 1060 |
+
if len(perturbed_seq_.shape) == 1:
|
| 1061 |
+
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
|
| 1062 |
+
logits = self.get_logits(perturbed_seq_, prompt_index)
|
| 1063 |
+
logits_.append(logits.cpu())
|
| 1064 |
+
logits = torch.cat(logits_, dim=0)
|
| 1065 |
+
|
| 1066 |
+
temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 1067 |
+
if self.nll_type == 'ar_ftb':
|
| 1068 |
+
temp_index = torch.triu(temp_index, diagonal=1)
|
| 1069 |
+
else:
|
| 1070 |
+
temp_index = torch.tril(temp_index, diagonal=-1)
|
| 1071 |
+
mask_index[temp_index] = False
|
| 1072 |
+
if self.log_type == 'ftb':
|
| 1073 |
+
logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
|
| 1074 |
+
else:
|
| 1075 |
+
logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
|
| 1076 |
+
|
| 1077 |
+
if self.log_type == 'ftb':
|
| 1078 |
+
loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
|
| 1079 |
+
else:
|
| 1080 |
+
loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
|
| 1081 |
+
return loss
|
| 1082 |
+
|
| 1083 |
+
def _encode_pair(self, context, continuation):
|
| 1084 |
+
if self.add_bos_token:
|
| 1085 |
+
context = self.tokenizer.bos_token + context
|
| 1086 |
+
|
| 1087 |
+
n_spaces = len(context) - len(context.rstrip())
|
| 1088 |
+
if n_spaces > 0:
|
| 1089 |
+
continuation = context[-n_spaces:] + continuation
|
| 1090 |
+
context = context[:-n_spaces]
|
| 1091 |
+
|
| 1092 |
+
whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
|
| 1093 |
+
context_enc = self.tokenizer.encode(context)
|
| 1094 |
+
|
| 1095 |
+
context_enc_len = len(context_enc)
|
| 1096 |
+
continuation_enc = whole_enc[context_enc_len:]
|
| 1097 |
+
|
| 1098 |
+
# by default truncate on the left
|
| 1099 |
+
cutoff_length = max(len(whole_enc) - self.max_length, 0)
|
| 1100 |
+
if cutoff_length > 0:
|
| 1101 |
+
eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
|
| 1102 |
+
context_remain = context_enc_len-cutoff_length
|
| 1103 |
+
if context_remain > 0:
|
| 1104 |
+
context_enc = context_enc[-context_remain:]
|
| 1105 |
+
else:
|
| 1106 |
+
eval_logger.warning(f"All context (prompt) is truncated.")
|
| 1107 |
+
context_enc = ""
|
| 1108 |
+
continuation_enc = whole_enc[-self.max_length:]
|
| 1109 |
+
return context_enc, continuation_enc
|
| 1110 |
+
|
| 1111 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 1112 |
+
def _tokenize(e):
|
| 1113 |
+
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
| 1114 |
+
return {
|
| 1115 |
+
"prefix_text": e["prefix"],
|
| 1116 |
+
"target_text": e["target"],
|
| 1117 |
+
"prefix": prefix,
|
| 1118 |
+
"target": target,
|
| 1119 |
+
}
|
| 1120 |
+
|
| 1121 |
+
ds = []
|
| 1122 |
+
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
| 1123 |
+
ds = Dataset.from_list(ds)
|
| 1124 |
+
print(ds[0])
|
| 1125 |
+
ds = ds.map(_tokenize)
|
| 1126 |
+
ds = ds.with_format("torch")
|
| 1127 |
+
|
| 1128 |
+
out = []
|
| 1129 |
+
with torch.no_grad():
|
| 1130 |
+
for elem in tqdm(ds, desc="Computing likelihood..."):
|
| 1131 |
+
prefix = elem["prefix"]
|
| 1132 |
+
target = elem["target"]
|
| 1133 |
+
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
|
| 1134 |
+
if self.nll_type == 'mc':
|
| 1135 |
+
ll = -self._eval_target_nll_mc(prefix, target)
|
| 1136 |
+
if self.log_type == 'union':
|
| 1137 |
+
ll = ll / (len(target) + len(prefix))
|
| 1138 |
+
elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
|
| 1139 |
+
ll = -self._eval_target_nll_ar(prefix, target)
|
| 1140 |
+
else:
|
| 1141 |
+
raise NotImplementedError(self.nll_type)
|
| 1142 |
+
|
| 1143 |
+
# TODO: greedy decoding
|
| 1144 |
+
is_target_greedy_dec = False
|
| 1145 |
+
|
| 1146 |
+
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
| 1147 |
+
return out
|
| 1148 |
+
|
| 1149 |
+
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
| 1150 |
+
raise NotImplementedError
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
if __name__ == "__main__":
|
| 1154 |
+
set_seed(1234)
|
| 1155 |
+
cli_evaluate()
|
Discrete-Diffusion-Forcing/D2F-eval/eval_dream.sh
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tasks="gsm8k_cot mbpp minerva_math"
|
| 2 |
+
nshots="8 3 4"
|
| 3 |
+
lengths="256 256 256"
|
| 4 |
+
temperatures="0 0 0"
|
| 5 |
+
limits="10000 10000 10000"
|
| 6 |
+
block_sizes="32 48 64"
|
| 7 |
+
block_add_thresholds="0.1 0.1 0.1"
|
| 8 |
+
decoded_token_thresholds="0.95 0.95 0.95"
|
| 9 |
+
skip_thresholds="0.9 0.9 0.9"
|
| 10 |
+
top_ps="none none none"
|
| 11 |
+
dtypes="bfloat16 bfloat16 bfloat16"
|
| 12 |
+
sampling_strategies="default default default"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
humaneval_nshots="0"
|
| 17 |
+
humaneval_lengths="256"
|
| 18 |
+
humaneval_temperatures="0"
|
| 19 |
+
humaneval_limits="10000"
|
| 20 |
+
humaneval_diffusion_steps="256"
|
| 21 |
+
humaneval_block_sizes="32"
|
| 22 |
+
humaneval_block_add_thresholds="0.9"
|
| 23 |
+
humaneval_decoded_token_thresholds="0.95"
|
| 24 |
+
humaneval_skip_thresholds="0.95"
|
| 25 |
+
humaneval_top_ps="none"
|
| 26 |
+
humaneval_dtypes="bfloat16"
|
| 27 |
+
humaneval_sampling_strategies="default"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
base_model=Dream-org/Dream-v0-Base-7B
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
lora_models=(
|
| 35 |
+
"SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
read -ra TASKS_ARRAY <<< "$tasks"
|
| 40 |
+
read -ra NSHOTS_ARRAY <<< "$nshots"
|
| 41 |
+
read -ra LENGTH_ARRAY <<< "$lengths"
|
| 42 |
+
read -ra TEMP_ARRAY <<< "$temperatures"
|
| 43 |
+
read -ra LIMITS_ARRAY <<< "$limits"
|
| 44 |
+
read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
|
| 45 |
+
read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
|
| 46 |
+
read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
|
| 47 |
+
read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
|
| 48 |
+
read -ra TOP_PS_ARRAY <<< "$top_ps"
|
| 49 |
+
read -ra DTYPES_ARRAY <<< "$dtypes"
|
| 50 |
+
read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
|
| 54 |
+
read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
|
| 55 |
+
read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
|
| 56 |
+
read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
|
| 57 |
+
read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
|
| 58 |
+
read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
|
| 59 |
+
read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
|
| 60 |
+
read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
|
| 61 |
+
read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
|
| 62 |
+
read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
|
| 63 |
+
read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
|
| 64 |
+
read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
array_length=${#TASKS_ARRAY[@]}
|
| 68 |
+
if [[ ${#NSHOTS_ARRAY[@]} -ne $array_length ]] || \
|
| 69 |
+
[[ ${#LENGTH_ARRAY[@]} -ne $array_length ]] || \
|
| 70 |
+
[[ ${#TEMP_ARRAY[@]} -ne $array_length ]] || \
|
| 71 |
+
[[ ${#LIMITS_ARRAY[@]} -ne $array_length ]] || \
|
| 72 |
+
[[ ${#BLOCK_SIZES_ARRAY[@]} -ne $array_length ]] || \
|
| 73 |
+
[[ ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 74 |
+
[[ ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 75 |
+
[[ ${#SKIP_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 76 |
+
[[ ${#TOP_PS_ARRAY[@]} -ne $array_length ]] || \
|
| 77 |
+
[[ ${#SAMPLING_STRATEGIES_ARRAY[@]} -ne $array_length ]] || \
|
| 78 |
+
[[ ${#DTYPES_ARRAY[@]} -ne $array_length ]]; then
|
| 79 |
+
echo "Error: All configuration arrays must have the same length!"
|
| 80 |
+
echo "Tasks: ${#TASKS_ARRAY[@]}, Nshots: ${#NSHOTS_ARRAY[@]}, Lengths: ${#LENGTH_ARRAY[@]}, Temperatures: ${#TEMP_ARRAY[@]}, Limits: ${#LIMITS_ARRAY[@]}, Block sizes: ${#BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#TOP_PS_ARRAY[@]}, Sampling strategies: ${#SAMPLING_STRATEGIES_ARRAY[@]}, Dtypes: ${#DTYPES_ARRAY[@]}"
|
| 81 |
+
exit 1
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
|
| 86 |
+
if [[ ${#HUMANEVAL_LENGTHS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 87 |
+
[[ ${#HUMANEVAL_TEMP_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 88 |
+
[[ ${#HUMANEVAL_LIMITS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 89 |
+
[[ ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 90 |
+
[[ ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 91 |
+
[[ ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 92 |
+
[[ ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 93 |
+
[[ ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 94 |
+
[[ ${#HUMANEVAL_TOP_PS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 95 |
+
[[ ${#HUMANEVAL_DTYPES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 96 |
+
[[ ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]} -ne $humaneval_array_length ]]; then
|
| 97 |
+
echo "Error: All HumanEval configuration arrays must have the same length!"
|
| 98 |
+
echo "HumanEval Nshots: ${#HUMANEVAL_NSHOTS_ARRAY[@]}, Lengths: ${#HUMANEVAL_LENGTHS_ARRAY[@]}, Temperatures: ${#HUMANEVAL_TEMP_ARRAY[@]}, Limits: ${#HUMANEVAL_LIMITS_ARRAY[@]}, Diffusion steps: ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]}, Block sizes: ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#HUMANEVAL_TOP_PS_ARRAY[@]}, Dtypes: ${#HUMANEVAL_DTYPES_ARRAY[@]}, Sampling strategies: ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]}"
|
| 99 |
+
exit 1
|
| 100 |
+
fi
|
| 101 |
+
|
| 102 |
+
export HF_ALLOW_CODE_EVAL=1
|
| 103 |
+
for lora_model in "${lora_models[@]}"; do
|
| 104 |
+
lora_model_name="$lora_model"
|
| 105 |
+
echo "===================================================================="
|
| 106 |
+
echo "Evaluating LoRA model: $lora_model_name"
|
| 107 |
+
echo "===================================================================="
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
|
| 112 |
+
output_path="evals_dream${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 113 |
+
echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
|
| 114 |
+
echo "HumanEval Config: Shots: ${HUMANEVAL_NSHOTS_ARRAY[$i]}, Length: ${HUMANEVAL_LENGTHS_ARRAY[$i]}, Temperature: ${HUMANEVAL_TEMP_ARRAY[$i]}, Limit: ${HUMANEVAL_LIMITS_ARRAY[$i]}, Diffusion Steps: ${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}, Block Size: ${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${HUMANEVAL_TOP_PS_ARRAY[$i]}, Sampling Strategy: ${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${HUMANEVAL_DTYPES_ARRAY[$i]}; Output: $output_path"
|
| 115 |
+
|
| 116 |
+
if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 117 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 118 |
+
else
|
| 119 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 120 |
+
fi
|
| 121 |
+
|
| 122 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_dream.py --model dream_lora \
|
| 123 |
+
--model_args $humaneval_model_args \
|
| 124 |
+
--tasks humaneval \
|
| 125 |
+
--num_fewshot ${HUMANEVAL_NSHOTS_ARRAY[$i]} \
|
| 126 |
+
--batch_size 1 \
|
| 127 |
+
--output_path $output_path \
|
| 128 |
+
--log_samples \
|
| 129 |
+
--confirm_run_unsafe_code
|
| 130 |
+
done
|
| 131 |
+
|
| 132 |
+
### NOTICE: use postprocess for humaneval
|
| 133 |
+
# python postprocess_code.py {the samples_xxx.jsonl file under output_path}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
for i in "${!TASKS_ARRAY[@]}"; do
|
| 137 |
+
output_path="evals_dream${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 138 |
+
echo "Task: ${TASKS_ARRAY[$i]}, Shots: ${NSHOTS_ARRAY[$i]}, Length: ${LENGTH_ARRAY[$i]}, Temperature: ${TEMP_ARRAY[$i]}, Limit: ${LIMITS_ARRAY[$i]}, Block Size: ${BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${TOP_PS_ARRAY[$i]}, Sampling Strategy: ${SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${DTYPES_ARRAY[$i]}; Output: $output_path"
|
| 139 |
+
|
| 140 |
+
if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 141 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 142 |
+
else
|
| 143 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 144 |
+
fi
|
| 145 |
+
|
| 146 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_dream.py --model dream_lora \
|
| 147 |
+
--model_args $model_args \
|
| 148 |
+
--tasks ${TASKS_ARRAY[$i]} \
|
| 149 |
+
--limit ${LIMITS_ARRAY[$i]} \
|
| 150 |
+
--num_fewshot ${NSHOTS_ARRAY[$i]} \
|
| 151 |
+
--batch_size 1 \
|
| 152 |
+
--output_path $output_path \
|
| 153 |
+
--log_samples \
|
| 154 |
+
--confirm_run_unsafe_code
|
| 155 |
+
done
|
| 156 |
+
done
|
| 157 |
+
|
| 158 |
+
echo "All evaluations completed!"
|
Discrete-Diffusion-Forcing/D2F-eval/eval_dream_d2f_vllm.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
from datetime import timedelta
|
| 6 |
+
from typing import List, Optional, Tuple, Type, TypeVar, Union
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.distributions as dists
|
| 10 |
+
import transformers
|
| 11 |
+
from accelerate import (
|
| 12 |
+
Accelerator,
|
| 13 |
+
InitProcessGroupKwargs,
|
| 14 |
+
)
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from packaging import version
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from peft import PeftConfig, PeftModel
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from lm_eval import utils
|
| 22 |
+
from lm_eval.api.instance import Instance
|
| 23 |
+
from lm_eval.api.model import LM
|
| 24 |
+
from lm_eval.api.registry import register_model
|
| 25 |
+
from lm_eval.models.utils import get_dtype
|
| 26 |
+
from lm_eval.__main__ import cli_evaluate
|
| 27 |
+
|
| 28 |
+
eval_logger = logging.getLogger(__name__)
|
| 29 |
+
T = TypeVar("T", bound="LM")
|
| 30 |
+
import random
|
| 31 |
+
def set_seed(seed):
|
| 32 |
+
torch.manual_seed(seed)
|
| 33 |
+
random.seed(seed)
|
| 34 |
+
np.random.seed(seed)
|
| 35 |
+
|
| 36 |
+
torch.backends.cudnn.deterministic = True
|
| 37 |
+
torch.backends.cudnn.benchmark = False
|
| 38 |
+
|
| 39 |
+
def shift_logits(logits):
|
| 40 |
+
shifted_logits = torch.zeros_like(logits)
|
| 41 |
+
shifted_logits[:, 1:, :] = logits[:, :-1, :]
|
| 42 |
+
shifted_logits[:, 0, :] = 1.0
|
| 43 |
+
return shifted_logits
|
| 44 |
+
|
| 45 |
+
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
|
| 46 |
+
"""
|
| 47 |
+
Creates a complete attention mask for the entire sequence with block-based causal attention.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
prompt_length: Length of the prompt (first irregular block)
|
| 51 |
+
max_length: Maximum total sequence length
|
| 52 |
+
block_size: Size of each regular block
|
| 53 |
+
device: Device to create tensor on
|
| 54 |
+
dtype: Data type for the attention mask
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
attention_mask: Tensor of shape [1, 1, max_length, max_length]
|
| 58 |
+
"""
|
| 59 |
+
# Use the provided dtype or default to bfloat16
|
| 60 |
+
if dtype is None:
|
| 61 |
+
dtype = torch.bfloat16
|
| 62 |
+
|
| 63 |
+
# Initialize mask with -inf (no attention)
|
| 64 |
+
attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
|
| 65 |
+
|
| 66 |
+
# Block 0: Prompt (can see itself)
|
| 67 |
+
attention_mask[:, :, :prompt_length, :prompt_length] = 0
|
| 68 |
+
|
| 69 |
+
# Calculate the number of regular blocks after prompt
|
| 70 |
+
remaining_length = max_length - prompt_length
|
| 71 |
+
num_blocks = (remaining_length + block_size - 1) // block_size
|
| 72 |
+
|
| 73 |
+
# Process each regular block
|
| 74 |
+
for b in range(num_blocks):
|
| 75 |
+
block_start = prompt_length + b * block_size
|
| 76 |
+
block_end = min(prompt_length + (b + 1) * block_size, max_length)
|
| 77 |
+
|
| 78 |
+
# Current block can see the prompt
|
| 79 |
+
attention_mask[:, :, block_start:block_end, :prompt_length] = 0
|
| 80 |
+
|
| 81 |
+
# Current block can see all previous regular blocks
|
| 82 |
+
for prev_b in range(b):
|
| 83 |
+
prev_start = prompt_length + prev_b * block_size
|
| 84 |
+
prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
|
| 85 |
+
attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
|
| 86 |
+
|
| 87 |
+
# Current block can see itself (full attention within block)
|
| 88 |
+
attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
|
| 89 |
+
|
| 90 |
+
return attention_mask
|
| 91 |
+
|
| 92 |
+
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
|
| 93 |
+
"""
|
| 94 |
+
Extract the relevant portion of attention mask for current forward pass.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
full_mask: Complete attention mask [1, 1, max_length, max_length]
|
| 98 |
+
start_pos: Starting position in the full sequence
|
| 99 |
+
input_length: Length of current input sequence
|
| 100 |
+
cache_length: Length of cached sequence
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
attention_mask: Extracted mask [1, 1, input_length, cache_length + input_length]
|
| 104 |
+
"""
|
| 105 |
+
end_pos = start_pos + input_length
|
| 106 |
+
total_length = cache_length + input_length
|
| 107 |
+
|
| 108 |
+
# Extract the relevant rows (current input positions)
|
| 109 |
+
# and columns (cache + current input positions)
|
| 110 |
+
extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf,
|
| 111 |
+
device=full_mask.device, dtype=full_mask.dtype)
|
| 112 |
+
|
| 113 |
+
# Copy cache columns (0 to cache_length in the extracted mask corresponds to 0 to cache_length in full mask)
|
| 114 |
+
extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
|
| 115 |
+
|
| 116 |
+
# Copy current input columns
|
| 117 |
+
extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
|
| 118 |
+
|
| 119 |
+
return extracted_mask
|
| 120 |
+
|
| 121 |
+
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
|
| 122 |
+
B, seq_len = input_ids.shape
|
| 123 |
+
# Use the provided dtype or default to float32
|
| 124 |
+
if dtype is None:
|
| 125 |
+
dtype = torch.float32
|
| 126 |
+
# Initialize to all -inf
|
| 127 |
+
attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)
|
| 128 |
+
# 1. Prompt part: each token can attend to the entire prompt
|
| 129 |
+
for i in range(B):
|
| 130 |
+
attn_mask[i, :, :, :prompt_length[i]] = 0.0 # Allow all tokens to see the prompt
|
| 131 |
+
|
| 132 |
+
# 2. Block division: divide into blocks starting from prompt_length
|
| 133 |
+
num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
|
| 134 |
+
|
| 135 |
+
for b in range(num_blocks):
|
| 136 |
+
block_start = prompt_length[i] + b * block_size
|
| 137 |
+
block_end = min(block_start + block_size, seq_len)
|
| 138 |
+
|
| 139 |
+
# Full attention within the block
|
| 140 |
+
attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
|
| 141 |
+
|
| 142 |
+
# Causal attention between blocks (can only see previous blocks)
|
| 143 |
+
for prev_b in range(b):
|
| 144 |
+
prev_start = prompt_length[i] + prev_b * block_size
|
| 145 |
+
prev_end = min(prev_start + block_size, seq_len)
|
| 146 |
+
|
| 147 |
+
# Current block can see previous blocks
|
| 148 |
+
attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
|
| 149 |
+
|
| 150 |
+
return attn_mask
|
| 151 |
+
|
| 152 |
+
def top_p_logits(logits, top_p=None):
|
| 153 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 154 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 155 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 156 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 157 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 158 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 159 |
+
|
| 160 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 161 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 162 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 163 |
+
return logits
|
| 164 |
+
|
| 165 |
+
def top_k_logits(logits, top_k=None):
|
| 166 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 167 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 168 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 169 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 170 |
+
return logits
|
| 171 |
+
|
| 172 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 173 |
+
if temperature > 0:
|
| 174 |
+
logits = logits / temperature
|
| 175 |
+
if top_p is not None and top_p < 1:
|
| 176 |
+
logits = top_p_logits(logits, top_p)
|
| 177 |
+
if top_k is not None:
|
| 178 |
+
logits = top_k_logits(logits, top_k)
|
| 179 |
+
probs = torch.softmax(logits, dim=-1)
|
| 180 |
+
|
| 181 |
+
if temperature > 0:
|
| 182 |
+
try:
|
| 183 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 184 |
+
initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 185 |
+
except:
|
| 186 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 187 |
+
else:
|
| 188 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 189 |
+
|
| 190 |
+
# Save initial confidence
|
| 191 |
+
confidence = initial_confidence.clone()
|
| 192 |
+
|
| 193 |
+
if margin_confidence:
|
| 194 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 195 |
+
# Extract top1 and top2 probabilities
|
| 196 |
+
top1_probs = sorted_probs[:, 0]
|
| 197 |
+
top2_probs = sorted_probs[:, 1]
|
| 198 |
+
# Calculate confidence as top1 - top2
|
| 199 |
+
confidence = top1_probs - top2_probs
|
| 200 |
+
|
| 201 |
+
if neg_entropy:
|
| 202 |
+
epsilon = 1e-10
|
| 203 |
+
log_probs = torch.log(probs + epsilon)
|
| 204 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 205 |
+
|
| 206 |
+
return confidence, x0, initial_confidence
|
| 207 |
+
|
| 208 |
+
@register_model("dream_lora")
|
| 209 |
+
class DreamLoRA(LM):
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
pretrained: Union[str, transformers.PreTrainedModel],
|
| 213 |
+
lora_path: str,
|
| 214 |
+
batch_size: Optional[Union[int, str]] = 1,
|
| 215 |
+
device: Optional[str] = "cuda",
|
| 216 |
+
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
| 217 |
+
max_new_tokens: Optional[int] = 128,
|
| 218 |
+
max_length: Optional[int] = 2048, # Updated to match example code
|
| 219 |
+
add_bos_token: Optional[bool] = False,
|
| 220 |
+
nll_type: Optional[str] = "mc",
|
| 221 |
+
log_type: Optional[str] = "ftb",
|
| 222 |
+
mc_num: Optional[int] = 128,
|
| 223 |
+
classifier_free_guidance: Optional[float] = 1.0,
|
| 224 |
+
sampling_eps: Optional[float] = 1e-3,
|
| 225 |
+
diffusion_steps: Optional[int] = 128,
|
| 226 |
+
trust_remote_code: Optional[bool] = True,
|
| 227 |
+
parallelize: Optional[bool] = False,
|
| 228 |
+
autogptq: Optional[Union[bool, str]] = False,
|
| 229 |
+
temperature: Optional[float] = 0.2, # Updated default
|
| 230 |
+
top_p: Optional[float] = None, # Updated default
|
| 231 |
+
top_k: Optional[float] = None,
|
| 232 |
+
alg: Optional[str] = "entropy",
|
| 233 |
+
alg_temp: Optional[float] = 0.0,
|
| 234 |
+
escape_until: Optional[bool] = False,
|
| 235 |
+
block_size: Optional[int] = 4, # Updated to match example code
|
| 236 |
+
mask_token_id: Optional[int] = 151666, # Added mask_token_id parameter
|
| 237 |
+
block_add_threshold: Optional[float] = 0.5, # Added block_add_threshold parameter
|
| 238 |
+
decoded_token_threshold: Optional[int] = 0.9, # Added decoded_token_threshold parameter
|
| 239 |
+
skip_threshold: Optional[float] = 1.0, # Added skip_threshold parameter
|
| 240 |
+
sampling_strategy: Optional[str] = "default", # Added sampling_strategy parameter
|
| 241 |
+
save_dir: Optional[str] = None,
|
| 242 |
+
**kwargs,
|
| 243 |
+
) -> None:
|
| 244 |
+
super().__init__()
|
| 245 |
+
|
| 246 |
+
# prepare for parallelism
|
| 247 |
+
assert isinstance(device, str)
|
| 248 |
+
assert isinstance(pretrained, str)
|
| 249 |
+
assert isinstance(batch_size, (int, str))
|
| 250 |
+
|
| 251 |
+
gpus = torch.cuda.device_count()
|
| 252 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 253 |
+
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
|
| 254 |
+
if accelerator.num_processes > 1:
|
| 255 |
+
self.accelerator = accelerator
|
| 256 |
+
|
| 257 |
+
if "npu" in accelerator.device.type:
|
| 258 |
+
gpus = torch.npu.device_count()
|
| 259 |
+
|
| 260 |
+
# using one process with no model parallelism
|
| 261 |
+
if not (parallelize or accelerator.num_processes > 1):
|
| 262 |
+
# use user-passed device
|
| 263 |
+
device_list = set(
|
| 264 |
+
["cuda", "cpu"]
|
| 265 |
+
+ [f"cuda:{i}" for i in range(gpus)]
|
| 266 |
+
+ ["mps", "mps:0"]
|
| 267 |
+
+ [f"npu:{i}" for i in range(gpus)]
|
| 268 |
+
)
|
| 269 |
+
if device and device in device_list:
|
| 270 |
+
self._device = torch.device(device)
|
| 271 |
+
eval_logger.info(f"Using device '{device}'")
|
| 272 |
+
if device in ("mps", "mps:0") and version.parse(
|
| 273 |
+
torch.__version__
|
| 274 |
+
) < version.parse("2.1"):
|
| 275 |
+
raise RuntimeError(
|
| 276 |
+
f"mps requires torch >= 2.1. You have {torch.__version__}"
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
eval_logger.info("Device not specified")
|
| 280 |
+
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
|
| 281 |
+
self._device = (
|
| 282 |
+
torch.device("cuda")
|
| 283 |
+
if torch.cuda.is_available()
|
| 284 |
+
else torch.device("cpu")
|
| 285 |
+
)
|
| 286 |
+
else: # Parallelism managed by accelerate
|
| 287 |
+
if device != "cuda":
|
| 288 |
+
eval_logger.info(
|
| 289 |
+
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
|
| 290 |
+
)
|
| 291 |
+
# TODO: include in warning that `load_in_8bit` etc. affect this too
|
| 292 |
+
self._device = (
|
| 293 |
+
self.accelerator.device
|
| 294 |
+
if hasattr(self, "accelerator")
|
| 295 |
+
else torch.device(device)
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.batch_size_per_gpu = batch_size
|
| 299 |
+
if isinstance(batch_size, str):
|
| 300 |
+
self.batch_size_per_gpu = int(batch_size)
|
| 301 |
+
|
| 302 |
+
# Save LoRA path and block_size
|
| 303 |
+
self.lora_path = lora_path
|
| 304 |
+
self.block_size = block_size
|
| 305 |
+
self.block_add_threshold = block_add_threshold # New block_add_threshold attribute
|
| 306 |
+
self.skip_threshold = skip_threshold # New skip_threshold attribute
|
| 307 |
+
self.sampling_strategy = sampling_strategy # Save sampling strategy parameter
|
| 308 |
+
self.decoded_token_threshold = decoded_token_threshold # New decoded_token_threshold attribute
|
| 309 |
+
self.save_dir = save_dir
|
| 310 |
+
|
| 311 |
+
# Add metric tracking
|
| 312 |
+
self.total_forward_passes = 0
|
| 313 |
+
self.total_generated_tokens = 0
|
| 314 |
+
self.total_prompts = 0
|
| 315 |
+
# Add time and token statistics
|
| 316 |
+
self.total_generation_time = 0.0
|
| 317 |
+
self.total_block_tokens = 0 # Number of blocks * block_size
|
| 318 |
+
self.total_actual_tokens = 0 # Actual generated tokens (excluding EOS)
|
| 319 |
+
self.total_non_eos_tokens = 0 # Total non-EOS tokens in the entire sequence
|
| 320 |
+
self.all_generation_times = []
|
| 321 |
+
self.all_block_tokens = []
|
| 322 |
+
self.all_actual_tokens = []
|
| 323 |
+
self.all_non_eos_tokens = []
|
| 324 |
+
|
| 325 |
+
# Save target_dtype for later use
|
| 326 |
+
self.target_dtype = get_dtype(dtype)
|
| 327 |
+
|
| 328 |
+
# if isinstance(pretrained, str):
|
| 329 |
+
# if gpus >= 1 or str(self.device) == "mps":
|
| 330 |
+
# # TODO: can remove this whole snippet except in the mps case, perhaps?
|
| 331 |
+
# if not (parallelize or autogptq or hasattr(self, "accelerator")):
|
| 332 |
+
# # place model onto device requested manually,
|
| 333 |
+
# # if not using HF Accelerate or device_map
|
| 334 |
+
# # or any other option that preloads model onto device
|
| 335 |
+
# try:
|
| 336 |
+
# self.model.to(self.device)
|
| 337 |
+
# except ValueError:
|
| 338 |
+
# eval_logger.debug(
|
| 339 |
+
# "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
|
| 340 |
+
# )
|
| 341 |
+
# # multigpu data-parallel support when launched with accelerate
|
| 342 |
+
# if gpus > 1:
|
| 343 |
+
# if accelerator.num_processes > 1:
|
| 344 |
+
# if parallelize:
|
| 345 |
+
# eval_logger.warning(
|
| 346 |
+
# "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
|
| 347 |
+
# )
|
| 348 |
+
# elif gpus > accelerator.num_processes:
|
| 349 |
+
# eval_logger.warning(
|
| 350 |
+
# "WARNING: The number of total system GPUs does not match the number of spawned processes. "
|
| 351 |
+
# "If you would like to use data parallelism, please launch the script "
|
| 352 |
+
# "with 'accelerate launch *script*'. "
|
| 353 |
+
# f"Current run will proceed with {accelerator.num_processes} devices."
|
| 354 |
+
# )
|
| 355 |
+
# if self.accelerator.is_local_main_process:
|
| 356 |
+
# eval_logger.info(
|
| 357 |
+
# f"Using {gpus} devices with data parallelism"
|
| 358 |
+
# )
|
| 359 |
+
|
| 360 |
+
# self._device = torch.device(f"{accelerator.device}")
|
| 361 |
+
# self.accelerator = accelerator
|
| 362 |
+
|
| 363 |
+
# self._rank = self.accelerator.local_process_index
|
| 364 |
+
# self._world_size = self.accelerator.num_processes
|
| 365 |
+
# else:
|
| 366 |
+
# # if we aren't launching via accelerate, ditch
|
| 367 |
+
# self._rank = 0
|
| 368 |
+
# self._world_size = 1
|
| 369 |
+
# else:
|
| 370 |
+
# # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
|
| 371 |
+
# eval_logger.warning(
|
| 372 |
+
# "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
|
| 373 |
+
# )
|
| 374 |
+
# self._rank = 0
|
| 375 |
+
# self._world_size = 1
|
| 376 |
+
|
| 377 |
+
self.max_length = max_length
|
| 378 |
+
self.add_bos_token = add_bos_token
|
| 379 |
+
# generation params
|
| 380 |
+
self.max_new_tokens = max_new_tokens
|
| 381 |
+
self.diffusion_steps = diffusion_steps
|
| 382 |
+
self.temperature = temperature
|
| 383 |
+
self.top_p = top_p
|
| 384 |
+
self.top_k = top_k
|
| 385 |
+
self.alg = alg
|
| 386 |
+
self.alg_temp = alg_temp
|
| 387 |
+
self.escape_until = escape_until
|
| 388 |
+
self.block_size = block_size
|
| 389 |
+
self.mask_token_id = mask_token_id
|
| 390 |
+
|
| 391 |
+
# loglikelihood params
|
| 392 |
+
self.nll_type = nll_type
|
| 393 |
+
self.log_type = log_type
|
| 394 |
+
self.mc_num = mc_num
|
| 395 |
+
self.classifier_free_guidance = classifier_free_guidance
|
| 396 |
+
self.sampling_eps = sampling_eps
|
| 397 |
+
|
| 398 |
+
self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
|
| 399 |
+
|
| 400 |
+
@property
|
| 401 |
+
def batch_size(self):
|
| 402 |
+
return self.batch_size_per_gpu
|
| 403 |
+
|
| 404 |
+
@property
|
| 405 |
+
def device(self):
|
| 406 |
+
return self._device
|
| 407 |
+
|
| 408 |
+
@property
|
| 409 |
+
def rank(self):
|
| 410 |
+
return self._rank
|
| 411 |
+
|
| 412 |
+
@property
|
| 413 |
+
def world_size(self):
|
| 414 |
+
return self._world_size
|
| 415 |
+
|
| 416 |
+
def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
|
| 417 |
+
from d2f_vllm import LLM, SamplingParams
|
| 418 |
+
|
| 419 |
+
self.LLM = LLM(
|
| 420 |
+
pretrained,
|
| 421 |
+
lora_path=self.lora_path,
|
| 422 |
+
use_lora=True,
|
| 423 |
+
model_name="dream",
|
| 424 |
+
model_type="diffusion_lm",
|
| 425 |
+
enforce_eager=True,
|
| 426 |
+
tensor_parallel_size=1,
|
| 427 |
+
gpu_memory_utilization=0.60,
|
| 428 |
+
max_num_batched_tokens=2048,
|
| 429 |
+
max_num_seqs=20,
|
| 430 |
+
max_model_len=1024,
|
| 431 |
+
accept_threshold=self.skip_threshold,
|
| 432 |
+
complete_threshold=self.decoded_token_threshold,
|
| 433 |
+
add_new_block_threshold=1-self.block_add_threshold,
|
| 434 |
+
kv_cache_layout="unified"
|
| 435 |
+
)
|
| 436 |
+
self.tokenizer = self.LLM.tokenizer
|
| 437 |
+
self.sampling_params = SamplingParams(temperature=self.temperature, max_tokens=self.max_new_tokens)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def tok_decode(self, tokens, skip_special_tokens=True):
|
| 441 |
+
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
| 442 |
+
|
| 443 |
+
def tok_encode(self, text, add_special_tokens=True):
|
| 444 |
+
return self.tokenizer(
|
| 445 |
+
text, return_tensors="pt", add_special_tokens=add_special_tokens
|
| 446 |
+
).input_ids
|
| 447 |
+
|
| 448 |
+
@classmethod
|
| 449 |
+
def create_from_arg_string(
|
| 450 |
+
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
|
| 451 |
+
) -> T:
|
| 452 |
+
"""
|
| 453 |
+
Creates an instance of the LM class using the given argument string and additional config.
|
| 454 |
+
|
| 455 |
+
Parameters:
|
| 456 |
+
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 457 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
- Instance of the LM class.
|
| 461 |
+
"""
|
| 462 |
+
additional_config = {} if additional_config is None else additional_config
|
| 463 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 464 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 465 |
+
return cls(**args, **args2)
|
| 466 |
+
|
| 467 |
+
def apply_chat_template(
|
| 468 |
+
self, chat_history, add_generation_prompt: bool = True
|
| 469 |
+
) -> str:
|
| 470 |
+
"""
|
| 471 |
+
Method to apply a chat template to a list of chat history between user and model.
|
| 472 |
+
"""
|
| 473 |
+
chat_templated = self.tokenizer.apply_chat_template(
|
| 474 |
+
chat_history,
|
| 475 |
+
tokenize=False,
|
| 476 |
+
add_generation_prompt=add_generation_prompt,
|
| 477 |
+
continue_final_message=not add_generation_prompt,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return chat_templated
|
| 481 |
+
|
| 482 |
+
@property
|
| 483 |
+
def tokenizer_name(self) -> str:
|
| 484 |
+
return self.tokenizer.name_or_path.replace("/", "__")
|
| 485 |
+
|
| 486 |
+
def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
|
| 487 |
+
res = []
|
| 488 |
+
|
| 489 |
+
# Initialize statistics counters
|
| 490 |
+
if not hasattr(self, 'total_generated_tokens'):
|
| 491 |
+
self.total_generated_tokens = 0
|
| 492 |
+
num_tokens = 0
|
| 493 |
+
num_nfe = 0 # Number of Forward Evaluations
|
| 494 |
+
|
| 495 |
+
prompts, gen_args = [], []
|
| 496 |
+
print("Preparing prompts...")
|
| 497 |
+
for req in tqdm(requests):
|
| 498 |
+
prompts.append(self.tokenizer.bos_token + req.arguments[0])
|
| 499 |
+
gen_args.append(req.arguments[1])
|
| 500 |
+
|
| 501 |
+
start_time = time.time()
|
| 502 |
+
|
| 503 |
+
outputs = self.LLM.generate(prompts, self.sampling_params)
|
| 504 |
+
|
| 505 |
+
end_time = time.time()
|
| 506 |
+
total_time = end_time - start_time
|
| 507 |
+
|
| 508 |
+
# Accumulate statistics
|
| 509 |
+
res = [output['text'] for output in outputs]
|
| 510 |
+
num_tokens = sum(len(output['token_ids']) for output in outputs)
|
| 511 |
+
num_nfe = sum(output['n_diff_steps'] for output in outputs)
|
| 512 |
+
|
| 513 |
+
# Save final statistics
|
| 514 |
+
final_stats = {
|
| 515 |
+
'processed_samples': len(requests),
|
| 516 |
+
'total_samples': len(requests),
|
| 517 |
+
'total_tokens': num_tokens,
|
| 518 |
+
'total_nfe': num_nfe,
|
| 519 |
+
'total_time': total_time,
|
| 520 |
+
'tokens_per_second': num_tokens / total_time if total_time > 0 else 0,
|
| 521 |
+
'nfe_per_token': num_nfe / num_tokens if num_tokens > 0 else 0,
|
| 522 |
+
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
# Save statistics to file
|
| 526 |
+
if self.save_dir is not None:
|
| 527 |
+
import os
|
| 528 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 529 |
+
|
| 530 |
+
# Save response results
|
| 531 |
+
save_path = os.path.join(self.save_dir, f'rank_{self.rank}_responses.jsonl')
|
| 532 |
+
with open(save_path, 'w', encoding='utf-8') as f:
|
| 533 |
+
for r in res:
|
| 534 |
+
f.write(json.dumps(r, ensure_ascii=False) + '\n')
|
| 535 |
+
|
| 536 |
+
# Save statistics results
|
| 537 |
+
stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
|
| 538 |
+
with open(stats_path, 'w', encoding='utf-8') as f:
|
| 539 |
+
json.dump(final_stats, f, ensure_ascii=False, indent=2)
|
| 540 |
+
|
| 541 |
+
# Print final statistics
|
| 542 |
+
print("\n" + "="*60)
|
| 543 |
+
print("=== Final Statistics ===")
|
| 544 |
+
print("="*60)
|
| 545 |
+
print(f"Processed Samples: {final_stats['processed_samples']}")
|
| 546 |
+
print(f"Total Samples: {final_stats['total_samples']}")
|
| 547 |
+
print(f"Total Tokens: {final_stats['total_tokens']}")
|
| 548 |
+
print(f"Total NFE: {final_stats['total_nfe']}")
|
| 549 |
+
print(f"Total Time: {final_stats['total_time']:.4f}s")
|
| 550 |
+
print(f"Tokens/Second: {final_stats['tokens_per_second']:.2f}")
|
| 551 |
+
print(f"NFE/Token: {final_stats['nfe_per_token']:.4f}")
|
| 552 |
+
print(f"Completion Time: {final_stats['timestamp']}")
|
| 553 |
+
print("="*60)
|
| 554 |
+
|
| 555 |
+
return res
|
| 556 |
+
|
| 557 |
+
def _forward_process(self, batch):
|
| 558 |
+
b, l = batch.shape
|
| 559 |
+
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
|
| 560 |
+
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
|
| 561 |
+
indices = torch.arange(b, device=batch.device).float()
|
| 562 |
+
t = (u0 + indices / b) % 1
|
| 563 |
+
|
| 564 |
+
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
|
| 565 |
+
|
| 566 |
+
p_mask = p_mask[:, None].repeat(1, l)
|
| 567 |
+
|
| 568 |
+
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
|
| 569 |
+
# always unmask bos and eos
|
| 570 |
+
mask_indices[:, 0] = False
|
| 571 |
+
mask_indices[:, -1] = False
|
| 572 |
+
|
| 573 |
+
noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
|
| 574 |
+
return noisy_batch, p_mask
|
| 575 |
+
|
| 576 |
+
@torch.no_grad()
|
| 577 |
+
def get_logits(self, batch, prompt_index):
|
| 578 |
+
'''
|
| 579 |
+
prompt_index : 1D bool tensor, length=batch.shape[1]
|
| 580 |
+
'''
|
| 581 |
+
if self.classifier_free_guidance > 1.:
|
| 582 |
+
assert len(prompt_index) == batch.shape[1]
|
| 583 |
+
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
| 584 |
+
un_batch = batch.clone()
|
| 585 |
+
un_batch[prompt_index] = self.mask_token_id
|
| 586 |
+
batch = torch.cat([batch, un_batch])
|
| 587 |
+
|
| 588 |
+
input = batch
|
| 589 |
+
|
| 590 |
+
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 591 |
+
logits = self.model(input).logits
|
| 592 |
+
# since bos always unmask, the first logits will not be used
|
| 593 |
+
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
| 594 |
+
|
| 595 |
+
if self.classifier_free_guidance > 1.:
|
| 596 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 597 |
+
logits = un_logits + self.cfg * (logits - un_logits)
|
| 598 |
+
return logits[:, :batch.shape[1]]
|
| 599 |
+
|
| 600 |
+
@torch.no_grad()
|
| 601 |
+
def _eval_target_nll_mc(self, prefix, target):
|
| 602 |
+
if prefix is None:
|
| 603 |
+
seq = target[None, :]
|
| 604 |
+
else:
|
| 605 |
+
seq = torch.concatenate([prefix, target])[None, :]
|
| 606 |
+
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
| 607 |
+
|
| 608 |
+
if self.log_type == 'ftb':
|
| 609 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 610 |
+
else:
|
| 611 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
|
| 612 |
+
|
| 613 |
+
loss_acc = []
|
| 614 |
+
for _ in range(max(self.mc_num // self.batch_size, 1)):
|
| 615 |
+
perturbed_seq = seq.clone()
|
| 616 |
+
# eval_logger.info("before noising")
|
| 617 |
+
perturbed_seq_, p_mask = self._forward_process(seq)
|
| 618 |
+
# eval_logger.info("end noising")
|
| 619 |
+
if self.log_type == 'ftb':
|
| 620 |
+
perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
|
| 621 |
+
elif self.log_type == 'btf':
|
| 622 |
+
perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
|
| 623 |
+
elif self.log_type == 'union':
|
| 624 |
+
perturbed_seq = perturbed_seq_
|
| 625 |
+
else:
|
| 626 |
+
raise NotImplementedError(self.log_type)
|
| 627 |
+
|
| 628 |
+
mask_indices = perturbed_seq == self.mask_token_id
|
| 629 |
+
logits = self.get_logits(perturbed_seq, prompt_index)
|
| 630 |
+
loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
|
| 631 |
+
loss = loss.sum() / self.batch_size
|
| 632 |
+
loss_acc.append(loss.item())
|
| 633 |
+
|
| 634 |
+
return sum(loss_acc) / len(loss_acc)
|
| 635 |
+
|
| 636 |
+
@torch.no_grad()
|
| 637 |
+
def _eval_target_nll_ar(self, prefix, target):
|
| 638 |
+
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
|
| 639 |
+
assert self.log_type in ['ftb', 'btf']
|
| 640 |
+
assert self.nll_type in ['ar_ftb', 'ar_btf']
|
| 641 |
+
|
| 642 |
+
if self.log_type == 'ftb':
|
| 643 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
|
| 644 |
+
else:
|
| 645 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
|
| 646 |
+
|
| 647 |
+
if self.log_type == 'ftb':
|
| 648 |
+
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
|
| 649 |
+
else:
|
| 650 |
+
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
|
| 651 |
+
|
| 652 |
+
mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 653 |
+
if self.nll_type == 'ar_ftb':
|
| 654 |
+
mask_index = torch.triu(mask_index)
|
| 655 |
+
else:
|
| 656 |
+
mask_index = torch.tril(mask_index)
|
| 657 |
+
perturbed_[mask_index] = self.mask_token_id
|
| 658 |
+
if self.log_type == 'ftb':
|
| 659 |
+
perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
|
| 660 |
+
else:
|
| 661 |
+
perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
|
| 662 |
+
|
| 663 |
+
logits_ = []
|
| 664 |
+
num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
|
| 665 |
+
for i in range(num):
|
| 666 |
+
end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
|
| 667 |
+
perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
|
| 668 |
+
perturbed_seq_ = perturbed_seq_.to(self.device)
|
| 669 |
+
if len(perturbed_seq_.shape) == 1:
|
| 670 |
+
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
|
| 671 |
+
logits = self.get_logits(perturbed_seq_, prompt_index)
|
| 672 |
+
logits_.append(logits.cpu())
|
| 673 |
+
logits = torch.cat(logits_, dim=0)
|
| 674 |
+
|
| 675 |
+
temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 676 |
+
if self.nll_type == 'ar_ftb':
|
| 677 |
+
temp_index = torch.triu(temp_index, diagonal=1)
|
| 678 |
+
else:
|
| 679 |
+
temp_index = torch.tril(temp_index, diagonal=-1)
|
| 680 |
+
mask_index[temp_index] = False
|
| 681 |
+
if self.log_type == 'ftb':
|
| 682 |
+
logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
|
| 683 |
+
else:
|
| 684 |
+
logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
|
| 685 |
+
|
| 686 |
+
if self.log_type == 'ftb':
|
| 687 |
+
loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
|
| 688 |
+
else:
|
| 689 |
+
loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
|
| 690 |
+
return loss
|
| 691 |
+
|
| 692 |
+
def _encode_pair(self, context, continuation):
|
| 693 |
+
if self.add_bos_token:
|
| 694 |
+
context = self.tokenizer.bos_token + context
|
| 695 |
+
|
| 696 |
+
n_spaces = len(context) - len(context.rstrip())
|
| 697 |
+
if n_spaces > 0:
|
| 698 |
+
continuation = context[-n_spaces:] + continuation
|
| 699 |
+
context = context[:-n_spaces]
|
| 700 |
+
|
| 701 |
+
whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
|
| 702 |
+
context_enc = self.tokenizer.encode(context)
|
| 703 |
+
|
| 704 |
+
context_enc_len = len(context_enc)
|
| 705 |
+
continuation_enc = whole_enc[context_enc_len:]
|
| 706 |
+
|
| 707 |
+
# by default truncate on the left
|
| 708 |
+
cutoff_length = max(len(whole_enc) - self.max_length, 0)
|
| 709 |
+
if cutoff_length > 0:
|
| 710 |
+
eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
|
| 711 |
+
context_remain = context_enc_len-cutoff_length
|
| 712 |
+
if context_remain > 0:
|
| 713 |
+
context_enc = context_enc[-context_remain:]
|
| 714 |
+
else:
|
| 715 |
+
eval_logger.warning(f"All context (prompt) is truncated.")
|
| 716 |
+
context_enc = ""
|
| 717 |
+
continuation_enc = whole_enc[-self.max_length:]
|
| 718 |
+
return context_enc, continuation_enc
|
| 719 |
+
|
| 720 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 721 |
+
def _tokenize(e):
|
| 722 |
+
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
| 723 |
+
return {
|
| 724 |
+
"prefix_text": e["prefix"],
|
| 725 |
+
"target_text": e["target"],
|
| 726 |
+
"prefix": prefix,
|
| 727 |
+
"target": target,
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
ds = []
|
| 731 |
+
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
| 732 |
+
ds = Dataset.from_list(ds)
|
| 733 |
+
print(ds[0])
|
| 734 |
+
ds = ds.map(_tokenize)
|
| 735 |
+
ds = ds.with_format("torch")
|
| 736 |
+
|
| 737 |
+
out = []
|
| 738 |
+
with torch.no_grad():
|
| 739 |
+
for elem in tqdm(ds, desc="Computing likelihood..."):
|
| 740 |
+
prefix = elem["prefix"]
|
| 741 |
+
target = elem["target"]
|
| 742 |
+
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
|
| 743 |
+
if self.nll_type == 'mc':
|
| 744 |
+
ll = -self._eval_target_nll_mc(prefix, target)
|
| 745 |
+
if self.log_type == 'union':
|
| 746 |
+
ll = ll / (len(target) + len(prefix))
|
| 747 |
+
elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
|
| 748 |
+
ll = -self._eval_target_nll_ar(prefix, target)
|
| 749 |
+
else:
|
| 750 |
+
raise NotImplementedError(self.nll_type)
|
| 751 |
+
|
| 752 |
+
# TODO: greedy decoding
|
| 753 |
+
is_target_greedy_dec = False
|
| 754 |
+
|
| 755 |
+
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
| 756 |
+
return out
|
| 757 |
+
|
| 758 |
+
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
| 759 |
+
raise NotImplementedError
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
if __name__ == "__main__":
|
| 763 |
+
set_seed(1234)
|
| 764 |
+
cli_evaluate()
|
Discrete-Diffusion-Forcing/D2F-eval/eval_dream_d2f_vllm.sh
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tasks="gsm8k_cot mbpp minerva_math"
|
| 2 |
+
nshots="8 3 4"
|
| 3 |
+
lengths="256 256 256"
|
| 4 |
+
temperatures="0 0 0"
|
| 5 |
+
limits="10000 10000 10000"
|
| 6 |
+
block_sizes="32 48 64"
|
| 7 |
+
block_add_thresholds="0.1 0.1 0.1"
|
| 8 |
+
decoded_token_thresholds="0.95 0.95 0.95"
|
| 9 |
+
skip_thresholds="0.9 0.9 0.9"
|
| 10 |
+
top_ps="none none none"
|
| 11 |
+
dtypes="bfloat16 bfloat16 bfloat16"
|
| 12 |
+
sampling_strategies="default default default"
|
| 13 |
+
|
| 14 |
+
humaneval_nshots="0"
|
| 15 |
+
humaneval_lengths="256"
|
| 16 |
+
humaneval_temperatures="0"
|
| 17 |
+
humaneval_limits="10000"
|
| 18 |
+
humaneval_diffusion_steps="256"
|
| 19 |
+
humaneval_block_sizes="32"
|
| 20 |
+
humaneval_block_add_thresholds="0.9"
|
| 21 |
+
humaneval_decoded_token_thresholds="0.95"
|
| 22 |
+
humaneval_skip_thresholds="0.95"
|
| 23 |
+
humaneval_top_ps="none"
|
| 24 |
+
humaneval_dtypes="bfloat16"
|
| 25 |
+
humaneval_sampling_strategies="default"
|
| 26 |
+
|
| 27 |
+
base_model=Dream-org/Dream-v0-Base-7B
|
| 28 |
+
|
| 29 |
+
lora_models=(
|
| 30 |
+
"SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
read -ra TASKS_ARRAY <<< "$tasks"
|
| 34 |
+
read -ra NSHOTS_ARRAY <<< "$nshots"
|
| 35 |
+
read -ra LENGTH_ARRAY <<< "$lengths"
|
| 36 |
+
read -ra TEMP_ARRAY <<< "$temperatures"
|
| 37 |
+
read -ra LIMITS_ARRAY <<< "$limits"
|
| 38 |
+
read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
|
| 39 |
+
read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
|
| 40 |
+
read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
|
| 41 |
+
read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
|
| 42 |
+
read -ra TOP_PS_ARRAY <<< "$top_ps"
|
| 43 |
+
read -ra DTYPES_ARRAY <<< "$dtypes"
|
| 44 |
+
read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"
|
| 45 |
+
|
| 46 |
+
read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
|
| 47 |
+
read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
|
| 48 |
+
read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
|
| 49 |
+
read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
|
| 50 |
+
read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
|
| 51 |
+
read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
|
| 52 |
+
read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
|
| 53 |
+
read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
|
| 54 |
+
read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
|
| 55 |
+
read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
|
| 56 |
+
read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
|
| 57 |
+
read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"
|
| 58 |
+
|
| 59 |
+
array_length=${#TASKS_ARRAY[@]}
|
| 60 |
+
if [[ ${#NSHOTS_ARRAY[@]} -ne $array_length ]] || \
|
| 61 |
+
[[ ${#LENGTH_ARRAY[@]} -ne $array_length ]] || \
|
| 62 |
+
[[ ${#TEMP_ARRAY[@]} -ne $array_length ]] || \
|
| 63 |
+
[[ ${#LIMITS_ARRAY[@]} -ne $array_length ]] || \
|
| 64 |
+
[[ ${#BLOCK_SIZES_ARRAY[@]} -ne $array_length ]] || \
|
| 65 |
+
[[ ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 66 |
+
[[ ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 67 |
+
[[ ${#SKIP_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 68 |
+
[[ ${#TOP_PS_ARRAY[@]} -ne $array_length ]] || \
|
| 69 |
+
[[ ${#SAMPLING_STRATEGIES_ARRAY[@]} -ne $array_length ]] || \
|
| 70 |
+
[[ ${#DTYPES_ARRAY[@]} -ne $array_length ]]; then
|
| 71 |
+
echo "Error: All configuration arrays must have the same length!"
|
| 72 |
+
exit 1
|
| 73 |
+
fi
|
| 74 |
+
|
| 75 |
+
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
|
| 76 |
+
if [[ ${#HUMANEVAL_LENGTHS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 77 |
+
[[ ${#HUMANEVAL_TEMP_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 78 |
+
[[ ${#HUMANEVAL_LIMITS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 79 |
+
[[ ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 80 |
+
[[ ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 81 |
+
[[ ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 82 |
+
[[ ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 83 |
+
[[ ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 84 |
+
[[ ${#HUMANEVAL_TOP_PS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 85 |
+
[[ ${#HUMANEVAL_DTYPES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 86 |
+
[[ ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]} -ne $humaneval_array_length ]]; then
|
| 87 |
+
echo "Error: All HumanEval configuration arrays must have the same length!"
|
| 88 |
+
exit 1
|
| 89 |
+
fi
|
| 90 |
+
|
| 91 |
+
export HF_ALLOW_CODE_EVAL=1
|
| 92 |
+
for lora_model in "${lora_models[@]}"; do
|
| 93 |
+
lora_model_name="$lora_model"
|
| 94 |
+
echo "===================================================================="
|
| 95 |
+
echo "Evaluating LoRA model: $lora_model_name"
|
| 96 |
+
echo "===================================================================="
|
| 97 |
+
|
| 98 |
+
for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
|
| 99 |
+
output_path="evals_dream${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 100 |
+
echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
|
| 101 |
+
if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 102 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 103 |
+
else
|
| 104 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 105 |
+
fi
|
| 106 |
+
CUDA_VISIBLE_DEVICES=5 accelerate launch --main_process_port 29520 --num_processes 1 eval_dream_d2f_vllm.py --model dream_lora \
|
| 107 |
+
--model_args $humaneval_model_args \
|
| 108 |
+
--tasks humaneval \
|
| 109 |
+
--num_fewshot ${HUMANEVAL_NSHOTS_ARRAY[$i]} \
|
| 110 |
+
--batch_size 1 \
|
| 111 |
+
--output_path $output_path \
|
| 112 |
+
--log_samples \
|
| 113 |
+
--confirm_run_unsafe_code
|
| 114 |
+
done
|
| 115 |
+
|
| 116 |
+
for i in "${!TASKS_ARRAY[@]}"; do
|
| 117 |
+
output_path="evals_dream${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 118 |
+
if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 119 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 120 |
+
else
|
| 121 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 122 |
+
fi
|
| 123 |
+
CUDA_VISIBLE_DEVICES=5 accelerate launch --main_process_port 29520 --num_processes 1 eval_dream_d2f_vllm.py --model dream_lora \
|
| 124 |
+
--model_args $model_args \
|
| 125 |
+
--tasks ${TASKS_ARRAY[$i]} \
|
| 126 |
+
--limit ${LIMITS_ARRAY[$i]} \
|
| 127 |
+
--num_fewshot ${NSHOTS_ARRAY[$i]} \
|
| 128 |
+
--batch_size 1 \
|
| 129 |
+
--output_path $output_path \
|
| 130 |
+
--log_samples \
|
| 131 |
+
--confirm_run_unsafe_code
|
| 132 |
+
done
|
| 133 |
+
done
|
| 134 |
+
|
| 135 |
+
echo "All evaluations completed!"
|
Discrete-Diffusion-Forcing/D2F-eval/eval_llada.py
ADDED
|
@@ -0,0 +1,1198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gc
|
| 3 |
+
import json
|
| 4 |
+
import time # Add time module
|
| 5 |
+
from datetime import timedelta
|
| 6 |
+
from typing import List, Optional, Tuple, Type, TypeVar, Union, Dict
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.distributions as dists
|
| 10 |
+
import transformers
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from peft import LoraConfig, get_peft_model
|
| 13 |
+
from accelerate import (
|
| 14 |
+
Accelerator,
|
| 15 |
+
InitProcessGroupKwargs,
|
| 16 |
+
)
|
| 17 |
+
from datasets import Dataset
|
| 18 |
+
from packaging import version
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from peft import PeftConfig, PeftModel
|
| 21 |
+
import numpy as np # Add numpy import
|
| 22 |
+
import os
|
| 23 |
+
import jinja2
|
| 24 |
+
|
| 25 |
+
# Import LLaDA model related modules
|
| 26 |
+
from model_cache.llada.modeling_llada import LLaDAModelLM
|
| 27 |
+
from model_cache.llada.configuration_llada import LLaDAConfig
|
| 28 |
+
|
| 29 |
+
from lm_eval import utils
|
| 30 |
+
from lm_eval.api.instance import Instance
|
| 31 |
+
from lm_eval.api.model import TemplateLM
|
| 32 |
+
from lm_eval.api.registry import register_model
|
| 33 |
+
from lm_eval.models.utils import get_dtype
|
| 34 |
+
from lm_eval.__main__ import cli_evaluate
|
| 35 |
+
|
| 36 |
+
eval_logger = logging.getLogger(__name__)
|
| 37 |
+
T = TypeVar("T", bound="TemplateLM")
|
| 38 |
+
|
| 39 |
+
import random
|
| 40 |
+
def set_seed(seed):
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
random.seed(seed)
|
| 43 |
+
np.random.seed(seed)
|
| 44 |
+
|
| 45 |
+
torch.backends.cudnn.deterministic = True
|
| 46 |
+
torch.backends.cudnn.benchmark = False
|
| 47 |
+
|
| 48 |
+
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
|
| 49 |
+
"""
|
| 50 |
+
Creates a complete attention mask for the entire sequence with block-based causal attention.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
prompt_length: Length of the prompt (first irregular block)
|
| 54 |
+
max_length: Maximum total sequence length
|
| 55 |
+
block_size: Size of each regular block
|
| 56 |
+
device: Device to create tensor on
|
| 57 |
+
dtype: Data type for the attention mask
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
attention_mask: Tensor of shape [1, 1, max_length, max_length]
|
| 61 |
+
"""
|
| 62 |
+
# Use the provided dtype or default to bfloat16
|
| 63 |
+
if dtype is None:
|
| 64 |
+
dtype = torch.bfloat16
|
| 65 |
+
|
| 66 |
+
# Initialize mask with -inf (no attention)
|
| 67 |
+
attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
|
| 68 |
+
|
| 69 |
+
# Block 0: Prompt (can see itself)
|
| 70 |
+
attention_mask[:, :, :prompt_length, :prompt_length] = 0
|
| 71 |
+
|
| 72 |
+
# Calculate the number of regular blocks after prompt
|
| 73 |
+
remaining_length = max_length - prompt_length
|
| 74 |
+
num_blocks = (remaining_length + block_size - 1) // block_size
|
| 75 |
+
|
| 76 |
+
# Process each regular block
|
| 77 |
+
for b in range(num_blocks):
|
| 78 |
+
block_start = prompt_length + b * block_size
|
| 79 |
+
block_end = min(prompt_length + (b + 1) * block_size, max_length)
|
| 80 |
+
|
| 81 |
+
# Current block can see the prompt
|
| 82 |
+
attention_mask[:, :, block_start:block_end, :prompt_length] = 0
|
| 83 |
+
|
| 84 |
+
# Current block can see all previous regular blocks
|
| 85 |
+
for prev_b in range(b):
|
| 86 |
+
prev_start = prompt_length + prev_b * block_size
|
| 87 |
+
prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
|
| 88 |
+
attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
|
| 89 |
+
|
| 90 |
+
# Current block can see itself (full attention within block)
|
| 91 |
+
attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
|
| 92 |
+
|
| 93 |
+
return attention_mask
|
| 94 |
+
|
| 95 |
+
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
|
| 96 |
+
"""
|
| 97 |
+
Extract the relevant portion of attention mask for current forward pass.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
full_mask: Complete attention mask [1, 1, max_length, max_length]
|
| 101 |
+
start_pos: Starting position in the full sequence
|
| 102 |
+
input_length: Length of current input sequence
|
| 103 |
+
cache_length: Length of cached sequence
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
attention_mask: Extracted mask [1, 1, input_length, cache_length + input_length]
|
| 107 |
+
"""
|
| 108 |
+
end_pos = start_pos + input_length
|
| 109 |
+
total_length = cache_length + input_length
|
| 110 |
+
|
| 111 |
+
# Extract the relevant rows (current input positions)
|
| 112 |
+
# and columns (cache + current input positions)
|
| 113 |
+
extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf,
|
| 114 |
+
device=full_mask.device, dtype=full_mask.dtype)
|
| 115 |
+
|
| 116 |
+
# Copy cache columns (0 to cache_length in the extracted mask corresponds to 0 to cache_length in full mask)
|
| 117 |
+
extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
|
| 118 |
+
|
| 119 |
+
# Copy current input columns
|
| 120 |
+
extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
|
| 121 |
+
|
| 122 |
+
return extracted_mask
|
| 123 |
+
|
| 124 |
+
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
|
| 125 |
+
"""
|
| 126 |
+
Builds a custom float attention mask with block-based causal attention.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
input_ids: Input token IDs.
|
| 130 |
+
prompt_length: Length of the prompt for each sequence in the batch.
|
| 131 |
+
block_size: Size of each regular block.
|
| 132 |
+
device: Device to create tensor on.
|
| 133 |
+
dtype: Data type for the attention mask.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
attn_mask: Tensor of shape [B, 1, seq_len, seq_len].
|
| 137 |
+
"""
|
| 138 |
+
B, seq_len = input_ids.shape
|
| 139 |
+
# Use the provided dtype or default to float32
|
| 140 |
+
if dtype is None:
|
| 141 |
+
dtype = torch.float32
|
| 142 |
+
# Initialize to all -inf
|
| 143 |
+
attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)
|
| 144 |
+
# 1. Prompt section: each token can attend to the entire prompt
|
| 145 |
+
for i in range(B):
|
| 146 |
+
attn_mask[i, :, :, :prompt_length[i]] = 0.0 # Allow all tokens to see the prompt
|
| 147 |
+
|
| 148 |
+
# 2. Block division: divide blocks starting from prompt_length
|
| 149 |
+
num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
|
| 150 |
+
|
| 151 |
+
for b in range(num_blocks):
|
| 152 |
+
block_start = prompt_length[i] + b * block_size
|
| 153 |
+
block_end = min(block_start + block_size, seq_len)
|
| 154 |
+
|
| 155 |
+
# Full attention within the block
|
| 156 |
+
attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
|
| 157 |
+
|
| 158 |
+
# Causal attention between blocks (can only see previous blocks)
|
| 159 |
+
for prev_b in range(b):
|
| 160 |
+
prev_start = prompt_length[i] + prev_b * block_size
|
| 161 |
+
prev_end = min(prev_start + block_size, seq_len)
|
| 162 |
+
|
| 163 |
+
# Current block can see previous blocks
|
| 164 |
+
attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
|
| 165 |
+
|
| 166 |
+
return attn_mask
|
| 167 |
+
|
| 168 |
+
def top_p_logits(logits, top_p=None):
|
| 169 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 170 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 171 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 172 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 173 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 174 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 175 |
+
|
| 176 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 177 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 178 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 179 |
+
return logits
|
| 180 |
+
|
| 181 |
+
def top_k_logits(logits, top_k=None):
|
| 182 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 183 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 184 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 185 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 186 |
+
return logits
|
| 187 |
+
|
| 188 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 189 |
+
if temperature > 0:
|
| 190 |
+
logits = logits / temperature
|
| 191 |
+
if top_p is not None and top_p < 1:
|
| 192 |
+
logits = top_p_logits(logits, top_p)
|
| 193 |
+
if top_k is not None:
|
| 194 |
+
logits = top_k_logits(logits, top_k)
|
| 195 |
+
probs = torch.softmax(logits, dim=-1)
|
| 196 |
+
|
| 197 |
+
if temperature > 0:
|
| 198 |
+
try:
|
| 199 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 200 |
+
initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 201 |
+
except:
|
| 202 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 203 |
+
else:
|
| 204 |
+
initial_confidence, x0 = probs.max(dim=-1)
|
| 205 |
+
|
| 206 |
+
# Save initial confidence
|
| 207 |
+
confidence = initial_confidence.clone()
|
| 208 |
+
|
| 209 |
+
if margin_confidence:
|
| 210 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 211 |
+
# Extract top1 and top2 probabilities
|
| 212 |
+
top1_probs = sorted_probs[:, 0]
|
| 213 |
+
top2_probs = sorted_probs[:, 1]
|
| 214 |
+
# Calculate confidence as top1 - top2
|
| 215 |
+
confidence = top1_probs - top2_probs
|
| 216 |
+
|
| 217 |
+
if neg_entropy:
|
| 218 |
+
epsilon = 1e-10
|
| 219 |
+
log_probs = torch.log(probs + epsilon)
|
| 220 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 221 |
+
|
| 222 |
+
return confidence, x0, initial_confidence
|
| 223 |
+
|
| 224 |
+
@register_model("dream_lora")
|
| 225 |
+
class DreamLoRA(TemplateLM):
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
pretrained: Union[str, transformers.PreTrainedModel],
|
| 229 |
+
lora_path: str,
|
| 230 |
+
batch_size: Optional[Union[int, str]] = 1,
|
| 231 |
+
device: Optional[str] = "cuda",
|
| 232 |
+
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
| 233 |
+
max_new_tokens: Optional[int] = 128,
|
| 234 |
+
max_length: Optional[int] = 4096, # Updated to match example code
|
| 235 |
+
add_bos_token: Optional[bool] = False,
|
| 236 |
+
nll_type: Optional[str] = "mc",
|
| 237 |
+
log_type: Optional[str] = "ftb",
|
| 238 |
+
mc_num: Optional[int] = 128,
|
| 239 |
+
classifier_free_guidance: Optional[float] = 1.0,
|
| 240 |
+
sampling_eps: Optional[float] = 1e-3,
|
| 241 |
+
diffusion_steps: Optional[int] = 128,
|
| 242 |
+
trust_remote_code: Optional[bool] = True,
|
| 243 |
+
parallelize: Optional[bool] = False,
|
| 244 |
+
autogptq: Optional[Union[bool, str]] = False,
|
| 245 |
+
temperature: Optional[float] = 0.2, # Updated default value
|
| 246 |
+
top_p: Optional[float] = None, # Updated default value
|
| 247 |
+
top_k: Optional[float] = None,
|
| 248 |
+
alg: Optional[str] = "entropy",
|
| 249 |
+
alg_temp: Optional[float] = 0.0,
|
| 250 |
+
escape_until: Optional[bool] = False,
|
| 251 |
+
block_size: Optional[int] = 4, # Updated to match example code
|
| 252 |
+
mask_token_id: Optional[int] = 126336, # Added mask_token_id parameter
|
| 253 |
+
block_add_threshold: Optional[float] = 0.5, # Added block_add_threshold parameter
|
| 254 |
+
decoded_token_threshold: Optional[float] = 0.9, # Added decoded token threshold parameter
|
| 255 |
+
skip_threshold: Optional[float] = 1.0, # Added skip_threshold parameter
|
| 256 |
+
sampling_strategy: Optional[str] = "default", # Added sampling strategy parameter
|
| 257 |
+
save_dir: Optional[str] = None, # Added save directory parameter
|
| 258 |
+
show_speed: Optional[bool] = True, # Added speed statistics parameter
|
| 259 |
+
**kwargs,
|
| 260 |
+
) -> None:
|
| 261 |
+
super().__init__()
|
| 262 |
+
|
| 263 |
+
# prepare for parallelism
|
| 264 |
+
assert isinstance(device, str)
|
| 265 |
+
assert isinstance(pretrained, str)
|
| 266 |
+
assert isinstance(batch_size, (int, str))
|
| 267 |
+
|
| 268 |
+
gpus = torch.cuda.device_count()
|
| 269 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 270 |
+
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
|
| 271 |
+
if accelerator.num_processes > 1:
|
| 272 |
+
self.accelerator = accelerator
|
| 273 |
+
|
| 274 |
+
if "npu" in accelerator.device.type:
|
| 275 |
+
gpus = torch.npu.device_count()
|
| 276 |
+
|
| 277 |
+
# using one process with no model parallelism
|
| 278 |
+
if not (parallelize or accelerator.num_processes > 1):
|
| 279 |
+
# use user-passed device
|
| 280 |
+
device_list = set(
|
| 281 |
+
["cuda", "cpu"]
|
| 282 |
+
+ [f"cuda:{i}" for i in range(gpus)]
|
| 283 |
+
+ ["mps", "mps:0"]
|
| 284 |
+
+ [f"npu:{i}" for i in range(gpus)]
|
| 285 |
+
)
|
| 286 |
+
if device and device in device_list:
|
| 287 |
+
self._device = torch.device(device)
|
| 288 |
+
eval_logger.info(f"Using device '{device}'")
|
| 289 |
+
if device in ("mps", "mps:0") and version.parse(
|
| 290 |
+
torch.__version__
|
| 291 |
+
) < version.parse("2.1"):
|
| 292 |
+
raise RuntimeError(
|
| 293 |
+
f"mps requires torch >= 2.1. You have {torch.__version__}"
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
eval_logger.info("Device not specified")
|
| 297 |
+
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
|
| 298 |
+
self._device = (
|
| 299 |
+
torch.device("cuda")
|
| 300 |
+
if torch.cuda.is_available()
|
| 301 |
+
else torch.device("cpu")
|
| 302 |
+
)
|
| 303 |
+
else: # Parallelism managed by accelerate
|
| 304 |
+
if device != "cuda":
|
| 305 |
+
eval_logger.info(
|
| 306 |
+
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
|
| 307 |
+
)
|
| 308 |
+
# TODO: include in warning that `load_in_8bit` etc. affect this too
|
| 309 |
+
self._device = (
|
| 310 |
+
self.accelerator.device
|
| 311 |
+
if hasattr(self, "accelerator")
|
| 312 |
+
else torch.device(device)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
self.batch_size_per_gpu = batch_size
|
| 316 |
+
if isinstance(batch_size, str):
|
| 317 |
+
self.batch_size_per_gpu = int(batch_size)
|
| 318 |
+
|
| 319 |
+
# Save LoRA path and block_size
|
| 320 |
+
self.lora_path = lora_path
|
| 321 |
+
self.block_size = block_size
|
| 322 |
+
self.block_add_threshold = block_add_threshold # Added block_add_threshold attribute
|
| 323 |
+
self.skip_threshold = skip_threshold # Added skip_threshold attribute
|
| 324 |
+
self.sampling_strategy = sampling_strategy # Save sampling strategy parameter
|
| 325 |
+
self.decoded_token_threshold = decoded_token_threshold # Added decoded token threshold attribute
|
| 326 |
+
|
| 327 |
+
# Save target_dtype for later use
|
| 328 |
+
self.target_dtype = get_dtype(dtype)
|
| 329 |
+
|
| 330 |
+
self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
|
| 331 |
+
|
| 332 |
+
if isinstance(pretrained, str):
|
| 333 |
+
if gpus >= 1 or str(self.device) == "mps":
|
| 334 |
+
# TODO: can remove this whole snippet except in the mps case, perhaps?
|
| 335 |
+
if not (parallelize or autogptq or hasattr(self, "accelerator")):
|
| 336 |
+
# place model onto device requested manually,
|
| 337 |
+
# if not using HF Accelerate or device_map
|
| 338 |
+
# or any other option that preloads model onto device
|
| 339 |
+
try:
|
| 340 |
+
self.model.to(self.device)
|
| 341 |
+
except ValueError:
|
| 342 |
+
eval_logger.debug(
|
| 343 |
+
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
|
| 344 |
+
)
|
| 345 |
+
# multigpu data-parallel support when launched with accelerate
|
| 346 |
+
if gpus > 1:
|
| 347 |
+
if accelerator.num_processes > 1:
|
| 348 |
+
if parallelize:
|
| 349 |
+
eval_logger.warning(
|
| 350 |
+
"You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
|
| 351 |
+
)
|
| 352 |
+
elif gpus > accelerator.num_processes:
|
| 353 |
+
eval_logger.warning(
|
| 354 |
+
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
|
| 355 |
+
"If you would like to use data parallelism, please launch the script "
|
| 356 |
+
"with 'accelerate launch *script*'. "
|
| 357 |
+
f"Current run will proceed with {accelerator.num_processes} devices."
|
| 358 |
+
)
|
| 359 |
+
if self.accelerator.is_local_main_process:
|
| 360 |
+
eval_logger.info(
|
| 361 |
+
f"Using {gpus} devices with data parallelism"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self._device = torch.device(f"{accelerator.device}")
|
| 365 |
+
self.accelerator = accelerator
|
| 366 |
+
|
| 367 |
+
self._rank = self.accelerator.local_process_index
|
| 368 |
+
self._world_size = self.accelerator.num_processes
|
| 369 |
+
else:
|
| 370 |
+
# if we aren't launching via accelerate, ditch
|
| 371 |
+
self._rank = 0
|
| 372 |
+
self._world_size = 1
|
| 373 |
+
else:
|
| 374 |
+
# if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
|
| 375 |
+
eval_logger.warning(
|
| 376 |
+
"Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
|
| 377 |
+
)
|
| 378 |
+
self._rank = 0
|
| 379 |
+
self._world_size = 1
|
| 380 |
+
|
| 381 |
+
self.max_length = max_length
|
| 382 |
+
self.add_bos_token = add_bos_token
|
| 383 |
+
# generation params
|
| 384 |
+
self.max_new_tokens = max_new_tokens
|
| 385 |
+
self.diffusion_steps = diffusion_steps
|
| 386 |
+
self.temperature = temperature
|
| 387 |
+
self.top_p = top_p
|
| 388 |
+
self.top_k = top_k
|
| 389 |
+
self.alg = alg
|
| 390 |
+
self.alg_temp = alg_temp
|
| 391 |
+
self.escape_until = escape_until
|
| 392 |
+
self.block_size = block_size
|
| 393 |
+
self.mask_token_id = mask_token_id
|
| 394 |
+
|
| 395 |
+
# loglikelihood params
|
| 396 |
+
self.nll_type = nll_type
|
| 397 |
+
self.log_type = log_type
|
| 398 |
+
self.mc_num = mc_num
|
| 399 |
+
self.classifier_free_guidance = classifier_free_guidance
|
| 400 |
+
self.sampling_eps = sampling_eps
|
| 401 |
+
|
| 402 |
+
# Add backend attribute, consistent with LLaDA.py
|
| 403 |
+
self.backend = "causal"
|
| 404 |
+
|
| 405 |
+
# Add truncation attribute, consistent with LLaDA.py
|
| 406 |
+
self.truncation = False
|
| 407 |
+
|
| 408 |
+
self.save_dir = save_dir
|
| 409 |
+
self.show_speed = show_speed
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def batch_size(self):
|
| 413 |
+
return self.batch_size_per_gpu
|
| 414 |
+
|
| 415 |
+
@property
|
| 416 |
+
def eot_token_id(self):
|
| 417 |
+
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
|
| 418 |
+
return self.tokenizer.eos_token_id
|
| 419 |
+
|
| 420 |
+
@property
|
| 421 |
+
def device(self):
|
| 422 |
+
return self._device
|
| 423 |
+
|
| 424 |
+
@property
|
| 425 |
+
def rank(self):
|
| 426 |
+
return self._rank
|
| 427 |
+
|
| 428 |
+
@property
|
| 429 |
+
def world_size(self):
|
| 430 |
+
return self._world_size
|
| 431 |
+
|
| 432 |
+
def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
|
| 433 |
+
# Get correct data type
|
| 434 |
+
target_dtype = get_dtype(dtype)
|
| 435 |
+
|
| 436 |
+
# Load LLaDA model and configuration
|
| 437 |
+
config = LLaDAConfig.from_pretrained(pretrained)
|
| 438 |
+
self.model = LLaDAModelLM.from_pretrained(
|
| 439 |
+
pretrained,
|
| 440 |
+
config=config,
|
| 441 |
+
torch_dtype=target_dtype,
|
| 442 |
+
trust_remote_code=False,
|
| 443 |
+
).eval()
|
| 444 |
+
|
| 445 |
+
# Load LoRA configuration and model
|
| 446 |
+
peft_config = PeftConfig.from_pretrained(self.lora_path)
|
| 447 |
+
self.model = PeftModel.from_pretrained(self.model, self.lora_path)
|
| 448 |
+
|
| 449 |
+
# Convert data type only when target_dtype is not None and not "auto"
|
| 450 |
+
if target_dtype is not None and target_dtype != "auto":
|
| 451 |
+
self.model = self.model.to(target_dtype)
|
| 452 |
+
|
| 453 |
+
# Move to specified device
|
| 454 |
+
self.model = self.model.to(self.device)
|
| 455 |
+
|
| 456 |
+
# Load tokenizer
|
| 457 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 458 |
+
pretrained, trust_remote_code=trust_remote_code
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
def tok_encode(
|
| 462 |
+
self, string: str, left_truncate_len=None, add_special_tokens=None
|
| 463 |
+
) -> List[int]:
|
| 464 |
+
""" """
|
| 465 |
+
# default for None - empty dict, use predefined tokenizer param
|
| 466 |
+
# used for all models except for CausalLM or predefined value
|
| 467 |
+
special_tokens_kwargs = {}
|
| 468 |
+
|
| 469 |
+
# by default for CausalLM - false or self.add_bos_token is set
|
| 470 |
+
if add_special_tokens is None:
|
| 471 |
+
if self.backend == "causal":
|
| 472 |
+
special_tokens_kwargs = {
|
| 473 |
+
"add_special_tokens": False or self.add_bos_token
|
| 474 |
+
}
|
| 475 |
+
# otherwise the method explicitly defines the value
|
| 476 |
+
else:
|
| 477 |
+
special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
|
| 478 |
+
|
| 479 |
+
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
|
| 480 |
+
|
| 481 |
+
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 482 |
+
if left_truncate_len:
|
| 483 |
+
encoding = encoding[-left_truncate_len:]
|
| 484 |
+
return encoding
|
| 485 |
+
|
| 486 |
+
def tok_batch_encode(
|
| 487 |
+
self,
|
| 488 |
+
strings: List[str],
|
| 489 |
+
padding_side: str = "left",
|
| 490 |
+
left_truncate_len: int = None,
|
| 491 |
+
truncation: bool = False,
|
| 492 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 493 |
+
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
|
| 494 |
+
old_padding_side = self.tokenizer.padding_side
|
| 495 |
+
self.tokenizer.padding_side = padding_side
|
| 496 |
+
|
| 497 |
+
add_special_tokens = {}
|
| 498 |
+
if self.backend == "causal":
|
| 499 |
+
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
|
| 500 |
+
|
| 501 |
+
encoding = self.tokenizer(
|
| 502 |
+
strings,
|
| 503 |
+
truncation=truncation,
|
| 504 |
+
padding="longest",
|
| 505 |
+
return_tensors="pt",
|
| 506 |
+
**add_special_tokens,
|
| 507 |
+
)
|
| 508 |
+
if left_truncate_len:
|
| 509 |
+
original_lengths = encoding["input_ids"].size(1)
|
| 510 |
+
if original_lengths > left_truncate_len:
|
| 511 |
+
eval_logger.warn(
|
| 512 |
+
f"Left truncation applied. Original sequence length was {original_lengths}, "
|
| 513 |
+
f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
|
| 514 |
+
)
|
| 515 |
+
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
|
| 516 |
+
encoding["attention_mask"] = encoding["attention_mask"][
|
| 517 |
+
:, -left_truncate_len:
|
| 518 |
+
]
|
| 519 |
+
self.tokenizer.padding_side = old_padding_side
|
| 520 |
+
|
| 521 |
+
return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device)
|
| 522 |
+
|
| 523 |
+
def tok_decode(self, tokens, skip_special_tokens=True):
|
| 524 |
+
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def _count_tokens_after_truncation(self, response_text: str, until_terms: List[str] = None) -> int:
|
| 529 |
+
"""
|
| 530 |
+
Unified token counting function: calculates the number of non-126081 tokens after truncating the response.
|
| 531 |
+
"""
|
| 532 |
+
# Apply truncation based on until parameters
|
| 533 |
+
truncated_text = response_text
|
| 534 |
+
if until_terms and not self.escape_until:
|
| 535 |
+
for term in until_terms:
|
| 536 |
+
if len(term) > 0:
|
| 537 |
+
truncated_text = truncated_text.split(term)[0]
|
| 538 |
+
|
| 539 |
+
# Re-tokenize processed answer and count non-126081 tokens
|
| 540 |
+
generated_answer_ids = torch.tensor(self.tokenizer(truncated_text)["input_ids"])
|
| 541 |
+
return int((generated_answer_ids != 126081).sum())
|
| 542 |
+
|
| 543 |
+
@classmethod
|
| 544 |
+
def create_from_arg_string(
|
| 545 |
+
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
|
| 546 |
+
) -> T:
|
| 547 |
+
"""
|
| 548 |
+
Creates an instance of the LM class using the given argument string and additional config.
|
| 549 |
+
|
| 550 |
+
Parameters:
|
| 551 |
+
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 552 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
- Instance of the LM class.
|
| 556 |
+
"""
|
| 557 |
+
additional_config = {} if additional_config is None else additional_config
|
| 558 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 559 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 560 |
+
return cls(**args, **args2)
|
| 561 |
+
|
| 562 |
+
def apply_chat_template(
|
| 563 |
+
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 564 |
+
) -> str:
|
| 565 |
+
"""
|
| 566 |
+
Method to apply a chat template to a list of chat history between user and model.
|
| 567 |
+
"""
|
| 568 |
+
try:
|
| 569 |
+
chat_templated = self.tokenizer.apply_chat_template(
|
| 570 |
+
chat_history,
|
| 571 |
+
tokenize=False,
|
| 572 |
+
add_generation_prompt=add_generation_prompt,
|
| 573 |
+
continue_final_message=not add_generation_prompt,
|
| 574 |
+
)
|
| 575 |
+
except jinja2.exceptions.TemplateError:
|
| 576 |
+
eval_logger.warning(
|
| 577 |
+
"Failed to apply chat template. removing the system role in chat history."
|
| 578 |
+
)
|
| 579 |
+
chat_history = [msg for msg in chat_history if msg["role"] != "system"]
|
| 580 |
+
chat_templated = self.tokenizer.apply_chat_template(
|
| 581 |
+
chat_history,
|
| 582 |
+
tokenize=False,
|
| 583 |
+
add_generation_prompt=add_generation_prompt,
|
| 584 |
+
continue_final_message=not add_generation_prompt,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
return chat_templated
|
| 588 |
+
|
| 589 |
+
@property
|
| 590 |
+
def tokenizer_name(self) -> str:
|
| 591 |
+
return self.tokenizer.name_or_path.replace("/", "__")
|
| 592 |
+
|
| 593 |
+
def _generate_block_single(self, prompt):
|
| 594 |
+
"""
|
| 595 |
+
Generates a response for a single prompt using parallel block generation, based on KV cache, and uses pre-generated attention masks.
|
| 596 |
+
Returns: generated_sequence (List[int]) - List of generated token IDs
|
| 597 |
+
"""
|
| 598 |
+
self.model.eval()
|
| 599 |
+
|
| 600 |
+
mask_id = self.mask_token_id
|
| 601 |
+
block_size = self.block_size
|
| 602 |
+
block_add_threshold = self.block_add_threshold
|
| 603 |
+
skip_threshold = self.skip_threshold
|
| 604 |
+
|
| 605 |
+
# Pre-generate the full attention mask, using the model's data type
|
| 606 |
+
prompt_length = prompt.shape[1]
|
| 607 |
+
full_attention_mask = create_full_block_attention_mask(
|
| 608 |
+
prompt_length=prompt_length,
|
| 609 |
+
max_length=self.max_length,
|
| 610 |
+
block_size=block_size,
|
| 611 |
+
device=self.device,
|
| 612 |
+
dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
with torch.inference_mode():
|
| 616 |
+
# Initialization
|
| 617 |
+
x_t = prompt.to(self.device)
|
| 618 |
+
|
| 619 |
+
# Track block states - states can be: 'active', 'to_cache', 'in_cache'
|
| 620 |
+
# Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete state (False)
|
| 621 |
+
block_states = {
|
| 622 |
+
0: {
|
| 623 |
+
'start_pos': 0,
|
| 624 |
+
'end_pos': prompt.shape[1],
|
| 625 |
+
'mask_count': 0,
|
| 626 |
+
'total_masks': prompt.shape[1],
|
| 627 |
+
'state': 'to_cache', # Prompt is immediately ready for caching
|
| 628 |
+
'is_complete': True, # Prompt is always in a complete state
|
| 629 |
+
},
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
# Initialize cache
|
| 633 |
+
past_key_values = None
|
| 634 |
+
|
| 635 |
+
current_blocks = 0 # Number of active blocks
|
| 636 |
+
step = 0
|
| 637 |
+
eos_detected = False # EOS detection flag
|
| 638 |
+
cache_length = 0
|
| 639 |
+
while current_blocks >= 0:
|
| 640 |
+
step += 1
|
| 641 |
+
|
| 642 |
+
# Check if a new block needs to be added
|
| 643 |
+
if len(block_states)-1 < (self.max_new_tokens // block_size) and not eos_detected:
|
| 644 |
+
last_block_id = len(block_states) - 1
|
| 645 |
+
current_progress = (block_states[last_block_id]['total_masks'] -
|
| 646 |
+
block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks']
|
| 647 |
+
if current_progress >= block_add_threshold:
|
| 648 |
+
# Add new block
|
| 649 |
+
new_block_id = len(block_states)
|
| 650 |
+
new_start_pos = x_t.shape[1]
|
| 651 |
+
x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1)
|
| 652 |
+
|
| 653 |
+
block_states[new_block_id] = {
|
| 654 |
+
'start_pos': new_start_pos,
|
| 655 |
+
'end_pos': new_start_pos + block_size,
|
| 656 |
+
'mask_count': block_size,
|
| 657 |
+
'total_masks': block_size,
|
| 658 |
+
'state': 'active',
|
| 659 |
+
'is_complete': False, # New block defaults to an incomplete state
|
| 660 |
+
}
|
| 661 |
+
current_blocks += 1
|
| 662 |
+
|
| 663 |
+
# At the beginning of each loop, update the block's complete/incomplete states
|
| 664 |
+
self._update_block_completion_states(block_states, self.decoded_token_threshold)
|
| 665 |
+
# Check if there are still mask tokens
|
| 666 |
+
mask_index = (x_t == mask_id)
|
| 667 |
+
if mask_index.sum() == 0 and current_blocks == 0:
|
| 668 |
+
break
|
| 669 |
+
|
| 670 |
+
# Determine which blocks need to be added to the cache
|
| 671 |
+
blocks_to_cache = [bid for bid, state in block_states.items()
|
| 672 |
+
if state['state'] == 'to_cache']
|
| 673 |
+
|
| 674 |
+
# Determine the part to be processed
|
| 675 |
+
update_kvcache = 0
|
| 676 |
+
if blocks_to_cache:
|
| 677 |
+
# Find the earliest block to be cached
|
| 678 |
+
earliest_block_id = min(blocks_to_cache)
|
| 679 |
+
earliest_pos = block_states[earliest_block_id]['start_pos']
|
| 680 |
+
|
| 681 |
+
# Find the latest block to be cached
|
| 682 |
+
latest_block_id = max(blocks_to_cache)
|
| 683 |
+
latest_pos = block_states[latest_block_id]['end_pos']
|
| 684 |
+
|
| 685 |
+
# Update the cache for all blocks within this range
|
| 686 |
+
update_kvcache = latest_pos - earliest_pos
|
| 687 |
+
|
| 688 |
+
# Create input sequence for forward pass
|
| 689 |
+
process_start_pos = cache_length
|
| 690 |
+
|
| 691 |
+
if update_kvcache > 0:
|
| 692 |
+
# Need to update cache - use completed blocks
|
| 693 |
+
earliest_block_to_cache = min(blocks_to_cache)
|
| 694 |
+
input_seq = x_t[:, block_states[earliest_block_to_cache]['start_pos']:]
|
| 695 |
+
process_start_pos = block_states[earliest_block_to_cache]['start_pos']
|
| 696 |
+
else:
|
| 697 |
+
# Only process active blocks
|
| 698 |
+
active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active']
|
| 699 |
+
if active_blocks:
|
| 700 |
+
# Get all active blocks after caching
|
| 701 |
+
earliest_active_after_cache = float('inf')
|
| 702 |
+
for bid in active_blocks:
|
| 703 |
+
if block_states[bid]['start_pos'] >= cache_length:
|
| 704 |
+
earliest_active_after_cache = min(earliest_active_after_cache, block_states[bid]['start_pos'])
|
| 705 |
+
|
| 706 |
+
if earliest_active_after_cache < float('inf'):
|
| 707 |
+
input_seq = x_t[:, earliest_active_after_cache:]
|
| 708 |
+
process_start_pos = earliest_active_after_cache
|
| 709 |
+
else:
|
| 710 |
+
# No active blocks after caching, this should not happen
|
| 711 |
+
input_seq = x_t[:, cache_length:]
|
| 712 |
+
# If cache length is already equal to or exceeds sequence length, exit
|
| 713 |
+
if cache_length >= x_t.shape[1]:
|
| 714 |
+
print(f"Cache length ({cache_length}) >= sequence length ({x_t.shape[1]}) at step {step}. Exiting generation loop.")
|
| 715 |
+
raise Exception("Cache length >= sequence length")
|
| 716 |
+
else:
|
| 717 |
+
# No active blocks, but blocks might need to be cached in the next iteration
|
| 718 |
+
break
|
| 719 |
+
|
| 720 |
+
# Check if input_seq is empty
|
| 721 |
+
if input_seq.shape[1] == 0:
|
| 722 |
+
print(f"Warning: input_seq is empty at step {step}. Breaking generation loop.")
|
| 723 |
+
raise Exception("input_seq is empty")
|
| 724 |
+
|
| 725 |
+
# Extract the attention mask for the current input from the pre-generated full mask
|
| 726 |
+
input_length = input_seq.shape[1]
|
| 727 |
+
attention_mask = extract_attention_mask(
|
| 728 |
+
full_mask=full_attention_mask,
|
| 729 |
+
start_pos=process_start_pos,
|
| 730 |
+
input_length=input_length,
|
| 731 |
+
cache_length=cache_length
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
outputs = self.model(
|
| 735 |
+
input_seq,
|
| 736 |
+
attention_bias=attention_mask,
|
| 737 |
+
past_key_values=past_key_values,
|
| 738 |
+
use_cache=True,
|
| 739 |
+
update_kvcache=update_kvcache+cache_length,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# Get current logits - LLaDA model directly uses logits, no shifting needed
|
| 743 |
+
logits = outputs.logits
|
| 744 |
+
|
| 745 |
+
# Update cache if needed
|
| 746 |
+
if update_kvcache > 0:
|
| 747 |
+
# Update cache
|
| 748 |
+
past_key_values = outputs.past_key_values
|
| 749 |
+
|
| 750 |
+
# Mark blocks as cached
|
| 751 |
+
for block_id in blocks_to_cache:
|
| 752 |
+
block_states[block_id]['state'] = 'in_cache'
|
| 753 |
+
|
| 754 |
+
# Process mask tokens for each active block
|
| 755 |
+
blocks_to_deactivate = []
|
| 756 |
+
|
| 757 |
+
for block_id in sorted(block_states.keys()):
|
| 758 |
+
if block_states[block_id]['state'] != 'active':
|
| 759 |
+
continue
|
| 760 |
+
|
| 761 |
+
# Get mask positions for this block
|
| 762 |
+
block_start = block_states[block_id]['start_pos']
|
| 763 |
+
block_end = block_states[block_id]['end_pos']
|
| 764 |
+
block_mask_index = mask_index.clone()
|
| 765 |
+
block_mask_index[:, :block_start] = False
|
| 766 |
+
block_mask_index[:, block_end:] = False
|
| 767 |
+
|
| 768 |
+
# Skip if the current block has no masks
|
| 769 |
+
if block_mask_index.sum() == 0:
|
| 770 |
+
blocks_to_deactivate.append(block_id)
|
| 771 |
+
continue
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
# Calculate relative position of logits
|
| 775 |
+
logit_offset = block_start - process_start_pos
|
| 776 |
+
block_rel_positions = torch.where(block_mask_index[0, block_start:block_end])[0]
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
if block_rel_positions.size(0) > 0:
|
| 780 |
+
# Get logits for masked positions
|
| 781 |
+
block_mask_logits = logits[:, logit_offset + block_rel_positions, :]
|
| 782 |
+
|
| 783 |
+
# Sample tokens
|
| 784 |
+
confidence, x0, initial_confidence = sample_tokens(
|
| 785 |
+
block_mask_logits.squeeze(0),
|
| 786 |
+
self.temperature,
|
| 787 |
+
top_p=self.top_p,
|
| 788 |
+
top_k=self.top_k,
|
| 789 |
+
neg_entropy=(self.sampling_strategy == "neg_entropy"),
|
| 790 |
+
margin_confidence=(self.sampling_strategy == "margin_confidence")
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# Use different sampling strategies based on the block's complete/incomplete state
|
| 794 |
+
is_complete = block_states[block_id]['is_complete']
|
| 795 |
+
|
| 796 |
+
if is_complete:
|
| 797 |
+
# Complete state: apply confidence threshold, if no high confidence, select the highest
|
| 798 |
+
high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
|
| 799 |
+
|
| 800 |
+
if len(high_conf_indices) == 0:
|
| 801 |
+
number_transfer_tokens = 1
|
| 802 |
+
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
|
| 803 |
+
else:
|
| 804 |
+
transfer_index = torch.tensor([], device=self.device, dtype=torch.long)
|
| 805 |
+
|
| 806 |
+
# Merge indices
|
| 807 |
+
all_indices = torch.unique(torch.cat([transfer_index, high_conf_indices]))
|
| 808 |
+
else:
|
| 809 |
+
# Incomplete state: only apply confidence threshold, if no tokens exceed the threshold, select none
|
| 810 |
+
high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
|
| 811 |
+
all_indices = high_conf_indices
|
| 812 |
+
|
| 813 |
+
# Update tokens
|
| 814 |
+
if len(all_indices) > 0:
|
| 815 |
+
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_id
|
| 816 |
+
x0_[all_indices] = x0[all_indices].clone()
|
| 817 |
+
|
| 818 |
+
# Map indices back to original positions
|
| 819 |
+
for i, idx in enumerate(all_indices):
|
| 820 |
+
abs_pos = block_start + block_rel_positions[idx]
|
| 821 |
+
x_t[0, abs_pos] = x0_[idx]
|
| 822 |
+
|
| 823 |
+
# Update block state
|
| 824 |
+
block_states[block_id]['mask_count'] -= len(all_indices)
|
| 825 |
+
|
| 826 |
+
# Check for EOS token
|
| 827 |
+
eos_token_id = 126081
|
| 828 |
+
if eos_token_id is not None:
|
| 829 |
+
for idx in all_indices:
|
| 830 |
+
if x0[idx].item() == eos_token_id:
|
| 831 |
+
eos_detected = True
|
| 832 |
+
break
|
| 833 |
+
|
| 834 |
+
# Deactivate this block if no masks remain
|
| 835 |
+
mask_index = (x_t == mask_id)
|
| 836 |
+
block_mask_index = mask_index.clone()
|
| 837 |
+
block_mask_index[:, :block_start] = False
|
| 838 |
+
block_mask_index[:, block_end:] = False
|
| 839 |
+
if block_mask_index.sum() == 0:
|
| 840 |
+
blocks_to_deactivate.append(block_id)
|
| 841 |
+
continue
|
| 842 |
+
|
| 843 |
+
# Deactivate completed blocks and mark them for caching in the next iteration
|
| 844 |
+
for block_id in blocks_to_deactivate:
|
| 845 |
+
if block_states[block_id]['state'] == 'active':
|
| 846 |
+
# Check if all preceding blocks are already in a non-active state
|
| 847 |
+
can_deactivate = True
|
| 848 |
+
for prev_block_id in range(block_id):
|
| 849 |
+
if prev_block_id in block_states and block_states[prev_block_id]['state'] == 'active':
|
| 850 |
+
can_deactivate = False
|
| 851 |
+
break
|
| 852 |
+
|
| 853 |
+
# Only mark the current block as 'to_cache' if all preceding blocks are not active
|
| 854 |
+
if can_deactivate:
|
| 855 |
+
block_states[block_id]['state'] = 'to_cache'
|
| 856 |
+
current_blocks -= 1
|
| 857 |
+
# If there are active preceding blocks, keep the current block in active state (do nothing)
|
| 858 |
+
|
| 859 |
+
if update_kvcache > 0:
|
| 860 |
+
cache_length += update_kvcache
|
| 861 |
+
# Safety check
|
| 862 |
+
if step > 10000:
|
| 863 |
+
print(f"WARNING: Hit safety check at step {step}. Exiting generation loop.")
|
| 864 |
+
break
|
| 865 |
+
|
| 866 |
+
current_text = self.tokenizer.decode(x_t[0, prompt.shape[1]:].tolist(),skip_special_tokens=False)
|
| 867 |
+
|
| 868 |
+
# Generate final answer
|
| 869 |
+
generated_sequence = x_t[0, prompt.shape[1]:].tolist()
|
| 870 |
+
|
| 871 |
+
return generated_sequence
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
|
| 876 |
+
res = []
|
| 877 |
+
start_time = time.time()
|
| 878 |
+
|
| 879 |
+
# Statistics variables
|
| 880 |
+
num_tokens = 0
|
| 881 |
+
num_nfe = 0
|
| 882 |
+
|
| 883 |
+
bar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests")
|
| 884 |
+
|
| 885 |
+
for i, req in enumerate(requests):
|
| 886 |
+
question = req.args[0]
|
| 887 |
+
# print("question:",question)
|
| 888 |
+
# exit()
|
| 889 |
+
gen_kwargs = req.args[1]
|
| 890 |
+
|
| 891 |
+
# Process input in LLaDA.py style
|
| 892 |
+
# print("Self.add_bos_token:", self.add_bos_token)
|
| 893 |
+
contexts = [question]
|
| 894 |
+
if self.add_bos_token:
|
| 895 |
+
contexts = [self.tokenizer.bos_token + p for p in contexts]
|
| 896 |
+
|
| 897 |
+
# Use the same tokenization method as LLaDA.py
|
| 898 |
+
context_enc, attn_masks = self.tok_batch_encode(
|
| 899 |
+
contexts,
|
| 900 |
+
truncation=self.truncation,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
input_ids = context_enc[0].unsqueeze(0) # Take the first one and add batch dimension
|
| 906 |
+
|
| 907 |
+
# Add length check
|
| 908 |
+
if input_ids.shape[1] > self.max_length - self.max_new_tokens:
|
| 909 |
+
eval_logger.warning(f"Prompt length {input_ids.shape[1]} is larger than {self.max_length-self.max_new_tokens}, cutoff on the left side")
|
| 910 |
+
input_ids = input_ids[:, -(self.max_length-self.max_new_tokens):]
|
| 911 |
+
|
| 912 |
+
# Generate token IDs
|
| 913 |
+
generated_answer = self._generate_block_single(input_ids)
|
| 914 |
+
|
| 915 |
+
# Use tokenizer.batch_decode for decoding, consistent with LLaDA.py
|
| 916 |
+
cont_toks_list = self.tokenizer.batch_decode([generated_answer], skip_special_tokens=True)
|
| 917 |
+
s = cont_toks_list[0] # Take the first (and only) result
|
| 918 |
+
|
| 919 |
+
# Use unified token counting function
|
| 920 |
+
if self.show_speed:
|
| 921 |
+
num_tokens += self._count_tokens_after_truncation(s, gen_kwargs.get("until", []))
|
| 922 |
+
num_nfe += 1 # NFE uses simplified statistics (fixed to 1)
|
| 923 |
+
|
| 924 |
+
# Handle until truncation in LLaDA.py style
|
| 925 |
+
if not self.escape_until:
|
| 926 |
+
for term in gen_kwargs.get("until", []):
|
| 927 |
+
if len(term) > 0:
|
| 928 |
+
s = s.split(term)[0]
|
| 929 |
+
|
| 930 |
+
res.append(s)
|
| 931 |
+
bar.update(1)
|
| 932 |
+
|
| 933 |
+
bar.close()
|
| 934 |
+
|
| 935 |
+
# Save statistics only at the end
|
| 936 |
+
if self.save_dir is not None:
|
| 937 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
| 938 |
+
final_time = time.time()
|
| 939 |
+
total_time = final_time - start_time
|
| 940 |
+
|
| 941 |
+
final_stats = {
|
| 942 |
+
"processed_samples": len(res),
|
| 943 |
+
"total_samples": len(requests),
|
| 944 |
+
"total_tokens": int(num_tokens),
|
| 945 |
+
"total_nfe": int(num_nfe),
|
| 946 |
+
"total_time": total_time,
|
| 947 |
+
"tokens_per_second": float(num_tokens) / total_time if total_time > 0 else 0.0,
|
| 948 |
+
"nfe_per_token": float(num_nfe) / float(num_tokens) if num_tokens > 0 else 0.0,
|
| 949 |
+
"timestamp": final_time
|
| 950 |
+
}
|
| 951 |
+
final_stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
|
| 952 |
+
with open(final_stats_path, 'w', encoding='utf-8') as f:
|
| 953 |
+
json.dump(final_stats, f, ensure_ascii=False, indent=2)
|
| 954 |
+
|
| 955 |
+
if self.show_speed:
|
| 956 |
+
final_time = time.time()
|
| 957 |
+
total_time = final_time - start_time
|
| 958 |
+
print(f"\n=== Final Statistics ===")
|
| 959 |
+
print(f"Processed samples: {len(res)}")
|
| 960 |
+
print(f"Total tokens: {num_tokens}")
|
| 961 |
+
print(f"Total time: {total_time:.2f} seconds")
|
| 962 |
+
print(f"Throughput: {num_tokens / total_time:.2f} tokens/s")
|
| 963 |
+
print(f"Total NFE: {num_nfe}")
|
| 964 |
+
|
| 965 |
+
return res
|
| 966 |
+
|
| 967 |
+
def _forward_process(self, batch):
|
| 968 |
+
b, l = batch.shape
|
| 969 |
+
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
|
| 970 |
+
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
|
| 971 |
+
indices = torch.arange(b, device=batch.device).float()
|
| 972 |
+
t = (u0 + indices / b) % 1
|
| 973 |
+
|
| 974 |
+
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
|
| 975 |
+
|
| 976 |
+
p_mask = p_mask[:, None].repeat(1, l)
|
| 977 |
+
|
| 978 |
+
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
|
| 979 |
+
# always unmask bos and eos
|
| 980 |
+
mask_indices[:, 0] = False
|
| 981 |
+
mask_indices[:, -1] = False
|
| 982 |
+
|
| 983 |
+
noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
|
| 984 |
+
return noisy_batch, p_mask
|
| 985 |
+
|
| 986 |
+
@torch.no_grad()
|
| 987 |
+
def get_logits(self, batch, prompt_index):
|
| 988 |
+
'''
|
| 989 |
+
prompt_index : 1D bool tensor, length=batch.shape[1]
|
| 990 |
+
'''
|
| 991 |
+
if self.classifier_free_guidance > 1.:
|
| 992 |
+
assert len(prompt_index) == batch.shape[1]
|
| 993 |
+
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
| 994 |
+
un_batch = batch.clone()
|
| 995 |
+
un_batch[prompt_index] = self.mask_token_id
|
| 996 |
+
batch = torch.cat([batch, un_batch])
|
| 997 |
+
|
| 998 |
+
input = batch
|
| 999 |
+
|
| 1000 |
+
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 1001 |
+
logits = self.model(input).logits
|
| 1002 |
+
# since bos always unmask, the first logits will not be used
|
| 1003 |
+
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
| 1004 |
+
|
| 1005 |
+
if self.classifier_free_guidance > 1.:
|
| 1006 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 1007 |
+
logits = un_logits + self.cfg * (logits - un_logits)
|
| 1008 |
+
return logits[:, :batch.shape[1]]
|
| 1009 |
+
|
| 1010 |
+
@torch.no_grad()
|
| 1011 |
+
def _eval_target_nll_mc(self, prefix, target):
|
| 1012 |
+
if prefix is None:
|
| 1013 |
+
seq = target[None, :]
|
| 1014 |
+
else:
|
| 1015 |
+
seq = torch.concatenate([prefix, target])[None, :]
|
| 1016 |
+
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
| 1017 |
+
|
| 1018 |
+
if self.log_type == 'ftb':
|
| 1019 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 1020 |
+
else:
|
| 1021 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
|
| 1022 |
+
|
| 1023 |
+
loss_acc = []
|
| 1024 |
+
for _ in range(max(self.mc_num // self.batch_size, 1)):
|
| 1025 |
+
perturbed_seq = seq.clone()
|
| 1026 |
+
# eval_logger.info("before noising")
|
| 1027 |
+
perturbed_seq_, p_mask = self._forward_process(seq)
|
| 1028 |
+
# eval_logger.info("end noising")
|
| 1029 |
+
if self.log_type == 'ftb':
|
| 1030 |
+
perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
|
| 1031 |
+
elif self.log_type == 'btf':
|
| 1032 |
+
perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
|
| 1033 |
+
elif self.log_type == 'union':
|
| 1034 |
+
perturbed_seq = perturbed_seq_
|
| 1035 |
+
else:
|
| 1036 |
+
raise NotImplementedError(self.log_type)
|
| 1037 |
+
|
| 1038 |
+
mask_indices = perturbed_seq == self.mask_token_id
|
| 1039 |
+
logits = self.get_logits(perturbed_seq, prompt_index)
|
| 1040 |
+
loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
|
| 1041 |
+
loss = loss.sum() / self.batch_size
|
| 1042 |
+
loss_acc.append(loss.item())
|
| 1043 |
+
|
| 1044 |
+
return sum(loss_acc) / len(loss_acc)
|
| 1045 |
+
|
| 1046 |
+
@torch.no_grad()
|
| 1047 |
+
def _eval_target_nll_ar(self, prefix, target):
|
| 1048 |
+
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
|
| 1049 |
+
assert self.log_type in ['ftb', 'btf']
|
| 1050 |
+
assert self.nll_type in ['ar_ftb', 'ar_btf']
|
| 1051 |
+
|
| 1052 |
+
if self.log_type == 'ftb':
|
| 1053 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
|
| 1054 |
+
else:
|
| 1055 |
+
prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
|
| 1056 |
+
|
| 1057 |
+
if self.log_type == 'ftb':
|
| 1058 |
+
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
|
| 1059 |
+
else:
|
| 1060 |
+
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
|
| 1061 |
+
|
| 1062 |
+
mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 1063 |
+
if self.nll_type == 'ar_ftb':
|
| 1064 |
+
mask_index = torch.triu(mask_index)
|
| 1065 |
+
else:
|
| 1066 |
+
mask_index = torch.tril(mask_index)
|
| 1067 |
+
perturbed_[mask_index] = self.mask_token_id
|
| 1068 |
+
if self.log_type == 'ftb':
|
| 1069 |
+
perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
|
| 1070 |
+
else:
|
| 1071 |
+
perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
|
| 1072 |
+
|
| 1073 |
+
logits_ = []
|
| 1074 |
+
num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
|
| 1075 |
+
for i in range(num):
|
| 1076 |
+
end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
|
| 1077 |
+
perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
|
| 1078 |
+
perturbed_seq_ = perturbed_seq_.to(self.device)
|
| 1079 |
+
if len(perturbed_seq_.shape) == 1:
|
| 1080 |
+
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
|
| 1081 |
+
logits = self.get_logits(perturbed_seq_, prompt_index)
|
| 1082 |
+
logits_.append(logits.cpu())
|
| 1083 |
+
logits = torch.cat(logits_, dim=0)
|
| 1084 |
+
|
| 1085 |
+
temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
|
| 1086 |
+
if self.nll_type == 'ar_ftb':
|
| 1087 |
+
temp_index = torch.triu(temp_index, diagonal=1)
|
| 1088 |
+
else:
|
| 1089 |
+
temp_index = torch.tril(temp_index, diagonal=-1)
|
| 1090 |
+
mask_index[temp_index] = False
|
| 1091 |
+
if self.log_type == 'ftb':
|
| 1092 |
+
logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
|
| 1093 |
+
else:
|
| 1094 |
+
logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
|
| 1095 |
+
|
| 1096 |
+
if self.log_type == 'ftb':
|
| 1097 |
+
loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
|
| 1098 |
+
else:
|
| 1099 |
+
loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
|
| 1100 |
+
return loss
|
| 1101 |
+
|
| 1102 |
+
def _encode_pair(self, context, continuation):
|
| 1103 |
+
if self.add_bos_token:
|
| 1104 |
+
context = self.tokenizer.bos_token + context
|
| 1105 |
+
|
| 1106 |
+
n_spaces = len(context) - len(context.rstrip())
|
| 1107 |
+
if n_spaces > 0:
|
| 1108 |
+
continuation = context[-n_spaces:] + continuation
|
| 1109 |
+
context = context[:-n_spaces]
|
| 1110 |
+
|
| 1111 |
+
whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
|
| 1112 |
+
context_enc = self.tokenizer.encode(context)
|
| 1113 |
+
|
| 1114 |
+
context_enc_len = len(context_enc)
|
| 1115 |
+
continuation_enc = whole_enc[context_enc_len:]
|
| 1116 |
+
|
| 1117 |
+
# by default truncate on the left
|
| 1118 |
+
cutoff_length = max(len(whole_enc) - self.max_length, 0)
|
| 1119 |
+
if cutoff_length > 0:
|
| 1120 |
+
eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
|
| 1121 |
+
context_remain = context_enc_len-cutoff_length
|
| 1122 |
+
if context_remain > 0:
|
| 1123 |
+
context_enc = context_enc[-context_remain:]
|
| 1124 |
+
else:
|
| 1125 |
+
eval_logger.warning(f"All context (prompt) is truncated.")
|
| 1126 |
+
context_enc = ""
|
| 1127 |
+
continuation_enc = whole_enc[-self.max_length:]
|
| 1128 |
+
return context_enc, continuation_enc
|
| 1129 |
+
|
| 1130 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 1131 |
+
def _tokenize(e):
|
| 1132 |
+
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
| 1133 |
+
return {
|
| 1134 |
+
"prefix_text": e["prefix"],
|
| 1135 |
+
"target_text": e["target"],
|
| 1136 |
+
"prefix": prefix,
|
| 1137 |
+
"target": target,
|
| 1138 |
+
}
|
| 1139 |
+
|
| 1140 |
+
ds = []
|
| 1141 |
+
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
| 1142 |
+
ds = Dataset.from_list(ds)
|
| 1143 |
+
print(ds[0])
|
| 1144 |
+
ds = ds.map(_tokenize)
|
| 1145 |
+
ds = ds.with_format("torch")
|
| 1146 |
+
|
| 1147 |
+
out = []
|
| 1148 |
+
with torch.no_grad():
|
| 1149 |
+
for elem in tqdm(ds, desc="Computing likelihood..."):
|
| 1150 |
+
prefix = elem["prefix"]
|
| 1151 |
+
target = elem["target"]
|
| 1152 |
+
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
|
| 1153 |
+
if self.nll_type == 'mc':
|
| 1154 |
+
ll = -self._eval_target_nll_mc(prefix, target)
|
| 1155 |
+
if self.log_type == 'union':
|
| 1156 |
+
ll = ll / (len(target) + len(prefix))
|
| 1157 |
+
elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
|
| 1158 |
+
ll = -self._eval_target_nll_ar(prefix, target)
|
| 1159 |
+
else:
|
| 1160 |
+
raise NotImplementedError(self.nll_type)
|
| 1161 |
+
|
| 1162 |
+
# TODO: greedy decoding
|
| 1163 |
+
is_target_greedy_dec = False
|
| 1164 |
+
|
| 1165 |
+
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
| 1166 |
+
return out
|
| 1167 |
+
|
| 1168 |
+
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
| 1169 |
+
raise NotImplementedError
|
| 1170 |
+
|
| 1171 |
+
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
|
| 1172 |
+
raise NotImplementedError
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
def _update_block_completion_states(self, block_states, decoded_token_threshold):
|
| 1176 |
+
"""
|
| 1177 |
+
Updates the complete/incomplete state of blocks.
|
| 1178 |
+
Iterates through blocks from front to back. If a block's decoded token count exceeds the threshold, the next block to its right (if it exists) is set to a complete state.
|
| 1179 |
+
"""
|
| 1180 |
+
for block_id in sorted(block_states.keys()):
|
| 1181 |
+
# if block_id == 0: # Skip prompt block
|
| 1182 |
+
# continue
|
| 1183 |
+
|
| 1184 |
+
# Calculate decoded tokens for the current block
|
| 1185 |
+
decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
|
| 1186 |
+
decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
|
| 1187 |
+
# If current block's decoded token count exceeds the threshold, the next block (if exists) is set to a complete state
|
| 1188 |
+
# print("decode_ratio",decode_ratio)
|
| 1189 |
+
# print("decoded_token_threshold",decoded_token_threshold)
|
| 1190 |
+
if decode_ratio >= decoded_token_threshold:
|
| 1191 |
+
next_block_id = block_id + 1
|
| 1192 |
+
if next_block_id in block_states:
|
| 1193 |
+
block_states[next_block_id]['is_complete'] = True
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
if __name__ == "__main__":
|
| 1197 |
+
set_seed(1234)
|
| 1198 |
+
cli_evaluate()
|
Discrete-Diffusion-Forcing/D2F-eval/eval_llada.sh
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
tasks="gsm8k mbpp minerva_math"
|
| 5 |
+
nshots="4 3 0"
|
| 6 |
+
lengths="512 512 512"
|
| 7 |
+
temperatures="0 0 0"
|
| 8 |
+
limits="10000 10000 10000"
|
| 9 |
+
block_sizes="64 32 32"
|
| 10 |
+
block_add_thresholds="0.7 0.9 0.1"
|
| 11 |
+
decoded_token_thresholds="0.95 0.95 0.95"
|
| 12 |
+
skip_thresholds="0.9 0.9 0.9"
|
| 13 |
+
top_ps="none none none"
|
| 14 |
+
dtypes="bfloat16 bfloat16 bfloat16"
|
| 15 |
+
sampling_strategies="default default default"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
humaneval_nshots="0"
|
| 19 |
+
humaneval_lengths="512"
|
| 20 |
+
humaneval_temperatures="0"
|
| 21 |
+
humaneval_limits="10000"
|
| 22 |
+
humaneval_diffusion_steps="512"
|
| 23 |
+
humaneval_block_sizes="32"
|
| 24 |
+
humaneval_block_add_thresholds="0.1"
|
| 25 |
+
humaneval_decoded_token_thresholds="0.95"
|
| 26 |
+
humaneval_skip_thresholds="0.9"
|
| 27 |
+
humaneval_top_ps="none"
|
| 28 |
+
humaneval_dtypes="bfloat16"
|
| 29 |
+
humaneval_sampling_strategies="default"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
base_model=GSAI-ML/LLaDA-8B-Instruct
|
| 33 |
+
|
| 34 |
+
lora_models=(
|
| 35 |
+
"SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
read -ra TASKS_ARRAY <<< "$tasks"
|
| 39 |
+
read -ra NSHOTS_ARRAY <<< "$nshots"
|
| 40 |
+
read -ra LENGTH_ARRAY <<< "$lengths"
|
| 41 |
+
read -ra TEMP_ARRAY <<< "$temperatures"
|
| 42 |
+
read -ra LIMITS_ARRAY <<< "$limits"
|
| 43 |
+
read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
|
| 44 |
+
read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
|
| 45 |
+
read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
|
| 46 |
+
read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
|
| 47 |
+
read -ra TOP_PS_ARRAY <<< "$top_ps"
|
| 48 |
+
read -ra DTYPES_ARRAY <<< "$dtypes"
|
| 49 |
+
read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"
|
| 50 |
+
|
| 51 |
+
read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
|
| 52 |
+
read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
|
| 53 |
+
read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
|
| 54 |
+
read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
|
| 55 |
+
read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
|
| 56 |
+
read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
|
| 57 |
+
read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
|
| 58 |
+
read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
|
| 59 |
+
read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
|
| 60 |
+
read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
|
| 61 |
+
read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
|
| 62 |
+
read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"
|
| 63 |
+
|
| 64 |
+
array_length=${#TASKS_ARRAY[@]}
|
| 65 |
+
if [[ ${#NSHOTS_ARRAY[@]} -ne $array_length ]] || \
|
| 66 |
+
[[ ${#LENGTH_ARRAY[@]} -ne $array_length ]] || \
|
| 67 |
+
[[ ${#TEMP_ARRAY[@]} -ne $array_length ]] || \
|
| 68 |
+
[[ ${#LIMITS_ARRAY[@]} -ne $array_length ]] || \
|
| 69 |
+
[[ ${#BLOCK_SIZES_ARRAY[@]} -ne $array_length ]] || \
|
| 70 |
+
[[ ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 71 |
+
[[ ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 72 |
+
[[ ${#SKIP_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
|
| 73 |
+
[[ ${#TOP_PS_ARRAY[@]} -ne $array_length ]] || \
|
| 74 |
+
[[ ${#SAMPLING_STRATEGIES_ARRAY[@]} -ne $array_length ]] || \
|
| 75 |
+
[[ ${#DTYPES_ARRAY[@]} -ne $array_length ]]; then
|
| 76 |
+
echo "Error: All configuration arrays must have the same length!"
|
| 77 |
+
echo "Tasks: ${#TASKS_ARRAY[@]}, Nshots: ${#NSHOTS_ARRAY[@]}, Lengths: ${#LENGTH_ARRAY[@]}, Temperatures: ${#TEMP_ARRAY[@]}, Limits: ${#LIMITS_ARRAY[@]}, Block sizes: ${#BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#TOP_PS_ARRAY[@]}, Sampling strategies: ${#SAMPLING_STRATEGIES_ARRAY[@]}, Dtypes: ${#DTYPES_ARRAY[@]}"
|
| 78 |
+
exit 1
|
| 79 |
+
fi
|
| 80 |
+
|
| 81 |
+
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
|
| 82 |
+
if [[ ${#HUMANEVAL_LENGTHS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 83 |
+
[[ ${#HUMANEVAL_TEMP_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 84 |
+
[[ ${#HUMANEVAL_LIMITS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 85 |
+
[[ ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 86 |
+
[[ ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 87 |
+
[[ ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 88 |
+
[[ ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 89 |
+
[[ ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 90 |
+
[[ ${#HUMANEVAL_TOP_PS_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 91 |
+
[[ ${#HUMANEVAL_DTYPES_ARRAY[@]} -ne $humaneval_array_length ]] || \
|
| 92 |
+
[[ ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]} -ne $humaneval_array_length ]]; then
|
| 93 |
+
echo "Error: All HumanEval configuration arrays must have the same length!"
|
| 94 |
+
echo "HumanEval Nshots: ${#HUMANEVAL_NSHOTS_ARRAY[@]}, Lengths: ${#HUMANEVAL_LENGTHS_ARRAY[@]}, Temperatures: ${#HUMANEVAL_TEMP_ARRAY[@]}, Limits: ${#HUMANEVAL_LIMITS_ARRAY[@]}, Diffusion steps: ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]}, Block sizes: ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#HUMANEVAL_TOP_PS_ARRAY[@]}, Dtypes: ${#HUMANEVAL_DTYPES_ARRAY[@]}, Sampling strategies: ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]}"
|
| 95 |
+
exit 1
|
| 96 |
+
fi
|
| 97 |
+
|
| 98 |
+
export HF_ALLOW_CODE_EVAL=1
|
| 99 |
+
for lora_model in "${lora_models[@]}"; do
|
| 100 |
+
lora_model_name="$lora_model"
|
| 101 |
+
echo "===================================================================="
|
| 102 |
+
echo "Evaluating LoRA model: $lora_model_name"
|
| 103 |
+
echo "===================================================================="
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
|
| 108 |
+
output_path="eval_llada${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 109 |
+
echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
|
| 110 |
+
echo "HumanEval Config: Shots: ${HUMANEVAL_NSHOTS_ARRAY[$i]}, Length: ${HUMANEVAL_LENGTHS_ARRAY[$i]}, Temperature: ${HUMANEVAL_TEMP_ARRAY[$i]}, Limit: ${HUMANEVAL_LIMITS_ARRAY[$i]}, Diffusion Steps: ${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}, Block Size: ${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${HUMANEVAL_TOP_PS_ARRAY[$i]}, Sampling Strategy: ${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${HUMANEVAL_DTYPES_ARRAY[$i]}; Output: $output_path"
|
| 111 |
+
|
| 112 |
+
if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 113 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 114 |
+
else
|
| 115 |
+
humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 116 |
+
fi
|
| 117 |
+
|
| 118 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_llada.py --model dream_lora \
|
| 119 |
+
--model_args $humaneval_model_args \
|
| 120 |
+
--tasks humaneval \
|
| 121 |
+
--num_fewshot ${HUMANEVAL_NSHOTS_ARRAY[$i]} \
|
| 122 |
+
--batch_size 1 \
|
| 123 |
+
--output_path $output_path \
|
| 124 |
+
--log_samples \
|
| 125 |
+
--confirm_run_unsafe_code
|
| 126 |
+
done
|
| 127 |
+
|
| 128 |
+
### NOTICE: use postprocess for humaneval
|
| 129 |
+
# python postprocess_code.py {the samples_xxx.jsonl file under output_path}
|
| 130 |
+
|
| 131 |
+
for i in "${!TASKS_ARRAY[@]}"; do
|
| 132 |
+
output_path="eval_llada${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
|
| 133 |
+
echo "Task: ${TASKS_ARRAY[$i]}, Shots: ${NSHOTS_ARRAY[$i]}, Length: ${LENGTH_ARRAY[$i]}, Temperature: ${TEMP_ARRAY[$i]}, Limit: ${LIMITS_ARRAY[$i]}, Block Size: ${BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${TOP_PS_ARRAY[$i]}, Sampling Strategy: ${SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${DTYPES_ARRAY[$i]}; Output: $output_path"
|
| 134 |
+
|
| 135 |
+
if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
|
| 136 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 137 |
+
else
|
| 138 |
+
model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_llada.py --model dream_lora \
|
| 142 |
+
--model_args $model_args \
|
| 143 |
+
--tasks ${TASKS_ARRAY[$i]} \
|
| 144 |
+
--limit ${LIMITS_ARRAY[$i]} \
|
| 145 |
+
--num_fewshot ${NSHOTS_ARRAY[$i]} \
|
| 146 |
+
--batch_size 1 \
|
| 147 |
+
--output_path $output_path \
|
| 148 |
+
--log_samples \
|
| 149 |
+
--confirm_run_unsafe_code \
|
| 150 |
+
--apply_chat_template \
|
| 151 |
+
--fewshot_as_multiturn
|
| 152 |
+
done
|
| 153 |
+
done
|
| 154 |
+
|
| 155 |
+
echo "All evaluations completed!"
|
Discrete-Diffusion-Forcing/D2F-eval/generate_llada_demo_ar.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.distributions as dists
|
| 4 |
+
import transformers
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
from peft import PeftModel, PeftConfig
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
from typing import List, Dict, Optional, Tuple, Iterator, Set
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import gc
|
| 14 |
+
|
| 15 |
+
# Suppress some Hugging Face warnings
|
| 16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 17 |
+
|
| 18 |
+
# Import necessary model classes
|
| 19 |
+
# Assuming these custom classes are in the correct path
|
| 20 |
+
from model_cache.llada.modeling_llada import LLaDAModelLM
|
| 21 |
+
from model_cache.llada.configuration_llada import LLaDAConfig
|
| 22 |
+
|
| 23 |
+
# --- Helper Functions (Unchanged) ---
|
| 24 |
+
def set_seed(seed):
|
| 25 |
+
torch.manual_seed(seed); random.seed(seed); np.random.seed(seed);
|
| 26 |
+
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
|
| 27 |
+
|
| 28 |
+
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
|
| 29 |
+
if dtype is None: dtype = torch.bfloat16
|
| 30 |
+
attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
|
| 31 |
+
attention_mask[:, :, :prompt_length, :prompt_length] = 0
|
| 32 |
+
remaining_length = max_length - prompt_length
|
| 33 |
+
num_blocks = (remaining_length + block_size - 1) // block_size
|
| 34 |
+
for b in range(num_blocks):
|
| 35 |
+
block_start = prompt_length + b * block_size; block_end = min(prompt_length + (b + 1) * block_size, max_length)
|
| 36 |
+
attention_mask[:, :, block_start:block_end, :prompt_length] = 0
|
| 37 |
+
for prev_b in range(b):
|
| 38 |
+
prev_start = prompt_length + prev_b * block_size; prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
|
| 39 |
+
attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
|
| 40 |
+
attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
|
| 41 |
+
return attention_mask
|
| 42 |
+
|
| 43 |
+
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
|
| 44 |
+
end_pos = start_pos + input_length; total_length = cache_length + input_length
|
| 45 |
+
extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf, device=full_mask.device, dtype=full_mask.dtype)
|
| 46 |
+
extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
|
| 47 |
+
extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
|
| 48 |
+
return extracted_mask
|
| 49 |
+
|
| 50 |
+
def top_p_logits(logits, top_p=None):
|
| 51 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 52 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 53 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 54 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 55 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 56 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 57 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 58 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 59 |
+
return logits
|
| 60 |
+
|
| 61 |
+
def top_k_logits(logits, top_k=None):
|
| 62 |
+
top_k = min(top_k, logits.size(-1))
|
| 63 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 64 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 65 |
+
return logits
|
| 66 |
+
|
| 67 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 68 |
+
if temperature > 0: logits = logits / temperature
|
| 69 |
+
if top_p is not None and top_p < 1: logits = top_p_logits(logits, top_p)
|
| 70 |
+
if top_k is not None: logits = top_k_logits(logits, top_k)
|
| 71 |
+
probs = torch.softmax(logits, dim=-1)
|
| 72 |
+
if temperature > 0:
|
| 73 |
+
try:
|
| 74 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 75 |
+
initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 76 |
+
except: initial_confidence, x0 = probs.max(dim=-1)
|
| 77 |
+
else: initial_confidence, x0 = probs.max(dim=-1)
|
| 78 |
+
confidence = initial_confidence.clone()
|
| 79 |
+
if margin_confidence:
|
| 80 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 81 |
+
confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
|
| 82 |
+
if neg_entropy:
|
| 83 |
+
epsilon = 1e-10
|
| 84 |
+
confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
|
| 85 |
+
return confidence, x0, initial_confidence
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class D2FInference:
|
| 89 |
+
CSS = """
|
| 90 |
+
.gradio-container {
|
| 91 |
+
font-family: -apple-system, BlinkMacSystemFont, sans-serif;
|
| 92 |
+
}
|
| 93 |
+
.model-header {
|
| 94 |
+
font-size: 1.2em;
|
| 95 |
+
font-weight: bold;
|
| 96 |
+
margin-bottom: 10px;
|
| 97 |
+
padding: 8px;
|
| 98 |
+
border-radius: 5px;
|
| 99 |
+
text-align: center;
|
| 100 |
+
}
|
| 101 |
+
.d2f-header {
|
| 102 |
+
background-color: #DBEAFE;
|
| 103 |
+
color: #1E40AF;
|
| 104 |
+
}
|
| 105 |
+
.llama-header {
|
| 106 |
+
background-color: #FEF3C7;
|
| 107 |
+
color: #92400E;
|
| 108 |
+
}
|
| 109 |
+
.stats-container {
|
| 110 |
+
padding: 15px;
|
| 111 |
+
border: 1px solid #10B981;
|
| 112 |
+
border-radius: 8px;
|
| 113 |
+
background-color: #F0FDF4;
|
| 114 |
+
margin-top: 10px;
|
| 115 |
+
margin-bottom: 20px;
|
| 116 |
+
}
|
| 117 |
+
.output-textbox textarea {
|
| 118 |
+
font-size: 1.5em !important;
|
| 119 |
+
line-height: 1.6 !important;
|
| 120 |
+
height: 70vh !important;
|
| 121 |
+
overflow-y: auto !important;
|
| 122 |
+
}
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, **kwargs):
|
| 126 |
+
print("Initializing D2F-LLaDA model...")
|
| 127 |
+
self.device = torch.device(kwargs.get("device", "cuda:3") if torch.cuda.is_available() else "cpu")
|
| 128 |
+
self.__dict__.update(kwargs)
|
| 129 |
+
if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported(): self.target_dtype = torch.bfloat16
|
| 130 |
+
elif self.dtype == "float16": self.target_dtype = torch.float16
|
| 131 |
+
else: self.target_dtype = torch.float32
|
| 132 |
+
self._setup_model(self.pretrained_path, self.lora_path)
|
| 133 |
+
print("D2F-LLaDA model and tokenizer setup complete.")
|
| 134 |
+
|
| 135 |
+
def _setup_model(self, pretrained_path, lora_path):
|
| 136 |
+
config = LLaDAConfig.from_pretrained(pretrained_path)
|
| 137 |
+
self.model = LLaDAModelLM.from_pretrained(pretrained_path, config=config, torch_dtype=self.target_dtype).eval()
|
| 138 |
+
self.model = PeftModel.from_pretrained(self.model, lora_path)
|
| 139 |
+
self.model = self.model.to(self.device)
|
| 140 |
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
|
| 141 |
+
if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 142 |
+
|
| 143 |
+
def _apply_chat_template(self, prompt):
|
| 144 |
+
chat_history = [{"role": "user", "content": prompt}]
|
| 145 |
+
return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
|
| 146 |
+
|
| 147 |
+
def _update_block_completion_states(self, block_states, decoded_token_threshold):
|
| 148 |
+
for block_id in sorted(block_states.keys()):
|
| 149 |
+
decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
|
| 150 |
+
if block_states[block_id]['total_masks'] > 0:
|
| 151 |
+
decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
|
| 152 |
+
if decode_ratio >= decoded_token_threshold:
|
| 153 |
+
if (next_block_id := block_id + 1) in block_states:
|
| 154 |
+
block_states[next_block_id]['is_complete'] = True
|
| 155 |
+
|
| 156 |
+
@torch.inference_mode()
|
| 157 |
+
def stream(
|
| 158 |
+
self,
|
| 159 |
+
prompt_text: str,
|
| 160 |
+
max_new_tokens: int,
|
| 161 |
+
block_size: int,
|
| 162 |
+
block_add_threshold: float,
|
| 163 |
+
decoded_token_threshold: float,
|
| 164 |
+
skip_threshold: float
|
| 165 |
+
) -> Iterator[Tuple[str, str]]:
|
| 166 |
+
|
| 167 |
+
start_time = time.time()
|
| 168 |
+
|
| 169 |
+
input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.device)
|
| 170 |
+
prompt_length = input_ids.shape[1]
|
| 171 |
+
|
| 172 |
+
full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.device, self.target_dtype)
|
| 173 |
+
x_t = input_ids
|
| 174 |
+
block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
|
| 175 |
+
past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0
|
| 176 |
+
|
| 177 |
+
yield "", None
|
| 178 |
+
|
| 179 |
+
tokens_generated = 0
|
| 180 |
+
|
| 181 |
+
while True:
|
| 182 |
+
step += 1
|
| 183 |
+
updated_block_ids = set()
|
| 184 |
+
|
| 185 |
+
if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
|
| 186 |
+
last_block_id = max(block_states.keys())
|
| 187 |
+
progress_ratio = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if block_states[last_block_id]['total_masks'] > 0 else 1.0
|
| 188 |
+
if progress_ratio >= block_add_threshold:
|
| 189 |
+
new_block_id = last_block_id + 1; new_start_pos = x_t.shape[1]
|
| 190 |
+
if new_start_pos + block_size <= self.max_length:
|
| 191 |
+
x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.device, dtype=torch.long)], dim=1)
|
| 192 |
+
block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
|
| 193 |
+
current_blocks += 1
|
| 194 |
+
|
| 195 |
+
self._update_block_completion_states(block_states, decoded_token_threshold)
|
| 196 |
+
if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0: break
|
| 197 |
+
|
| 198 |
+
blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
|
| 199 |
+
update_kvcache = 0
|
| 200 |
+
if blocks_to_cache:
|
| 201 |
+
start_pos, end_pos = block_states[min(blocks_to_cache)]['start_pos'], block_states[max(blocks_to_cache)]['end_pos']
|
| 202 |
+
update_kvcache = end_pos - start_pos; input_seq, process_start_pos = x_t[:, start_pos:], start_pos
|
| 203 |
+
else:
|
| 204 |
+
active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
|
| 205 |
+
if not active_blocks: break
|
| 206 |
+
start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks); input_seq, process_start_pos = x_t[:, start_pos:], start_pos
|
| 207 |
+
|
| 208 |
+
if input_seq.shape[1] == 0: break
|
| 209 |
+
|
| 210 |
+
attention_mask = extract_attention_mask(full_mask=full_attention_mask,
|
| 211 |
+
start_pos=process_start_pos,
|
| 212 |
+
input_length=input_seq.shape[1],
|
| 213 |
+
cache_length=cache_length)
|
| 214 |
+
|
| 215 |
+
outputs = self.model(input_seq,
|
| 216 |
+
attention_bias=attention_mask,
|
| 217 |
+
past_key_values=past_key_values,
|
| 218 |
+
use_cache=True,
|
| 219 |
+
update_kvcache=update_kvcache + cache_length)
|
| 220 |
+
|
| 221 |
+
if update_kvcache > 0:
|
| 222 |
+
past_key_values = outputs.past_key_values
|
| 223 |
+
for bid in blocks_to_cache:
|
| 224 |
+
block_states[bid]['state'] = 'in_cache'
|
| 225 |
+
|
| 226 |
+
blocks_to_deactivate = []
|
| 227 |
+
for block_id, state in block_states.items():
|
| 228 |
+
if state['state'] != 'active':
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)
|
| 232 |
+
|
| 233 |
+
if block_mask_locs.numel() == 0:
|
| 234 |
+
blocks_to_deactivate.append(block_id)
|
| 235 |
+
continue
|
| 236 |
+
|
| 237 |
+
logit_offset = state['start_pos'] - process_start_pos
|
| 238 |
+
block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
|
| 239 |
+
_, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
|
| 240 |
+
all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)
|
| 241 |
+
|
| 242 |
+
if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
|
| 243 |
+
all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.device)
|
| 244 |
+
|
| 245 |
+
if all_indices.numel() > 0:
|
| 246 |
+
updated_block_ids.add(block_id)
|
| 247 |
+
positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
|
| 248 |
+
x_t[0, positions_to_update] = x0[all_indices]
|
| 249 |
+
state['mask_count'] -= all_indices.numel()
|
| 250 |
+
tokens_generated += all_indices.numel()
|
| 251 |
+
|
| 252 |
+
if self.tokenizer.eos_token_id in x0[all_indices]:
|
| 253 |
+
eos_detected = True
|
| 254 |
+
|
| 255 |
+
if state['mask_count'] == 0:
|
| 256 |
+
blocks_to_deactivate.append(block_id)
|
| 257 |
+
|
| 258 |
+
for bid in blocks_to_deactivate:
|
| 259 |
+
if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
|
| 260 |
+
block_states[bid]['state'] = 'to_cache'
|
| 261 |
+
current_blocks -= 1
|
| 262 |
+
|
| 263 |
+
if update_kvcache > 0:
|
| 264 |
+
cache_length += update_kvcache
|
| 265 |
+
|
| 266 |
+
generated_ids = x_t[0, prompt_length:]
|
| 267 |
+
valid_ids = generated_ids[generated_ids != self.mask_token_id]
|
| 268 |
+
live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
|
| 269 |
+
|
| 270 |
+
yield live_text, None
|
| 271 |
+
|
| 272 |
+
total_time = time.time() - start_time
|
| 273 |
+
final_generated_ids = x_t[0, prompt_length:]
|
| 274 |
+
eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()
|
| 275 |
+
|
| 276 |
+
if eos_positions.numel() > 0:
|
| 277 |
+
final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]
|
| 278 |
+
|
| 279 |
+
final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)
|
| 280 |
+
|
| 281 |
+
tokens_incl_eos = len(final_generated_ids)
|
| 282 |
+
tokens_per_second = tokens_incl_eos / total_time if total_time > 0 else 0
|
| 283 |
+
|
| 284 |
+
stats = {
|
| 285 |
+
"total_time": total_time,
|
| 286 |
+
"tokens_generated": tokens_incl_eos,
|
| 287 |
+
"tokens_per_second": tokens_per_second
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
if past_key_values is not None:
|
| 291 |
+
del past_key_values
|
| 292 |
+
del full_attention_mask
|
| 293 |
+
torch.cuda.empty_cache()
|
| 294 |
+
|
| 295 |
+
yield final_text, stats
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class LlamaInference:
|
| 299 |
+
def __init__(self, **kwargs):
|
| 300 |
+
print("Initializing LLaMA model...")
|
| 301 |
+
self.device = torch.device(kwargs.get("device", "cuda:4") if torch.cuda.is_available() else "cpu")
|
| 302 |
+
self.__dict__.update(kwargs)
|
| 303 |
+
self._setup_model(self.model_id)
|
| 304 |
+
print("LLaMA model and tokenizer setup complete.")
|
| 305 |
+
|
| 306 |
+
def _setup_model(self, model_id):
|
| 307 |
+
print(f"Loading LLaMA model {model_id} on {self.device}...")
|
| 308 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 309 |
+
|
| 310 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 311 |
+
model_id,
|
| 312 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
|
| 313 |
+
device_map=self.device
|
| 314 |
+
).eval()
|
| 315 |
+
|
| 316 |
+
if self.tokenizer.eos_token is None:
|
| 317 |
+
self.tokenizer.eos_token = "</s>"
|
| 318 |
+
|
| 319 |
+
if self.tokenizer.pad_token is None:
|
| 320 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 321 |
+
|
| 322 |
+
def _apply_chat_template(self, prompt):
|
| 323 |
+
chat_history = [{"role": "user", "content": prompt}]
|
| 324 |
+
return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
|
| 325 |
+
|
| 326 |
+
@torch.inference_mode()
|
| 327 |
+
def stream(
|
| 328 |
+
self,
|
| 329 |
+
prompt_text: str,
|
| 330 |
+
max_new_tokens: int,
|
| 331 |
+
temperature: float = 0.0,
|
| 332 |
+
top_p: float = 0.9,
|
| 333 |
+
top_k: int = None
|
| 334 |
+
) -> Iterator[Tuple[str, str]]:
|
| 335 |
+
|
| 336 |
+
start_time = time.time()
|
| 337 |
+
|
| 338 |
+
formatted_prompt = self._apply_chat_template(prompt_text)
|
| 339 |
+
input_ids = self.tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(self.device)
|
| 340 |
+
prompt_length = input_ids.shape[1]
|
| 341 |
+
|
| 342 |
+
yield "", None
|
| 343 |
+
|
| 344 |
+
tokens_generated = 0
|
| 345 |
+
current_input_ids = input_ids.clone()
|
| 346 |
+
|
| 347 |
+
for i in range(max_new_tokens):
|
| 348 |
+
with torch.no_grad():
|
| 349 |
+
outputs = self.model(current_input_ids, use_cache=True)
|
| 350 |
+
|
| 351 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 352 |
+
|
| 353 |
+
if temperature > 0:
|
| 354 |
+
next_token_logits = next_token_logits / temperature
|
| 355 |
+
if top_p is not None and top_p < 1:
|
| 356 |
+
next_token_logits = top_p_logits(next_token_logits, top_p)
|
| 357 |
+
if top_k is not None:
|
| 358 |
+
next_token_logits = top_k_logits(next_token_logits, top_k)
|
| 359 |
+
probs = torch.softmax(next_token_logits, dim=-1)
|
| 360 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 361 |
+
else:
|
| 362 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
| 363 |
+
|
| 364 |
+
current_input_ids = torch.cat([current_input_ids, next_token], dim=-1)
|
| 365 |
+
tokens_generated += 1
|
| 366 |
+
|
| 367 |
+
if next_token[0, 0].item() == self.tokenizer.eos_token_id:
|
| 368 |
+
break
|
| 369 |
+
|
| 370 |
+
generated_text = self.tokenizer.decode(
|
| 371 |
+
current_input_ids[0, prompt_length:],
|
| 372 |
+
skip_special_tokens=True
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
yield generated_text, None
|
| 376 |
+
|
| 377 |
+
del outputs
|
| 378 |
+
|
| 379 |
+
total_time = time.time() - start_time
|
| 380 |
+
tokens_per_second = tokens_generated / total_time if total_time > 0 else 0
|
| 381 |
+
|
| 382 |
+
final_text = self.tokenizer.decode(current_input_ids[0, prompt_length:], skip_special_tokens=True)
|
| 383 |
+
|
| 384 |
+
stats = {
|
| 385 |
+
"total_time": total_time,
|
| 386 |
+
"tokens_generated": tokens_generated,
|
| 387 |
+
"tokens_per_second": tokens_per_second
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
del current_input_ids
|
| 391 |
+
torch.cuda.empty_cache()
|
| 392 |
+
|
| 393 |
+
yield final_text, stats
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# --- Comparison Helper Functions ---
|
| 397 |
+
def create_comparison_html(d2f_results, llama_results):
|
| 398 |
+
d_tokens = d2f_results["tokens_generated"]
|
| 399 |
+
d_time = d2f_results["total_time"]
|
| 400 |
+
d_tokens_per_sec = d2f_results["tokens_per_second"]
|
| 401 |
+
|
| 402 |
+
a_tokens = llama_results["tokens_generated"]
|
| 403 |
+
a_time = llama_results["total_time"]
|
| 404 |
+
a_tokens_per_sec = llama_results["tokens_per_second"]
|
| 405 |
+
|
| 406 |
+
if a_tokens_per_sec > 0:
|
| 407 |
+
speedup = d_tokens_per_sec / a_tokens_per_sec
|
| 408 |
+
else:
|
| 409 |
+
speedup = 0
|
| 410 |
+
|
| 411 |
+
comparison_html = f"""
|
| 412 |
+
<div class="stats-container" style="background-color: #F9FAFB; border-color: #6366F1;">
|
| 413 |
+
<h3>⚡ Performance Comparison</h3>
|
| 414 |
+
<table style="width:100%; text-align: left; border-collapse: collapse;">
|
| 415 |
+
<tr style="background-color: #EEF2FF;">
|
| 416 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Metric</th>
|
| 417 |
+
<th style="padding: 8px; border: 1px solid #ddd;">D2F-LLaDA-Instruct-8B</th>
|
| 418 |
+
<th style="padding: 8px; border: 1px solid #ddd;">LLaMA3-Instruct-8B</th>
|
| 419 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Difference</th>
|
| 420 |
+
</tr>
|
| 421 |
+
<tr>
|
| 422 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Total tokens</td>
|
| 423 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{d_tokens}</td>
|
| 424 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{a_tokens}</td>
|
| 425 |
+
<td style="padding: 8px; border: 1px solid #ddd;">-</td>
|
| 426 |
+
</tr>
|
| 427 |
+
<tr>
|
| 428 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Generation time</td>
|
| 429 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{d_time:.2f}s</td>
|
| 430 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{a_time:.2f}s</td>
|
| 431 |
+
<td style="padding: 8px; border: 1px solid #ddd;">
|
| 432 |
+
{"D2F-LLaDA is " + f"{(a_time/d_time):.1f}x faster" if d_time > 0 and d_time < a_time else "LLaMA3 is " + f"{(d_time/a_time):.1f}x faster"}
|
| 433 |
+
</td>
|
| 434 |
+
</tr>
|
| 435 |
+
<tr>
|
| 436 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Tokens per second</td>
|
| 437 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{d_tokens_per_sec:.2f}</td>
|
| 438 |
+
<td style="padding: 8px; border: 1px solid #ddd;">{a_tokens_per_sec:.2f}</td>
|
| 439 |
+
<td style="padding: 8px; border: 1px solid #ddd;">
|
| 440 |
+
{"D2F-LLaDA is " + f"{speedup:.1f}x faster" if speedup > 1 else "LLaMA3 is " + f"{(1/speedup if speedup > 0 else 0):.1f}x faster"}
|
| 441 |
+
</td>
|
| 442 |
+
</tr>
|
| 443 |
+
</table>
|
| 444 |
+
</div>
|
| 445 |
+
"""
|
| 446 |
+
|
| 447 |
+
return comparison_html
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def create_stats_html(model_name, results):
|
| 451 |
+
stats_html = f"""
|
| 452 |
+
<div class="stats-container">
|
| 453 |
+
<h3>✓ {model_name} Generation Complete</h3>
|
| 454 |
+
<ul>
|
| 455 |
+
<li><b>Total time:</b> {results["total_time"]:.2f} seconds</li>
|
| 456 |
+
<li><b>Tokens generated:</b> {results["tokens_generated"]}</li>
|
| 457 |
+
<li><b>Tokens per second:</b> {results["tokens_per_second"]:.2f}</li>
|
| 458 |
+
</ul>
|
| 459 |
+
</div>
|
| 460 |
+
"""
|
| 461 |
+
|
| 462 |
+
return stats_html
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# --- Main Interface ---
|
| 466 |
+
if __name__ == "__main__":
|
| 467 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
|
| 468 |
+
|
| 469 |
+
torch.cuda.empty_cache()
|
| 470 |
+
|
| 471 |
+
d2f_config = {
|
| 472 |
+
"pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
|
| 473 |
+
"lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
|
| 474 |
+
"device": "cuda:0",
|
| 475 |
+
"dtype": "bfloat16",
|
| 476 |
+
"max_length": 4096,
|
| 477 |
+
"temperature": 0.0,
|
| 478 |
+
"top_p": None,
|
| 479 |
+
"top_k": None,
|
| 480 |
+
"mask_token_id": 126336,
|
| 481 |
+
"sampling_strategy": "default",
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
llama_config = {
|
| 485 |
+
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
| 486 |
+
"device": "cuda:1",
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
set_seed(42)
|
| 490 |
+
|
| 491 |
+
d2f_engine = D2FInference(**d2f_config)
|
| 492 |
+
llama_engine = LlamaInference(**llama_config)
|
| 493 |
+
|
| 494 |
+
with gr.Blocks(css=D2FInference.CSS, theme=gr.themes.Soft()) as demo:
|
| 495 |
+
gr.Markdown("# 🚀 D2F-LLaDA vs LLaMA3: Speed Comparison")
|
| 496 |
+
|
| 497 |
+
with gr.Row():
|
| 498 |
+
with gr.Column(scale=1):
|
| 499 |
+
prompt_input = gr.Textbox(
|
| 500 |
+
label="Enter your question",
|
| 501 |
+
placeholder="Example: Natalia sold clips to...",
|
| 502 |
+
lines=5
|
| 503 |
+
)
|
| 504 |
+
generate_button = gr.Button("🚀 Run Speed Comparison", variant="primary")
|
| 505 |
+
|
| 506 |
+
with gr.Accordion("⚙️ D2F-LLaDA Parameter Settings", open=True):
|
| 507 |
+
with gr.Row():
|
| 508 |
+
max_new_tokens_slider = gr.Slider(
|
| 509 |
+
minimum=64, maximum=2048, value=1024, step=64,
|
| 510 |
+
label="Max Tokens to Generate"
|
| 511 |
+
)
|
| 512 |
+
block_size_slider = gr.Slider(
|
| 513 |
+
minimum=16, maximum=128, value=32, step=16,
|
| 514 |
+
label="Block Size"
|
| 515 |
+
)
|
| 516 |
+
with gr.Row():
|
| 517 |
+
block_add_thresh_slider = gr.Slider(
|
| 518 |
+
minimum=0.0, maximum=1.0, value=0.1, step=0.05,
|
| 519 |
+
label="Block Add Threshold"
|
| 520 |
+
)
|
| 521 |
+
decoded_token_thresh_slider = gr.Slider(
|
| 522 |
+
minimum=0.0, maximum=1.0, value=0.5, step=0.05,
|
| 523 |
+
label="Decoding Completion Threshold"
|
| 524 |
+
)
|
| 525 |
+
skip_thresh_slider = gr.Slider(
|
| 526 |
+
minimum=0.0, maximum=1.0, value=0.9, step=0.01,
|
| 527 |
+
label="Skip Threshold"
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
comparison_output = gr.HTML(label="Performance Comparison", elem_id="comparison-container")
|
| 531 |
+
|
| 532 |
+
with gr.Row():
|
| 533 |
+
with gr.Column(scale=1):
|
| 534 |
+
gr.HTML("<div class='model-header d2f-header'>✨ D2F-LLaDA-Instruct-8B (Parallel Decoding)</div>")
|
| 535 |
+
d2f_output = gr.Textbox(
|
| 536 |
+
label="D2F-LLaDA Output",
|
| 537 |
+
interactive=False,
|
| 538 |
+
elem_classes=["output-textbox"]
|
| 539 |
+
)
|
| 540 |
+
d2f_status = gr.HTML(label="D2F-LLaDA Stats")
|
| 541 |
+
|
| 542 |
+
with gr.Column(scale=1):
|
| 543 |
+
gr.HTML("<div class='model-header llama-header'>🔄 LLaMA3-Instruct-8B (Standard)</div>")
|
| 544 |
+
llama_output = gr.Textbox(
|
| 545 |
+
label="LLaMA3 Output",
|
| 546 |
+
interactive=False,
|
| 547 |
+
elem_classes=["output-textbox"]
|
| 548 |
+
)
|
| 549 |
+
llama_status = gr.HTML(label="LLaMA3 Stats")
|
| 550 |
+
|
| 551 |
+
gr.Examples(
|
| 552 |
+
examples=[
|
| 553 |
+
["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 1024, 32, 0.1, 0.55, 0.9],
|
| 554 |
+
["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 1024, 32, 0.1, 0.5, 0.9],
|
| 555 |
+
["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 1024, 32, 0.1, 0.5, 0.9],
|
| 556 |
+
],
|
| 557 |
+
inputs=[
|
| 558 |
+
prompt_input, max_new_tokens_slider, block_size_slider,
|
| 559 |
+
block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
|
| 560 |
+
],
|
| 561 |
+
label="Examples (Math Problems)"
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
def run_models_streaming(
|
| 565 |
+
prompt_text,
|
| 566 |
+
max_new_tokens,
|
| 567 |
+
block_size,
|
| 568 |
+
block_add_threshold,
|
| 569 |
+
decoded_token_threshold,
|
| 570 |
+
skip_threshold
|
| 571 |
+
):
|
| 572 |
+
torch.cuda.empty_cache()
|
| 573 |
+
|
| 574 |
+
d2f_generator = d2f_engine.stream(
|
| 575 |
+
prompt_text=prompt_text,
|
| 576 |
+
max_new_tokens=max_new_tokens,
|
| 577 |
+
block_size=block_size,
|
| 578 |
+
block_add_threshold=block_add_threshold,
|
| 579 |
+
decoded_token_threshold=decoded_token_threshold,
|
| 580 |
+
skip_threshold=skip_threshold
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
llama_generator = llama_engine.stream(
|
| 584 |
+
prompt_text=prompt_text,
|
| 585 |
+
max_new_tokens=max_new_tokens
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
d2f_text = ""
|
| 589 |
+
llama_text = ""
|
| 590 |
+
d2f_stats = None
|
| 591 |
+
llama_stats = None
|
| 592 |
+
|
| 593 |
+
yield d2f_text, llama_text, "", "", ""
|
| 594 |
+
|
| 595 |
+
d2f_done = False
|
| 596 |
+
llama_done = False
|
| 597 |
+
|
| 598 |
+
while not (d2f_done and llama_done):
|
| 599 |
+
if not d2f_done:
|
| 600 |
+
try:
|
| 601 |
+
new_d2f_text, new_d2f_stats = next(d2f_generator)
|
| 602 |
+
d2f_text = new_d2f_text
|
| 603 |
+
if new_d2f_stats is not None:
|
| 604 |
+
d2f_stats = new_d2f_stats
|
| 605 |
+
d2f_done = True
|
| 606 |
+
except StopIteration:
|
| 607 |
+
d2f_done = True
|
| 608 |
+
|
| 609 |
+
if not llama_done:
|
| 610 |
+
try:
|
| 611 |
+
new_llama_text, new_llama_stats = next(llama_generator)
|
| 612 |
+
llama_text = new_llama_text
|
| 613 |
+
if new_llama_stats is not None:
|
| 614 |
+
llama_stats = new_llama_stats
|
| 615 |
+
llama_done = True
|
| 616 |
+
except StopIteration:
|
| 617 |
+
llama_done = True
|
| 618 |
+
|
| 619 |
+
d2f_status_html = create_stats_html("D2F-LLaDA", d2f_stats) if d2f_stats else ""
|
| 620 |
+
llama_status_html = create_stats_html("LLaMA3", llama_stats) if llama_stats else ""
|
| 621 |
+
|
| 622 |
+
comparison = ""
|
| 623 |
+
if d2f_done and llama_done and d2f_stats and llama_stats:
|
| 624 |
+
comparison = create_comparison_html(d2f_stats, llama_stats)
|
| 625 |
+
|
| 626 |
+
yield d2f_text, llama_text, d2f_status_html, llama_status_html, comparison
|
| 627 |
+
|
| 628 |
+
# MODIFICATION: Removed the _js parameter from here
|
| 629 |
+
generate_button.click(
|
| 630 |
+
fn=run_models_streaming,
|
| 631 |
+
inputs=[
|
| 632 |
+
prompt_input, max_new_tokens_slider, block_size_slider,
|
| 633 |
+
block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
|
| 634 |
+
],
|
| 635 |
+
outputs=[
|
| 636 |
+
d2f_output, llama_output,
|
| 637 |
+
d2f_status, llama_status,
|
| 638 |
+
comparison_output
|
| 639 |
+
]
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# MODIFICATION: Added a hidden HTML component with a script for auto-scrolling
|
| 643 |
+
# This method is compatible with older Gradio versions.
|
| 644 |
+
gr.HTML(
|
| 645 |
+
"""
|
| 646 |
+
<script>
|
| 647 |
+
function_to_run = () => {
|
| 648 |
+
const textboxes = document.querySelectorAll('.output-textbox textarea');
|
| 649 |
+
textboxes.forEach(textbox => {
|
| 650 |
+
textbox.scrollTop = textbox.scrollHeight;
|
| 651 |
+
});
|
| 652 |
+
}
|
| 653 |
+
// Run the function every 250ms to ensure autoscrolling
|
| 654 |
+
setInterval(function_to_run, 250);
|
| 655 |
+
</script>
|
| 656 |
+
""",
|
| 657 |
+
visible=False
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
demo.queue().launch(share=True)
|
Discrete-Diffusion-Forcing/D2F-eval/generate_llada_demo_block.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.distributions as dists
|
| 4 |
+
import transformers
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from peft import PeftModel, PeftConfig
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
from typing import List, Dict, Optional, Tuple, Iterator, Set
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import ipdb
|
| 14 |
+
# Suppress some Hugging Face warnings
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
# Import necessary model classes
|
| 18 |
+
from model_cache.llada.modeling_llada import LLaDAModelLM
|
| 19 |
+
from model_cache.llada.configuration_llada import LLaDAConfig
|
| 20 |
+
|
| 21 |
+
# --- Helper Functions (Unchanged) ---
|
| 22 |
+
def set_seed(seed):
|
| 23 |
+
torch.manual_seed(seed); random.seed(seed); np.random.seed(seed);
|
| 24 |
+
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
|
| 25 |
+
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
|
| 26 |
+
if dtype is None: dtype = torch.bfloat16
|
| 27 |
+
attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
|
| 28 |
+
attention_mask[:, :, :prompt_length, :prompt_length] = 0
|
| 29 |
+
remaining_length = max_length - prompt_length
|
| 30 |
+
num_blocks = (remaining_length + block_size - 1) // block_size
|
| 31 |
+
for b in range(num_blocks):
|
| 32 |
+
block_start = prompt_length + b * block_size; block_end = min(prompt_length + (b + 1) * block_size, max_length)
|
| 33 |
+
attention_mask[:, :, block_start:block_end, :prompt_length] = 0
|
| 34 |
+
for prev_b in range(b):
|
| 35 |
+
prev_start = prompt_length + prev_b * block_size; prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
|
| 36 |
+
attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
|
| 37 |
+
attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
|
| 38 |
+
return attention_mask
|
| 39 |
+
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
|
| 40 |
+
end_pos = start_pos + input_length; total_length = cache_length + input_length
|
| 41 |
+
extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf, device=full_mask.device, dtype=full_mask.dtype)
|
| 42 |
+
extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
|
| 43 |
+
extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
|
| 44 |
+
return extracted_mask
|
| 45 |
+
def top_p_logits(logits, top_p=None):
|
| 46 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 47 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 48 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 49 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 50 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 51 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 52 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 53 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 54 |
+
return logits
|
| 55 |
+
def top_k_logits(logits, top_k=None):
|
| 56 |
+
top_k = min(top_k, logits.size(-1))
|
| 57 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 58 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 59 |
+
return logits
|
| 60 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 61 |
+
if temperature > 0: logits = logits / temperature
|
| 62 |
+
if top_p is not None and top_p < 1: logits = top_p_logits(logits, top_p)
|
| 63 |
+
if top_k is not None: logits = top_k_logits(logits, top_k)
|
| 64 |
+
probs = torch.softmax(logits, dim=-1)
|
| 65 |
+
if temperature > 0:
|
| 66 |
+
try:
|
| 67 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 68 |
+
initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 69 |
+
except: initial_confidence, x0 = probs.max(dim=-1)
|
| 70 |
+
else: initial_confidence, x0 = probs.max(dim=-1)
|
| 71 |
+
confidence = initial_confidence.clone()
|
| 72 |
+
if margin_confidence:
|
| 73 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 74 |
+
confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
|
| 75 |
+
if neg_entropy:
|
| 76 |
+
epsilon = 1e-10
|
| 77 |
+
confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
|
| 78 |
+
return confidence, x0, initial_confidence
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class DreamLoRAInference:
|
| 82 |
+
CSS = """
|
| 83 |
+
/* Fixed height, scrollable visualization container */
|
| 84 |
+
#viz-container {
|
| 85 |
+
height: 500px;
|
| 86 |
+
overflow-y: auto !important;
|
| 87 |
+
border: 1px solid #E5E7EB;
|
| 88 |
+
border-radius: 8px;
|
| 89 |
+
padding: 10px;
|
| 90 |
+
position: relative;
|
| 91 |
+
}
|
| 92 |
+
.block-container {
|
| 93 |
+
display: inline-block; border: 2px solid transparent; border-radius: 8px;
|
| 94 |
+
padding: 5px; margin: 4px 0; transition: border-color 0.3s, box-shadow 0.3s;
|
| 95 |
+
}
|
| 96 |
+
.block-updating {
|
| 97 |
+
border-color: #FF4500 !important;
|
| 98 |
+
box-shadow: 0 0 8px rgba(255, 69, 0, 0.7);
|
| 99 |
+
}
|
| 100 |
+
.token { padding: 2px 4px; margin: 2px; border-radius: 4px; display: inline-block; line-height: 1.4; font-family: monospace; }
|
| 101 |
+
.token.prompt { background-color: #E5E7EB; color: #4B5563; }
|
| 102 |
+
.token.gen-0 { background-color: #DBEAFE; color: #1E40AF; } /* Blue */
|
| 103 |
+
.token.gen-1 { background-color: #D1FAE5; color: #065F46; } /* Green */
|
| 104 |
+
.token.gen-2 { background-color: #FEF3C7; color: #92400E; } /* Yellow */
|
| 105 |
+
.token.gen-3 { background-color: #FEE2E2; color: #991B1B; } /* Red */
|
| 106 |
+
.token.gen-4 { background-color: #E0E7FF; color: #3730A3; } /* Indigo */
|
| 107 |
+
.token.gen-5 { background-color: #F3E8FF; color: #6B21A8; } /* Purple */
|
| 108 |
+
.token.mask { background-color: #F3F4F6; color: #9CA3AF; border: 1px dashed #D1D5DB; }
|
| 109 |
+
|
| 110 |
+
/* Independent status box styles */
|
| 111 |
+
#status-container {
|
| 112 |
+
height: 300px;
|
| 113 |
+
overflow-y: auto !important;
|
| 114 |
+
margin-top: 10px; padding: 15px; border: 1px solid #E5E7EB; border-radius: 8px; background-color: #F9FAFB;
|
| 115 |
+
position: relative;
|
| 116 |
+
}
|
| 117 |
+
#status-container h4 { margin-top: 0; }
|
| 118 |
+
.status-line { font-family: monospace; font-size: 13px; margin-bottom: 5px; margin-top: 5px; padding: 2px 4px; border-radius: 3px;}
|
| 119 |
+
#stats-output { padding: 15px; border: 1px solid #10B981; border-radius: 8px; background-color: #F0FDF4; margin-top: 10px; }
|
| 120 |
+
|
| 121 |
+
/* Scroll anchor */
|
| 122 |
+
.scroll-anchor {
|
| 123 |
+
height: 1px;
|
| 124 |
+
width: 100%;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
/* Force scrollbar styles */
|
| 128 |
+
#viz-container::-webkit-scrollbar, #status-container::-webkit-scrollbar {
|
| 129 |
+
width: 10px !important;
|
| 130 |
+
background-color: #f5f5f5 !important;
|
| 131 |
+
}
|
| 132 |
+
#viz-container::-webkit-scrollbar-thumb, #status-container::-webkit-scrollbar-thumb {
|
| 133 |
+
background-color: #888 !important;
|
| 134 |
+
border-radius: 5px !important;
|
| 135 |
+
}
|
| 136 |
+
#viz-container::-webkit-scrollbar-track, #status-container::-webkit-scrollbar-track {
|
| 137 |
+
background-color: #f5f5f5 !important;
|
| 138 |
+
border-radius: 5px !important;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/* Column height alignment */
|
| 142 |
+
.left-column, .right-column {
|
| 143 |
+
display: flex;
|
| 144 |
+
flex-direction: column;
|
| 145 |
+
height: auto !important;
|
| 146 |
+
min-height: 800px;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
.live-text-container, .viz-status-container {
|
| 150 |
+
display: flex;
|
| 151 |
+
flex-direction: column;
|
| 152 |
+
flex: 1;
|
| 153 |
+
overflow: visible;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
#live-text-output, #stats-output {
|
| 157 |
+
margin-bottom: 20px;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/* Fix for bottom content being cut off */
|
| 161 |
+
.container {
|
| 162 |
+
padding-bottom: 40px;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
/* Make sure content is fully visible */
|
| 166 |
+
.gradio-container {
|
| 167 |
+
overflow-y: visible !important;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/* Add padding to bottom of page */
|
| 171 |
+
.footer {
|
| 172 |
+
margin-top: 30px;
|
| 173 |
+
padding-bottom: 30px;
|
| 174 |
+
}
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(self, **kwargs):
|
| 178 |
+
print("Initializing DreamLoRAInference...")
|
| 179 |
+
self.device = torch.device(kwargs.get("device", "cuda") if torch.cuda.is_available() else "cpu")
|
| 180 |
+
self.__dict__.update(kwargs)
|
| 181 |
+
if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported(): self.target_dtype = torch.bfloat16
|
| 182 |
+
elif self.dtype == "float16": self.target_dtype = torch.float16
|
| 183 |
+
else: self.target_dtype = torch.float32
|
| 184 |
+
self._setup_model(self.pretrained_path, self.lora_path)
|
| 185 |
+
print("Model and tokenizer setup complete.")
|
| 186 |
+
|
| 187 |
+
def _setup_model(self, pretrained_path, lora_path):
|
| 188 |
+
config = LLaDAConfig.from_pretrained(pretrained_path)
|
| 189 |
+
self.model = LLaDAModelLM.from_pretrained(pretrained_path, config=config, torch_dtype=self.target_dtype).eval()
|
| 190 |
+
self.model = PeftModel.from_pretrained(self.model, lora_path)
|
| 191 |
+
self.model = self.model.to(self.device)
|
| 192 |
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
|
| 193 |
+
if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 194 |
+
|
| 195 |
+
def _apply_chat_template(self, prompt):
|
| 196 |
+
chat_history = [{"role": "user", "content": prompt}]
|
| 197 |
+
return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
|
| 198 |
+
|
| 199 |
+
def _update_block_completion_states(self, block_states, decoded_token_threshold):
|
| 200 |
+
for block_id in sorted(block_states.keys()):
|
| 201 |
+
decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
|
| 202 |
+
if block_states[block_id]['total_masks'] > 0:
|
| 203 |
+
decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
|
| 204 |
+
if decode_ratio >= decoded_token_threshold:
|
| 205 |
+
if (next_block_id := block_id + 1) in block_states:
|
| 206 |
+
block_states[next_block_id]['is_complete'] = True
|
| 207 |
+
|
| 208 |
+
# Render visualization part (excluding prompt status info)
|
| 209 |
+
def _render_visualization_html(self, step: int, x_t: torch.Tensor, block_states: Dict, cache_length: int, updated_block_ids: Set[int]) -> str:
|
| 210 |
+
timestamp = int(time.time() * 1000)
|
| 211 |
+
|
| 212 |
+
html_parts = []
|
| 213 |
+
for block_id in sorted(k for k in block_states.keys() if k > 0): # Only render generated part (block_id > 0)
|
| 214 |
+
state = block_states[block_id]
|
| 215 |
+
container_classes = ["block-container"]
|
| 216 |
+
if block_id in updated_block_ids: container_classes.append("block-updating")
|
| 217 |
+
html_parts.append(f'<div class="{" ".join(container_classes)}" id="block-{block_id}-{timestamp}">')
|
| 218 |
+
block_tokens = x_t[0, state['start_pos']:state['end_pos']]
|
| 219 |
+
for token_id in block_tokens:
|
| 220 |
+
token_id_int = token_id.item()
|
| 221 |
+
token_classes = ["token"]
|
| 222 |
+
if token_id_int == self.mask_token_id:
|
| 223 |
+
token_str = '░'; token_classes.append("mask")
|
| 224 |
+
else:
|
| 225 |
+
token_str = self.tokenizer.decode([token_id_int], skip_special_tokens=False)
|
| 226 |
+
token_str = token_str.replace('&', '&').replace('<', '<').replace('>', '>')
|
| 227 |
+
token_classes.append(f"gen-{(block_id - 1) % 6}")
|
| 228 |
+
html_parts.append(f'<span class="{" ".join(token_classes)}">{token_str}</span>')
|
| 229 |
+
html_parts.append('</div>')
|
| 230 |
+
|
| 231 |
+
html_parts.append(f'<div class="scroll-anchor" id="viz-anchor-{timestamp}"></div>')
|
| 232 |
+
|
| 233 |
+
complete_html = f"""
|
| 234 |
+
<div class="viz-content" id="viz-content-{timestamp}">
|
| 235 |
+
{''.join(html_parts)}
|
| 236 |
+
</div>
|
| 237 |
+
|
| 238 |
+
<script>
|
| 239 |
+
function executeVizScroll() {{
|
| 240 |
+
const container = document.getElementById('viz-container');
|
| 241 |
+
const anchor = document.getElementById('viz-anchor-{timestamp}');
|
| 242 |
+
if (container && anchor) {{
|
| 243 |
+
try {{
|
| 244 |
+
container.scrollTo(0, container.scrollHeight);
|
| 245 |
+
container.scrollTop = container.scrollHeight;
|
| 246 |
+
anchor.scrollIntoView({{behavior: 'auto', block: 'end'}});
|
| 247 |
+
}} catch (e) {{
|
| 248 |
+
console.error('Scroll error:', e);
|
| 249 |
+
}}
|
| 250 |
+
}}
|
| 251 |
+
}}
|
| 252 |
+
|
| 253 |
+
setTimeout(executeVizScroll, 10);
|
| 254 |
+
setTimeout(executeVizScroll, 50);
|
| 255 |
+
setTimeout(executeVizScroll, 150);
|
| 256 |
+
setTimeout(executeVizScroll, 300);
|
| 257 |
+
|
| 258 |
+
try {{
|
| 259 |
+
const vizContent = document.getElementById('viz-content-{timestamp}');
|
| 260 |
+
const vizContainer = document.getElementById('viz-container');
|
| 261 |
+
|
| 262 |
+
if (vizContent && vizContainer) {{
|
| 263 |
+
const resizeObserver = new ResizeObserver(() => {{
|
| 264 |
+
executeVizScroll();
|
| 265 |
+
}});
|
| 266 |
+
resizeObserver.observe(vizContent);
|
| 267 |
+
|
| 268 |
+
const mutationObserver = new MutationObserver(() => {{
|
| 269 |
+
executeVizScroll();
|
| 270 |
+
}});
|
| 271 |
+
mutationObserver.observe(vizContainer, {{
|
| 272 |
+
childList: true,
|
| 273 |
+
subtree: true,
|
| 274 |
+
characterData: true
|
| 275 |
+
}});
|
| 276 |
+
}}
|
| 277 |
+
}} catch (e) {{
|
| 278 |
+
console.error('Observer error:', e);
|
| 279 |
+
}}
|
| 280 |
+
</script>
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
return complete_html
|
| 284 |
+
|
| 285 |
+
# Render status box part (only shows generation block information)
|
| 286 |
+
def _render_status_html(self, step: int, block_states: Dict, cache_length: int) -> str:
|
| 287 |
+
timestamp = int(time.time() * 1000)
|
| 288 |
+
|
| 289 |
+
html_parts = []
|
| 290 |
+
html_parts.append(f'<h4>Generation Block Status (Step: {step}, Cache Length: {cache_length})</h4>')
|
| 291 |
+
for block_id in [k for k in sorted(block_states.keys()) if k > 0]:
|
| 292 |
+
state = block_states[block_id]
|
| 293 |
+
block_type = f"Block {block_id}"
|
| 294 |
+
masks_filled = state['total_masks'] - state['mask_count']
|
| 295 |
+
color_class = f"gen-{(block_id - 1) % 6}"
|
| 296 |
+
status_line = f'<b>{block_type.ljust(8)}</b>: Pos=[{str(state["start_pos"]).rjust(4)}:{str(state["end_pos"]).ljust(4)}] | State=\'{state["state"].ljust(8)}\' | Filled={str(masks_filled).rjust(2)}/{state["total_masks"]}'
|
| 297 |
+
html_parts.append(f'<p class="status-line token {color_class}" id="status-line-{block_id}-{timestamp}">{status_line}</p>')
|
| 298 |
+
|
| 299 |
+
html_parts.append(f'<div class="scroll-anchor" id="status-anchor-{timestamp}"></div>')
|
| 300 |
+
|
| 301 |
+
complete_html = f"""
|
| 302 |
+
<div class="status-content" id="status-content-{timestamp}">
|
| 303 |
+
{''.join(html_parts)}
|
| 304 |
+
</div>
|
| 305 |
+
|
| 306 |
+
<script>
|
| 307 |
+
function executeStatusScroll() {{
|
| 308 |
+
const container = document.getElementById('status-container');
|
| 309 |
+
const anchor = document.getElementById('status-anchor-{timestamp}');
|
| 310 |
+
if (container && anchor) {{
|
| 311 |
+
try {{
|
| 312 |
+
container.scrollTo(0, container.scrollHeight);
|
| 313 |
+
container.scrollTop = container.scrollHeight;
|
| 314 |
+
anchor.scrollIntoView({{behavior: 'auto', block: 'end'}});
|
| 315 |
+
}} catch (e) {{
|
| 316 |
+
console.error('Status scroll error:', e);
|
| 317 |
+
}}
|
| 318 |
+
}}
|
| 319 |
+
}}
|
| 320 |
+
|
| 321 |
+
setTimeout(executeStatusScroll, 10);
|
| 322 |
+
setTimeout(executeStatusScroll, 50);
|
| 323 |
+
setTimeout(executeStatusScroll, 150);
|
| 324 |
+
setTimeout(executeStatusScroll, 300);
|
| 325 |
+
|
| 326 |
+
try {{
|
| 327 |
+
const statusContent = document.getElementById('status-content-{timestamp}');
|
| 328 |
+
const statusContainer = document.getElementById('status-container');
|
| 329 |
+
|
| 330 |
+
if (statusContent && statusContainer) {{
|
| 331 |
+
const resizeObserver = new ResizeObserver(() => {{
|
| 332 |
+
executeStatusScroll();
|
| 333 |
+
}});
|
| 334 |
+
resizeObserver.observe(statusContent);
|
| 335 |
+
|
| 336 |
+
const mutationObserver = new MutationObserver(() => {{
|
| 337 |
+
executeStatusScroll();
|
| 338 |
+
}});
|
| 339 |
+
mutationObserver.observe(statusContainer, {{
|
| 340 |
+
childList: true,
|
| 341 |
+
subtree: true,
|
| 342 |
+
characterData: true
|
| 343 |
+
}});
|
| 344 |
+
}}
|
| 345 |
+
}} catch (e) {{
|
| 346 |
+
console.error('Status observer error:', e);
|
| 347 |
+
}}
|
| 348 |
+
</script>
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
return complete_html
|
| 352 |
+
|
| 353 |
+
@torch.inference_mode()
|
| 354 |
+
def stream_and_capture_for_gradio(
|
| 355 |
+
self,
|
| 356 |
+
prompt_text: str,
|
| 357 |
+
max_new_tokens: int,
|
| 358 |
+
block_size: int,
|
| 359 |
+
block_add_threshold: float,
|
| 360 |
+
decoded_token_threshold: float,
|
| 361 |
+
skip_threshold: float
|
| 362 |
+
) -> Iterator[Tuple[str, List[Tuple[str, str]], str, str, str]]:
|
| 363 |
+
|
| 364 |
+
start_time = time.time()
|
| 365 |
+
captured_frames: List[Tuple[str, str]] = []
|
| 366 |
+
|
| 367 |
+
# Initialization
|
| 368 |
+
ipdb.set_trace()
|
| 369 |
+
input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.device)
|
| 370 |
+
prompt_length = input_ids.shape[1]
|
| 371 |
+
|
| 372 |
+
full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.device, self.target_dtype)
|
| 373 |
+
x_t = input_ids
|
| 374 |
+
block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
|
| 375 |
+
past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0
|
| 376 |
+
|
| 377 |
+
# Capture initial state
|
| 378 |
+
initial_viz_html = self._render_visualization_html(0, x_t, block_states, 0, set())
|
| 379 |
+
initial_status_html = self._render_status_html(0, block_states, 0)
|
| 380 |
+
captured_frames.append((initial_viz_html, initial_status_html))
|
| 381 |
+
|
| 382 |
+
yield "", captured_frames, "Initializing generation process...", "Initializing visualization...", "Initializing block status..."
|
| 383 |
+
|
| 384 |
+
# Main generation loop
|
| 385 |
+
while True:
|
| 386 |
+
step += 1
|
| 387 |
+
updated_block_ids: Set[int] = set()
|
| 388 |
+
|
| 389 |
+
if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
|
| 390 |
+
last_block_id = max(block_states.keys())
|
| 391 |
+
progress = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if block_states[last_block_id]['total_masks'] > 0 else 1.0
|
| 392 |
+
if progress >= block_add_threshold:
|
| 393 |
+
new_block_id = last_block_id + 1; new_start_pos = x_t.shape[1]
|
| 394 |
+
if new_start_pos + block_size <= self.max_length:
|
| 395 |
+
x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.device, dtype=torch.long)], dim=1)
|
| 396 |
+
block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
|
| 397 |
+
current_blocks += 1
|
| 398 |
+
|
| 399 |
+
self._update_block_completion_states(block_states, decoded_token_threshold)
|
| 400 |
+
if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0: break
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
#### D2F-BLOCK ####
|
| 405 |
+
blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
|
| 406 |
+
update_kvcache = 0
|
| 407 |
+
if blocks_to_cache:
|
| 408 |
+
start_pos, end_pos = block_states[min(blocks_to_cache)]['start_pos'], block_states[max(blocks_to_cache)]['end_pos']
|
| 409 |
+
update_kvcache = end_pos - start_pos; input_seq, process_start_pos = x_t[:, start_pos:], start_pos
|
| 410 |
+
else:
|
| 411 |
+
active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
|
| 412 |
+
if not active_blocks: break
|
| 413 |
+
start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks); input_seq, process_start_pos = x_t[:, start_pos:], start_pos
|
| 414 |
+
|
| 415 |
+
if input_seq.shape[1] == 0: break
|
| 416 |
+
|
| 417 |
+
attention_mask = extract_attention_mask(full_attention_mask, process_start_pos, input_seq.shape[1], cache_length)
|
| 418 |
+
outputs = self.model(input_seq, attention_bias=attention_mask, past_key_values=past_key_values, use_cache=True, update_kvcache=update_kvcache + cache_length)
|
| 419 |
+
if update_kvcache > 0:
|
| 420 |
+
past_key_values = outputs.past_key_values
|
| 421 |
+
for bid in blocks_to_cache: block_states[bid]['state'] = 'in_cache'
|
| 422 |
+
|
| 423 |
+
blocks_to_deactivate = []
|
| 424 |
+
for block_id, state in block_states.items():
|
| 425 |
+
if state['state'] != 'active': continue
|
| 426 |
+
block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)
|
| 427 |
+
if block_mask_locs.numel() == 0:
|
| 428 |
+
blocks_to_deactivate.append(block_id); continue
|
| 429 |
+
logit_offset = state['start_pos'] - process_start_pos
|
| 430 |
+
block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
|
| 431 |
+
_, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
|
| 432 |
+
all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)
|
| 433 |
+
if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
|
| 434 |
+
all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.device)
|
| 435 |
+
|
| 436 |
+
if all_indices.numel() > 0:
|
| 437 |
+
updated_block_ids.add(block_id)
|
| 438 |
+
positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
|
| 439 |
+
x_t[0, positions_to_update] = x0[all_indices]; state['mask_count'] -= all_indices.numel()
|
| 440 |
+
if self.tokenizer.eos_token_id in x0[all_indices]: eos_detected = True
|
| 441 |
+
if state['mask_count'] == 0: blocks_to_deactivate.append(block_id)
|
| 442 |
+
|
| 443 |
+
for bid in blocks_to_deactivate:
|
| 444 |
+
if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
|
| 445 |
+
block_states[bid]['state'] = 'to_cache'; current_blocks -= 1
|
| 446 |
+
if update_kvcache > 0: cache_length += update_kvcache
|
| 447 |
+
|
| 448 |
+
#### FlexMDM Cache Update ####
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
# Capture current step's visualization and status frames
|
| 455 |
+
generated_ids = x_t[0, prompt_length:]
|
| 456 |
+
valid_ids = generated_ids[generated_ids != self.mask_token_id]
|
| 457 |
+
live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
|
| 458 |
+
|
| 459 |
+
current_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, updated_block_ids)
|
| 460 |
+
current_status_html = self._render_status_html(step, block_states, cache_length)
|
| 461 |
+
captured_frames.append((current_viz_html, current_status_html))
|
| 462 |
+
|
| 463 |
+
yield live_text, captured_frames, "Generating...", "Generating...", "Generating..."
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# Final output
|
| 468 |
+
total_time = time.time() - start_time
|
| 469 |
+
final_generated_ids = x_t[0, prompt_length:]
|
| 470 |
+
eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()
|
| 471 |
+
if eos_positions.numel() > 0:
|
| 472 |
+
final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]
|
| 473 |
+
|
| 474 |
+
final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)
|
| 475 |
+
final_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, set())
|
| 476 |
+
final_status_html = self._render_status_html(step, block_states, cache_length)
|
| 477 |
+
captured_frames.append((final_viz_html, final_status_html))
|
| 478 |
+
|
| 479 |
+
tokens_incl_eos = len(final_generated_ids)
|
| 480 |
+
tokens_excl_eos = len(final_generated_ids[final_generated_ids != self.tokenizer.eos_token_id])
|
| 481 |
+
stats_text = f"""
|
| 482 |
+
### ✅ Generation Complete!
|
| 483 |
+
---
|
| 484 |
+
- **Total time:** `{total_time:.2f} seconds`
|
| 485 |
+
- **Tokens generated (incl. EOS):** `{tokens_incl_eos}`
|
| 486 |
+
- **Tokens generated (excl. EOS):** `{tokens_excl_eos}`
|
| 487 |
+
- **Tokens per second:** `{(tokens_incl_eos / total_time):.2f}`
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
yield final_text, captured_frames, stats_text, "Generation complete, playback starting soon", "Generation complete, playback starting soon"
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# --- Gradio UI and Event Handlers ---
|
| 494 |
+
if __name__ == "__main__":
|
| 495 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
| 496 |
+
config = {
|
| 497 |
+
"pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
|
| 498 |
+
"lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
|
| 499 |
+
"device": "cuda", "dtype": "bfloat16", "max_length": 4096,
|
| 500 |
+
"temperature": 0.0, "top_p": None, "top_k": None, "mask_token_id": 126336,
|
| 501 |
+
"sampling_strategy": "default",
|
| 502 |
+
}
|
| 503 |
+
set_seed(42)
|
| 504 |
+
inference_engine = DreamLoRAInference(**config)
|
| 505 |
+
|
| 506 |
+
# Gradio helper for animation
|
| 507 |
+
def animate_visualization(html_frames_list: List[Tuple[str, str]], delay: float) -> Iterator[Tuple[str, str]]:
|
| 508 |
+
if not html_frames_list:
|
| 509 |
+
yield "No visualization data captured", "No status data captured"
|
| 510 |
+
return
|
| 511 |
+
for viz_frame, status_frame in html_frames_list:
|
| 512 |
+
yield viz_frame, status_frame
|
| 513 |
+
time.sleep(delay)
|
| 514 |
+
|
| 515 |
+
# Global auto-scroll JS
|
| 516 |
+
auto_scroll_js = """
|
| 517 |
+
<script>
|
| 518 |
+
function globalForceScroll() {
|
| 519 |
+
// Scroll visualization container
|
| 520 |
+
var vizContainer = document.getElementById('viz-container');
|
| 521 |
+
if (vizContainer) {
|
| 522 |
+
vizContainer.scrollTop = vizContainer.scrollHeight;
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
// Scroll status container
|
| 526 |
+
var statusContainer = document.getElementById('status-container');
|
| 527 |
+
if (statusContainer) {
|
| 528 |
+
statusContainer.scrollTop = statusContainer.scrollHeight;
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
+
// Scroll all anchors
|
| 532 |
+
var anchors = document.querySelectorAll('.scroll-anchor');
|
| 533 |
+
anchors.forEach(function(anchor) {
|
| 534 |
+
try {
|
| 535 |
+
anchor.scrollIntoView({behavior: 'auto', block: 'end'});
|
| 536 |
+
} catch(e) {}
|
| 537 |
+
});
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
// Periodic scrolling
|
| 541 |
+
setInterval(globalForceScroll, 200);
|
| 542 |
+
|
| 543 |
+
document.addEventListener('DOMContentLoaded', function() {
|
| 544 |
+
// Monitor content changes
|
| 545 |
+
var observer = new MutationObserver(function(mutations) {
|
| 546 |
+
globalForceScroll();
|
| 547 |
+
});
|
| 548 |
+
|
| 549 |
+
observer.observe(document.body, {
|
| 550 |
+
childList: true,
|
| 551 |
+
subtree: true,
|
| 552 |
+
characterData: true
|
| 553 |
+
});
|
| 554 |
+
|
| 555 |
+
// Initial scrolling
|
| 556 |
+
setTimeout(globalForceScroll, 100);
|
| 557 |
+
setTimeout(globalForceScroll, 500);
|
| 558 |
+
setTimeout(globalForceScroll, 1000);
|
| 559 |
+
});
|
| 560 |
+
</script>
|
| 561 |
+
"""
|
| 562 |
+
|
| 563 |
+
with gr.Blocks(css=DreamLoRAInference.CSS, theme=gr.themes.Soft()) as demo:
|
| 564 |
+
html_frames_state = gr.State([])
|
| 565 |
+
|
| 566 |
+
gr.Markdown("# ✨ D2F-LLaDA: Real-time Text vs. Slow-motion Visualization")
|
| 567 |
+
gr.Markdown("Left side shows real-time streaming output. Right side plays back the decoding process visualization after generation completes.")
|
| 568 |
+
|
| 569 |
+
# Inject global auto-scroll JS
|
| 570 |
+
gr.HTML(auto_scroll_js)
|
| 571 |
+
|
| 572 |
+
with gr.Row():
|
| 573 |
+
# --- Left Column ---
|
| 574 |
+
with gr.Column(scale=2, elem_classes=["left-column"]):
|
| 575 |
+
prompt_input = gr.Textbox(label="Enter your question", placeholder="Example: Natalia sold clips to...", lines=5)
|
| 576 |
+
generate_button = gr.Button("🚀 Generate & Visualize", variant="primary")
|
| 577 |
+
with gr.Group(elem_classes=["live-text-container"]):
|
| 578 |
+
live_text_output = gr.Textbox(label="Real-time Generation Output", interactive=False, lines=25, elem_id="live-text-output")
|
| 579 |
+
stats_output = gr.Markdown(label="Generation Statistics", elem_id="stats-output")
|
| 580 |
+
|
| 581 |
+
# --- Right Column ---
|
| 582 |
+
with gr.Column(scale=3, elem_classes=["right-column"]):
|
| 583 |
+
with gr.Accordion("⚙️ Parameter Settings", open=True):
|
| 584 |
+
with gr.Row():
|
| 585 |
+
max_new_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Max Tokens to Generate")
|
| 586 |
+
block_size_slider = gr.Slider(minimum=16, maximum=128, value=32, step=16, label="Block Size")
|
| 587 |
+
with gr.Row():
|
| 588 |
+
block_add_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Block Add Threshold")
|
| 589 |
+
decoded_token_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Decoding Completion Threshold")
|
| 590 |
+
skip_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Skip Threshold")
|
| 591 |
+
delay_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Playback Delay (seconds)", info="Adjust visualization playback speed.")
|
| 592 |
+
|
| 593 |
+
with gr.Group(elem_classes=["viz-status-container"]):
|
| 594 |
+
visualization_output = gr.HTML(label="Generation Process Visualization", elem_id="viz-container")
|
| 595 |
+
status_output_html = gr.HTML(label="Generation Block Status", elem_id="status-container")
|
| 596 |
+
|
| 597 |
+
gr.Examples(
|
| 598 |
+
examples=[
|
| 599 |
+
["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 1024, 32, 0.1, 0.55, 0.9, 0.1],
|
| 600 |
+
|
| 601 |
+
["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 1024, 32, 0.1, 0.5, 0.9, 0.1],
|
| 602 |
+
|
| 603 |
+
["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 1024, 32, 0.1, 0.5, 0.9, 0.1],
|
| 604 |
+
],
|
| 605 |
+
inputs=[
|
| 606 |
+
prompt_input, max_new_tokens_slider, block_size_slider, block_add_thresh_slider,
|
| 607 |
+
decoded_token_thresh_slider, skip_thresh_slider, delay_slider
|
| 608 |
+
],
|
| 609 |
+
label="Examples (Math Problems)"
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# --- Event Handling Chain ---
|
| 613 |
+
inputs_list = [
|
| 614 |
+
prompt_input, max_new_tokens_slider, block_size_slider,
|
| 615 |
+
block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
|
| 616 |
+
]
|
| 617 |
+
ipdb.set_trace()
|
| 618 |
+
generation_event = generate_button.click(
|
| 619 |
+
fn=inference_engine.stream_and_capture_for_gradio,
|
| 620 |
+
inputs=inputs_list,
|
| 621 |
+
outputs=[live_text_output, html_frames_state, stats_output, visualization_output, status_output_html]
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
generation_event.then(
|
| 625 |
+
fn=animate_visualization,
|
| 626 |
+
inputs=[html_frames_state, delay_slider],
|
| 627 |
+
outputs=[visualization_output, status_output_html]
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
demo.queue().launch(share=True)
|
Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/configuration_dream.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""Dream model configuration"""
|
| 17 |
+
|
| 18 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DreamConfig(PretrainedConfig):
|
| 27 |
+
model_type = "Dream"
|
| 28 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
vocab_size=151936,
|
| 33 |
+
hidden_size=4096,
|
| 34 |
+
intermediate_size=22016,
|
| 35 |
+
num_hidden_layers=32,
|
| 36 |
+
num_attention_heads=32,
|
| 37 |
+
num_key_value_heads=32,
|
| 38 |
+
hidden_act="silu",
|
| 39 |
+
max_position_embeddings=32768,
|
| 40 |
+
initializer_range=0.02,
|
| 41 |
+
rms_norm_eps=1e-6,
|
| 42 |
+
use_cache=False, # cache not used in diffusion
|
| 43 |
+
tie_word_embeddings=False,
|
| 44 |
+
rope_theta=10000.0,
|
| 45 |
+
rope_scaling=None,
|
| 46 |
+
use_sliding_window=False,
|
| 47 |
+
sliding_window=4096,
|
| 48 |
+
max_window_layers=28,
|
| 49 |
+
attention_dropout=0.0,
|
| 50 |
+
mask_token_id=151666,
|
| 51 |
+
pad_token_id=151643,
|
| 52 |
+
**kwargs,
|
| 53 |
+
):
|
| 54 |
+
self.vocab_size = vocab_size
|
| 55 |
+
self.max_position_embeddings = max_position_embeddings
|
| 56 |
+
self.hidden_size = hidden_size
|
| 57 |
+
self.intermediate_size = intermediate_size
|
| 58 |
+
self.num_hidden_layers = num_hidden_layers
|
| 59 |
+
self.num_attention_heads = num_attention_heads
|
| 60 |
+
self.use_sliding_window = use_sliding_window
|
| 61 |
+
self.sliding_window = sliding_window if use_sliding_window else None
|
| 62 |
+
self.max_window_layers = max_window_layers
|
| 63 |
+
|
| 64 |
+
# for backward compatibility
|
| 65 |
+
if num_key_value_heads is None:
|
| 66 |
+
num_key_value_heads = num_attention_heads
|
| 67 |
+
|
| 68 |
+
self.num_key_value_heads = num_key_value_heads
|
| 69 |
+
self.hidden_act = hidden_act
|
| 70 |
+
self.initializer_range = initializer_range
|
| 71 |
+
self.rms_norm_eps = rms_norm_eps
|
| 72 |
+
self.use_cache = use_cache
|
| 73 |
+
self.rope_theta = rope_theta
|
| 74 |
+
self.rope_scaling = rope_scaling
|
| 75 |
+
self.attention_dropout = attention_dropout
|
| 76 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 77 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 78 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 79 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 80 |
+
rope_config_validation(self)
|
| 81 |
+
|
| 82 |
+
super().__init__(
|
| 83 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 84 |
+
**kwargs,
|
| 85 |
+
)
|
| 86 |
+
self.mask_token_id = mask_token_id
|
| 87 |
+
self.pad_token_id = pad_token_id
|
| 88 |
+
|
Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/generation_utils.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
import copy
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributions as dists
|
| 23 |
+
from torch.nn import functional as F
|
| 24 |
+
from transformers import __version__
|
| 25 |
+
from transformers.generation.configuration_utils import (
|
| 26 |
+
GenerationConfig
|
| 27 |
+
)
|
| 28 |
+
from transformers.utils import (
|
| 29 |
+
ModelOutput,
|
| 30 |
+
is_torchdynamo_compiling,
|
| 31 |
+
logging,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def top_p_logits(logits, top_p=None):
|
| 38 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 39 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 40 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 41 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 42 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 43 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 44 |
+
|
| 45 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 46 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 47 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 48 |
+
return logits
|
| 49 |
+
|
| 50 |
+
def top_k_logits(logits, top_k=None):
|
| 51 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 52 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 53 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 54 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 55 |
+
return logits
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 59 |
+
|
| 60 |
+
if temperature > 0:
|
| 61 |
+
logits = logits / temperature
|
| 62 |
+
if top_p is not None and top_p < 1:
|
| 63 |
+
logits = top_p_logits(logits, top_p)
|
| 64 |
+
if top_k is not None:
|
| 65 |
+
logits = top_k_logits(logits, top_k)
|
| 66 |
+
probs = torch.softmax(logits, dim=-1)
|
| 67 |
+
|
| 68 |
+
if temperature > 0:
|
| 69 |
+
try:
|
| 70 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 71 |
+
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 72 |
+
except:
|
| 73 |
+
confidence, x0 = probs.max(dim=-1)
|
| 74 |
+
else:
|
| 75 |
+
confidence, x0 = probs.max(dim=-1)
|
| 76 |
+
|
| 77 |
+
if margin_confidence:
|
| 78 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 79 |
+
# Extract top1 and top2 probabilities
|
| 80 |
+
top1_probs = sorted_probs[:, 0]
|
| 81 |
+
top2_probs = sorted_probs[:, 1]
|
| 82 |
+
# Calculate confidence as top1 - top2
|
| 83 |
+
confidence = top1_probs - top2_probs
|
| 84 |
+
|
| 85 |
+
if neg_entropy:
|
| 86 |
+
epsilon = 1e-10
|
| 87 |
+
log_probs = torch.log(probs + epsilon)
|
| 88 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 89 |
+
|
| 90 |
+
return confidence, x0
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
|
| 94 |
+
class DreamModelOutput(ModelOutput):
|
| 95 |
+
sequences: torch.LongTensor = None
|
| 96 |
+
history: Optional[Tuple[torch.FloatTensor]] = None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DreamGenerationConfig(GenerationConfig):
|
| 100 |
+
def __init__(self, **kwargs):
|
| 101 |
+
self.temperature: float = kwargs.pop("temperature", 0.0)
|
| 102 |
+
self.top_p: Optional[float] = kwargs.pop("top_p", None)
|
| 103 |
+
self.top_k: Optional[int] = kwargs.pop("top_k", None)
|
| 104 |
+
self.max_length = kwargs.pop("max_length", 20)
|
| 105 |
+
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 106 |
+
# diffusion specific params
|
| 107 |
+
self.eps: float = kwargs.pop("eps", 1e-3)
|
| 108 |
+
self.steps: int = kwargs.pop("steps", 512)
|
| 109 |
+
self.alg: str = kwargs.pop("alg", 'origin')
|
| 110 |
+
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 111 |
+
|
| 112 |
+
# Parameters that define the output variables of `generate`
|
| 113 |
+
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 114 |
+
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
| 115 |
+
self.output_history: bool = kwargs.pop("output_history", False)
|
| 116 |
+
|
| 117 |
+
# Special tokens that can be used at generation time
|
| 118 |
+
self.mask_token_id = kwargs.pop("mask_token_id", None)
|
| 119 |
+
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
| 120 |
+
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
| 121 |
+
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
| 122 |
+
|
| 123 |
+
# Wild card
|
| 124 |
+
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 125 |
+
|
| 126 |
+
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
| 127 |
+
# interface.
|
| 128 |
+
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 129 |
+
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 130 |
+
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
| 131 |
+
|
| 132 |
+
# Additional attributes without default values
|
| 133 |
+
if not self._from_model_config:
|
| 134 |
+
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
|
| 135 |
+
# model's default configuration file
|
| 136 |
+
for key, value in kwargs.items():
|
| 137 |
+
try:
|
| 138 |
+
setattr(self, key, value)
|
| 139 |
+
except AttributeError as err:
|
| 140 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
| 141 |
+
raise err
|
| 142 |
+
|
| 143 |
+
# Validate the values of the attributes
|
| 144 |
+
self.validate(is_init=True)
|
| 145 |
+
|
| 146 |
+
def validate(self, is_init=False):
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
class DreamGenerationMixin:
|
| 150 |
+
@staticmethod
|
| 151 |
+
def _expand_inputs_for_generation(
|
| 152 |
+
expand_size: int = 1,
|
| 153 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 154 |
+
attention_mask: Optional[torch.LongTensor] = None
|
| 155 |
+
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
| 156 |
+
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
|
| 157 |
+
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
|
| 158 |
+
# the input tensor and thus requires more memory although no change is applied
|
| 159 |
+
if expand_size == 1:
|
| 160 |
+
return input_ids, attention_mask
|
| 161 |
+
if input_ids is not None:
|
| 162 |
+
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
| 163 |
+
if attention_mask is not None:
|
| 164 |
+
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 165 |
+
return input_ids, attention_mask
|
| 166 |
+
|
| 167 |
+
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 168 |
+
"""Performs validation related to the resulting generated length"""
|
| 169 |
+
|
| 170 |
+
# Can't throw warnings/exceptions during compilation
|
| 171 |
+
if is_torchdynamo_compiling():
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
# 1. Max length warnings related to poor parameterization
|
| 175 |
+
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
| 176 |
+
# 20 is the default max_length of the generation config
|
| 177 |
+
warnings.warn(
|
| 178 |
+
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
|
| 179 |
+
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
|
| 180 |
+
"generation.",
|
| 181 |
+
UserWarning,
|
| 182 |
+
)
|
| 183 |
+
if input_ids_length >= generation_config.max_length:
|
| 184 |
+
input_ids_string = "input_ids"
|
| 185 |
+
raise ValueError(
|
| 186 |
+
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
| 187 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
| 188 |
+
" increasing `max_length` or, better yet, setting `max_new_tokens`."
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def _prepare_generated_length(
|
| 192 |
+
self,
|
| 193 |
+
generation_config,
|
| 194 |
+
has_default_max_length,
|
| 195 |
+
input_ids_length,
|
| 196 |
+
):
|
| 197 |
+
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
| 198 |
+
|
| 199 |
+
if generation_config.max_new_tokens is not None:
|
| 200 |
+
if not has_default_max_length and generation_config.max_length is not None:
|
| 201 |
+
logger.warning(
|
| 202 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 203 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
| 204 |
+
"Please refer to the documentation for more information. "
|
| 205 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
| 206 |
+
)
|
| 207 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
| 208 |
+
|
| 209 |
+
elif has_default_max_length:
|
| 210 |
+
if generation_config.max_length == DreamGenerationConfig().max_length:
|
| 211 |
+
generation_config.max_length = generation_config.max_length + input_ids_length
|
| 212 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 213 |
+
if max_position_embeddings is not None:
|
| 214 |
+
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
| 215 |
+
|
| 216 |
+
return generation_config
|
| 217 |
+
|
| 218 |
+
def _prepare_generation_config(
|
| 219 |
+
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
|
| 220 |
+
) -> DreamGenerationConfig:
|
| 221 |
+
"""
|
| 222 |
+
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
| 223 |
+
function handles retrocompatibility with respect to configuration files.
|
| 224 |
+
"""
|
| 225 |
+
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
| 226 |
+
using_model_generation_config = False
|
| 227 |
+
if generation_config is None:
|
| 228 |
+
generation_config = DreamGenerationConfig.from_model_config(self.config)
|
| 229 |
+
using_model_generation_config = True
|
| 230 |
+
|
| 231 |
+
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
|
| 232 |
+
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
|
| 233 |
+
# exception will be raised in `_validate_model_kwargs`
|
| 234 |
+
if not is_torchdynamo_compiling():
|
| 235 |
+
generation_config = copy.deepcopy(generation_config)
|
| 236 |
+
_kwargs = generation_config.update(**kwargs)
|
| 237 |
+
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
| 238 |
+
if not using_model_generation_config:
|
| 239 |
+
if generation_config.bos_token_id is None:
|
| 240 |
+
generation_config.bos_token_id = self.generation_config.bos_token_id
|
| 241 |
+
if generation_config.eos_token_id is None:
|
| 242 |
+
generation_config.eos_token_id = self.generation_config.eos_token_id
|
| 243 |
+
if generation_config.pad_token_id is None:
|
| 244 |
+
generation_config.pad_token_id = self.generation_config.pad_token_id
|
| 245 |
+
if generation_config.mask_token_id is None:
|
| 246 |
+
generation_config.mask_token_id = self.generation_config.mask_token_id
|
| 247 |
+
|
| 248 |
+
return generation_config
|
| 249 |
+
|
| 250 |
+
def _prepare_special_tokens(
|
| 251 |
+
self,
|
| 252 |
+
generation_config: DreamGenerationConfig,
|
| 253 |
+
device: Optional[Union[torch.device, str]] = None,
|
| 254 |
+
):
|
| 255 |
+
"""
|
| 256 |
+
Prepares the special tokens for generation, overwriting the generation config with their processed versions
|
| 257 |
+
converted to tensor.
|
| 258 |
+
Note that `generation_config` is changed in place and stops being serializable after this method is called.
|
| 259 |
+
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
|
| 260 |
+
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
# Convert special tokens to tensors
|
| 264 |
+
def _tensor_or_none(token, device=None):
|
| 265 |
+
if token is None:
|
| 266 |
+
return token
|
| 267 |
+
|
| 268 |
+
device = device if device is not None else self.device
|
| 269 |
+
if isinstance(token, torch.Tensor):
|
| 270 |
+
return token.to(device)
|
| 271 |
+
return torch.tensor(token, device=device, dtype=torch.long)
|
| 272 |
+
|
| 273 |
+
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
|
| 274 |
+
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
|
| 275 |
+
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
|
| 276 |
+
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
|
| 277 |
+
|
| 278 |
+
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
| 279 |
+
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
|
| 280 |
+
eos_token_tensor = eos_token_tensor.unsqueeze(0)
|
| 281 |
+
|
| 282 |
+
# Set pad token if unset (and there are conditions to do so)
|
| 283 |
+
if pad_token_tensor is None and eos_token_tensor is not None:
|
| 284 |
+
pad_token_tensor = eos_token_tensor[0]
|
| 285 |
+
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
| 286 |
+
|
| 287 |
+
# Update generation config with the updated special tokens tensors
|
| 288 |
+
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
|
| 289 |
+
# (in their non-tensor form), in order to enable end-to-end compilation. See
|
| 290 |
+
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
|
| 291 |
+
generation_config._bos_token_tensor = bos_token_tensor
|
| 292 |
+
generation_config._eos_token_tensor = eos_token_tensor
|
| 293 |
+
generation_config._pad_token_tensor = pad_token_tensor
|
| 294 |
+
generation_config._mask_token_tensor = mask_token_tensor
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
|
| 297 |
+
def diffusion_generate(
|
| 298 |
+
self,
|
| 299 |
+
inputs: Optional[torch.Tensor] = None,
|
| 300 |
+
generation_config: Optional[DreamGenerationConfig] = None,
|
| 301 |
+
**kwargs,
|
| 302 |
+
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 303 |
+
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
| 304 |
+
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
| 305 |
+
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|
| 306 |
+
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
| 307 |
+
|
| 308 |
+
# 2. Define model inputs
|
| 309 |
+
assert inputs is not None
|
| 310 |
+
input_ids = inputs
|
| 311 |
+
device = input_ids.device
|
| 312 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
| 313 |
+
self._prepare_special_tokens(generation_config, device=device)
|
| 314 |
+
|
| 315 |
+
# 3. Prepare `max_length`.
|
| 316 |
+
input_ids_length = input_ids.shape[-1]
|
| 317 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 318 |
+
generation_config = self._prepare_generated_length(
|
| 319 |
+
generation_config=generation_config,
|
| 320 |
+
has_default_max_length=has_default_max_length,
|
| 321 |
+
input_ids_length=input_ids_length,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
| 325 |
+
|
| 326 |
+
# 4. Check input_ids
|
| 327 |
+
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
| 328 |
+
warnings.warn(
|
| 329 |
+
"You are calling .generate() with the `input_ids` being on a device type different"
|
| 330 |
+
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
| 331 |
+
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
| 332 |
+
" Please make sure that you have put `input_ids` to the"
|
| 333 |
+
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
| 334 |
+
" running `.generate()`.",
|
| 335 |
+
UserWarning,
|
| 336 |
+
)
|
| 337 |
+
if (
|
| 338 |
+
hasattr(generation_config, "pad_token_id") and
|
| 339 |
+
torch.any(input_ids == generation_config.pad_token_id) and
|
| 340 |
+
attention_mask is None
|
| 341 |
+
):
|
| 342 |
+
warnings.warn(
|
| 343 |
+
"Padding was detected but no attention mask is passed here. For correct "
|
| 344 |
+
"generation results, please set `attention_mask` when batch-padding inputs.",
|
| 345 |
+
UserWarning,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
input_ids, attention_mask = self._expand_inputs_for_generation(
|
| 349 |
+
expand_size=generation_config.num_return_sequences,
|
| 350 |
+
input_ids=input_ids,
|
| 351 |
+
attention_mask=attention_mask
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
result = self._sample(
|
| 355 |
+
input_ids,
|
| 356 |
+
attention_mask=attention_mask,
|
| 357 |
+
generation_config=generation_config,
|
| 358 |
+
generation_tokens_hook_func=generation_tokens_hook_func,
|
| 359 |
+
generation_logits_hook_func=generation_logits_hook_func
|
| 360 |
+
)
|
| 361 |
+
return result
|
| 362 |
+
|
| 363 |
+
def _sample(
|
| 364 |
+
self,
|
| 365 |
+
input_ids: torch.LongTensor,
|
| 366 |
+
attention_mask: Optional[torch.LongTensor],
|
| 367 |
+
generation_config: DreamGenerationConfig,
|
| 368 |
+
generation_tokens_hook_func,
|
| 369 |
+
generation_logits_hook_func
|
| 370 |
+
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 371 |
+
# init values
|
| 372 |
+
output_history = generation_config.output_history
|
| 373 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 374 |
+
max_length = generation_config.max_length
|
| 375 |
+
mask_token_id = generation_config.mask_token_id
|
| 376 |
+
steps = generation_config.steps
|
| 377 |
+
eps = generation_config.eps
|
| 378 |
+
alg = generation_config.alg
|
| 379 |
+
alg_temp = generation_config.alg_temp
|
| 380 |
+
temperature = generation_config.temperature
|
| 381 |
+
top_p = generation_config.top_p
|
| 382 |
+
top_k = generation_config.top_k
|
| 383 |
+
|
| 384 |
+
histories = [] if (return_dict_in_generate and output_history) else None
|
| 385 |
+
|
| 386 |
+
# pad input_ids to max_length
|
| 387 |
+
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
| 388 |
+
|
| 389 |
+
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
| 390 |
+
# we do not mask the [MASK] tokens so value = 1.0
|
| 391 |
+
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
|
| 392 |
+
tok_idx = attention_mask.long().cumsum(-1) - 1
|
| 393 |
+
tok_idx.masked_fill_(attention_mask == 0, 1)
|
| 394 |
+
# attention_mask is of shape [B, N]
|
| 395 |
+
# broadcast to [B, 1, N, N]
|
| 396 |
+
attention_mask = torch.logical_and(
|
| 397 |
+
attention_mask.unsqueeze(1).unsqueeze(-2),
|
| 398 |
+
attention_mask.unsqueeze(1).unsqueeze(-1),
|
| 399 |
+
)
|
| 400 |
+
else:
|
| 401 |
+
tok_idx = None
|
| 402 |
+
attention_mask = "full"
|
| 403 |
+
|
| 404 |
+
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 405 |
+
|
| 406 |
+
# this allows user-defined token control of the intermediate steps
|
| 407 |
+
x = generation_tokens_hook_func(None, x, None)
|
| 408 |
+
for i in range(steps):
|
| 409 |
+
mask_index = (x == mask_token_id)
|
| 410 |
+
logits = self(x, attention_mask, tok_idx).logits
|
| 411 |
+
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
| 412 |
+
|
| 413 |
+
# this allows user-defined logits control of the intermediate steps
|
| 414 |
+
logits = generation_logits_hook_func(i, x, logits)
|
| 415 |
+
|
| 416 |
+
mask_logits = logits[mask_index]
|
| 417 |
+
t = timesteps[i]
|
| 418 |
+
s = timesteps[i + 1]
|
| 419 |
+
|
| 420 |
+
if alg == 'origin':
|
| 421 |
+
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 422 |
+
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
| 423 |
+
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
| 424 |
+
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
| 425 |
+
x[mask_index] = x0.clone()
|
| 426 |
+
else:
|
| 427 |
+
if alg == 'maskgit_plus':
|
| 428 |
+
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 429 |
+
elif alg == 'topk_margin':
|
| 430 |
+
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
| 431 |
+
elif alg == 'entropy':
|
| 432 |
+
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
| 433 |
+
else:
|
| 434 |
+
raise RuntimeError(f"Unknown alg: {alg}")
|
| 435 |
+
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 436 |
+
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
| 437 |
+
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
| 438 |
+
full_confidence[mask_index] = confidence
|
| 439 |
+
if number_transfer_tokens > 0:
|
| 440 |
+
if alg_temp is None or alg_temp == 0:
|
| 441 |
+
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 442 |
+
else:
|
| 443 |
+
full_confidence = full_confidence / alg_temp
|
| 444 |
+
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 445 |
+
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 446 |
+
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 447 |
+
x_[mask_index] = x0.clone()
|
| 448 |
+
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 449 |
+
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
| 450 |
+
|
| 451 |
+
# this allows user-defined token control of the intermediate steps
|
| 452 |
+
x = generation_tokens_hook_func(i, x, logits)
|
| 453 |
+
|
| 454 |
+
if histories is not None:
|
| 455 |
+
histories.append(x.clone())
|
| 456 |
+
|
| 457 |
+
if return_dict_in_generate:
|
| 458 |
+
return DreamModelOutput(
|
| 459 |
+
sequences=x,
|
| 460 |
+
history=histories,
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
return x
|
Discrete-Diffusion-Forcing/D2F-eval/model_cache/dream/model_dream.py
ADDED
|
@@ -0,0 +1,1029 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face's logo
|
| 2 |
+
# Hugging Face
|
| 3 |
+
# Models
|
| 4 |
+
# Datasets
|
| 5 |
+
# Spaces
|
| 6 |
+
# Community
|
| 7 |
+
# Docs
|
| 8 |
+
# Enterprise
|
| 9 |
+
# Pricing
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Dream-org
|
| 15 |
+
# /
|
| 16 |
+
# Dream-v0-Instruct-7B
|
| 17 |
+
|
| 18 |
+
# like
|
| 19 |
+
# 94
|
| 20 |
+
|
| 21 |
+
# Follow
|
| 22 |
+
|
| 23 |
+
# Dream Org
|
| 24 |
+
# 81
|
| 25 |
+
# Feature Extraction
|
| 26 |
+
# Transformers
|
| 27 |
+
# Safetensors
|
| 28 |
+
# Dream
|
| 29 |
+
# custom_code
|
| 30 |
+
|
| 31 |
+
# License:
|
| 32 |
+
# apache-2.0
|
| 33 |
+
# Model card
|
| 34 |
+
# Files and versions
|
| 35 |
+
# Community
|
| 36 |
+
# 2
|
| 37 |
+
# Dream-v0-Instruct-7B
|
| 38 |
+
# /
|
| 39 |
+
# modeling_dream.py
|
| 40 |
+
|
| 41 |
+
# jiacheng-ye's picture
|
| 42 |
+
# jiacheng-ye
|
| 43 |
+
# Upload model
|
| 44 |
+
# 373705a
|
| 45 |
+
# verified
|
| 46 |
+
# about 2 months ago
|
| 47 |
+
# raw
|
| 48 |
+
|
| 49 |
+
# Copy download link
|
| 50 |
+
# history
|
| 51 |
+
# blame
|
| 52 |
+
# contribute
|
| 53 |
+
# delete
|
| 54 |
+
|
| 55 |
+
# 36.8 kB
|
| 56 |
+
# # coding=utf-8
|
| 57 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
| 58 |
+
#
|
| 59 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 60 |
+
# and OPT and Qwen implementations in this library. It has been modified from its
|
| 61 |
+
# original forms to accommodate minor architectural differences compared
|
| 62 |
+
# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
|
| 63 |
+
#
|
| 64 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 65 |
+
# you may not use this file except in compliance with the License.
|
| 66 |
+
# You may obtain a copy of the License at
|
| 67 |
+
#
|
| 68 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 69 |
+
#
|
| 70 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 71 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 72 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 73 |
+
# See the License for the specific language governing permissions and
|
| 74 |
+
# limitations under the License.
|
| 75 |
+
"""PyTorch Dream model."""
|
| 76 |
+
from transformers import Qwen2Model
|
| 77 |
+
from torch.nn.attention.flex_attention import flex_attention
|
| 78 |
+
import math
|
| 79 |
+
from typing import List, Optional, Tuple, Union
|
| 80 |
+
import os
|
| 81 |
+
import torch
|
| 82 |
+
import torch.utils.checkpoint
|
| 83 |
+
from torch import nn
|
| 84 |
+
|
| 85 |
+
from transformers.activations import ACT2FN
|
| 86 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 87 |
+
from transformers.modeling_outputs import (
|
| 88 |
+
BaseModelOutput,
|
| 89 |
+
MaskedLMOutput,
|
| 90 |
+
BaseModelOutputWithPast,
|
| 91 |
+
CausalLMOutputWithPast
|
| 92 |
+
)
|
| 93 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 94 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 95 |
+
from transformers.utils import (
|
| 96 |
+
add_start_docstrings,
|
| 97 |
+
add_start_docstrings_to_model_forward,
|
| 98 |
+
is_flash_attn_2_available,
|
| 99 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 100 |
+
logging,
|
| 101 |
+
)
|
| 102 |
+
from transformers import PretrainedConfig
|
| 103 |
+
from model_cache.dream.configuration_dream import DreamConfig
|
| 104 |
+
from model_cache.dream.generation_utils import DreamGenerationMixin, DreamGenerationConfig
|
| 105 |
+
if is_flash_attn_2_available():
|
| 106 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
logger = logging.get_logger(__name__)
|
| 110 |
+
|
| 111 |
+
from transformers import Qwen2ForCausalLM
|
| 112 |
+
_CHECKPOINT_FOR_DOC = "Dream-7B"
|
| 113 |
+
_CONFIG_FOR_DOC = "DreamConfig"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
|
| 117 |
+
class DreamRMSNorm(nn.Module):
|
| 118 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 119 |
+
"""
|
| 120 |
+
DreamRMSNorm is equivalent to T5LayerNorm
|
| 121 |
+
"""
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 124 |
+
self.variance_epsilon = eps
|
| 125 |
+
|
| 126 |
+
def forward(self, hidden_states):
|
| 127 |
+
input_dtype = hidden_states.dtype
|
| 128 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 129 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 130 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 131 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 132 |
+
|
| 133 |
+
def extra_repr(self):
|
| 134 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
|
| 138 |
+
class DreamRotaryEmbedding(nn.Module):
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
dim=None,
|
| 142 |
+
max_position_embeddings=2048,
|
| 143 |
+
base=10000,
|
| 144 |
+
device=None,
|
| 145 |
+
scaling_factor=1.0,
|
| 146 |
+
rope_type="default",
|
| 147 |
+
config: Optional[DreamConfig] = None,
|
| 148 |
+
):
|
| 149 |
+
super().__init__()
|
| 150 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 151 |
+
self.rope_kwargs = {}
|
| 152 |
+
if config is None:
|
| 153 |
+
logger.warning_once(
|
| 154 |
+
"`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 155 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 156 |
+
)
|
| 157 |
+
self.rope_kwargs = {
|
| 158 |
+
"rope_type": rope_type,
|
| 159 |
+
"factor": scaling_factor,
|
| 160 |
+
"dim": dim,
|
| 161 |
+
"base": base,
|
| 162 |
+
"max_position_embeddings": max_position_embeddings,
|
| 163 |
+
}
|
| 164 |
+
self.rope_type = rope_type
|
| 165 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 166 |
+
self.original_max_seq_len = max_position_embeddings
|
| 167 |
+
else:
|
| 168 |
+
# BC: "rope_type" was originally "type"
|
| 169 |
+
if config.rope_scaling is not None:
|
| 170 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 171 |
+
else:
|
| 172 |
+
self.rope_type = "default"
|
| 173 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 174 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 175 |
+
|
| 176 |
+
self.config = config
|
| 177 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 178 |
+
|
| 179 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 180 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 181 |
+
self.original_inv_freq = self.inv_freq
|
| 182 |
+
|
| 183 |
+
def reset_parameters(self):
|
| 184 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
|
| 185 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 186 |
+
self.original_inv_freq = self.inv_freq
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 190 |
+
"""
|
| 191 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 192 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 193 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 194 |
+
"""
|
| 195 |
+
seq_len = torch.max(position_ids) + 1
|
| 196 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 197 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 198 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 199 |
+
)
|
| 200 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 201 |
+
self.max_seq_len_cached = seq_len
|
| 202 |
+
|
| 203 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 204 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 205 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 206 |
+
|
| 207 |
+
@torch.no_grad()
|
| 208 |
+
def forward(self, x, position_ids):
|
| 209 |
+
if "dynamic" in self.rope_type:
|
| 210 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 211 |
+
|
| 212 |
+
# Core RoPE block
|
| 213 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 214 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 215 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 216 |
+
device_type = x.device.type
|
| 217 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 218 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 219 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 220 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 221 |
+
cos = emb.cos()
|
| 222 |
+
sin = emb.sin()
|
| 223 |
+
|
| 224 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 225 |
+
cos = cos * self.attention_scaling
|
| 226 |
+
sin = sin * self.attention_scaling
|
| 227 |
+
|
| 228 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 232 |
+
def rotate_half(x):
|
| 233 |
+
"""Rotates half the hidden dims of the input."""
|
| 234 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 235 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 236 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 240 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 241 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 242 |
+
Args:
|
| 243 |
+
q (`torch.Tensor`): The query tensor.
|
| 244 |
+
k (`torch.Tensor`): The key tensor.
|
| 245 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 246 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 247 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 248 |
+
Deprecated and unused.
|
| 249 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 250 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 251 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 252 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 253 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 254 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 255 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 256 |
+
Returns:
|
| 257 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 258 |
+
"""
|
| 259 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 260 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 261 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 262 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 263 |
+
return q_embed, k_embed
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
|
| 267 |
+
class DreamMLP(nn.Module):
|
| 268 |
+
def __init__(self, config):
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.hidden_size = config.hidden_size
|
| 271 |
+
self.intermediate_size = config.intermediate_size
|
| 272 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 273 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 274 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 275 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 276 |
+
|
| 277 |
+
def forward(self, hidden_state):
|
| 278 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 282 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 283 |
+
"""
|
| 284 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 285 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 286 |
+
"""
|
| 287 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 288 |
+
if n_rep == 1:
|
| 289 |
+
return hidden_states
|
| 290 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 291 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class DreamAttention(nn.Module):
|
| 295 |
+
"""
|
| 296 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 297 |
+
and "Generating Long Sequences with Sparse Transformers".
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
+
def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
|
| 301 |
+
super().__init__()
|
| 302 |
+
self.config = config
|
| 303 |
+
self.layer_idx = layer_idx
|
| 304 |
+
if layer_idx is None:
|
| 305 |
+
logger.warning_once(
|
| 306 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
| 307 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
| 308 |
+
"when creating this class."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
self.hidden_size = config.hidden_size
|
| 312 |
+
self.num_heads = config.num_attention_heads
|
| 313 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 314 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 315 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 316 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 317 |
+
self.rope_theta = config.rope_theta
|
| 318 |
+
self.is_causal = False
|
| 319 |
+
self.attention_dropout = config.attention_dropout
|
| 320 |
+
|
| 321 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 322 |
+
raise ValueError(
|
| 323 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 324 |
+
f" and `num_heads`: {self.num_heads})."
|
| 325 |
+
)
|
| 326 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 327 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 328 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 329 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 330 |
+
|
| 331 |
+
self.rotary_emb = DreamRotaryEmbedding(config=self.config)
|
| 332 |
+
|
| 333 |
+
def forward(
|
| 334 |
+
self,
|
| 335 |
+
hidden_states: torch.Tensor,
|
| 336 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 337 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 338 |
+
past_key_value: Optional[Cache] = None,
|
| 339 |
+
output_attentions: bool = False,
|
| 340 |
+
use_cache: bool = False,
|
| 341 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 342 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 343 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 344 |
+
bsz, q_len, _ = hidden_states.size()
|
| 345 |
+
|
| 346 |
+
query_states = self.q_proj(hidden_states)
|
| 347 |
+
key_states = self.k_proj(hidden_states)
|
| 348 |
+
value_states = self.v_proj(hidden_states)
|
| 349 |
+
|
| 350 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 351 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 352 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 353 |
+
|
| 354 |
+
if position_embeddings is None:
|
| 355 |
+
logger.warning_once(
|
| 356 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 357 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 358 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 359 |
+
"removed and `position_embeddings` will be mandatory."
|
| 360 |
+
)
|
| 361 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 362 |
+
else:
|
| 363 |
+
cos, sin = position_embeddings
|
| 364 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 365 |
+
|
| 366 |
+
if past_key_value is not None:
|
| 367 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 368 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 369 |
+
|
| 370 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 371 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 372 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 373 |
+
|
| 374 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 375 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 376 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 377 |
+
attn_weights = attn_weights + causal_mask
|
| 378 |
+
|
| 379 |
+
# upcast attention to fp32
|
| 380 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 381 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 382 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 383 |
+
|
| 384 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 385 |
+
raise ValueError(
|
| 386 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 387 |
+
f" {attn_output.size()}"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 391 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 392 |
+
|
| 393 |
+
attn_output = self.o_proj(attn_output)
|
| 394 |
+
|
| 395 |
+
if not output_attentions:
|
| 396 |
+
attn_weights = None
|
| 397 |
+
|
| 398 |
+
return attn_output, attn_weights, past_key_value
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class DreamSdpaAttention(DreamAttention):
|
| 402 |
+
"""
|
| 403 |
+
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 404 |
+
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 405 |
+
SDPA API.
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
# Adapted from DreamAttention.forward
|
| 409 |
+
def forward(
|
| 410 |
+
self,
|
| 411 |
+
hidden_states: torch.Tensor,
|
| 412 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 413 |
+
update_kvcache: torch.int32 = None,
|
| 414 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 415 |
+
past_key_value: Optional[Cache] = None,
|
| 416 |
+
output_attentions: bool = False,
|
| 417 |
+
use_cache: bool = False,
|
| 418 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 419 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 420 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 421 |
+
if output_attentions:
|
| 422 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 423 |
+
logger.warning_once(
|
| 424 |
+
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 425 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 426 |
+
)
|
| 427 |
+
return super().forward(
|
| 428 |
+
hidden_states=hidden_states,
|
| 429 |
+
attention_mask=attention_mask,
|
| 430 |
+
position_ids=position_ids,
|
| 431 |
+
past_key_value=past_key_value,
|
| 432 |
+
output_attentions=output_attentions,
|
| 433 |
+
use_cache=use_cache,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
bsz, q_len, _ = hidden_states.size()
|
| 437 |
+
|
| 438 |
+
query_states = self.q_proj(hidden_states)
|
| 439 |
+
key_states = self.k_proj(hidden_states)
|
| 440 |
+
value_states = self.v_proj(hidden_states)
|
| 441 |
+
|
| 442 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 443 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 444 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 445 |
+
|
| 446 |
+
if position_embeddings is None:
|
| 447 |
+
logger.warning_once(
|
| 448 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 449 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 450 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 451 |
+
"removed and `position_embeddings` will be mandatory."
|
| 452 |
+
)
|
| 453 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 454 |
+
else:
|
| 455 |
+
cos, sin = position_embeddings
|
| 456 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 457 |
+
|
| 458 |
+
if past_key_value is not None:
|
| 459 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 460 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 461 |
+
|
| 462 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 463 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 464 |
+
|
| 465 |
+
# causal_mask = attention_mask
|
| 466 |
+
# if attention_mask is not None: # no matter the length, we just slice it
|
| 467 |
+
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 468 |
+
|
| 469 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 470 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 471 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
| 472 |
+
query_states = query_states.contiguous()
|
| 473 |
+
key_states = key_states.contiguous()
|
| 474 |
+
value_states = value_states.contiguous()
|
| 475 |
+
|
| 476 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 477 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 478 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
| 479 |
+
# is_causal = True if causal_mask is None and q_len > 1 else False
|
| 480 |
+
|
| 481 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 482 |
+
query_states,
|
| 483 |
+
key_states,
|
| 484 |
+
value_states,
|
| 485 |
+
attn_mask=attention_mask if attention_mask is not None else None,
|
| 486 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 487 |
+
is_causal=False, # hard coded
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 491 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 492 |
+
|
| 493 |
+
attn_output = self.o_proj(attn_output)
|
| 494 |
+
|
| 495 |
+
return attn_output, None, past_key_value
|
| 496 |
+
class DreamFlexAttention(DreamAttention):
|
| 497 |
+
"""
|
| 498 |
+
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 499 |
+
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 500 |
+
SDPA API.
|
| 501 |
+
"""
|
| 502 |
+
|
| 503 |
+
# Adapted from DreamAttention.forward
|
| 504 |
+
def forward(
|
| 505 |
+
self,
|
| 506 |
+
hidden_states: torch.Tensor,
|
| 507 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 508 |
+
update_kvcache: torch.int32 = None,
|
| 509 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 510 |
+
past_key_value: Optional[Cache] = None,
|
| 511 |
+
output_attentions: bool = False,
|
| 512 |
+
use_cache: bool = False,
|
| 513 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 514 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 515 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 516 |
+
if output_attentions:
|
| 517 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 518 |
+
logger.warning_once(
|
| 519 |
+
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 520 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 521 |
+
)
|
| 522 |
+
return super().forward(
|
| 523 |
+
hidden_states=hidden_states,
|
| 524 |
+
attention_mask=attention_mask,
|
| 525 |
+
position_ids=position_ids,
|
| 526 |
+
past_key_value=past_key_value,
|
| 527 |
+
output_attentions=output_attentions,
|
| 528 |
+
use_cache=use_cache,
|
| 529 |
+
)
|
| 530 |
+
# print("hidden_states",hidden_states)
|
| 531 |
+
bsz, q_len, _ = hidden_states.size()
|
| 532 |
+
|
| 533 |
+
query_states = self.q_proj(hidden_states)
|
| 534 |
+
key_states = self.k_proj(hidden_states)
|
| 535 |
+
value_states = self.v_proj(hidden_states)
|
| 536 |
+
|
| 537 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 538 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 539 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 540 |
+
|
| 541 |
+
if position_embeddings is None:
|
| 542 |
+
logger.warning_once(
|
| 543 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 544 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 545 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 546 |
+
"removed and `position_embeddings` will be mandatory."
|
| 547 |
+
)
|
| 548 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 549 |
+
else:
|
| 550 |
+
cos, sin = position_embeddings
|
| 551 |
+
# print(query_states.shape,key_states.shape,cos.shape,sin.shape)
|
| 552 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 553 |
+
# print("k,v",key_states.shape,value_states.shape,past_key_value)
|
| 554 |
+
# print(cos.shape,sin.shape,cache_position.shape)
|
| 555 |
+
if past_key_value is not None:
|
| 556 |
+
if update_kvcache == 0:
|
| 557 |
+
past_key_states, past_value_states = past_key_value[self.layer_idx]
|
| 558 |
+
key_states=torch.cat([past_key_states, key_states], dim=2)
|
| 559 |
+
value_states=torch.cat([past_value_states, value_states], dim=2)
|
| 560 |
+
# Specific to RoPE models
|
| 561 |
+
else:
|
| 562 |
+
cache_kwargs = {"sin": sin[:,:update_kvcache,:], "cos": cos[:,:update_kvcache,:], "cache_position": cache_position[:update_kvcache]}
|
| 563 |
+
# print("update_kvcache",update_kvcache)
|
| 564 |
+
new_key_states, new_value_states = past_key_value.update(key_states[:,:,:update_kvcache, :], value_states[:,:,:update_kvcache, : ], self.layer_idx, cache_kwargs)
|
| 565 |
+
# print("new_kv",new_key_states.shape,new_value_states.shape)
|
| 566 |
+
# print("k,v",new_key_states.shape,new_value_states.shape)
|
| 567 |
+
key_states = torch.cat([new_key_states,key_states[:,:,update_kvcache:,:]], dim=2)
|
| 568 |
+
value_states = torch.cat([new_value_states,value_states[:,:,update_kvcache:,:]], dim=2)
|
| 569 |
+
# print("k,v",key_states.shape,value_states.shape)
|
| 570 |
+
# print(key_states.shape,value_states.shape)
|
| 571 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 572 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 573 |
+
|
| 574 |
+
# causal_mask = attention_mask
|
| 575 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 576 |
+
atte_mask = attention_mask[:,:, :, : key_states.shape[-2]].clone()
|
| 577 |
+
# print(update_kvcache,attention_mask.shape)
|
| 578 |
+
# if attention_mask.shape[3]>86+32:
|
| 579 |
+
# if attention_mask.shape[-1]!=attention_mask.shape[-2]:
|
| 580 |
+
# atte_mask[:,:,:update_kvcache,-update_kvcache:]=-torch.inf
|
| 581 |
+
|
| 582 |
+
# if update_kvcache > 0:
|
| 583 |
+
# print("attention_mask中出现过的值",atte_mask.unique())
|
| 584 |
+
# print('tTTTTTTTTT')
|
| 585 |
+
# print("-"*20)
|
| 586 |
+
# print("attention_mask",attention_mask,update_kvcache)
|
| 587 |
+
# print(attention_mask)
|
| 588 |
+
# exit()
|
| 589 |
+
# print(attention_mask[0,0,:,:],attention_mask[0,0,:,:].shape)
|
| 590 |
+
# exit(0)
|
| 591 |
+
|
| 592 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 593 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 594 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
| 595 |
+
query_states = query_states.contiguous()
|
| 596 |
+
key_states = key_states.contiguous()
|
| 597 |
+
value_states = value_states.contiguous()
|
| 598 |
+
|
| 599 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 600 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 601 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
| 602 |
+
# is_causal = True if causal_mask is None and q_len > 1 else False
|
| 603 |
+
# print(query_states.shape[2], key_states.shape[2])
|
| 604 |
+
# attention_mask=attention_mask[:,:, :key_states.shape[2], :key_states.shape[2]] if attention_mask is not None else None
|
| 605 |
+
# attn_output = flex_attention(query_states, key_states, value_states, block_mask= attention_mask ),
|
| 606 |
+
# print(query_states.shape, key_states.shape, value_states.shape, attention_mask.shape if attention_mask is not None else None)
|
| 607 |
+
# print(query_states.dtype,attention_mask.dtype if attention_mask is not None else None)
|
| 608 |
+
# print(self.training)
|
| 609 |
+
# print("key_states",key_states[:,:,:84,:])
|
| 610 |
+
# torch.save(key_states,"key_states1.pt")
|
| 611 |
+
# torch.save(value_states,"value_states1.pt")
|
| 612 |
+
# torch.save(value_states,"query_state1.pt")
|
| 613 |
+
# torch.save(attention_mask,"attention_mask1.pt")
|
| 614 |
+
# print(atte_mask.shape)
|
| 615 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 616 |
+
query_states,
|
| 617 |
+
key_states,
|
| 618 |
+
value_states,
|
| 619 |
+
attn_mask=atte_mask if attention_mask is not None else None,
|
| 620 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 621 |
+
is_causal=False, # hard coded
|
| 622 |
+
)
|
| 623 |
+
# print("attn_output",attn_output[:,:,:84,:],attn_output.shape)
|
| 624 |
+
# print(atte_mask[:,:,:84,:84],attenti_mask.shape)
|
| 625 |
+
# exit()
|
| 626 |
+
# if self.layer_idx==2:
|
| 627 |
+
# torch.save(attn_output,"attn_output2.pt")
|
| 628 |
+
# exit()
|
| 629 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 630 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 631 |
+
|
| 632 |
+
attn_output = self.o_proj(attn_output)
|
| 633 |
+
|
| 634 |
+
return attn_output, None, past_key_value
|
| 635 |
+
|
| 636 |
+
class DreamDecoderLayer(nn.Module):
|
| 637 |
+
def __init__(self, config: DreamConfig, layer_idx: int):
|
| 638 |
+
super().__init__()
|
| 639 |
+
self.hidden_size = config.hidden_size
|
| 640 |
+
|
| 641 |
+
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
| 642 |
+
logger.warning_once(
|
| 643 |
+
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
| 644 |
+
"unexpected results may be encountered."
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
# self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 648 |
+
self.self_attn = DreamFlexAttention(config, layer_idx)
|
| 649 |
+
|
| 650 |
+
self.mlp = DreamMLP(config)
|
| 651 |
+
self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 652 |
+
self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 653 |
+
|
| 654 |
+
def forward(
|
| 655 |
+
self,
|
| 656 |
+
hidden_states: torch.Tensor,
|
| 657 |
+
update_kvcache: torch.int32 = None,
|
| 658 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 659 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 660 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 661 |
+
output_attentions: Optional[bool] = False,
|
| 662 |
+
use_cache: Optional[bool] = False,
|
| 663 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 664 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 665 |
+
**kwargs,
|
| 666 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 667 |
+
"""
|
| 668 |
+
Args:
|
| 669 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 670 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 671 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
| 672 |
+
output_attentions (`bool`, *optional*):
|
| 673 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 674 |
+
returned tensors for more detail.
|
| 675 |
+
use_cache (`bool`, *optional*):
|
| 676 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 677 |
+
(see `past_key_values`).
|
| 678 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 679 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 680 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 681 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 682 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 683 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 684 |
+
kwargs (`dict`, *optional*):
|
| 685 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 686 |
+
into the model
|
| 687 |
+
"""
|
| 688 |
+
|
| 689 |
+
residual = hidden_states
|
| 690 |
+
|
| 691 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 692 |
+
|
| 693 |
+
# Self Attention
|
| 694 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 695 |
+
hidden_states=hidden_states,
|
| 696 |
+
attention_mask=attention_mask,
|
| 697 |
+
update_kvcache=update_kvcache,
|
| 698 |
+
position_ids=position_ids,
|
| 699 |
+
past_key_value=past_key_value,
|
| 700 |
+
output_attentions=output_attentions,
|
| 701 |
+
use_cache=use_cache,
|
| 702 |
+
cache_position=cache_position,
|
| 703 |
+
position_embeddings=position_embeddings,
|
| 704 |
+
)
|
| 705 |
+
hidden_states = residual + hidden_states
|
| 706 |
+
|
| 707 |
+
# Fully Connected
|
| 708 |
+
residual = hidden_states
|
| 709 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 710 |
+
hidden_states = self.mlp(hidden_states)
|
| 711 |
+
hidden_states = residual + hidden_states
|
| 712 |
+
|
| 713 |
+
outputs = (hidden_states,)
|
| 714 |
+
|
| 715 |
+
if output_attentions:
|
| 716 |
+
outputs += (self_attn_weights,)
|
| 717 |
+
|
| 718 |
+
if use_cache:
|
| 719 |
+
outputs += (present_key_value,)
|
| 720 |
+
|
| 721 |
+
return outputs
|
| 722 |
+
|
| 723 |
+
class DreamPreTrainedModel(PreTrainedModel):
|
| 724 |
+
config_class = DreamConfig
|
| 725 |
+
base_model_prefix = "model"
|
| 726 |
+
supports_gradient_checkpointing = True
|
| 727 |
+
_no_split_modules = ["DreamDecoderLayer"]
|
| 728 |
+
_skip_keys_device_placement = "past_key_values"
|
| 729 |
+
_supports_flash_attn_2 = True
|
| 730 |
+
_supports_sdpa = True
|
| 731 |
+
_supports_cache_class = True
|
| 732 |
+
_supports_quantized_cache = True
|
| 733 |
+
_supports_static_cache = True
|
| 734 |
+
|
| 735 |
+
def _init_weights(self, module):
|
| 736 |
+
std = self.config.initializer_range
|
| 737 |
+
if isinstance(module, nn.Linear):
|
| 738 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 739 |
+
if module.bias is not None:
|
| 740 |
+
module.bias.data.zero_()
|
| 741 |
+
elif isinstance(module, nn.Embedding):
|
| 742 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 743 |
+
if module.padding_idx is not None:
|
| 744 |
+
module.weight.data[module.padding_idx].zero_()
|
| 745 |
+
|
| 746 |
+
@classmethod
|
| 747 |
+
def from_pretrained(
|
| 748 |
+
cls,
|
| 749 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 750 |
+
*model_args,
|
| 751 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
| 752 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 753 |
+
ignore_mismatched_sizes: bool = False,
|
| 754 |
+
force_download: bool = False,
|
| 755 |
+
local_files_only: bool = False,
|
| 756 |
+
token: Optional[Union[str, bool]] = None,
|
| 757 |
+
revision: str = "main",
|
| 758 |
+
use_safetensors: Optional[bool] = None,
|
| 759 |
+
weights_only: bool = True,
|
| 760 |
+
**kwargs,
|
| 761 |
+
):
|
| 762 |
+
_model = super().from_pretrained(
|
| 763 |
+
pretrained_model_name_or_path,
|
| 764 |
+
*model_args,
|
| 765 |
+
config=config,
|
| 766 |
+
cache_dir=cache_dir,
|
| 767 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
| 768 |
+
force_download=force_download,
|
| 769 |
+
local_files_only=local_files_only,
|
| 770 |
+
token=token,
|
| 771 |
+
revision=revision,
|
| 772 |
+
use_safetensors=use_safetensors,
|
| 773 |
+
weights_only=weights_only,
|
| 774 |
+
**kwargs,
|
| 775 |
+
)
|
| 776 |
+
# NOTE(Lin): we need to override the generation config
|
| 777 |
+
# because the generation config loaded in `from_pretrained`
|
| 778 |
+
# does not include all the attributes of DreamGenerationConfig
|
| 779 |
+
resume_download = kwargs.get("resume_download", None)
|
| 780 |
+
proxies = kwargs.get("proxies", None)
|
| 781 |
+
subfolder = kwargs.get("subfolder", "")
|
| 782 |
+
from_auto_class = kwargs.get("_from_auto", False)
|
| 783 |
+
from_pipeline = kwargs.get("_from_pipeline", None)
|
| 784 |
+
_model.generation_config = DreamGenerationConfig.from_pretrained(
|
| 785 |
+
pretrained_model_name_or_path,
|
| 786 |
+
cache_dir=cache_dir,
|
| 787 |
+
force_download=force_download,
|
| 788 |
+
resume_download=resume_download,
|
| 789 |
+
proxies=proxies,
|
| 790 |
+
local_files_only=local_files_only,
|
| 791 |
+
token=token,
|
| 792 |
+
revision=revision,
|
| 793 |
+
subfolder=subfolder,
|
| 794 |
+
_from_auto=from_auto_class,
|
| 795 |
+
_from_pipeline=from_pipeline,
|
| 796 |
+
)
|
| 797 |
+
return _model
|
| 798 |
+
|
| 799 |
+
class DreamBaseModel(DreamPreTrainedModel):
|
| 800 |
+
"""
|
| 801 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
|
| 802 |
+
Args:
|
| 803 |
+
config: DreamConfig
|
| 804 |
+
"""
|
| 805 |
+
|
| 806 |
+
def __init__(self, config: DreamConfig):
|
| 807 |
+
super().__init__(config)
|
| 808 |
+
self.padding_idx = config.pad_token_id
|
| 809 |
+
self.vocab_size = config.vocab_size
|
| 810 |
+
|
| 811 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 812 |
+
self.layers = nn.ModuleList(
|
| 813 |
+
[DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 814 |
+
)
|
| 815 |
+
self._attn_implementation = config._attn_implementation
|
| 816 |
+
self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 817 |
+
self.rotary_emb = DreamRotaryEmbedding(config=config)
|
| 818 |
+
|
| 819 |
+
self.gradient_checkpointing = False
|
| 820 |
+
# Initialize weights and apply final processing
|
| 821 |
+
self.post_init()
|
| 822 |
+
|
| 823 |
+
def get_input_embeddings(self):
|
| 824 |
+
return self.embed_tokens
|
| 825 |
+
|
| 826 |
+
def set_input_embeddings(self, value):
|
| 827 |
+
self.embed_tokens = value
|
| 828 |
+
|
| 829 |
+
def forward(
|
| 830 |
+
self,
|
| 831 |
+
input_ids: torch.LongTensor = None,
|
| 832 |
+
update_kvcache: torch.int32 = None,
|
| 833 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 834 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 835 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 836 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 837 |
+
use_cache: Optional[bool] = None,
|
| 838 |
+
output_attentions: Optional[bool] = None,
|
| 839 |
+
output_hidden_states: Optional[bool] = None,
|
| 840 |
+
return_dict: Optional[bool] = None,
|
| 841 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 842 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 843 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 844 |
+
output_hidden_states = (
|
| 845 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 846 |
+
)
|
| 847 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 848 |
+
|
| 849 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 850 |
+
|
| 851 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 852 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 853 |
+
|
| 854 |
+
if self.gradient_checkpointing and self.training:
|
| 855 |
+
if use_cache:
|
| 856 |
+
logger.warning_once(
|
| 857 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 858 |
+
)
|
| 859 |
+
use_cache = False
|
| 860 |
+
|
| 861 |
+
if inputs_embeds is None:
|
| 862 |
+
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 863 |
+
# input_ids = input_ids[:, past_seen_tokens:]
|
| 864 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 865 |
+
# print("inputs_embeds",inputs_embeds.shape)
|
| 866 |
+
|
| 867 |
+
if use_cache and past_key_values is None:
|
| 868 |
+
past_key_values = DynamicCache()
|
| 869 |
+
|
| 870 |
+
if cache_position is None:
|
| 871 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 872 |
+
cache_position = torch.arange(
|
| 873 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
if position_ids is None:
|
| 877 |
+
position_ids = cache_position.unsqueeze(0)
|
| 878 |
+
|
| 879 |
+
hidden_states = inputs_embeds
|
| 880 |
+
|
| 881 |
+
# create position embeddings to be shared across the decoder layers
|
| 882 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 883 |
+
|
| 884 |
+
# decoder layers
|
| 885 |
+
all_hidden_states = () if output_hidden_states else None
|
| 886 |
+
all_self_attns = () if output_attentions else None
|
| 887 |
+
|
| 888 |
+
for decoder_layer in self.layers:
|
| 889 |
+
if output_hidden_states:
|
| 890 |
+
all_hidden_states += (hidden_states,)
|
| 891 |
+
|
| 892 |
+
if self.gradient_checkpointing and self.training:
|
| 893 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 894 |
+
decoder_layer.__call__,
|
| 895 |
+
hidden_states,
|
| 896 |
+
attention_mask,
|
| 897 |
+
position_ids,
|
| 898 |
+
past_key_values,
|
| 899 |
+
output_attentions,
|
| 900 |
+
use_cache,
|
| 901 |
+
cache_position,
|
| 902 |
+
position_embeddings,
|
| 903 |
+
)
|
| 904 |
+
else:
|
| 905 |
+
layer_outputs = decoder_layer(
|
| 906 |
+
hidden_states,
|
| 907 |
+
attention_mask=attention_mask,
|
| 908 |
+
update_kvcache=update_kvcache,
|
| 909 |
+
position_ids=position_ids,
|
| 910 |
+
past_key_value=past_key_values,
|
| 911 |
+
output_attentions=output_attentions,
|
| 912 |
+
use_cache=use_cache,
|
| 913 |
+
cache_position=cache_position,
|
| 914 |
+
position_embeddings=position_embeddings,
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
hidden_states = layer_outputs[0]
|
| 918 |
+
|
| 919 |
+
if output_attentions:
|
| 920 |
+
all_self_attns += (layer_outputs[1],)
|
| 921 |
+
|
| 922 |
+
hidden_states = self.norm(hidden_states)
|
| 923 |
+
|
| 924 |
+
# add hidden states from the last decoder layer
|
| 925 |
+
if output_hidden_states:
|
| 926 |
+
all_hidden_states += (hidden_states,)
|
| 927 |
+
|
| 928 |
+
if not return_dict:
|
| 929 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
|
| 930 |
+
return BaseModelOutputWithPast(
|
| 931 |
+
last_hidden_state=hidden_states,
|
| 932 |
+
past_key_values=past_key_values if use_cache else None,
|
| 933 |
+
hidden_states=all_hidden_states,
|
| 934 |
+
attentions=all_self_attns,
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
|
| 939 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 940 |
+
|
| 941 |
+
def __init__(self, config):
|
| 942 |
+
super().__init__(config)
|
| 943 |
+
self.model = DreamBaseModel(config)
|
| 944 |
+
self.vocab_size = config.vocab_size
|
| 945 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 946 |
+
|
| 947 |
+
# Initialize weights and apply final processing
|
| 948 |
+
self.post_init()
|
| 949 |
+
|
| 950 |
+
def reset_rope_parameters(self):
|
| 951 |
+
self.model.rotary_emb.reset_parameters()
|
| 952 |
+
for layer in self.model.layers:
|
| 953 |
+
layer.self_attn.rotary_emb.reset_parameters()
|
| 954 |
+
|
| 955 |
+
def get_input_embeddings(self):
|
| 956 |
+
return self.model.embed_tokens
|
| 957 |
+
|
| 958 |
+
def set_input_embeddings(self, value):
|
| 959 |
+
self.model.embed_tokens = value
|
| 960 |
+
|
| 961 |
+
def get_output_embeddings(self):
|
| 962 |
+
return self.lm_head
|
| 963 |
+
|
| 964 |
+
def set_output_embeddings(self, new_embeddings):
|
| 965 |
+
self.lm_head = new_embeddings
|
| 966 |
+
|
| 967 |
+
def set_decoder(self, decoder):
|
| 968 |
+
self.model = decoder
|
| 969 |
+
|
| 970 |
+
def get_decoder(self):
|
| 971 |
+
return self.model
|
| 972 |
+
|
| 973 |
+
def forward(
|
| 974 |
+
self,
|
| 975 |
+
input_ids: torch.LongTensor = None,
|
| 976 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 977 |
+
update_kvcache: torch.int32 = None,
|
| 978 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 979 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 980 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 981 |
+
labels: Optional[torch.LongTensor] = None,
|
| 982 |
+
use_cache: Optional[bool] = None,
|
| 983 |
+
output_attentions: Optional[bool] = None,
|
| 984 |
+
output_hidden_states: Optional[bool] = None,
|
| 985 |
+
return_dict: Optional[bool] = None,
|
| 986 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 987 |
+
num_logits_to_keep: int = 0,
|
| 988 |
+
**loss_kwargs,
|
| 989 |
+
) -> Union[Tuple, MaskedLMOutput]:
|
| 990 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 991 |
+
output_hidden_states = (
|
| 992 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 993 |
+
)
|
| 994 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 995 |
+
|
| 996 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 997 |
+
outputs = self.model(
|
| 998 |
+
input_ids=input_ids,
|
| 999 |
+
attention_mask=attention_mask,
|
| 1000 |
+
update_kvcache=update_kvcache,
|
| 1001 |
+
position_ids=position_ids,
|
| 1002 |
+
past_key_values=past_key_values,
|
| 1003 |
+
inputs_embeds=inputs_embeds,
|
| 1004 |
+
use_cache=use_cache,
|
| 1005 |
+
output_attentions=output_attentions,
|
| 1006 |
+
output_hidden_states=output_hidden_states,
|
| 1007 |
+
return_dict=return_dict,
|
| 1008 |
+
cache_position=cache_position,
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
hidden_states = outputs[0]
|
| 1012 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1013 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 1014 |
+
|
| 1015 |
+
loss = None
|
| 1016 |
+
if labels is not None:
|
| 1017 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1018 |
+
|
| 1019 |
+
if not return_dict:
|
| 1020 |
+
output = (logits,) + outputs[1:]
|
| 1021 |
+
return (loss,) + output if loss is not None else output
|
| 1022 |
+
|
| 1023 |
+
return CausalLMOutputWithPast(
|
| 1024 |
+
loss=loss,
|
| 1025 |
+
logits=logits,
|
| 1026 |
+
past_key_values=outputs.past_key_values,
|
| 1027 |
+
hidden_states=outputs.hidden_states,
|
| 1028 |
+
attentions=outputs.attentions,
|
| 1029 |
+
)
|
Discrete-Diffusion-Forcing/D2F-eval/model_cache/llada/configuration_llada.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaDA configuration
|
| 3 |
+
"""
|
| 4 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from os import PathLike
|
| 8 |
+
from typing import Union
|
| 9 |
+
from dataclasses import asdict, dataclass, field
|
| 10 |
+
from glob import glob
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Dict,
|
| 15 |
+
Iterable,
|
| 16 |
+
List,
|
| 17 |
+
Optional,
|
| 18 |
+
Tuple,
|
| 19 |
+
Type,
|
| 20 |
+
TypeVar,
|
| 21 |
+
Union,
|
| 22 |
+
cast,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"ActivationType",
|
| 28 |
+
"ActivationCheckpointingStrategy",
|
| 29 |
+
"BlockType",
|
| 30 |
+
"LayerNormType",
|
| 31 |
+
"InitFnType",
|
| 32 |
+
"ModelConfig",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
PathOrStr = Union[str, PathLike]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class StrEnum(str, Enum):
|
| 39 |
+
"""
|
| 40 |
+
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
|
| 41 |
+
We include this here for compatibility with older version of Python.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __str__(self) -> str:
|
| 45 |
+
return self.value
|
| 46 |
+
|
| 47 |
+
def __repr__(self) -> str:
|
| 48 |
+
return f"'{str(self)}'"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LayerNormType(StrEnum):
|
| 52 |
+
default = "default"
|
| 53 |
+
"""
|
| 54 |
+
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
low_precision = "low_precision"
|
| 58 |
+
"""
|
| 59 |
+
A low-precision version of the default LayerNorm.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
rms = "rms"
|
| 63 |
+
"""
|
| 64 |
+
An RMSNorm implementation. When using ``torch.compile`` this is
|
| 65 |
+
probably the fastest implementation.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
gemma_rms = "gemma_rms"
|
| 69 |
+
"""
|
| 70 |
+
An RMSNorm implementation by gemmma. When using ``torch.compile`` this is
|
| 71 |
+
probably the fastest implementation.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
amd_compatible = "amd_compatible"
|
| 75 |
+
"""
|
| 76 |
+
LayerNorm implemented manually to work around an issue with ROCm.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ActivationType(StrEnum):
|
| 81 |
+
gelu = "gelu"
|
| 82 |
+
relu = "relu"
|
| 83 |
+
silu = "silu"
|
| 84 |
+
swiglu = "swiglu"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BlockType(StrEnum):
|
| 88 |
+
sequential = "sequential"
|
| 89 |
+
parallel = "parallel"
|
| 90 |
+
|
| 91 |
+
llama = "llama"
|
| 92 |
+
"""
|
| 93 |
+
A block similar to the sequential block with slightly different
|
| 94 |
+
implementations of operations like attention to imitate the behavior of Llama.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class InitFnType(StrEnum):
|
| 99 |
+
mitchell = "mitchell"
|
| 100 |
+
"""
|
| 101 |
+
The strategy suggested to us by Mitchell Wortsman from UW.
|
| 102 |
+
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
| 103 |
+
on the size of the weights as well as the depth of the layer.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
normal = "normal"
|
| 107 |
+
"""
|
| 108 |
+
All weights are initialized from the same normal distribution.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
kaiming_normal = "kaiming_normal"
|
| 112 |
+
"""
|
| 113 |
+
All weights are initialized with the Kaiming method from a normal distribution.
|
| 114 |
+
Note this currently won't work with FSDP.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
fan_in = "fan_in"
|
| 118 |
+
"""
|
| 119 |
+
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
| 120 |
+
is the input dimensionality of the kernel.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
full_megatron = "full_megatron"
|
| 124 |
+
"""
|
| 125 |
+
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
|
| 130 |
+
class ModelConfig():
|
| 131 |
+
"""
|
| 132 |
+
LLaDA (model) configuration.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
| 136 |
+
|
| 137 |
+
d_model: int = 768
|
| 138 |
+
"""
|
| 139 |
+
The hidden size of the model.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
n_heads: int = 12
|
| 143 |
+
"""
|
| 144 |
+
The number of self-attention heads.
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
n_kv_heads: Optional[int] = None
|
| 148 |
+
"""
|
| 149 |
+
The number of heads to use for keys and values. Defaults to `n_heads`.
|
| 150 |
+
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
| 151 |
+
Set this to 1 for multi-query attention.
|
| 152 |
+
Set it to some in-between value for Llama2-style grouped query attention.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
n_layers: int = 12
|
| 156 |
+
"""
|
| 157 |
+
The number of layers/blocks.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
mlp_ratio: int = 4
|
| 161 |
+
"""
|
| 162 |
+
The ratio of the inner MLP dimensionality to ``d_model``.
|
| 163 |
+
This is only used when ``mlp_hidden_size`` is not set.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
mlp_hidden_size: Optional[int] = None
|
| 167 |
+
"""
|
| 168 |
+
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
activation_type: ActivationType = ActivationType.swiglu
|
| 172 |
+
"""
|
| 173 |
+
The activation function to use within the MLP layers.
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
block_type: BlockType = BlockType.sequential
|
| 177 |
+
"""
|
| 178 |
+
The transformer block implementation.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
block_group_size: int = 1
|
| 182 |
+
"""
|
| 183 |
+
The number of blocks to group together into a single parent block.
|
| 184 |
+
This has no affect on the number of parameters in the model and is only used to wrap groups
|
| 185 |
+
of blocks together with a single FSDP wrapper during training.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
alibi: bool = False
|
| 189 |
+
"""
|
| 190 |
+
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
alibi_bias_max: float = 8.0
|
| 194 |
+
"""
|
| 195 |
+
Maximum absolute value of ALiBi bias.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
rope: bool = False
|
| 199 |
+
"""
|
| 200 |
+
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
rope_full_precision: bool = True
|
| 204 |
+
"""
|
| 205 |
+
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
| 206 |
+
apply RoPE at the precision of the input.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
flash_attention: bool = False
|
| 210 |
+
"""
|
| 211 |
+
If ``True``, use ``FlashAttention``.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
attention_dropout: float = 0.1
|
| 215 |
+
"""
|
| 216 |
+
The dropout probability within the attention modules.
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
multi_query_attention: Optional[bool] = None
|
| 220 |
+
"""
|
| 221 |
+
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
| 222 |
+
and is more efficient during inference.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
attention_layer_norm: bool = False
|
| 226 |
+
"""
|
| 227 |
+
Apply layer norm to the keys and queries within the attention mechanism.
|
| 228 |
+
This can help stabilize training.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
residual_dropout: float = 0.1
|
| 232 |
+
"""
|
| 233 |
+
The dropout probability for the MLP and attention output within each block.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
embedding_dropout: float = 0.1
|
| 237 |
+
"""
|
| 238 |
+
The dropout probability for embeddings.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
input_emb_norm: bool = False
|
| 242 |
+
"""
|
| 243 |
+
An input hidden_states norm implementation by gemmma.
|
| 244 |
+
"""
|
| 245 |
+
|
| 246 |
+
layer_norm_type: LayerNormType = LayerNormType.default
|
| 247 |
+
"""
|
| 248 |
+
The layernorm implementation to use.
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
layer_norm_with_affine: bool = True
|
| 252 |
+
"""
|
| 253 |
+
Whether to include bias and weight parameters for the layer norms.
|
| 254 |
+
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
| 255 |
+
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
| 256 |
+
to ``False``.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
rms_norm_eps: float = 1e-05
|
| 260 |
+
"""
|
| 261 |
+
The rms layernorm eps param.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
attention_layer_norm_with_affine: bool = True
|
| 265 |
+
"""
|
| 266 |
+
Toggle affine transform for the QK norms.
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
max_sequence_length: int = 1024
|
| 270 |
+
"""
|
| 271 |
+
The maximum input sequence length supported by the model.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
rope_theta: float = 10000.0
|
| 275 |
+
"""
|
| 276 |
+
The rope base param.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
include_qkv_bias: Optional[bool] = False
|
| 280 |
+
"""
|
| 281 |
+
Whether or not to include bias parameters in qkv linear layers.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
include_bias: bool = False
|
| 285 |
+
"""
|
| 286 |
+
Whether or not to include bias parameters in linear layers.
|
| 287 |
+
In PaLM, they got rid of all bias terms because they found that large
|
| 288 |
+
models tend to have near 0 bias terms anyway.
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
bias_for_layer_norm: Optional[bool] = None
|
| 292 |
+
"""
|
| 293 |
+
Whether or not to include bias parameters in layer norm.
|
| 294 |
+
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
| 295 |
+
layer norm.
|
| 296 |
+
When this is None (the default), it inherits the setting from include_bias.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
scale_logits: bool = False
|
| 300 |
+
"""
|
| 301 |
+
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
vocab_size: int = 50257
|
| 305 |
+
"""
|
| 306 |
+
Vocabulary size of the model.
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
embedding_size: Optional[int] = 50304
|
| 310 |
+
"""
|
| 311 |
+
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
| 312 |
+
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
| 313 |
+
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
| 314 |
+
substantially.
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
weight_tying: bool = True
|
| 318 |
+
"""
|
| 319 |
+
Whether to tie output linear weights to the input embedding.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
eos_token_id: int = 50256
|
| 323 |
+
"""
|
| 324 |
+
The ID of the end-of-sentence special token.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
pad_token_id: int = 50256
|
| 328 |
+
"""
|
| 329 |
+
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
mask_token_id: Optional[int] = 50256
|
| 333 |
+
"""
|
| 334 |
+
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
init_device: Optional[str] = None
|
| 338 |
+
"""
|
| 339 |
+
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
init_fn: InitFnType = InitFnType.normal
|
| 343 |
+
"""
|
| 344 |
+
The weight initialization strategy.
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
init_std: float = 0.02
|
| 348 |
+
"""
|
| 349 |
+
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 350 |
+
as "normal".
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
init_cutoff_factor: Optional[float] = None
|
| 354 |
+
"""
|
| 355 |
+
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 356 |
+
as "normal". Setting this to None means values are not cutoff.
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
precision: Optional[str] = None
|
| 360 |
+
"""
|
| 361 |
+
Precision used to train/evaluate with. You shouldn't set this directly.
|
| 362 |
+
See :data:`TrainConfig.precision` instead.
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
@property
|
| 366 |
+
def effective_n_kv_heads(self) -> int:
|
| 367 |
+
if self.n_kv_heads is None:
|
| 368 |
+
if self.multi_query_attention is True:
|
| 369 |
+
return 1
|
| 370 |
+
else:
|
| 371 |
+
return self.n_heads
|
| 372 |
+
else:
|
| 373 |
+
if self.multi_query_attention is None:
|
| 374 |
+
return self.n_kv_heads
|
| 375 |
+
if self.multi_query_attention:
|
| 376 |
+
n_kv_heads_should_be = 1
|
| 377 |
+
else:
|
| 378 |
+
n_kv_heads_should_be = self.n_heads
|
| 379 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
| 380 |
+
return n_kv_heads_should_be
|
| 381 |
+
else:
|
| 382 |
+
raise Exception(
|
| 383 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
class ActivationCheckpointingStrategy(StrEnum):
|
| 387 |
+
whole_layer = "whole_layer"
|
| 388 |
+
"""
|
| 389 |
+
Checkpoint every transformer layer.
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
one_in_two = "one_in_two"
|
| 393 |
+
"""
|
| 394 |
+
Checkpoint one in two transformer layers.
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
one_in_three = "one_in_three"
|
| 398 |
+
"""
|
| 399 |
+
Checkpoint one in three transformer layers.
|
| 400 |
+
"""
|
| 401 |
+
|
| 402 |
+
one_in_four = "one_in_four"
|
| 403 |
+
"""
|
| 404 |
+
Checkpoint one in four transformer layers.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
two_in_three = "two_in_three"
|
| 408 |
+
"""
|
| 409 |
+
Checkpoint two out of every three transformer layers.
|
| 410 |
+
"""
|
| 411 |
+
|
| 412 |
+
three_in_four = "three_in_four"
|
| 413 |
+
"""
|
| 414 |
+
Checkpoint three out of four of every transformer layers.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
four_in_five = "four_in_five"
|
| 418 |
+
"""
|
| 419 |
+
Checkpoint four out of five of every transformer layers.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
nine_in_ten = "nine_in_ten"
|
| 423 |
+
"""
|
| 424 |
+
Checkpoint nine out of ten of every transformer layers.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
fine_grained = "fine_grained"
|
| 428 |
+
"""
|
| 429 |
+
Focus checkpointing on where it is cheap to recompute and saves most memory.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class LLaDAConfig(PretrainedConfig):
|
| 434 |
+
model_type = "llada"
|
| 435 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
| 436 |
+
|
| 437 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
| 438 |
+
model_config = ModelConfig()
|
| 439 |
+
all_kwargs = model_config.__dict__
|
| 440 |
+
all_kwargs.update(kwargs)
|
| 441 |
+
all_kwargs.update({"use_cache": use_cache})
|
| 442 |
+
all_kwargs.update(
|
| 443 |
+
{
|
| 444 |
+
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
| 445 |
+
}
|
| 446 |
+
)
|
| 447 |
+
super().__init__(**all_kwargs)
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def num_attention_heads(self):
|
| 451 |
+
return self.n_heads
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def num_hidden_layers(self):
|
| 455 |
+
return self.n_layers
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def hidden_size(self):
|
| 459 |
+
return self.d_model
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
| 463 |
+
AutoConfig.register("llada", LLaDAConfig)
|
Discrete-Diffusion-Forcing/D2F-eval/model_cache/llada/modeling_llada.py
ADDED
|
@@ -0,0 +1,1504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Iterable,
|
| 13 |
+
List,
|
| 14 |
+
NamedTuple,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
cast,
|
| 20 |
+
)
|
| 21 |
+
from dataclasses import fields
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.backends.cuda
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from torch import einsum
|
| 29 |
+
from transformers import PreTrainedModel
|
| 30 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 31 |
+
from transformers.models.auto import AutoModel
|
| 32 |
+
from transformers.cache_utils import Cache
|
| 33 |
+
|
| 34 |
+
from .configuration_llada import (
|
| 35 |
+
LLaDAConfig,
|
| 36 |
+
StrEnum,
|
| 37 |
+
InitFnType,
|
| 38 |
+
ActivationType,
|
| 39 |
+
BlockType,
|
| 40 |
+
LayerNormType,
|
| 41 |
+
ModelConfig,
|
| 42 |
+
ActivationCheckpointingStrategy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if sys.version_info.minor > 8:
|
| 46 |
+
from collections.abc import MutableMapping
|
| 47 |
+
elif sys.version_info.minor == 8:
|
| 48 |
+
from typing import MutableMapping
|
| 49 |
+
else:
|
| 50 |
+
raise SystemExit("This script supports Python 3.8 or higher")
|
| 51 |
+
|
| 52 |
+
__all__ = [
|
| 53 |
+
"LayerNormBase",
|
| 54 |
+
"LayerNorm",
|
| 55 |
+
"RMSLayerNorm",
|
| 56 |
+
"GemmaRMSLayerNorm",
|
| 57 |
+
"RotaryEmbedding",
|
| 58 |
+
"Activation",
|
| 59 |
+
"GELU",
|
| 60 |
+
"ReLU",
|
| 61 |
+
"SwiGLU",
|
| 62 |
+
"LLaDABlock",
|
| 63 |
+
"LLaDASequentialBlock",
|
| 64 |
+
"LLaDAModel",
|
| 65 |
+
"LLaDAOutput",
|
| 66 |
+
"LLaDAGenerateOutput",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
log = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ModuleType(StrEnum):
|
| 74 |
+
in_module = "in"
|
| 75 |
+
out_module = "out"
|
| 76 |
+
emb = "emb"
|
| 77 |
+
final_out = "final_out"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def init_weights(
|
| 81 |
+
config: ModelConfig,
|
| 82 |
+
module: Union[nn.Linear, nn.Embedding],
|
| 83 |
+
d: Optional[int] = None,
|
| 84 |
+
layer_id: Optional[int] = None,
|
| 85 |
+
std_factor: float = 1.0,
|
| 86 |
+
type_of_module: Optional[ModuleType] = None,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Initialize weights of a linear or embedding module.
|
| 90 |
+
|
| 91 |
+
:param config: The model config.
|
| 92 |
+
:param module: The linear or embedding submodule to initialize.
|
| 93 |
+
:param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
|
| 94 |
+
for fused layers.
|
| 95 |
+
:param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
|
| 96 |
+
``1 / sqrt(2 * (layer_id + 1))``.
|
| 97 |
+
"""
|
| 98 |
+
d = d if d is not None else config.d_model
|
| 99 |
+
if config.init_fn == InitFnType.normal:
|
| 100 |
+
std = config.init_std * std_factor
|
| 101 |
+
if config.init_cutoff_factor is not None:
|
| 102 |
+
cutoff_value = config.init_cutoff_factor * std
|
| 103 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
| 104 |
+
else:
|
| 105 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 106 |
+
elif config.init_fn == InitFnType.mitchell:
|
| 107 |
+
std = std_factor / math.sqrt(d)
|
| 108 |
+
if layer_id is not None:
|
| 109 |
+
std = std / math.sqrt(2 * (layer_id + 1))
|
| 110 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
| 111 |
+
elif config.init_fn == InitFnType.kaiming_normal:
|
| 112 |
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
| 113 |
+
elif config.init_fn == InitFnType.fan_in:
|
| 114 |
+
std = std_factor / math.sqrt(d)
|
| 115 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 116 |
+
elif config.init_fn == InitFnType.full_megatron:
|
| 117 |
+
if type_of_module is None:
|
| 118 |
+
raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
|
| 119 |
+
|
| 120 |
+
cutoff_factor = config.init_cutoff_factor
|
| 121 |
+
if cutoff_factor is None:
|
| 122 |
+
cutoff_factor = 3
|
| 123 |
+
|
| 124 |
+
if type_of_module == ModuleType.in_module:
|
| 125 |
+
# for att_proj (same as QKV), ff_proj
|
| 126 |
+
std = config.init_std
|
| 127 |
+
elif type_of_module == ModuleType.out_module:
|
| 128 |
+
# for attn_out, ff_out
|
| 129 |
+
std = config.init_std / math.sqrt(2.0 * config.n_layers)
|
| 130 |
+
elif type_of_module == ModuleType.emb:
|
| 131 |
+
# positional embeddings (wpe)
|
| 132 |
+
# token embeddings (wte)
|
| 133 |
+
std = config.init_std
|
| 134 |
+
elif type_of_module == ModuleType.final_out:
|
| 135 |
+
# final output (ff_out)
|
| 136 |
+
std = config.d_model**-0.5
|
| 137 |
+
else:
|
| 138 |
+
raise RuntimeError(f"Unknown module type '{type_of_module}'")
|
| 139 |
+
nn.init.trunc_normal_(
|
| 140 |
+
module.weight,
|
| 141 |
+
mean=0.0,
|
| 142 |
+
std=std,
|
| 143 |
+
a=-cutoff_factor * std,
|
| 144 |
+
b=cutoff_factor * std,
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
raise NotImplementedError(config.init_fn)
|
| 148 |
+
|
| 149 |
+
if isinstance(module, nn.Linear):
|
| 150 |
+
if module.bias is not None:
|
| 151 |
+
nn.init.zeros_(module.bias)
|
| 152 |
+
|
| 153 |
+
if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
module.weight.div_(math.sqrt(2 * config.n_layers))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
| 159 |
+
"""
|
| 160 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
| 161 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
| 162 |
+
"""
|
| 163 |
+
if check_neg_inf:
|
| 164 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
| 165 |
+
if check_pos_inf:
|
| 166 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def activation_checkpoint_function(cfg: ModelConfig):
|
| 170 |
+
preserve_rng_state = (
|
| 171 |
+
(cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
|
| 172 |
+
)
|
| 173 |
+
from torch.utils.checkpoint import checkpoint
|
| 174 |
+
|
| 175 |
+
return partial(
|
| 176 |
+
checkpoint,
|
| 177 |
+
preserve_rng_state=preserve_rng_state,
|
| 178 |
+
use_reentrant=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
| 183 |
+
"""
|
| 184 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
| 185 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
| 186 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
| 187 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
| 188 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
| 189 |
+
NaNs when they're synchronized due to casting or some other issue.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _non_meta_init_device(config: ModelConfig) -> torch.device:
|
| 194 |
+
if config.init_device is not None and config.init_device != "meta":
|
| 195 |
+
return torch.device(config.init_device)
|
| 196 |
+
else:
|
| 197 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Dropout(nn.Dropout):
|
| 201 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 202 |
+
if self.p == 0.0:
|
| 203 |
+
return input
|
| 204 |
+
else:
|
| 205 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class LayerNormBase(nn.Module):
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
config: ModelConfig,
|
| 212 |
+
*,
|
| 213 |
+
size: Optional[int] = None,
|
| 214 |
+
elementwise_affine: Optional[bool] = True,
|
| 215 |
+
eps: float = 1e-05,
|
| 216 |
+
):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
self.eps = eps
|
| 220 |
+
self.normalized_shape = (size or config.d_model,)
|
| 221 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
| 222 |
+
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
|
| 223 |
+
use_bias = self.config.bias_for_layer_norm
|
| 224 |
+
if use_bias is None:
|
| 225 |
+
use_bias = self.config.include_bias
|
| 226 |
+
if use_bias:
|
| 227 |
+
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
|
| 228 |
+
else:
|
| 229 |
+
self.register_parameter("bias", None)
|
| 230 |
+
else:
|
| 231 |
+
self.register_parameter("bias", None)
|
| 232 |
+
self.register_parameter("weight", None)
|
| 233 |
+
|
| 234 |
+
@abstractmethod
|
| 235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 236 |
+
raise NotImplementedError
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
|
| 240 |
+
if config.layer_norm_type == LayerNormType.default:
|
| 241 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
| 242 |
+
elif config.layer_norm_type == LayerNormType.low_precision:
|
| 243 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
| 244 |
+
elif config.layer_norm_type == LayerNormType.rms:
|
| 245 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
| 246 |
+
elif config.layer_norm_type == LayerNormType.gemma_rms:
|
| 247 |
+
return GemmaRMSLayerNorm(config, size=size, **kwargs)
|
| 248 |
+
else:
|
| 249 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
| 250 |
+
|
| 251 |
+
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
| 252 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 253 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 254 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 255 |
+
if tensor.device.type == "cuda" and torch.is_autocast_enabled():
|
| 256 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
|
| 257 |
+
elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 258 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
| 259 |
+
else:
|
| 260 |
+
return tensor
|
| 261 |
+
|
| 262 |
+
def reset_parameters(self):
|
| 263 |
+
if self.weight is not None:
|
| 264 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
| 265 |
+
if self.bias is not None:
|
| 266 |
+
torch.nn.init.zeros_(self.bias) # type: ignore
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LayerNorm(LayerNormBase):
|
| 270 |
+
"""
|
| 271 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
config: ModelConfig,
|
| 277 |
+
size: Optional[int] = None,
|
| 278 |
+
low_precision: bool = False,
|
| 279 |
+
elementwise_affine: Optional[bool] = None,
|
| 280 |
+
eps: float = 1e-05,
|
| 281 |
+
):
|
| 282 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 283 |
+
self.low_precision = low_precision
|
| 284 |
+
|
| 285 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 286 |
+
if self.low_precision:
|
| 287 |
+
module_device = x.device
|
| 288 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
| 289 |
+
downcast_weight = (
|
| 290 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 291 |
+
)
|
| 292 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 293 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 294 |
+
return F.layer_norm(
|
| 295 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
| 296 |
+
)
|
| 297 |
+
else:
|
| 298 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class RMSLayerNorm(LayerNormBase):
|
| 302 |
+
"""
|
| 303 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
config: ModelConfig,
|
| 309 |
+
size: Optional[int] = None,
|
| 310 |
+
elementwise_affine: Optional[bool] = None,
|
| 311 |
+
eps: float = 1e-5,
|
| 312 |
+
):
|
| 313 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 314 |
+
|
| 315 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 316 |
+
# with torch.autocast(enabled=False, device_type=x.device.type):
|
| 317 |
+
og_dtype = x.dtype
|
| 318 |
+
x = x.to(torch.float32)
|
| 319 |
+
# print(x.dtype,x.shape)
|
| 320 |
+
variance = x*x
|
| 321 |
+
# print(variance)
|
| 322 |
+
variance = variance.mean(dim=-1,keepdim=True)
|
| 323 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 324 |
+
x = x.to(og_dtype)
|
| 325 |
+
|
| 326 |
+
if self.weight is not None:
|
| 327 |
+
if self.bias is not None:
|
| 328 |
+
return self.weight * x + self.bias
|
| 329 |
+
else:
|
| 330 |
+
return self.weight * x
|
| 331 |
+
else:
|
| 332 |
+
return x
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class GemmaRMSLayerNorm(LayerNormBase):
|
| 336 |
+
"""
|
| 337 |
+
Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
def __init__(
|
| 341 |
+
self,
|
| 342 |
+
config: ModelConfig,
|
| 343 |
+
size: Optional[int] = None,
|
| 344 |
+
elementwise_affine: Optional[bool] = None,
|
| 345 |
+
eps: float = 1e-5,
|
| 346 |
+
):
|
| 347 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 348 |
+
|
| 349 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 350 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 351 |
+
og_dtype = x.dtype
|
| 352 |
+
x = x.to(torch.float32)
|
| 353 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 354 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 355 |
+
x = x.to(og_dtype)
|
| 356 |
+
|
| 357 |
+
if self.weight is not None:
|
| 358 |
+
if self.bias is not None:
|
| 359 |
+
return x * (1 + self.weight) + self.bias
|
| 360 |
+
else:
|
| 361 |
+
return x * (1 + self.weight)
|
| 362 |
+
else:
|
| 363 |
+
return x
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class RotaryEmbedding(nn.Module):
|
| 367 |
+
"""
|
| 368 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
def __init__(self, config: ModelConfig, cache: BufferCache):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.config = config
|
| 374 |
+
self.__cache = cache
|
| 375 |
+
# Warm up cache.
|
| 376 |
+
self.rope_theta = config.rope_theta
|
| 377 |
+
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
|
| 378 |
+
|
| 379 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 380 |
+
if (
|
| 381 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
| 382 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
| 383 |
+
and pos_sin.shape[-2] >= seq_len
|
| 384 |
+
and pos_cos.shape[-2] >= seq_len
|
| 385 |
+
):
|
| 386 |
+
if pos_sin.device != device:
|
| 387 |
+
pos_sin = pos_sin.to(device)
|
| 388 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 389 |
+
if pos_cos.device != device:
|
| 390 |
+
pos_cos = pos_cos.to(device)
|
| 391 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 392 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
| 393 |
+
|
| 394 |
+
with torch.autocast(device.type, enabled=False):
|
| 395 |
+
dim = self.config.d_model // self.config.n_heads
|
| 396 |
+
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
|
| 397 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
| 398 |
+
freqs = einsum("i , j -> i j", seq, inv_freq)
|
| 399 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
| 400 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
| 401 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 402 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 403 |
+
return pos_sin, pos_cos
|
| 404 |
+
|
| 405 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
| 406 |
+
B, nh, T, hs = x.size()
|
| 407 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
| 408 |
+
x1, x2 = x.unbind(dim=-2)
|
| 409 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 410 |
+
|
| 411 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 412 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
| 413 |
+
|
| 414 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 415 |
+
if self.config.rope_full_precision:
|
| 416 |
+
q_, k_ = q.float(), k.float()
|
| 417 |
+
else:
|
| 418 |
+
q_, k_ = q, k
|
| 419 |
+
|
| 420 |
+
with torch.autocast(q.device.type, enabled=False):
|
| 421 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
| 422 |
+
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
|
| 423 |
+
pos_sin = pos_sin.type_as(q_)
|
| 424 |
+
pos_cos = pos_cos.type_as(q_)
|
| 425 |
+
q_ = self.apply_rotary_pos_emb(
|
| 426 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
| 427 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
| 428 |
+
q_,
|
| 429 |
+
)
|
| 430 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
| 431 |
+
return q_.type_as(q), k_.type_as(k)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class Activation(nn.Module):
|
| 435 |
+
def __init__(self, config: ModelConfig):
|
| 436 |
+
super().__init__()
|
| 437 |
+
self.config = config
|
| 438 |
+
|
| 439 |
+
@abstractmethod
|
| 440 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 441 |
+
raise NotImplementedError
|
| 442 |
+
|
| 443 |
+
@property
|
| 444 |
+
@abstractmethod
|
| 445 |
+
def output_multiplier(self) -> float:
|
| 446 |
+
raise NotImplementedError
|
| 447 |
+
|
| 448 |
+
@classmethod
|
| 449 |
+
def build(cls, config: ModelConfig) -> Activation:
|
| 450 |
+
if config.activation_type == ActivationType.gelu:
|
| 451 |
+
return cast(Activation, GELU(approximate="none"))
|
| 452 |
+
elif config.activation_type == ActivationType.relu:
|
| 453 |
+
return cast(Activation, ReLU(inplace=False))
|
| 454 |
+
elif config.activation_type == ActivationType.silu:
|
| 455 |
+
return cast(Activation, SiLU(inplace=False))
|
| 456 |
+
elif config.activation_type == ActivationType.swiglu:
|
| 457 |
+
return SwiGLU(config)
|
| 458 |
+
else:
|
| 459 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class GELU(nn.GELU):
|
| 463 |
+
@property
|
| 464 |
+
def output_multiplier(self) -> float:
|
| 465 |
+
return 1.0
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class ReLU(nn.ReLU):
|
| 469 |
+
@property
|
| 470 |
+
def output_multiplier(self) -> float:
|
| 471 |
+
return 1.0
|
| 472 |
+
|
| 473 |
+
class SiLU(nn.SiLU):
|
| 474 |
+
@property
|
| 475 |
+
def output_multiplier(self) -> float:
|
| 476 |
+
return 1.0
|
| 477 |
+
|
| 478 |
+
class SwiGLU(Activation):
|
| 479 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 480 |
+
x, gate = x.chunk(2, dim=-1)
|
| 481 |
+
return F.silu(gate) * x
|
| 482 |
+
|
| 483 |
+
@property
|
| 484 |
+
def output_multiplier(self) -> float:
|
| 485 |
+
return 0.5
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
| 489 |
+
att_bias = torch.triu(
|
| 490 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
| 491 |
+
diagonal=1,
|
| 492 |
+
)
|
| 493 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
| 494 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 498 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
| 499 |
+
if causal_bias.device != device:
|
| 500 |
+
causal_bias = causal_bias.to(device)
|
| 501 |
+
cache["causal_attention_bias"] = causal_bias
|
| 502 |
+
return causal_bias
|
| 503 |
+
with torch.autocast(device.type, enabled=False):
|
| 504 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
| 505 |
+
cache["causal_attention_bias"] = causal_bias
|
| 506 |
+
return causal_bias
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
|
| 510 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
|
| 511 |
+
|
| 512 |
+
# shape: (1, 1, seq_len, seq_len)
|
| 513 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
|
| 514 |
+
alibi_bias.abs_().mul_(-1)
|
| 515 |
+
|
| 516 |
+
# shape: (n_heads,)
|
| 517 |
+
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
|
| 518 |
+
m.mul_(config.alibi_bias_max / config.n_heads)
|
| 519 |
+
|
| 520 |
+
# shape: (1, n_heads, seq_len, seq_len)
|
| 521 |
+
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class LLaDABlock(nn.Module):
|
| 525 |
+
"""
|
| 526 |
+
A base class for transformer block implementations.
|
| 527 |
+
"""
|
| 528 |
+
|
| 529 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.layer_id = layer_id
|
| 532 |
+
self.config = config
|
| 533 |
+
self.hidden_size = (
|
| 534 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 535 |
+
)
|
| 536 |
+
self.__cache = cache
|
| 537 |
+
assert config.d_model % config.n_heads == 0
|
| 538 |
+
|
| 539 |
+
self._activation_checkpoint_fn = None
|
| 540 |
+
|
| 541 |
+
# Dropout.
|
| 542 |
+
self.dropout = Dropout(config.residual_dropout)
|
| 543 |
+
|
| 544 |
+
# Layer norms.
|
| 545 |
+
self.k_norm: Optional[LayerNormBase] = None
|
| 546 |
+
self.q_norm: Optional[LayerNormBase] = None
|
| 547 |
+
if config.attention_layer_norm:
|
| 548 |
+
self.k_norm = LayerNormBase.build(
|
| 549 |
+
config,
|
| 550 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
| 551 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
| 552 |
+
)
|
| 553 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
| 554 |
+
|
| 555 |
+
# Activation function.
|
| 556 |
+
self.act = Activation.build(config)
|
| 557 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
| 558 |
+
|
| 559 |
+
# Attention output projection.
|
| 560 |
+
self.attn_out = nn.Linear(
|
| 561 |
+
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# Feed-forward output projection.
|
| 565 |
+
self.ff_out = nn.Linear(
|
| 566 |
+
int(self.act.output_multiplier * self.hidden_size),
|
| 567 |
+
config.d_model,
|
| 568 |
+
bias=config.include_bias,
|
| 569 |
+
device=config.init_device,
|
| 570 |
+
)
|
| 571 |
+
self.ff_out._is_residual = True # type: ignore
|
| 572 |
+
|
| 573 |
+
# Rotary embeddings.
|
| 574 |
+
if self.config.rope:
|
| 575 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
| 576 |
+
|
| 577 |
+
self.flash_attn_func = None
|
| 578 |
+
if config.flash_attention:
|
| 579 |
+
try:
|
| 580 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 581 |
+
|
| 582 |
+
self.flash_attn_func = flash_attn_func
|
| 583 |
+
except ModuleNotFoundError:
|
| 584 |
+
pass
|
| 585 |
+
|
| 586 |
+
def reset_parameters(self):
|
| 587 |
+
if self.k_norm is not None:
|
| 588 |
+
self.k_norm.reset_parameters()
|
| 589 |
+
if self.q_norm is not None:
|
| 590 |
+
self.q_norm.reset_parameters()
|
| 591 |
+
init_weights(
|
| 592 |
+
self.config,
|
| 593 |
+
self.attn_out,
|
| 594 |
+
d=self.config.d_model,
|
| 595 |
+
layer_id=self.layer_id,
|
| 596 |
+
type_of_module=ModuleType.out_module,
|
| 597 |
+
)
|
| 598 |
+
init_weights(
|
| 599 |
+
self.config,
|
| 600 |
+
self.ff_out,
|
| 601 |
+
d=self.ff_out.in_features,
|
| 602 |
+
layer_id=self.layer_id,
|
| 603 |
+
type_of_module=ModuleType.out_module,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 607 |
+
if strategy == ActivationCheckpointingStrategy.fine_grained:
|
| 608 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 609 |
+
else:
|
| 610 |
+
self._activation_checkpoint_fn = None
|
| 611 |
+
|
| 612 |
+
@classmethod
|
| 613 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
| 614 |
+
target_dtype = input_dtype
|
| 615 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 616 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 617 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 618 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
| 619 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 620 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 621 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
| 622 |
+
if bias.dtype != target_dtype:
|
| 623 |
+
bias = bias.to(target_dtype)
|
| 624 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
| 625 |
+
return bias
|
| 626 |
+
|
| 627 |
+
def _scaled_dot_product_attention(
|
| 628 |
+
self,
|
| 629 |
+
q: torch.Tensor,
|
| 630 |
+
k: torch.Tensor,
|
| 631 |
+
v: torch.Tensor,
|
| 632 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 633 |
+
dropout_p: float = 0.0,
|
| 634 |
+
is_causal: bool = False,
|
| 635 |
+
) -> torch.Tensor:
|
| 636 |
+
"""
|
| 637 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
| 638 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
| 639 |
+
"""
|
| 640 |
+
if self.flash_attn_func is not None and attn_mask is None:
|
| 641 |
+
r = self.flash_attn_func(
|
| 642 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
|
| 643 |
+
)
|
| 644 |
+
return r.transpose(1, 2)
|
| 645 |
+
else:
|
| 646 |
+
# torch's sdpa doesn't support GQA, so we're doing this
|
| 647 |
+
assert k.size(1) == v.size(1)
|
| 648 |
+
num_kv_heads = k.size(1)
|
| 649 |
+
num_q_heads = q.size(1)
|
| 650 |
+
if num_q_heads != num_kv_heads:
|
| 651 |
+
assert num_q_heads % num_kv_heads == 0
|
| 652 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 653 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 654 |
+
# Modify: MDM set causal to False, and with no attn_mask.
|
| 655 |
+
return F.scaled_dot_product_attention(
|
| 656 |
+
q,
|
| 657 |
+
k,
|
| 658 |
+
v,
|
| 659 |
+
attn_mask=attn_mask,
|
| 660 |
+
dropout_p=dropout_p,
|
| 661 |
+
is_causal=False,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
def attention(
|
| 665 |
+
self,
|
| 666 |
+
q: torch.Tensor,
|
| 667 |
+
k: torch.Tensor,
|
| 668 |
+
v: torch.Tensor,
|
| 669 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 670 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 671 |
+
use_cache: bool = False,
|
| 672 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 673 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
| 674 |
+
dtype = k.dtype
|
| 675 |
+
|
| 676 |
+
# Optionally apply layer norm to keys and queries.
|
| 677 |
+
if self.q_norm is not None and self.k_norm is not None:
|
| 678 |
+
q = self.q_norm(q).to(dtype=dtype)
|
| 679 |
+
k = self.k_norm(k).to(dtype=dtype)
|
| 680 |
+
|
| 681 |
+
# Move head forward to be next to the batch dim.
|
| 682 |
+
# shape: (B, nh, T, hs)
|
| 683 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
| 684 |
+
# shape: (B, n_kv_h, T, hs)
|
| 685 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 686 |
+
# shape: (B, n_kv_h, T, hs)
|
| 687 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 688 |
+
|
| 689 |
+
if layer_past is not None:
|
| 690 |
+
past_key, past_value = layer_past
|
| 691 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 692 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 693 |
+
|
| 694 |
+
present = (k, v) if use_cache else None
|
| 695 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
| 696 |
+
|
| 697 |
+
if self.config.rope:
|
| 698 |
+
# Apply rotary embeddings.
|
| 699 |
+
q, k = self.rotary_emb(q, k)
|
| 700 |
+
|
| 701 |
+
# if attention_bias is not None:
|
| 702 |
+
# # Resize and cast attention bias.
|
| 703 |
+
# # The current dtype of the attention bias might not match the dtype that the SDP attn function will
|
| 704 |
+
# # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
|
| 705 |
+
# # as down-casting the attention bias to the autocast precision will result in -infs, which will
|
| 706 |
+
# # cause the SDP attn function to produce NaNs.
|
| 707 |
+
# attention_bias = self._cast_attn_bias(
|
| 708 |
+
# attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
| 709 |
+
# )
|
| 710 |
+
|
| 711 |
+
# Get the attention scores.
|
| 712 |
+
# shape: (B, nh, T, hs)
|
| 713 |
+
att = self._scaled_dot_product_attention(
|
| 714 |
+
q,
|
| 715 |
+
k,
|
| 716 |
+
v,
|
| 717 |
+
attn_mask=attention_bias,
|
| 718 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
| 719 |
+
is_causal=False,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
# Re-assemble all head outputs side-by-side.
|
| 723 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
| 724 |
+
|
| 725 |
+
# Apply output projection.
|
| 726 |
+
return self.attn_out(att), present
|
| 727 |
+
|
| 728 |
+
@abstractmethod
|
| 729 |
+
def forward(
|
| 730 |
+
self,
|
| 731 |
+
x: torch.Tensor,
|
| 732 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 733 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 734 |
+
use_cache: bool = False,
|
| 735 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 736 |
+
raise NotImplementedError
|
| 737 |
+
|
| 738 |
+
@classmethod
|
| 739 |
+
def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
|
| 740 |
+
if config.block_type == BlockType.sequential:
|
| 741 |
+
return LLaDASequentialBlock(layer_id, config, cache)
|
| 742 |
+
elif config.block_type == BlockType.llama:
|
| 743 |
+
return LLaDALlamaBlock(layer_id, config, cache)
|
| 744 |
+
else:
|
| 745 |
+
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
class LLaDASequentialBlock(LLaDABlock):
|
| 749 |
+
"""
|
| 750 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 751 |
+
(plus another skip connection).
|
| 752 |
+
"""
|
| 753 |
+
|
| 754 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 755 |
+
super().__init__(layer_id, config, cache)
|
| 756 |
+
# Layer norms.
|
| 757 |
+
self.attn_norm = LayerNorm.build(config)
|
| 758 |
+
self.ff_norm = LayerNorm.build(config)
|
| 759 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 760 |
+
head_dim = config.d_model // config.n_heads
|
| 761 |
+
self.fused_dims = (
|
| 762 |
+
config.d_model,
|
| 763 |
+
config.effective_n_kv_heads * head_dim,
|
| 764 |
+
config.effective_n_kv_heads * head_dim,
|
| 765 |
+
)
|
| 766 |
+
self.att_proj = nn.Linear(
|
| 767 |
+
config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 768 |
+
)
|
| 769 |
+
# Feed-forward input projection.
|
| 770 |
+
self.ff_proj = nn.Linear(
|
| 771 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
def reset_parameters(self):
|
| 775 |
+
super().reset_parameters()
|
| 776 |
+
self.attn_norm.reset_parameters()
|
| 777 |
+
self.ff_norm.reset_parameters()
|
| 778 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 779 |
+
init_weights(
|
| 780 |
+
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 781 |
+
)
|
| 782 |
+
init_weights(
|
| 783 |
+
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
def forward(
|
| 787 |
+
self,
|
| 788 |
+
x: torch.Tensor,
|
| 789 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 790 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 791 |
+
use_cache: bool = False,
|
| 792 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 793 |
+
# Get query, key, value projections.
|
| 794 |
+
# shape:
|
| 795 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 796 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 797 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 798 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 799 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 800 |
+
if self._activation_checkpoint_fn is not None:
|
| 801 |
+
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
|
| 802 |
+
self.fused_dims, dim=-1
|
| 803 |
+
)
|
| 804 |
+
else:
|
| 805 |
+
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
|
| 806 |
+
|
| 807 |
+
# Get attention scores.
|
| 808 |
+
if self._activation_checkpoint_fn is not None:
|
| 809 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 810 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 811 |
+
)
|
| 812 |
+
else:
|
| 813 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 814 |
+
|
| 815 |
+
# Add attention scores.
|
| 816 |
+
# shape: (B, T, C)
|
| 817 |
+
x = x + self.dropout(att)
|
| 818 |
+
|
| 819 |
+
# Add feed-forward projection.
|
| 820 |
+
# shape: (batch_size, seq_len, d_model)
|
| 821 |
+
og_x = x
|
| 822 |
+
if self._activation_checkpoint_fn is not None:
|
| 823 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 824 |
+
else:
|
| 825 |
+
x = self.ff_norm(x)
|
| 826 |
+
x = self.ff_proj(x)
|
| 827 |
+
if self._activation_checkpoint_fn is not None:
|
| 828 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 829 |
+
else:
|
| 830 |
+
x = self.act(x)
|
| 831 |
+
x = self.ff_out(x)
|
| 832 |
+
x = self.dropout(x)
|
| 833 |
+
x = og_x + x
|
| 834 |
+
|
| 835 |
+
return x, cache
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
class LLaDALlamaBlock(LLaDABlock):
|
| 839 |
+
"""
|
| 840 |
+
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 841 |
+
(plus another skip connection). This block is similar to `LLaDASequentialBlock`
|
| 842 |
+
but some operations have slightly different implementations to imitate the
|
| 843 |
+
behavior of Llama.
|
| 844 |
+
"""
|
| 845 |
+
|
| 846 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 847 |
+
super().__init__(layer_id, config, cache)
|
| 848 |
+
# Layer norms.
|
| 849 |
+
self.attn_norm = LayerNorm.build(config)
|
| 850 |
+
self.ff_norm = LayerNorm.build(config)
|
| 851 |
+
self.__cache = cache
|
| 852 |
+
|
| 853 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 854 |
+
head_dim = config.d_model // config.n_heads
|
| 855 |
+
q_proj_out_dim = config.d_model
|
| 856 |
+
k_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 857 |
+
v_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 858 |
+
self.q_proj = nn.Linear(
|
| 859 |
+
config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 860 |
+
)
|
| 861 |
+
self.k_proj = nn.Linear(
|
| 862 |
+
config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 863 |
+
)
|
| 864 |
+
self.v_proj = nn.Linear(
|
| 865 |
+
config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# Feed-forward input projection.
|
| 869 |
+
self.ff_proj = nn.Linear(
|
| 870 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 871 |
+
)
|
| 872 |
+
# new add
|
| 873 |
+
self.up_proj = nn.Linear(
|
| 874 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
def reset_parameters(self):
|
| 878 |
+
super().reset_parameters()
|
| 879 |
+
self.attn_norm.reset_parameters()
|
| 880 |
+
self.ff_norm.reset_parameters()
|
| 881 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 882 |
+
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
|
| 883 |
+
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
|
| 884 |
+
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
|
| 885 |
+
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
|
| 886 |
+
init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add
|
| 887 |
+
|
| 888 |
+
def forward(
|
| 889 |
+
self,
|
| 890 |
+
x: torch.Tensor,
|
| 891 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 892 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 893 |
+
use_cache: bool = False,
|
| 894 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 895 |
+
# Get query, key, value projections.
|
| 896 |
+
# shape:
|
| 897 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 898 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 899 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 900 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 901 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 902 |
+
# print(x)
|
| 903 |
+
x_normed = self.attn_norm(x)
|
| 904 |
+
q = self.q_proj(x_normed)
|
| 905 |
+
k = self.k_proj(x_normed)
|
| 906 |
+
v = self.v_proj(x_normed)
|
| 907 |
+
|
| 908 |
+
# Get attention scores.
|
| 909 |
+
if self._activation_checkpoint_fn is not None:
|
| 910 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 911 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 912 |
+
)
|
| 913 |
+
else:
|
| 914 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 915 |
+
|
| 916 |
+
# Add attention scores.
|
| 917 |
+
# shape: (B, T, C)
|
| 918 |
+
x = x + self.dropout(att)
|
| 919 |
+
|
| 920 |
+
# Add feed-forward projection.
|
| 921 |
+
# shape: (batch_size, seq_len, d_model)
|
| 922 |
+
og_x = x
|
| 923 |
+
if self._activation_checkpoint_fn is not None:
|
| 924 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 925 |
+
else:
|
| 926 |
+
x = self.ff_norm(x)
|
| 927 |
+
x, x_up = self.ff_proj(x), self.up_proj(x) # new add
|
| 928 |
+
if self._activation_checkpoint_fn is not None:
|
| 929 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 930 |
+
else:
|
| 931 |
+
x = self.act(x)
|
| 932 |
+
x = x * x_up # new add
|
| 933 |
+
x = self.ff_out(x)
|
| 934 |
+
x = self.dropout(x)
|
| 935 |
+
x = og_x + x
|
| 936 |
+
|
| 937 |
+
return x, cache
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
class LLaDAOutput(NamedTuple):
|
| 941 |
+
logits: torch.FloatTensor
|
| 942 |
+
"""
|
| 943 |
+
A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
|
| 944 |
+
for the next token *before* normalization via (log) softmax.
|
| 945 |
+
"""
|
| 946 |
+
|
| 947 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
|
| 948 |
+
"""
|
| 949 |
+
Attention keys and values from each block.
|
| 950 |
+
"""
|
| 951 |
+
|
| 952 |
+
hidden_states: Optional[Tuple[torch.Tensor]]
|
| 953 |
+
"""
|
| 954 |
+
Hidden states from each block.
|
| 955 |
+
"""
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
class LLaDAGenerateOutput(NamedTuple):
|
| 959 |
+
token_ids: torch.LongTensor
|
| 960 |
+
"""
|
| 961 |
+
The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
|
| 962 |
+
These do *not* include the original input IDs.
|
| 963 |
+
"""
|
| 964 |
+
|
| 965 |
+
scores: torch.FloatTensor
|
| 966 |
+
"""
|
| 967 |
+
The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
|
| 968 |
+
"""
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class LLaDABlockGroup(nn.ModuleList):
|
| 972 |
+
def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
|
| 973 |
+
super().__init__(modules)
|
| 974 |
+
self.config = config
|
| 975 |
+
self.layer_offset = layer_offset
|
| 976 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 977 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 978 |
+
|
| 979 |
+
def forward(
|
| 980 |
+
self,
|
| 981 |
+
x: torch.Tensor,
|
| 982 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 983 |
+
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 984 |
+
use_cache: bool = False,
|
| 985 |
+
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
| 986 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 987 |
+
for block_idx, block in enumerate(self):
|
| 988 |
+
layer_past = None if layers_past is None else layers_past[block_idx]
|
| 989 |
+
block_idx += self.layer_offset
|
| 990 |
+
if (
|
| 991 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 992 |
+
or (
|
| 993 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 994 |
+
and block_idx % 2 == 0
|
| 995 |
+
)
|
| 996 |
+
or (
|
| 997 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 998 |
+
and block_idx % 3 == 0
|
| 999 |
+
)
|
| 1000 |
+
or (
|
| 1001 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1002 |
+
and block_idx % 4 == 0
|
| 1003 |
+
)
|
| 1004 |
+
):
|
| 1005 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1006 |
+
x, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1007 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1008 |
+
)
|
| 1009 |
+
else:
|
| 1010 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1011 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1012 |
+
if attn_key_values is not None:
|
| 1013 |
+
assert cache is not None
|
| 1014 |
+
attn_key_values.append(cache)
|
| 1015 |
+
return x, attn_key_values
|
| 1016 |
+
|
| 1017 |
+
def reset_parameters(self):
|
| 1018 |
+
for block in self:
|
| 1019 |
+
block.reset_parameters()
|
| 1020 |
+
|
| 1021 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1022 |
+
self.activation_checkpointing_strategy = strategy
|
| 1023 |
+
for block in self:
|
| 1024 |
+
block.set_activation_checkpointing(strategy)
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
class LLaDAModel(nn.Module):
|
| 1028 |
+
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1029 |
+
super().__init__()
|
| 1030 |
+
self.config = config
|
| 1031 |
+
self.__cache = BufferCache()
|
| 1032 |
+
|
| 1033 |
+
# Validate config.
|
| 1034 |
+
if self.config.alibi and self.config.flash_attention:
|
| 1035 |
+
raise Exception("ALiBi is currently not supported with FlashAttention")
|
| 1036 |
+
|
| 1037 |
+
if self.config.alibi and self.config.rope:
|
| 1038 |
+
raise Exception("ALiBi and RoPE are mutually exclusive")
|
| 1039 |
+
|
| 1040 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
| 1041 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 1042 |
+
raise Exception("embedding size should be at least as big as vocab size")
|
| 1043 |
+
elif self.config.embedding_size % 128 != 0:
|
| 1044 |
+
import warnings
|
| 1045 |
+
|
| 1046 |
+
warnings.warn(
|
| 1047 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 1051 |
+
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
|
| 1052 |
+
|
| 1053 |
+
if not (
|
| 1054 |
+
0 < self.config.block_group_size <= self.config.n_layers
|
| 1055 |
+
and self.config.n_layers % self.config.block_group_size == 0
|
| 1056 |
+
):
|
| 1057 |
+
raise Exception("n layers must be divisible by block group size")
|
| 1058 |
+
|
| 1059 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 1060 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
| 1061 |
+
|
| 1062 |
+
self.transformer = nn.ModuleDict(
|
| 1063 |
+
dict(
|
| 1064 |
+
wte=nn.Embedding(
|
| 1065 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
| 1066 |
+
),
|
| 1067 |
+
emb_drop=Dropout(config.embedding_dropout),
|
| 1068 |
+
ln_f=LayerNorm.build(config),
|
| 1069 |
+
)
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
| 1073 |
+
if self.config.block_group_size > 1:
|
| 1074 |
+
block_groups = [
|
| 1075 |
+
LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
|
| 1076 |
+
for i in range(0, config.n_layers, config.block_group_size)
|
| 1077 |
+
]
|
| 1078 |
+
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
|
| 1079 |
+
else:
|
| 1080 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
| 1081 |
+
|
| 1082 |
+
if not (self.config.alibi or self.config.rope):
|
| 1083 |
+
self.transformer.update(
|
| 1084 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
| 1085 |
+
)
|
| 1086 |
+
if not config.weight_tying:
|
| 1087 |
+
self.transformer.update(
|
| 1088 |
+
{
|
| 1089 |
+
"ff_out": nn.Linear(
|
| 1090 |
+
config.d_model,
|
| 1091 |
+
config.embedding_size or config.vocab_size,
|
| 1092 |
+
bias=config.include_bias,
|
| 1093 |
+
device=config.init_device,
|
| 1094 |
+
)
|
| 1095 |
+
}
|
| 1096 |
+
)
|
| 1097 |
+
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1098 |
+
if init_params and self.config.init_device != "meta":
|
| 1099 |
+
self.reset_parameters()
|
| 1100 |
+
self.__num_fwd_flops: Optional[int] = None
|
| 1101 |
+
|
| 1102 |
+
# Warm up cache.
|
| 1103 |
+
if self.config.alibi:
|
| 1104 |
+
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
|
| 1105 |
+
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
|
| 1106 |
+
|
| 1107 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1108 |
+
self.activation_checkpointing_strategy = strategy
|
| 1109 |
+
if self.config.block_group_size != 1:
|
| 1110 |
+
for block_group in self.transformer.block_groups:
|
| 1111 |
+
block_group.set_activation_checkpointing(strategy)
|
| 1112 |
+
else:
|
| 1113 |
+
for block in self.transformer.blocks:
|
| 1114 |
+
block.set_activation_checkpointing(strategy)
|
| 1115 |
+
|
| 1116 |
+
@property
|
| 1117 |
+
def device(self) -> torch.device:
|
| 1118 |
+
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
| 1119 |
+
if device.type == "meta":
|
| 1120 |
+
return _non_meta_init_device(self.config)
|
| 1121 |
+
else:
|
| 1122 |
+
return device
|
| 1123 |
+
|
| 1124 |
+
def reset_parameters(self):
|
| 1125 |
+
log.info("Initializing model parameters...")
|
| 1126 |
+
# Top-level embeddings / linear layers.
|
| 1127 |
+
init_weights(
|
| 1128 |
+
self.config,
|
| 1129 |
+
self.transformer.wte, # type: ignore
|
| 1130 |
+
std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
|
| 1131 |
+
type_of_module=ModuleType.emb,
|
| 1132 |
+
)
|
| 1133 |
+
if hasattr(self.transformer, "wpe"):
|
| 1134 |
+
init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
|
| 1135 |
+
|
| 1136 |
+
# Top-level layer norm.
|
| 1137 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
| 1138 |
+
|
| 1139 |
+
# Output weights.
|
| 1140 |
+
if hasattr(self.transformer, "ff_out"):
|
| 1141 |
+
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
|
| 1142 |
+
|
| 1143 |
+
# Let the blocks handle themselves.
|
| 1144 |
+
if self.config.block_group_size == 1:
|
| 1145 |
+
for block in self.transformer.blocks:
|
| 1146 |
+
block.reset_parameters()
|
| 1147 |
+
else:
|
| 1148 |
+
for block_group in self.transformer.block_groups:
|
| 1149 |
+
block_group.reset_parameters()
|
| 1150 |
+
|
| 1151 |
+
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 1152 |
+
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
|
| 1153 |
+
-1
|
| 1154 |
+
] >= seq_len:
|
| 1155 |
+
if alibi_bias.device != device:
|
| 1156 |
+
alibi_bias = alibi_bias.to(device)
|
| 1157 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1158 |
+
return alibi_bias
|
| 1159 |
+
with torch.autocast(device.type, enabled=False):
|
| 1160 |
+
alibi_bias = alibi_attention_bias(seq_len, self.config, device)
|
| 1161 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1162 |
+
return alibi_bias
|
| 1163 |
+
|
| 1164 |
+
def forward(
|
| 1165 |
+
self,
|
| 1166 |
+
input_ids: torch.LongTensor,
|
| 1167 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
| 1168 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1169 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1170 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1171 |
+
use_cache: bool = False,
|
| 1172 |
+
update_kvcache: bool = False,
|
| 1173 |
+
last_logits_only: bool = False,
|
| 1174 |
+
output_hidden_states: Optional[bool] = None,
|
| 1175 |
+
) -> LLaDAOutput:
|
| 1176 |
+
"""
|
| 1177 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
| 1178 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
| 1179 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
| 1180 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1181 |
+
which input IDs are masked. A `1` value in the mask means that
|
| 1182 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
| 1183 |
+
that the corresponding input ID is masked.
|
| 1184 |
+
|
| 1185 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
| 1186 |
+
library.
|
| 1187 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
| 1188 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
| 1189 |
+
to introduce causal or other biases.
|
| 1190 |
+
|
| 1191 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
| 1192 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
| 1193 |
+
element in the sequence.
|
| 1194 |
+
|
| 1195 |
+
If the tensor is a float tensor, it will just be added to the attention
|
| 1196 |
+
scores before the softmax.
|
| 1197 |
+
|
| 1198 |
+
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
|
| 1199 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
| 1200 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
| 1201 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
| 1202 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
| 1203 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
| 1204 |
+
This can speed up decoding when you only care about the next token.
|
| 1205 |
+
"""
|
| 1206 |
+
# Add Basic MDM Model config check
|
| 1207 |
+
# print(input_ids.dtype)
|
| 1208 |
+
assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
|
| 1209 |
+
assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
|
| 1210 |
+
# assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM."
|
| 1211 |
+
|
| 1212 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 1213 |
+
|
| 1214 |
+
if past_key_values:
|
| 1215 |
+
assert len(past_key_values) == self.config.n_layers
|
| 1216 |
+
|
| 1217 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
| 1218 |
+
if past_key_values is None:
|
| 1219 |
+
past_length = 0
|
| 1220 |
+
else:
|
| 1221 |
+
past_length = past_key_values[0][0].size(-2)
|
| 1222 |
+
|
| 1223 |
+
# Get embeddings of input.
|
| 1224 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1225 |
+
# print(input_ids.dtype,"wte")
|
| 1226 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
| 1227 |
+
|
| 1228 |
+
if self.config.input_emb_norm:
|
| 1229 |
+
x = x * (self.config.d_model**0.5)
|
| 1230 |
+
|
| 1231 |
+
if not (self.config.alibi or self.config.rope):
|
| 1232 |
+
# Get positional embeddings.
|
| 1233 |
+
# shape: (1, seq_len)
|
| 1234 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
| 1235 |
+
# shape: (1, seq_len, d_model)
|
| 1236 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
| 1237 |
+
x = pos_emb + x
|
| 1238 |
+
|
| 1239 |
+
# Add input + positional embeddings and apply dropout.
|
| 1240 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1241 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
| 1242 |
+
|
| 1243 |
+
# Transform the attention mask into what the blocks expect.
|
| 1244 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1245 |
+
# shape: (batch_size, 1, 1, seq_len)
|
| 1246 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
| 1247 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
| 1248 |
+
else:
|
| 1249 |
+
attention_mask = None
|
| 1250 |
+
|
| 1251 |
+
# Merge attention mask with attention bias.
|
| 1252 |
+
if (
|
| 1253 |
+
attention_bias is not None
|
| 1254 |
+
or attention_mask is not None
|
| 1255 |
+
or self.config.alibi
|
| 1256 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
| 1257 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
| 1258 |
+
# scores correctly.
|
| 1259 |
+
or past_key_values is not None
|
| 1260 |
+
):
|
| 1261 |
+
if attention_bias is None and self.config.alibi:
|
| 1262 |
+
attention_bias = get_causal_attention_bias(
|
| 1263 |
+
self.__cache, past_length + seq_len, x.device
|
| 1264 |
+
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
|
| 1265 |
+
elif attention_bias is None:
|
| 1266 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
| 1267 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
| 1268 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
| 1269 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
| 1270 |
+
|
| 1271 |
+
# Transform to the right shape and data type.
|
| 1272 |
+
mask_len = seq_len
|
| 1273 |
+
if attention_mask is not None:
|
| 1274 |
+
mask_len = attention_mask.shape[-1]
|
| 1275 |
+
elif past_key_values is not None:
|
| 1276 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
| 1277 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
| 1278 |
+
|
| 1279 |
+
# Add in the masking bias.
|
| 1280 |
+
if attention_mask is not None:
|
| 1281 |
+
attention_bias = attention_bias + attention_mask
|
| 1282 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
| 1283 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
| 1284 |
+
# it can produce NaNs.
|
| 1285 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
| 1286 |
+
|
| 1287 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1288 |
+
|
| 1289 |
+
# decoder layers
|
| 1290 |
+
all_hidden_states = []
|
| 1291 |
+
|
| 1292 |
+
# Apply blocks one-by-one.
|
| 1293 |
+
if self.config.block_group_size == 1:
|
| 1294 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
| 1295 |
+
if output_hidden_states:
|
| 1296 |
+
# add hidden states
|
| 1297 |
+
all_hidden_states.append(x)
|
| 1298 |
+
|
| 1299 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
| 1300 |
+
if (
|
| 1301 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 1302 |
+
or (
|
| 1303 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 1304 |
+
and block_idx % 2 == 0
|
| 1305 |
+
)
|
| 1306 |
+
or (
|
| 1307 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 1308 |
+
and block_idx % 3 == 0
|
| 1309 |
+
)
|
| 1310 |
+
or (
|
| 1311 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1312 |
+
and block_idx % 4 == 0
|
| 1313 |
+
)
|
| 1314 |
+
):
|
| 1315 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1316 |
+
x, cache = self._activation_checkpoint_fn(
|
| 1317 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1318 |
+
)
|
| 1319 |
+
else:
|
| 1320 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1321 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1322 |
+
if attn_key_values is not None:
|
| 1323 |
+
if update_kvcache:
|
| 1324 |
+
cache = (cache[0][:,:,:update_kvcache],cache[1][:,:,:update_kvcache,:])
|
| 1325 |
+
# print("True")
|
| 1326 |
+
attn_key_values.append(cache)
|
| 1327 |
+
else:
|
| 1328 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
| 1329 |
+
if output_hidden_states:
|
| 1330 |
+
# add hidden states
|
| 1331 |
+
all_hidden_states.append(x)
|
| 1332 |
+
|
| 1333 |
+
layers_past = (
|
| 1334 |
+
None
|
| 1335 |
+
if past_key_values is None
|
| 1336 |
+
else past_key_values[
|
| 1337 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
| 1338 |
+
]
|
| 1339 |
+
)
|
| 1340 |
+
x, cache = block_group(
|
| 1341 |
+
x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
|
| 1342 |
+
)
|
| 1343 |
+
if attn_key_values is not None:
|
| 1344 |
+
assert cache is not None
|
| 1345 |
+
attn_key_values.extend(cache)
|
| 1346 |
+
|
| 1347 |
+
if last_logits_only:
|
| 1348 |
+
# shape: (batch_size, 1, d_model)
|
| 1349 |
+
x = x[:, -1, :].unsqueeze(1)
|
| 1350 |
+
|
| 1351 |
+
# Apply final layer norm.
|
| 1352 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
| 1353 |
+
x = self.transformer.ln_f(x) # type: ignore
|
| 1354 |
+
if output_hidden_states:
|
| 1355 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
| 1356 |
+
all_hidden_states.append(x)
|
| 1357 |
+
|
| 1358 |
+
# Get logits.
|
| 1359 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
| 1360 |
+
if self.config.weight_tying:
|
| 1361 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1362 |
+
else:
|
| 1363 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 1364 |
+
if self.config.scale_logits:
|
| 1365 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 1366 |
+
if use_cache == True and update_kvcache == False:
|
| 1367 |
+
attn_key_values=past_key_values
|
| 1368 |
+
return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
|
| 1369 |
+
|
| 1370 |
+
|
| 1371 |
+
def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
| 1372 |
+
"""
|
| 1373 |
+
Utility function
|
| 1374 |
+
"""
|
| 1375 |
+
|
| 1376 |
+
kwargs = {}
|
| 1377 |
+
for field in fields(ModelConfig):
|
| 1378 |
+
kwargs[field.name] = getattr(config, field.name)
|
| 1379 |
+
|
| 1380 |
+
model_config = ModelConfig(**kwargs)
|
| 1381 |
+
return model_config
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
class LLaDAModelLM(PreTrainedModel):
|
| 1385 |
+
"""
|
| 1386 |
+
Extremely barebones HF model wrapper.
|
| 1387 |
+
"""
|
| 1388 |
+
|
| 1389 |
+
config_class = LLaDAConfig
|
| 1390 |
+
base_model_prefix = "model"
|
| 1391 |
+
_no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
|
| 1392 |
+
|
| 1393 |
+
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
|
| 1394 |
+
super().__init__(config)
|
| 1395 |
+
|
| 1396 |
+
if not model:
|
| 1397 |
+
model_config = create_model_config_from_pretrained_config(config)
|
| 1398 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
| 1399 |
+
model_config.init_device = "cpu"
|
| 1400 |
+
self.model = LLaDAModel(model_config, init_params=init_params)
|
| 1401 |
+
else:
|
| 1402 |
+
self.model = model
|
| 1403 |
+
|
| 1404 |
+
def forward(
|
| 1405 |
+
self,
|
| 1406 |
+
input_ids: torch.LongTensor = None,
|
| 1407 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1408 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1409 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1410 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1411 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1412 |
+
use_cache: Optional[bool] = None,
|
| 1413 |
+
update_kvcache: Optional[bool] = False,
|
| 1414 |
+
output_attentions: Optional[bool] = None,
|
| 1415 |
+
output_hidden_states: Optional[bool] = None,
|
| 1416 |
+
return_dict: Optional[bool] = None,
|
| 1417 |
+
cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
|
| 1418 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1419 |
+
if use_cache is None:
|
| 1420 |
+
use_cache = self.config.use_cache
|
| 1421 |
+
|
| 1422 |
+
if output_attentions:
|
| 1423 |
+
raise ValueError("output_attentions is not yet supported in LLaDA")
|
| 1424 |
+
|
| 1425 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1426 |
+
|
| 1427 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1428 |
+
outputs = self.model.forward(
|
| 1429 |
+
input_ids=input_ids,
|
| 1430 |
+
input_embeddings=inputs_embeds,
|
| 1431 |
+
attention_mask=attention_mask,
|
| 1432 |
+
attention_bias=attention_bias,
|
| 1433 |
+
past_key_values=past_key_values,
|
| 1434 |
+
use_cache=use_cache,
|
| 1435 |
+
update_kvcache=update_kvcache,
|
| 1436 |
+
output_hidden_states=output_hidden_states,
|
| 1437 |
+
)
|
| 1438 |
+
|
| 1439 |
+
logits = outputs.logits
|
| 1440 |
+
hidden_states = outputs.hidden_states
|
| 1441 |
+
|
| 1442 |
+
loss = None
|
| 1443 |
+
if labels is not None:
|
| 1444 |
+
import warnings
|
| 1445 |
+
warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
|
| 1446 |
+
if not return_dict:
|
| 1447 |
+
output = (logits,) + outputs[1:]
|
| 1448 |
+
return (loss,) + output if loss is not None else output
|
| 1449 |
+
|
| 1450 |
+
return CausalLMOutputWithPast(
|
| 1451 |
+
logits=logits,
|
| 1452 |
+
past_key_values=outputs.attn_key_values,
|
| 1453 |
+
hidden_states=hidden_states,
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
def can_generate(self) -> bool:
|
| 1457 |
+
return True
|
| 1458 |
+
|
| 1459 |
+
def prepare_inputs_for_generation(
|
| 1460 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 1461 |
+
):
|
| 1462 |
+
if past_key_values:
|
| 1463 |
+
# This is because we want the model to only process the last generated token.
|
| 1464 |
+
input_ids = input_ids[:, -1:]
|
| 1465 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 1466 |
+
|
| 1467 |
+
model_inputs.update(kwargs)
|
| 1468 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 1469 |
+
return model_inputs
|
| 1470 |
+
|
| 1471 |
+
# TODO: these are required to make the implementation complete.
|
| 1472 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
| 1473 |
+
# pass
|
| 1474 |
+
#
|
| 1475 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
| 1476 |
+
# pass
|
| 1477 |
+
#
|
| 1478 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
| 1479 |
+
# pass
|
| 1480 |
+
|
| 1481 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1482 |
+
return self.model.transformer.wte
|
| 1483 |
+
|
| 1484 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 1485 |
+
self.model.transformer.wte = value
|
| 1486 |
+
|
| 1487 |
+
def get_output_embeddings(self):
|
| 1488 |
+
if self.config.weight_tying:
|
| 1489 |
+
return self.model.transformer.wte
|
| 1490 |
+
else:
|
| 1491 |
+
return self.model.transformer.ff_out
|
| 1492 |
+
|
| 1493 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 1494 |
+
if self.config.weight_tying:
|
| 1495 |
+
self.model.transformer.wte = value
|
| 1496 |
+
else:
|
| 1497 |
+
self.model.transformer.ff_out = value
|
| 1498 |
+
|
| 1499 |
+
def tie_weights(self):
|
| 1500 |
+
if self.config.weight_tying:
|
| 1501 |
+
self.model.transformer.ff_out = self.model.transformer.wte
|
| 1502 |
+
|
| 1503 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
| 1504 |
+
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
Discrete-Diffusion-Forcing/D2F-eval/postprocess_code.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# Modified from Dream repos: https://github.com/HKUNLP/Dream
|
| 17 |
+
|
| 18 |
+
import evaluate as hf_evaluate
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from sanitize import sanitize
|
| 22 |
+
|
| 23 |
+
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
| 24 |
+
pass_at_k = hf_evaluate.load("code_eval")
|
| 25 |
+
|
| 26 |
+
def pass_at_1(references, predictions):
|
| 27 |
+
return pass_at_k.compute(
|
| 28 |
+
references=references,
|
| 29 |
+
predictions=predictions,
|
| 30 |
+
k=[1],
|
| 31 |
+
)[0]["pass@1"]
|
| 32 |
+
|
| 33 |
+
import json
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def read_jsonl(file_path):
|
| 37 |
+
data = []
|
| 38 |
+
with open(file_path, 'r') as file:
|
| 39 |
+
for line in file:
|
| 40 |
+
data.append(json.loads(line))
|
| 41 |
+
return data
|
| 42 |
+
|
| 43 |
+
file_path = sys.argv[1]
|
| 44 |
+
data = read_jsonl(file_path)
|
| 45 |
+
|
| 46 |
+
references = [sample['target'] for sample in data]
|
| 47 |
+
|
| 48 |
+
predictions = [[sanitize(sample['doc']['prompt'] + "\n" + sample['resps'][0][0].split('```python\n', 1)[-1].split('```')[0],
|
| 49 |
+
sample['doc']["entry_point"])]
|
| 50 |
+
for sample in data]
|
| 51 |
+
|
| 52 |
+
pass_at_1s = [pass_at_1([reference], [prediction]) for reference, prediction in zip(references, predictions)]
|
| 53 |
+
print(sum(pass_at_1s)/len(pass_at_1s))
|
| 54 |
+
|
| 55 |
+
def write_jsonl(data, file_path):
|
| 56 |
+
with open(file_path, 'w') as file:
|
| 57 |
+
for item in data:
|
| 58 |
+
file.write(json.dumps(item) + '\n')
|
| 59 |
+
|
| 60 |
+
res = [{"task_id": sample['doc']['task_id'], "completion": pred, "pass_at_1": res}
|
| 61 |
+
for sample, pred, res in zip(data, predictions, pass_at_1s)]
|
| 62 |
+
write_jsonl(res, file_path+'.cleaned')
|
Discrete-Diffusion-Forcing/D2F-eval/sanitize.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# Modified from Dream repos: https://github.com/HKUNLP/Dream
|
| 17 |
+
|
| 18 |
+
"""Post-processing LLM-generated Python code implemented using tree-sitter."""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import pathlib
|
| 23 |
+
|
| 24 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))])
|
| 26 |
+
|
| 27 |
+
import ast
|
| 28 |
+
import traceback
|
| 29 |
+
|
| 30 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 31 |
+
|
| 32 |
+
def refine_text(text: str) -> str:
|
| 33 |
+
text = text.replace("\t", " ")
|
| 34 |
+
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
| 35 |
+
return text.strip() + "\n"
|
| 36 |
+
|
| 37 |
+
def syntax_check(code, verbose = False):
|
| 38 |
+
try:
|
| 39 |
+
ast.parse(code)
|
| 40 |
+
return True
|
| 41 |
+
except (SyntaxError, MemoryError):
|
| 42 |
+
if verbose:
|
| 43 |
+
traceback.print_exc()
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
def extract_longest_valid_code(text: str) -> str:
|
| 47 |
+
lines = text.splitlines()
|
| 48 |
+
|
| 49 |
+
if len(lines) > 100:
|
| 50 |
+
lines = lines[:100]
|
| 51 |
+
max_valid_lines = 0
|
| 52 |
+
max_valid_snippet = ""
|
| 53 |
+
|
| 54 |
+
for i in range(len(lines)):
|
| 55 |
+
for j in range(i, len(lines)):
|
| 56 |
+
current_snippet = "\n".join(lines[i:j+1])
|
| 57 |
+
if syntax_check(current_snippet):
|
| 58 |
+
valid_line_count = sum(1 for line in lines[i:j+1] if line.strip())
|
| 59 |
+
if valid_line_count > max_valid_lines:
|
| 60 |
+
max_valid_lines = valid_line_count
|
| 61 |
+
max_valid_snippet = current_snippet
|
| 62 |
+
|
| 63 |
+
return max_valid_snippet
|
| 64 |
+
|
| 65 |
+
def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
|
| 66 |
+
name2deps = {}
|
| 67 |
+
for name, node in nodes:
|
| 68 |
+
deps = set()
|
| 69 |
+
stack = [node]
|
| 70 |
+
while stack:
|
| 71 |
+
current = stack.pop()
|
| 72 |
+
for child in ast.iter_child_nodes(current):
|
| 73 |
+
if isinstance(child, ast.Name):
|
| 74 |
+
deps.add(child.id)
|
| 75 |
+
elif isinstance(child, ast.Attribute):
|
| 76 |
+
deps.add(child.attr)
|
| 77 |
+
else:
|
| 78 |
+
stack.append(child)
|
| 79 |
+
name2deps[name] = deps
|
| 80 |
+
return name2deps
|
| 81 |
+
|
| 82 |
+
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
|
| 83 |
+
visited = set()
|
| 84 |
+
to_visit = [entrypoint]
|
| 85 |
+
|
| 86 |
+
while to_visit:
|
| 87 |
+
current = to_visit.pop(0)
|
| 88 |
+
if current not in visited:
|
| 89 |
+
visited.add(current)
|
| 90 |
+
to_visit.extend(call_graph.get(current, set()) - visited)
|
| 91 |
+
|
| 92 |
+
return visited
|
| 93 |
+
|
| 94 |
+
def get_definition_name(node: ast.AST) -> Optional[str]:
|
| 95 |
+
if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
|
| 96 |
+
return node.name
|
| 97 |
+
elif isinstance(node, ast.Assign):
|
| 98 |
+
targets = node.targets
|
| 99 |
+
if targets and isinstance(targets[0], ast.Name):
|
| 100 |
+
return targets[0].id
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
def has_return_statement(node: ast.AST) -> bool:
|
| 104 |
+
return any(isinstance(n, ast.Return) for n in ast.walk(node))
|
| 105 |
+
|
| 106 |
+
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
|
| 107 |
+
|
| 108 |
+
text = refine_text(text)
|
| 109 |
+
|
| 110 |
+
# text = python_extract(text)
|
| 111 |
+
|
| 112 |
+
code = extract_longest_valid_code(text)
|
| 113 |
+
tree = ast.parse(code)
|
| 114 |
+
|
| 115 |
+
definitions = {}
|
| 116 |
+
|
| 117 |
+
imports = []
|
| 118 |
+
|
| 119 |
+
for node in tree.body:
|
| 120 |
+
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
| 121 |
+
imports.append(node)
|
| 122 |
+
elif isinstance(node, ast.ClassDef):
|
| 123 |
+
name = node.name
|
| 124 |
+
definitions[name] = ('class', node)
|
| 125 |
+
elif isinstance(node, ast.FunctionDef):
|
| 126 |
+
name = node.name
|
| 127 |
+
if has_return_statement(node):
|
| 128 |
+
definitions[name] = ('function', node)
|
| 129 |
+
elif isinstance(node, ast.Assign):
|
| 130 |
+
name = get_definition_name(node)
|
| 131 |
+
if name:
|
| 132 |
+
definitions[name] = ('variable', node)
|
| 133 |
+
|
| 134 |
+
if entrypoint:
|
| 135 |
+
name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
|
| 136 |
+
reachable = get_function_dependency(entrypoint, name2deps)
|
| 137 |
+
|
| 138 |
+
sanitized_output = []
|
| 139 |
+
|
| 140 |
+
for node in imports:
|
| 141 |
+
sanitized_output.append(ast.unparse(node))
|
| 142 |
+
|
| 143 |
+
for name, (_, node) in definitions.items():
|
| 144 |
+
if not entrypoint or name in reachable:
|
| 145 |
+
sanitized_output.append(ast.unparse(node))
|
| 146 |
+
|
| 147 |
+
return "\n".join(sanitized_output)
|
Discrete-Diffusion-Forcing/D2F-train/config/acc_config
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
gradient_accumulation_steps: 3
|
| 5 |
+
gradient_clipping: 1.0
|
| 6 |
+
offload_optimizer_device: none
|
| 7 |
+
offload_param_device: none
|
| 8 |
+
zero3_init_flag: False
|
| 9 |
+
zero_stage: 2
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
enable_cpu_affinity: True
|
| 13 |
+
machine_rank: 0
|
| 14 |
+
main_training_function: main
|
| 15 |
+
mixed_precision: fp16
|
| 16 |
+
num_machines: 1
|
| 17 |
+
num_processes: 4
|
| 18 |
+
rdzv_backend: static
|
| 19 |
+
same_network: false
|
| 20 |
+
tpu_env: []
|
| 21 |
+
tpu_use_cluster: false
|
| 22 |
+
tpu_use_sudo: false
|
| 23 |
+
use_cpu: True
|
Discrete-Diffusion-Forcing/D2F-train/config/dream_eagle.yaml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training mode configuration
|
| 2 |
+
training_mode: 'dream' # 'llada' or 'dream'
|
| 3 |
+
|
| 4 |
+
# Model and data path configuration
|
| 5 |
+
paths:
|
| 6 |
+
model: 'Dream-org/Dream-v0-Base-7B'
|
| 7 |
+
experiment: 'ckpt_dream_base'
|
| 8 |
+
data:
|
| 9 |
+
bs: 'Lansechen/bs17k_collection_filtered_hard_maxlength600'
|
| 10 |
+
bs_easy: 'Lansechen/bs17k_collection_filtered_easy_maxlength600'
|
| 11 |
+
|
| 12 |
+
denoiser:
|
| 13 |
+
encoder:
|
| 14 |
+
name: 'dream'
|
| 15 |
+
mask_id: 151666
|
| 16 |
+
|
| 17 |
+
decoder:
|
| 18 |
+
wiinit: true
|
| 19 |
+
name: 'eagle_rope'
|
| 20 |
+
num_blocks: 1
|
| 21 |
+
seq_len: &seq_len 1024
|
| 22 |
+
input_dim: 3584
|
| 23 |
+
hidden_dim: &dim 3584
|
| 24 |
+
vocab_size: 152064
|
| 25 |
+
block:
|
| 26 |
+
seq_len: *seq_len
|
| 27 |
+
hidden_dim: *dim
|
| 28 |
+
num_heads: 32
|
| 29 |
+
|
| 30 |
+
train:
|
| 31 |
+
# Will use paths.experiment path
|
| 32 |
+
decoder_resume_path:
|
| 33 |
+
head_resume_path:
|
| 34 |
+
skipped_keys:
|
| 35 |
+
global_step:
|
| 36 |
+
exp_name: &exp_name 'ddt_test'
|
| 37 |
+
wandb_proj: *exp_name
|
| 38 |
+
output_dir: 'ddt_test'
|
| 39 |
+
logging_dir: 'logs'
|
| 40 |
+
mixed_precision: 'fp16'
|
| 41 |
+
gradient_accumulation_steps: 5
|
| 42 |
+
report_to: 'wandb'
|
| 43 |
+
block_size: 16
|
| 44 |
+
|
| 45 |
+
lr: 5e-6
|
| 46 |
+
num_iters: 50000
|
| 47 |
+
eval_every: 100000
|
| 48 |
+
save_every: 1000
|
| 49 |
+
|
| 50 |
+
enable_shift: true
|
| 51 |
+
share_steps: 2
|
| 52 |
+
self_align: true
|
| 53 |
+
feature_align: false
|
| 54 |
+
self_step: true
|
| 55 |
+
|
| 56 |
+
data:
|
| 57 |
+
name: 'bs17k' #['numinamath', 'bs17k']
|
| 58 |
+
batch_size: 1
|
| 59 |
+
max_length: *seq_len
|
Discrete-Diffusion-Forcing/D2F-train/config/llada.yaml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training mode configuration
|
| 2 |
+
training_mode: 'llada' # 'llada' or 'dream'
|
| 3 |
+
|
| 4 |
+
# Model and data path configuration
|
| 5 |
+
paths:
|
| 6 |
+
model: 'GSAI-ML/LLaDA-8B-Instruct'
|
| 7 |
+
experiment: 'ckpt_llada_instruct'
|
| 8 |
+
data:
|
| 9 |
+
bs: 'Lansechen/bs17k_collection_filtered_hard_maxlength600'
|
| 10 |
+
bs_easy: 'Lansechen/bs17k_collection_filtered_easy_maxlength600'
|
| 11 |
+
|
| 12 |
+
denoiser:
|
| 13 |
+
encoder:
|
| 14 |
+
name: 'dream'
|
| 15 |
+
mask_id: 151666
|
| 16 |
+
|
| 17 |
+
decoder:
|
| 18 |
+
wiinit: true
|
| 19 |
+
name: 'eagle_rope'
|
| 20 |
+
num_blocks: 1
|
| 21 |
+
seq_len: &seq_len 1024
|
| 22 |
+
input_dim: 3584
|
| 23 |
+
hidden_dim: &dim 3584
|
| 24 |
+
vocab_size: 152064
|
| 25 |
+
block:
|
| 26 |
+
seq_len: *seq_len
|
| 27 |
+
hidden_dim: *dim
|
| 28 |
+
num_heads: 32
|
| 29 |
+
|
| 30 |
+
train:
|
| 31 |
+
# Will use paths.experiment path
|
| 32 |
+
decoder_resume_path:
|
| 33 |
+
head_resume_path:
|
| 34 |
+
skipped_keys:
|
| 35 |
+
global_step:
|
| 36 |
+
exp_name: &exp_name 'llada_ddt_maskteacher'
|
| 37 |
+
wandb_proj: *exp_name
|
| 38 |
+
output_dir: 'ddt_test'
|
| 39 |
+
logging_dir: 'logs'
|
| 40 |
+
mixed_precision: 'fp16'
|
| 41 |
+
gradient_accumulation_steps: 5
|
| 42 |
+
report_to: 'wandb'
|
| 43 |
+
block_size: 16
|
| 44 |
+
|
| 45 |
+
lr: 1e-5
|
| 46 |
+
num_iters: 50000
|
| 47 |
+
eval_every: 100000
|
| 48 |
+
save_every: 1000
|
| 49 |
+
|
| 50 |
+
enable_shift: true
|
| 51 |
+
share_steps: 2
|
| 52 |
+
self_align: true
|
| 53 |
+
feature_align: false
|
| 54 |
+
self_step: true
|
| 55 |
+
|
| 56 |
+
data:
|
| 57 |
+
name: 'bs17k' #['numinamath', 'bs17k']
|
| 58 |
+
batch_size: 1
|
| 59 |
+
max_length: *seq_len
|
Discrete-Diffusion-Forcing/D2F-train/model/configuration_llada.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaDA configuration
|
| 3 |
+
"""
|
| 4 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from os import PathLike
|
| 8 |
+
from typing import Union
|
| 9 |
+
from dataclasses import asdict, dataclass, field
|
| 10 |
+
from glob import glob
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Dict,
|
| 15 |
+
Iterable,
|
| 16 |
+
List,
|
| 17 |
+
Optional,
|
| 18 |
+
Tuple,
|
| 19 |
+
Type,
|
| 20 |
+
TypeVar,
|
| 21 |
+
Union,
|
| 22 |
+
cast,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"ActivationType",
|
| 28 |
+
"ActivationCheckpointingStrategy",
|
| 29 |
+
"BlockType",
|
| 30 |
+
"LayerNormType",
|
| 31 |
+
"InitFnType",
|
| 32 |
+
"ModelConfig",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
PathOrStr = Union[str, PathLike]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class StrEnum(str, Enum):
|
| 39 |
+
"""
|
| 40 |
+
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
|
| 41 |
+
We include this here for compatibility with older version of Python.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __str__(self) -> str:
|
| 45 |
+
return self.value
|
| 46 |
+
|
| 47 |
+
def __repr__(self) -> str:
|
| 48 |
+
return f"'{str(self)}'"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LayerNormType(StrEnum):
|
| 52 |
+
default = "default"
|
| 53 |
+
"""
|
| 54 |
+
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
low_precision = "low_precision"
|
| 58 |
+
"""
|
| 59 |
+
A low-precision version of the default LayerNorm.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
rms = "rms"
|
| 63 |
+
"""
|
| 64 |
+
An RMSNorm implementation. When using ``torch.compile`` this is
|
| 65 |
+
probably the fastest implementation.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
gemma_rms = "gemma_rms"
|
| 69 |
+
"""
|
| 70 |
+
An RMSNorm implementation by gemmma. When using ``torch.compile`` this is
|
| 71 |
+
probably the fastest implementation.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
amd_compatible = "amd_compatible"
|
| 75 |
+
"""
|
| 76 |
+
LayerNorm implemented manually to work around an issue with ROCm.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ActivationType(StrEnum):
|
| 81 |
+
gelu = "gelu"
|
| 82 |
+
relu = "relu"
|
| 83 |
+
silu = "silu"
|
| 84 |
+
swiglu = "swiglu"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BlockType(StrEnum):
|
| 88 |
+
sequential = "sequential"
|
| 89 |
+
parallel = "parallel"
|
| 90 |
+
|
| 91 |
+
llama = "llama"
|
| 92 |
+
"""
|
| 93 |
+
A block similar to the sequential block with slightly different
|
| 94 |
+
implementations of operations like attention to imitate the behavior of Llama.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class InitFnType(StrEnum):
|
| 99 |
+
mitchell = "mitchell"
|
| 100 |
+
"""
|
| 101 |
+
The strategy suggested to us by Mitchell Wortsman from UW.
|
| 102 |
+
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
| 103 |
+
on the size of the weights as well as the depth of the layer.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
normal = "normal"
|
| 107 |
+
"""
|
| 108 |
+
All weights are initialized from the same normal distribution.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
kaiming_normal = "kaiming_normal"
|
| 112 |
+
"""
|
| 113 |
+
All weights are initialized with the Kaiming method from a normal distribution.
|
| 114 |
+
Note this currently won't work with FSDP.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
fan_in = "fan_in"
|
| 118 |
+
"""
|
| 119 |
+
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
| 120 |
+
is the input dimensionality of the kernel.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
full_megatron = "full_megatron"
|
| 124 |
+
"""
|
| 125 |
+
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
|
| 130 |
+
class ModelConfig():
|
| 131 |
+
"""
|
| 132 |
+
LLaDA (model) configuration.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
| 136 |
+
|
| 137 |
+
d_model: int = 768
|
| 138 |
+
"""
|
| 139 |
+
The hidden size of the model.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
n_heads: int = 12
|
| 143 |
+
"""
|
| 144 |
+
The number of self-attention heads.
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
n_kv_heads: Optional[int] = None
|
| 148 |
+
"""
|
| 149 |
+
The number of heads to use for keys and values. Defaults to `n_heads`.
|
| 150 |
+
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
| 151 |
+
Set this to 1 for multi-query attention.
|
| 152 |
+
Set it to some in-between value for Llama2-style grouped query attention.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
n_layers: int = 12
|
| 156 |
+
"""
|
| 157 |
+
The number of layers/blocks.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
mlp_ratio: int = 4
|
| 161 |
+
"""
|
| 162 |
+
The ratio of the inner MLP dimensionality to ``d_model``.
|
| 163 |
+
This is only used when ``mlp_hidden_size`` is not set.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
mlp_hidden_size: Optional[int] = None
|
| 167 |
+
"""
|
| 168 |
+
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
activation_type: ActivationType = ActivationType.swiglu
|
| 172 |
+
"""
|
| 173 |
+
The activation function to use within the MLP layers.
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
block_type: BlockType = BlockType.sequential
|
| 177 |
+
"""
|
| 178 |
+
The transformer block implementation.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
block_group_size: int = 1
|
| 182 |
+
"""
|
| 183 |
+
The number of blocks to group together into a single parent block.
|
| 184 |
+
This has no affect on the number of parameters in the model and is only used to wrap groups
|
| 185 |
+
of blocks together with a single FSDP wrapper during training.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
alibi: bool = False
|
| 189 |
+
"""
|
| 190 |
+
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
alibi_bias_max: float = 8.0
|
| 194 |
+
"""
|
| 195 |
+
Maximum absolute value of ALiBi bias.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
rope: bool = False
|
| 199 |
+
"""
|
| 200 |
+
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
rope_full_precision: bool = True
|
| 204 |
+
"""
|
| 205 |
+
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
| 206 |
+
apply RoPE at the precision of the input.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
flash_attention: bool = False
|
| 210 |
+
"""
|
| 211 |
+
If ``True``, use ``FlashAttention``.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
attention_dropout: float = 0.1
|
| 215 |
+
"""
|
| 216 |
+
The dropout probability within the attention modules.
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
multi_query_attention: Optional[bool] = None
|
| 220 |
+
"""
|
| 221 |
+
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
| 222 |
+
and is more efficient during inference.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
attention_layer_norm: bool = False
|
| 226 |
+
"""
|
| 227 |
+
Apply layer norm to the keys and queries within the attention mechanism.
|
| 228 |
+
This can help stabilize training.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
residual_dropout: float = 0.1
|
| 232 |
+
"""
|
| 233 |
+
The dropout probability for the MLP and attention output within each block.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
embedding_dropout: float = 0.1
|
| 237 |
+
"""
|
| 238 |
+
The dropout probability for embeddings.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
input_emb_norm: bool = False
|
| 242 |
+
"""
|
| 243 |
+
An input hidden_states norm implementation by gemmma.
|
| 244 |
+
"""
|
| 245 |
+
|
| 246 |
+
layer_norm_type: LayerNormType = LayerNormType.default
|
| 247 |
+
"""
|
| 248 |
+
The layernorm implementation to use.
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
layer_norm_with_affine: bool = True
|
| 252 |
+
"""
|
| 253 |
+
Whether to include bias and weight parameters for the layer norms.
|
| 254 |
+
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
| 255 |
+
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
| 256 |
+
to ``False``.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
rms_norm_eps: float = 1e-05
|
| 260 |
+
"""
|
| 261 |
+
The rms layernorm eps param.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
attention_layer_norm_with_affine: bool = True
|
| 265 |
+
"""
|
| 266 |
+
Toggle affine transform for the QK norms.
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
max_sequence_length: int = 1024
|
| 270 |
+
"""
|
| 271 |
+
The maximum input sequence length supported by the model.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
rope_theta: float = 10000.0
|
| 275 |
+
"""
|
| 276 |
+
The rope base param.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
include_qkv_bias: Optional[bool] = False
|
| 280 |
+
"""
|
| 281 |
+
Whether or not to include bias parameters in qkv linear layers.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
include_bias: bool = False
|
| 285 |
+
"""
|
| 286 |
+
Whether or not to include bias parameters in linear layers.
|
| 287 |
+
In PaLM, they got rid of all bias terms because they found that large
|
| 288 |
+
models tend to have near 0 bias terms anyway.
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
bias_for_layer_norm: Optional[bool] = None
|
| 292 |
+
"""
|
| 293 |
+
Whether or not to include bias parameters in layer norm.
|
| 294 |
+
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
| 295 |
+
layer norm.
|
| 296 |
+
When this is None (the default), it inherits the setting from include_bias.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
scale_logits: bool = False
|
| 300 |
+
"""
|
| 301 |
+
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
vocab_size: int = 50257
|
| 305 |
+
"""
|
| 306 |
+
Vocabulary size of the model.
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
embedding_size: Optional[int] = 50304
|
| 310 |
+
"""
|
| 311 |
+
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
| 312 |
+
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
| 313 |
+
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
| 314 |
+
substantially.
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
weight_tying: bool = True
|
| 318 |
+
"""
|
| 319 |
+
Whether to tie output linear weights to the input embedding.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
eos_token_id: int = 50256
|
| 323 |
+
"""
|
| 324 |
+
The ID of the end-of-sentence special token.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
pad_token_id: int = 50256
|
| 328 |
+
"""
|
| 329 |
+
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
mask_token_id: Optional[int] = 50256
|
| 333 |
+
"""
|
| 334 |
+
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
init_device: Optional[str] = None
|
| 338 |
+
"""
|
| 339 |
+
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
init_fn: InitFnType = InitFnType.normal
|
| 343 |
+
"""
|
| 344 |
+
The weight initialization strategy.
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
init_std: float = 0.02
|
| 348 |
+
"""
|
| 349 |
+
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 350 |
+
as "normal".
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
init_cutoff_factor: Optional[float] = None
|
| 354 |
+
"""
|
| 355 |
+
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
| 356 |
+
as "normal". Setting this to None means values are not cutoff.
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
precision: Optional[str] = None
|
| 360 |
+
"""
|
| 361 |
+
Precision used to train/evaluate with. You shouldn't set this directly.
|
| 362 |
+
See :data:`TrainConfig.precision` instead.
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
@property
|
| 366 |
+
def effective_n_kv_heads(self) -> int:
|
| 367 |
+
if self.n_kv_heads is None:
|
| 368 |
+
if self.multi_query_attention is True:
|
| 369 |
+
return 1
|
| 370 |
+
else:
|
| 371 |
+
return self.n_heads
|
| 372 |
+
else:
|
| 373 |
+
if self.multi_query_attention is None:
|
| 374 |
+
return self.n_kv_heads
|
| 375 |
+
if self.multi_query_attention:
|
| 376 |
+
n_kv_heads_should_be = 1
|
| 377 |
+
else:
|
| 378 |
+
n_kv_heads_should_be = self.n_heads
|
| 379 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
| 380 |
+
return n_kv_heads_should_be
|
| 381 |
+
else:
|
| 382 |
+
raise Exception(
|
| 383 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
class ActivationCheckpointingStrategy(StrEnum):
|
| 387 |
+
whole_layer = "whole_layer"
|
| 388 |
+
"""
|
| 389 |
+
Checkpoint every transformer layer.
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
one_in_two = "one_in_two"
|
| 393 |
+
"""
|
| 394 |
+
Checkpoint one in two transformer layers.
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
one_in_three = "one_in_three"
|
| 398 |
+
"""
|
| 399 |
+
Checkpoint one in three transformer layers.
|
| 400 |
+
"""
|
| 401 |
+
|
| 402 |
+
one_in_four = "one_in_four"
|
| 403 |
+
"""
|
| 404 |
+
Checkpoint one in four transformer layers.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
two_in_three = "two_in_three"
|
| 408 |
+
"""
|
| 409 |
+
Checkpoint two out of every three transformer layers.
|
| 410 |
+
"""
|
| 411 |
+
|
| 412 |
+
three_in_four = "three_in_four"
|
| 413 |
+
"""
|
| 414 |
+
Checkpoint three out of four of every transformer layers.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
four_in_five = "four_in_five"
|
| 418 |
+
"""
|
| 419 |
+
Checkpoint four out of five of every transformer layers.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
nine_in_ten = "nine_in_ten"
|
| 423 |
+
"""
|
| 424 |
+
Checkpoint nine out of ten of every transformer layers.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
fine_grained = "fine_grained"
|
| 428 |
+
"""
|
| 429 |
+
Focus checkpointing on where it is cheap to recompute and saves most memory.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class LLaDAConfig(PretrainedConfig):
|
| 434 |
+
model_type = "llada"
|
| 435 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
| 436 |
+
|
| 437 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
| 438 |
+
model_config = ModelConfig()
|
| 439 |
+
all_kwargs = model_config.__dict__
|
| 440 |
+
all_kwargs.update(kwargs)
|
| 441 |
+
all_kwargs.update({"use_cache": use_cache})
|
| 442 |
+
all_kwargs.update(
|
| 443 |
+
{
|
| 444 |
+
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
| 445 |
+
}
|
| 446 |
+
)
|
| 447 |
+
super().__init__(**all_kwargs)
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def num_attention_heads(self):
|
| 451 |
+
return self.n_heads
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def num_hidden_layers(self):
|
| 455 |
+
return self.n_layers
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def hidden_size(self):
|
| 459 |
+
return self.d_model
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
| 463 |
+
AutoConfig.register("llada", LLaDAConfig)
|
Discrete-Diffusion-Forcing/D2F-train/model/modeling_llada.py
ADDED
|
@@ -0,0 +1,1500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Iterable,
|
| 13 |
+
List,
|
| 14 |
+
NamedTuple,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
cast,
|
| 20 |
+
)
|
| 21 |
+
from dataclasses import fields
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.backends.cuda
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from torch import einsum
|
| 29 |
+
from transformers import PreTrainedModel
|
| 30 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 31 |
+
from transformers.models.auto import AutoModel
|
| 32 |
+
from transformers.cache_utils import Cache
|
| 33 |
+
|
| 34 |
+
from .configuration_llada import (
|
| 35 |
+
LLaDAConfig,
|
| 36 |
+
StrEnum,
|
| 37 |
+
InitFnType,
|
| 38 |
+
ActivationType,
|
| 39 |
+
BlockType,
|
| 40 |
+
LayerNormType,
|
| 41 |
+
ModelConfig,
|
| 42 |
+
ActivationCheckpointingStrategy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if sys.version_info.minor > 8:
|
| 46 |
+
from collections.abc import MutableMapping
|
| 47 |
+
elif sys.version_info.minor == 8:
|
| 48 |
+
from typing import MutableMapping
|
| 49 |
+
else:
|
| 50 |
+
raise SystemExit("This script supports Python 3.8 or higher")
|
| 51 |
+
|
| 52 |
+
__all__ = [
|
| 53 |
+
"LayerNormBase",
|
| 54 |
+
"LayerNorm",
|
| 55 |
+
"RMSLayerNorm",
|
| 56 |
+
"GemmaRMSLayerNorm",
|
| 57 |
+
"RotaryEmbedding",
|
| 58 |
+
"Activation",
|
| 59 |
+
"GELU",
|
| 60 |
+
"ReLU",
|
| 61 |
+
"SwiGLU",
|
| 62 |
+
"LLaDABlock",
|
| 63 |
+
"LLaDASequentialBlock",
|
| 64 |
+
"LLaDAModel",
|
| 65 |
+
"LLaDAOutput",
|
| 66 |
+
"LLaDAGenerateOutput",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
log = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ModuleType(StrEnum):
|
| 74 |
+
in_module = "in"
|
| 75 |
+
out_module = "out"
|
| 76 |
+
emb = "emb"
|
| 77 |
+
final_out = "final_out"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def init_weights(
|
| 81 |
+
config: ModelConfig,
|
| 82 |
+
module: Union[nn.Linear, nn.Embedding],
|
| 83 |
+
d: Optional[int] = None,
|
| 84 |
+
layer_id: Optional[int] = None,
|
| 85 |
+
std_factor: float = 1.0,
|
| 86 |
+
type_of_module: Optional[ModuleType] = None,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Initialize weights of a linear or embedding module.
|
| 90 |
+
|
| 91 |
+
:param config: The model config.
|
| 92 |
+
:param module: The linear or embedding submodule to initialize.
|
| 93 |
+
:param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
|
| 94 |
+
for fused layers.
|
| 95 |
+
:param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
|
| 96 |
+
``1 / sqrt(2 * (layer_id + 1))``.
|
| 97 |
+
"""
|
| 98 |
+
d = d if d is not None else config.d_model
|
| 99 |
+
if config.init_fn == InitFnType.normal:
|
| 100 |
+
std = config.init_std * std_factor
|
| 101 |
+
if config.init_cutoff_factor is not None:
|
| 102 |
+
cutoff_value = config.init_cutoff_factor * std
|
| 103 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
| 104 |
+
else:
|
| 105 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 106 |
+
elif config.init_fn == InitFnType.mitchell:
|
| 107 |
+
std = std_factor / math.sqrt(d)
|
| 108 |
+
if layer_id is not None:
|
| 109 |
+
std = std / math.sqrt(2 * (layer_id + 1))
|
| 110 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
| 111 |
+
elif config.init_fn == InitFnType.kaiming_normal:
|
| 112 |
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
| 113 |
+
elif config.init_fn == InitFnType.fan_in:
|
| 114 |
+
std = std_factor / math.sqrt(d)
|
| 115 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 116 |
+
elif config.init_fn == InitFnType.full_megatron:
|
| 117 |
+
if type_of_module is None:
|
| 118 |
+
raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
|
| 119 |
+
|
| 120 |
+
cutoff_factor = config.init_cutoff_factor
|
| 121 |
+
if cutoff_factor is None:
|
| 122 |
+
cutoff_factor = 3
|
| 123 |
+
|
| 124 |
+
if type_of_module == ModuleType.in_module:
|
| 125 |
+
# for att_proj (same as QKV), ff_proj
|
| 126 |
+
std = config.init_std
|
| 127 |
+
elif type_of_module == ModuleType.out_module:
|
| 128 |
+
# for attn_out, ff_out
|
| 129 |
+
std = config.init_std / math.sqrt(2.0 * config.n_layers)
|
| 130 |
+
elif type_of_module == ModuleType.emb:
|
| 131 |
+
# positional embeddings (wpe)
|
| 132 |
+
# token embeddings (wte)
|
| 133 |
+
std = config.init_std
|
| 134 |
+
elif type_of_module == ModuleType.final_out:
|
| 135 |
+
# final output (ff_out)
|
| 136 |
+
std = config.d_model**-0.5
|
| 137 |
+
else:
|
| 138 |
+
raise RuntimeError(f"Unknown module type '{type_of_module}'")
|
| 139 |
+
nn.init.trunc_normal_(
|
| 140 |
+
module.weight,
|
| 141 |
+
mean=0.0,
|
| 142 |
+
std=std,
|
| 143 |
+
a=-cutoff_factor * std,
|
| 144 |
+
b=cutoff_factor * std,
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
raise NotImplementedError(config.init_fn)
|
| 148 |
+
|
| 149 |
+
if isinstance(module, nn.Linear):
|
| 150 |
+
if module.bias is not None:
|
| 151 |
+
nn.init.zeros_(module.bias)
|
| 152 |
+
|
| 153 |
+
if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
module.weight.div_(math.sqrt(2 * config.n_layers))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
| 159 |
+
"""
|
| 160 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
| 161 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
| 162 |
+
"""
|
| 163 |
+
if check_neg_inf:
|
| 164 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
| 165 |
+
if check_pos_inf:
|
| 166 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def activation_checkpoint_function(cfg: ModelConfig):
|
| 170 |
+
preserve_rng_state = (
|
| 171 |
+
(cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
|
| 172 |
+
)
|
| 173 |
+
from torch.utils.checkpoint import checkpoint
|
| 174 |
+
|
| 175 |
+
return partial(
|
| 176 |
+
checkpoint,
|
| 177 |
+
preserve_rng_state=preserve_rng_state,
|
| 178 |
+
use_reentrant=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
| 183 |
+
"""
|
| 184 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
| 185 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
| 186 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
| 187 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
| 188 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
| 189 |
+
NaNs when they're synchronized due to casting or some other issue.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _non_meta_init_device(config: ModelConfig) -> torch.device:
|
| 194 |
+
if config.init_device is not None and config.init_device != "meta":
|
| 195 |
+
return torch.device(config.init_device)
|
| 196 |
+
else:
|
| 197 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Dropout(nn.Dropout):
|
| 201 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 202 |
+
if self.p == 0.0:
|
| 203 |
+
return input
|
| 204 |
+
else:
|
| 205 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class LayerNormBase(nn.Module):
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
config: ModelConfig,
|
| 212 |
+
*,
|
| 213 |
+
size: Optional[int] = None,
|
| 214 |
+
elementwise_affine: Optional[bool] = True,
|
| 215 |
+
eps: float = 1e-05,
|
| 216 |
+
):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
self.eps = eps
|
| 220 |
+
self.normalized_shape = (size or config.d_model,)
|
| 221 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
| 222 |
+
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
|
| 223 |
+
use_bias = self.config.bias_for_layer_norm
|
| 224 |
+
if use_bias is None:
|
| 225 |
+
use_bias = self.config.include_bias
|
| 226 |
+
if use_bias:
|
| 227 |
+
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
|
| 228 |
+
else:
|
| 229 |
+
self.register_parameter("bias", None)
|
| 230 |
+
else:
|
| 231 |
+
self.register_parameter("bias", None)
|
| 232 |
+
self.register_parameter("weight", None)
|
| 233 |
+
|
| 234 |
+
@abstractmethod
|
| 235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 236 |
+
raise NotImplementedError
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
|
| 240 |
+
if config.layer_norm_type == LayerNormType.default:
|
| 241 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
| 242 |
+
elif config.layer_norm_type == LayerNormType.low_precision:
|
| 243 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
| 244 |
+
elif config.layer_norm_type == LayerNormType.rms:
|
| 245 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
| 246 |
+
elif config.layer_norm_type == LayerNormType.gemma_rms:
|
| 247 |
+
return GemmaRMSLayerNorm(config, size=size, **kwargs)
|
| 248 |
+
else:
|
| 249 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
| 250 |
+
|
| 251 |
+
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
| 252 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 253 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 254 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 255 |
+
if tensor.device.type == "cuda" and torch.is_autocast_enabled():
|
| 256 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
|
| 257 |
+
elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 258 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
| 259 |
+
else:
|
| 260 |
+
return tensor
|
| 261 |
+
|
| 262 |
+
def reset_parameters(self):
|
| 263 |
+
if self.weight is not None:
|
| 264 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
| 265 |
+
if self.bias is not None:
|
| 266 |
+
torch.nn.init.zeros_(self.bias) # type: ignore
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LayerNorm(LayerNormBase):
|
| 270 |
+
"""
|
| 271 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
config: ModelConfig,
|
| 277 |
+
size: Optional[int] = None,
|
| 278 |
+
low_precision: bool = False,
|
| 279 |
+
elementwise_affine: Optional[bool] = None,
|
| 280 |
+
eps: float = 1e-05,
|
| 281 |
+
):
|
| 282 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 283 |
+
self.low_precision = low_precision
|
| 284 |
+
|
| 285 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 286 |
+
if self.low_precision:
|
| 287 |
+
module_device = x.device
|
| 288 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
| 289 |
+
downcast_weight = (
|
| 290 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 291 |
+
)
|
| 292 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 293 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 294 |
+
return F.layer_norm(
|
| 295 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
| 296 |
+
)
|
| 297 |
+
else:
|
| 298 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class RMSLayerNorm(LayerNormBase):
|
| 302 |
+
"""
|
| 303 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
config: ModelConfig,
|
| 309 |
+
size: Optional[int] = None,
|
| 310 |
+
elementwise_affine: Optional[bool] = None,
|
| 311 |
+
eps: float = 1e-5,
|
| 312 |
+
):
|
| 313 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 314 |
+
|
| 315 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 316 |
+
# with torch.autocast(enabled=False, device_type=x.device.type):
|
| 317 |
+
og_dtype = x.dtype
|
| 318 |
+
x = x.to(torch.float32)
|
| 319 |
+
# print(x.dtype,x.shape)
|
| 320 |
+
variance = x*x
|
| 321 |
+
# print(variance)
|
| 322 |
+
variance = variance.mean(dim=-1,keepdim=True)
|
| 323 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 324 |
+
x = x.to(og_dtype)
|
| 325 |
+
|
| 326 |
+
if self.weight is not None:
|
| 327 |
+
if self.bias is not None:
|
| 328 |
+
return self.weight * x + self.bias
|
| 329 |
+
else:
|
| 330 |
+
return self.weight * x
|
| 331 |
+
else:
|
| 332 |
+
return x
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class GemmaRMSLayerNorm(LayerNormBase):
|
| 336 |
+
"""
|
| 337 |
+
Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
def __init__(
|
| 341 |
+
self,
|
| 342 |
+
config: ModelConfig,
|
| 343 |
+
size: Optional[int] = None,
|
| 344 |
+
elementwise_affine: Optional[bool] = None,
|
| 345 |
+
eps: float = 1e-5,
|
| 346 |
+
):
|
| 347 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 348 |
+
|
| 349 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 350 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 351 |
+
og_dtype = x.dtype
|
| 352 |
+
x = x.to(torch.float32)
|
| 353 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 354 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 355 |
+
x = x.to(og_dtype)
|
| 356 |
+
|
| 357 |
+
if self.weight is not None:
|
| 358 |
+
if self.bias is not None:
|
| 359 |
+
return x * (1 + self.weight) + self.bias
|
| 360 |
+
else:
|
| 361 |
+
return x * (1 + self.weight)
|
| 362 |
+
else:
|
| 363 |
+
return x
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class RotaryEmbedding(nn.Module):
|
| 367 |
+
"""
|
| 368 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
def __init__(self, config: ModelConfig, cache: BufferCache):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.config = config
|
| 374 |
+
self.__cache = cache
|
| 375 |
+
# Warm up cache.
|
| 376 |
+
self.rope_theta = config.rope_theta
|
| 377 |
+
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
|
| 378 |
+
|
| 379 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 380 |
+
if (
|
| 381 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
| 382 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
| 383 |
+
and pos_sin.shape[-2] >= seq_len
|
| 384 |
+
and pos_cos.shape[-2] >= seq_len
|
| 385 |
+
):
|
| 386 |
+
if pos_sin.device != device:
|
| 387 |
+
pos_sin = pos_sin.to(device)
|
| 388 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 389 |
+
if pos_cos.device != device:
|
| 390 |
+
pos_cos = pos_cos.to(device)
|
| 391 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 392 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
| 393 |
+
|
| 394 |
+
with torch.autocast(device.type, enabled=False):
|
| 395 |
+
dim = self.config.d_model // self.config.n_heads
|
| 396 |
+
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
|
| 397 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
| 398 |
+
freqs = einsum("i , j -> i j", seq, inv_freq)
|
| 399 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
| 400 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
| 401 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 402 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 403 |
+
return pos_sin, pos_cos
|
| 404 |
+
|
| 405 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
| 406 |
+
B, nh, T, hs = x.size()
|
| 407 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
| 408 |
+
x1, x2 = x.unbind(dim=-2)
|
| 409 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 410 |
+
|
| 411 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 412 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
| 413 |
+
|
| 414 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 415 |
+
if self.config.rope_full_precision:
|
| 416 |
+
q_, k_ = q.float(), k.float()
|
| 417 |
+
else:
|
| 418 |
+
q_, k_ = q, k
|
| 419 |
+
|
| 420 |
+
with torch.autocast(q.device.type, enabled=False):
|
| 421 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
| 422 |
+
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
|
| 423 |
+
pos_sin = pos_sin.type_as(q_)
|
| 424 |
+
pos_cos = pos_cos.type_as(q_)
|
| 425 |
+
q_ = self.apply_rotary_pos_emb(
|
| 426 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
| 427 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
| 428 |
+
q_,
|
| 429 |
+
)
|
| 430 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
| 431 |
+
return q_.type_as(q), k_.type_as(k)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class Activation(nn.Module):
|
| 435 |
+
def __init__(self, config: ModelConfig):
|
| 436 |
+
super().__init__()
|
| 437 |
+
self.config = config
|
| 438 |
+
|
| 439 |
+
@abstractmethod
|
| 440 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 441 |
+
raise NotImplementedError
|
| 442 |
+
|
| 443 |
+
@property
|
| 444 |
+
@abstractmethod
|
| 445 |
+
def output_multiplier(self) -> float:
|
| 446 |
+
raise NotImplementedError
|
| 447 |
+
|
| 448 |
+
@classmethod
|
| 449 |
+
def build(cls, config: ModelConfig) -> Activation:
|
| 450 |
+
if config.activation_type == ActivationType.gelu:
|
| 451 |
+
return cast(Activation, GELU(approximate="none"))
|
| 452 |
+
elif config.activation_type == ActivationType.relu:
|
| 453 |
+
return cast(Activation, ReLU(inplace=False))
|
| 454 |
+
elif config.activation_type == ActivationType.silu:
|
| 455 |
+
return cast(Activation, SiLU(inplace=False))
|
| 456 |
+
elif config.activation_type == ActivationType.swiglu:
|
| 457 |
+
return SwiGLU(config)
|
| 458 |
+
else:
|
| 459 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class GELU(nn.GELU):
|
| 463 |
+
@property
|
| 464 |
+
def output_multiplier(self) -> float:
|
| 465 |
+
return 1.0
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class ReLU(nn.ReLU):
|
| 469 |
+
@property
|
| 470 |
+
def output_multiplier(self) -> float:
|
| 471 |
+
return 1.0
|
| 472 |
+
|
| 473 |
+
class SiLU(nn.SiLU):
|
| 474 |
+
@property
|
| 475 |
+
def output_multiplier(self) -> float:
|
| 476 |
+
return 1.0
|
| 477 |
+
|
| 478 |
+
class SwiGLU(Activation):
|
| 479 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 480 |
+
x, gate = x.chunk(2, dim=-1)
|
| 481 |
+
return F.silu(gate) * x
|
| 482 |
+
|
| 483 |
+
@property
|
| 484 |
+
def output_multiplier(self) -> float:
|
| 485 |
+
return 0.5
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
| 489 |
+
att_bias = torch.triu(
|
| 490 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
| 491 |
+
diagonal=1,
|
| 492 |
+
)
|
| 493 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
| 494 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 498 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
| 499 |
+
if causal_bias.device != device:
|
| 500 |
+
causal_bias = causal_bias.to(device)
|
| 501 |
+
cache["causal_attention_bias"] = causal_bias
|
| 502 |
+
return causal_bias
|
| 503 |
+
with torch.autocast(device.type, enabled=False):
|
| 504 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
| 505 |
+
cache["causal_attention_bias"] = causal_bias
|
| 506 |
+
return causal_bias
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
|
| 510 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
|
| 511 |
+
|
| 512 |
+
# shape: (1, 1, seq_len, seq_len)
|
| 513 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
|
| 514 |
+
alibi_bias.abs_().mul_(-1)
|
| 515 |
+
|
| 516 |
+
# shape: (n_heads,)
|
| 517 |
+
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
|
| 518 |
+
m.mul_(config.alibi_bias_max / config.n_heads)
|
| 519 |
+
|
| 520 |
+
# shape: (1, n_heads, seq_len, seq_len)
|
| 521 |
+
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class LLaDABlock(nn.Module):
|
| 525 |
+
"""
|
| 526 |
+
A base class for transformer block implementations.
|
| 527 |
+
"""
|
| 528 |
+
|
| 529 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.layer_id = layer_id
|
| 532 |
+
self.config = config
|
| 533 |
+
self.hidden_size = (
|
| 534 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 535 |
+
)
|
| 536 |
+
self.__cache = cache
|
| 537 |
+
assert config.d_model % config.n_heads == 0
|
| 538 |
+
|
| 539 |
+
self._activation_checkpoint_fn = None
|
| 540 |
+
|
| 541 |
+
# Dropout.
|
| 542 |
+
self.dropout = Dropout(config.residual_dropout)
|
| 543 |
+
|
| 544 |
+
# Layer norms.
|
| 545 |
+
self.k_norm: Optional[LayerNormBase] = None
|
| 546 |
+
self.q_norm: Optional[LayerNormBase] = None
|
| 547 |
+
if config.attention_layer_norm:
|
| 548 |
+
self.k_norm = LayerNormBase.build(
|
| 549 |
+
config,
|
| 550 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
| 551 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
| 552 |
+
)
|
| 553 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
| 554 |
+
|
| 555 |
+
# Activation function.
|
| 556 |
+
self.act = Activation.build(config)
|
| 557 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
| 558 |
+
|
| 559 |
+
# Attention output projection.
|
| 560 |
+
self.attn_out = nn.Linear(
|
| 561 |
+
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# Feed-forward output projection.
|
| 565 |
+
self.ff_out = nn.Linear(
|
| 566 |
+
int(self.act.output_multiplier * self.hidden_size),
|
| 567 |
+
config.d_model,
|
| 568 |
+
bias=config.include_bias,
|
| 569 |
+
device=config.init_device,
|
| 570 |
+
)
|
| 571 |
+
self.ff_out._is_residual = True # type: ignore
|
| 572 |
+
|
| 573 |
+
# Rotary embeddings.
|
| 574 |
+
if self.config.rope:
|
| 575 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
| 576 |
+
|
| 577 |
+
self.flash_attn_func = None
|
| 578 |
+
if config.flash_attention:
|
| 579 |
+
try:
|
| 580 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 581 |
+
|
| 582 |
+
self.flash_attn_func = flash_attn_func
|
| 583 |
+
except ModuleNotFoundError:
|
| 584 |
+
pass
|
| 585 |
+
|
| 586 |
+
def reset_parameters(self):
|
| 587 |
+
if self.k_norm is not None:
|
| 588 |
+
self.k_norm.reset_parameters()
|
| 589 |
+
if self.q_norm is not None:
|
| 590 |
+
self.q_norm.reset_parameters()
|
| 591 |
+
init_weights(
|
| 592 |
+
self.config,
|
| 593 |
+
self.attn_out,
|
| 594 |
+
d=self.config.d_model,
|
| 595 |
+
layer_id=self.layer_id,
|
| 596 |
+
type_of_module=ModuleType.out_module,
|
| 597 |
+
)
|
| 598 |
+
init_weights(
|
| 599 |
+
self.config,
|
| 600 |
+
self.ff_out,
|
| 601 |
+
d=self.ff_out.in_features,
|
| 602 |
+
layer_id=self.layer_id,
|
| 603 |
+
type_of_module=ModuleType.out_module,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 607 |
+
if strategy == ActivationCheckpointingStrategy.fine_grained:
|
| 608 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 609 |
+
else:
|
| 610 |
+
self._activation_checkpoint_fn = None
|
| 611 |
+
|
| 612 |
+
@classmethod
|
| 613 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
| 614 |
+
target_dtype = input_dtype
|
| 615 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 616 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 617 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 618 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
| 619 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 620 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 621 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
| 622 |
+
if bias.dtype != target_dtype:
|
| 623 |
+
bias = bias.to(target_dtype)
|
| 624 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
| 625 |
+
return bias
|
| 626 |
+
|
| 627 |
+
def _scaled_dot_product_attention(
|
| 628 |
+
self,
|
| 629 |
+
q: torch.Tensor,
|
| 630 |
+
k: torch.Tensor,
|
| 631 |
+
v: torch.Tensor,
|
| 632 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 633 |
+
dropout_p: float = 0.0,
|
| 634 |
+
is_causal: bool = False,
|
| 635 |
+
) -> torch.Tensor:
|
| 636 |
+
"""
|
| 637 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
| 638 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
| 639 |
+
"""
|
| 640 |
+
if self.flash_attn_func is not None and attn_mask is None:
|
| 641 |
+
r = self.flash_attn_func(
|
| 642 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
|
| 643 |
+
)
|
| 644 |
+
return r.transpose(1, 2)
|
| 645 |
+
else:
|
| 646 |
+
# torch's sdpa doesn't support GQA, so we're doing this
|
| 647 |
+
assert k.size(1) == v.size(1)
|
| 648 |
+
num_kv_heads = k.size(1)
|
| 649 |
+
num_q_heads = q.size(1)
|
| 650 |
+
if num_q_heads != num_kv_heads:
|
| 651 |
+
assert num_q_heads % num_kv_heads == 0
|
| 652 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 653 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 654 |
+
# Modify: MDM set causal to False, and with no attn_mask.
|
| 655 |
+
return F.scaled_dot_product_attention(
|
| 656 |
+
q,
|
| 657 |
+
k,
|
| 658 |
+
v,
|
| 659 |
+
attn_mask=attn_mask,
|
| 660 |
+
dropout_p=dropout_p,
|
| 661 |
+
is_causal=False,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
def attention(
|
| 665 |
+
self,
|
| 666 |
+
q: torch.Tensor,
|
| 667 |
+
k: torch.Tensor,
|
| 668 |
+
v: torch.Tensor,
|
| 669 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 670 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 671 |
+
use_cache: bool = False,
|
| 672 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 673 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
| 674 |
+
dtype = k.dtype
|
| 675 |
+
|
| 676 |
+
# Optionally apply layer norm to keys and queries.
|
| 677 |
+
if self.q_norm is not None and self.k_norm is not None:
|
| 678 |
+
q = self.q_norm(q).to(dtype=dtype)
|
| 679 |
+
k = self.k_norm(k).to(dtype=dtype)
|
| 680 |
+
|
| 681 |
+
# Move head forward to be next to the batch dim.
|
| 682 |
+
# shape: (B, nh, T, hs)
|
| 683 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
| 684 |
+
# shape: (B, n_kv_h, T, hs)
|
| 685 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 686 |
+
# shape: (B, n_kv_h, T, hs)
|
| 687 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 688 |
+
|
| 689 |
+
if layer_past is not None:
|
| 690 |
+
past_key, past_value = layer_past
|
| 691 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 692 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 693 |
+
|
| 694 |
+
present = (k, v) if use_cache else None
|
| 695 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
| 696 |
+
|
| 697 |
+
if self.config.rope:
|
| 698 |
+
# Apply rotary embeddings.
|
| 699 |
+
q, k = self.rotary_emb(q, k)
|
| 700 |
+
|
| 701 |
+
# if attention_bias is not None:
|
| 702 |
+
# # Resize and cast attention bias.
|
| 703 |
+
# # The current dtype of the attention bias might not match the dtype that the SDP attn function will
|
| 704 |
+
# # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
|
| 705 |
+
# # as down-casting the attention bias to the autocast precision will result in -infs, which will
|
| 706 |
+
# # cause the SDP attn function to produce NaNs.
|
| 707 |
+
# attention_bias = self._cast_attn_bias(
|
| 708 |
+
# attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
| 709 |
+
# )
|
| 710 |
+
|
| 711 |
+
# Get the attention scores.
|
| 712 |
+
# shape: (B, nh, T, hs)
|
| 713 |
+
att = self._scaled_dot_product_attention(
|
| 714 |
+
q,
|
| 715 |
+
k,
|
| 716 |
+
v,
|
| 717 |
+
attn_mask=attention_bias,
|
| 718 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
| 719 |
+
is_causal=False,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
# Re-assemble all head outputs side-by-side.
|
| 723 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
| 724 |
+
|
| 725 |
+
# Apply output projection.
|
| 726 |
+
return self.attn_out(att), present
|
| 727 |
+
|
| 728 |
+
@abstractmethod
|
| 729 |
+
def forward(
|
| 730 |
+
self,
|
| 731 |
+
x: torch.Tensor,
|
| 732 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 733 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 734 |
+
use_cache: bool = False,
|
| 735 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 736 |
+
raise NotImplementedError
|
| 737 |
+
|
| 738 |
+
@classmethod
|
| 739 |
+
def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
|
| 740 |
+
if config.block_type == BlockType.sequential:
|
| 741 |
+
return LLaDASequentialBlock(layer_id, config, cache)
|
| 742 |
+
elif config.block_type == BlockType.llama:
|
| 743 |
+
return LLaDALlamaBlock(layer_id, config, cache)
|
| 744 |
+
else:
|
| 745 |
+
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
class LLaDASequentialBlock(LLaDABlock):
|
| 749 |
+
"""
|
| 750 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 751 |
+
(plus another skip connection).
|
| 752 |
+
"""
|
| 753 |
+
|
| 754 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 755 |
+
super().__init__(layer_id, config, cache)
|
| 756 |
+
# Layer norms.
|
| 757 |
+
self.attn_norm = LayerNorm.build(config)
|
| 758 |
+
self.ff_norm = LayerNorm.build(config)
|
| 759 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 760 |
+
head_dim = config.d_model // config.n_heads
|
| 761 |
+
self.fused_dims = (
|
| 762 |
+
config.d_model,
|
| 763 |
+
config.effective_n_kv_heads * head_dim,
|
| 764 |
+
config.effective_n_kv_heads * head_dim,
|
| 765 |
+
)
|
| 766 |
+
self.att_proj = nn.Linear(
|
| 767 |
+
config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 768 |
+
)
|
| 769 |
+
# Feed-forward input projection.
|
| 770 |
+
self.ff_proj = nn.Linear(
|
| 771 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
def reset_parameters(self):
|
| 775 |
+
super().reset_parameters()
|
| 776 |
+
self.attn_norm.reset_parameters()
|
| 777 |
+
self.ff_norm.reset_parameters()
|
| 778 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 779 |
+
init_weights(
|
| 780 |
+
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 781 |
+
)
|
| 782 |
+
init_weights(
|
| 783 |
+
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
def forward(
|
| 787 |
+
self,
|
| 788 |
+
x: torch.Tensor,
|
| 789 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 790 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 791 |
+
use_cache: bool = False,
|
| 792 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 793 |
+
# Get query, key, value projections.
|
| 794 |
+
# shape:
|
| 795 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 796 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 797 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 798 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 799 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 800 |
+
if self._activation_checkpoint_fn is not None:
|
| 801 |
+
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
|
| 802 |
+
self.fused_dims, dim=-1
|
| 803 |
+
)
|
| 804 |
+
else:
|
| 805 |
+
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
|
| 806 |
+
|
| 807 |
+
# Get attention scores.
|
| 808 |
+
if self._activation_checkpoint_fn is not None:
|
| 809 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 810 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 811 |
+
)
|
| 812 |
+
else:
|
| 813 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 814 |
+
|
| 815 |
+
# Add attention scores.
|
| 816 |
+
# shape: (B, T, C)
|
| 817 |
+
x = x + self.dropout(att)
|
| 818 |
+
|
| 819 |
+
# Add feed-forward projection.
|
| 820 |
+
# shape: (batch_size, seq_len, d_model)
|
| 821 |
+
og_x = x
|
| 822 |
+
if self._activation_checkpoint_fn is not None:
|
| 823 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 824 |
+
else:
|
| 825 |
+
x = self.ff_norm(x)
|
| 826 |
+
x = self.ff_proj(x)
|
| 827 |
+
if self._activation_checkpoint_fn is not None:
|
| 828 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 829 |
+
else:
|
| 830 |
+
x = self.act(x)
|
| 831 |
+
x = self.ff_out(x)
|
| 832 |
+
x = self.dropout(x)
|
| 833 |
+
x = og_x + x
|
| 834 |
+
|
| 835 |
+
return x, cache
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
class LLaDALlamaBlock(LLaDABlock):
|
| 839 |
+
"""
|
| 840 |
+
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 841 |
+
(plus another skip connection). This block is similar to `LLaDASequentialBlock`
|
| 842 |
+
but some operations have slightly different implementations to imitate the
|
| 843 |
+
behavior of Llama.
|
| 844 |
+
"""
|
| 845 |
+
|
| 846 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 847 |
+
super().__init__(layer_id, config, cache)
|
| 848 |
+
# Layer norms.
|
| 849 |
+
self.attn_norm = LayerNorm.build(config)
|
| 850 |
+
self.ff_norm = LayerNorm.build(config)
|
| 851 |
+
self.__cache = cache
|
| 852 |
+
|
| 853 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 854 |
+
head_dim = config.d_model // config.n_heads
|
| 855 |
+
q_proj_out_dim = config.d_model
|
| 856 |
+
k_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 857 |
+
v_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 858 |
+
self.q_proj = nn.Linear(
|
| 859 |
+
config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 860 |
+
)
|
| 861 |
+
self.k_proj = nn.Linear(
|
| 862 |
+
config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 863 |
+
)
|
| 864 |
+
self.v_proj = nn.Linear(
|
| 865 |
+
config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# Feed-forward input projection.
|
| 869 |
+
self.ff_proj = nn.Linear(
|
| 870 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 871 |
+
)
|
| 872 |
+
# new add
|
| 873 |
+
self.up_proj = nn.Linear(
|
| 874 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
def reset_parameters(self):
|
| 878 |
+
super().reset_parameters()
|
| 879 |
+
self.attn_norm.reset_parameters()
|
| 880 |
+
self.ff_norm.reset_parameters()
|
| 881 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 882 |
+
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
|
| 883 |
+
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
|
| 884 |
+
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
|
| 885 |
+
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
|
| 886 |
+
init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add
|
| 887 |
+
|
| 888 |
+
def forward(
|
| 889 |
+
self,
|
| 890 |
+
x: torch.Tensor,
|
| 891 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 892 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 893 |
+
use_cache: bool = False,
|
| 894 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 895 |
+
# Get query, key, value projections.
|
| 896 |
+
# shape:
|
| 897 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 898 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 899 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 900 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 901 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 902 |
+
# print(x)
|
| 903 |
+
x_normed = self.attn_norm(x)
|
| 904 |
+
q = self.q_proj(x_normed)
|
| 905 |
+
k = self.k_proj(x_normed)
|
| 906 |
+
v = self.v_proj(x_normed)
|
| 907 |
+
|
| 908 |
+
# Get attention scores.
|
| 909 |
+
if self._activation_checkpoint_fn is not None:
|
| 910 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 911 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 912 |
+
)
|
| 913 |
+
else:
|
| 914 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 915 |
+
|
| 916 |
+
# Add attention scores.
|
| 917 |
+
# shape: (B, T, C)
|
| 918 |
+
x = x + self.dropout(att)
|
| 919 |
+
|
| 920 |
+
# Add feed-forward projection.
|
| 921 |
+
# shape: (batch_size, seq_len, d_model)
|
| 922 |
+
og_x = x
|
| 923 |
+
if self._activation_checkpoint_fn is not None:
|
| 924 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 925 |
+
else:
|
| 926 |
+
x = self.ff_norm(x)
|
| 927 |
+
x, x_up = self.ff_proj(x), self.up_proj(x) # new add
|
| 928 |
+
if self._activation_checkpoint_fn is not None:
|
| 929 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 930 |
+
else:
|
| 931 |
+
x = self.act(x)
|
| 932 |
+
x = x * x_up # new add
|
| 933 |
+
x = self.ff_out(x)
|
| 934 |
+
x = self.dropout(x)
|
| 935 |
+
x = og_x + x
|
| 936 |
+
|
| 937 |
+
return x, cache
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
class LLaDAOutput(NamedTuple):
|
| 941 |
+
logits: torch.FloatTensor
|
| 942 |
+
"""
|
| 943 |
+
A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
|
| 944 |
+
for the next token *before* normalization via (log) softmax.
|
| 945 |
+
"""
|
| 946 |
+
|
| 947 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
|
| 948 |
+
"""
|
| 949 |
+
Attention keys and values from each block.
|
| 950 |
+
"""
|
| 951 |
+
|
| 952 |
+
hidden_states: Optional[Tuple[torch.Tensor]]
|
| 953 |
+
"""
|
| 954 |
+
Hidden states from each block.
|
| 955 |
+
"""
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
class LLaDAGenerateOutput(NamedTuple):
|
| 959 |
+
token_ids: torch.LongTensor
|
| 960 |
+
"""
|
| 961 |
+
The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
|
| 962 |
+
These do *not* include the original input IDs.
|
| 963 |
+
"""
|
| 964 |
+
|
| 965 |
+
scores: torch.FloatTensor
|
| 966 |
+
"""
|
| 967 |
+
The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
|
| 968 |
+
"""
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class LLaDABlockGroup(nn.ModuleList):
|
| 972 |
+
def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
|
| 973 |
+
super().__init__(modules)
|
| 974 |
+
self.config = config
|
| 975 |
+
self.layer_offset = layer_offset
|
| 976 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 977 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 978 |
+
|
| 979 |
+
def forward(
|
| 980 |
+
self,
|
| 981 |
+
x: torch.Tensor,
|
| 982 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 983 |
+
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 984 |
+
use_cache: bool = False,
|
| 985 |
+
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
| 986 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 987 |
+
for block_idx, block in enumerate(self):
|
| 988 |
+
layer_past = None if layers_past is None else layers_past[block_idx]
|
| 989 |
+
block_idx += self.layer_offset
|
| 990 |
+
if (
|
| 991 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 992 |
+
or (
|
| 993 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 994 |
+
and block_idx % 2 == 0
|
| 995 |
+
)
|
| 996 |
+
or (
|
| 997 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 998 |
+
and block_idx % 3 == 0
|
| 999 |
+
)
|
| 1000 |
+
or (
|
| 1001 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1002 |
+
and block_idx % 4 == 0
|
| 1003 |
+
)
|
| 1004 |
+
):
|
| 1005 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1006 |
+
x, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1007 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1008 |
+
)
|
| 1009 |
+
else:
|
| 1010 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1011 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1012 |
+
if attn_key_values is not None:
|
| 1013 |
+
assert cache is not None
|
| 1014 |
+
attn_key_values.append(cache)
|
| 1015 |
+
return x, attn_key_values
|
| 1016 |
+
|
| 1017 |
+
def reset_parameters(self):
|
| 1018 |
+
for block in self:
|
| 1019 |
+
block.reset_parameters()
|
| 1020 |
+
|
| 1021 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1022 |
+
self.activation_checkpointing_strategy = strategy
|
| 1023 |
+
for block in self:
|
| 1024 |
+
block.set_activation_checkpointing(strategy)
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
class LLaDAModel(nn.Module):
|
| 1028 |
+
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1029 |
+
super().__init__()
|
| 1030 |
+
self.config = config
|
| 1031 |
+
self.__cache = BufferCache()
|
| 1032 |
+
|
| 1033 |
+
# Validate config.
|
| 1034 |
+
if self.config.alibi and self.config.flash_attention:
|
| 1035 |
+
raise Exception("ALiBi is currently not supported with FlashAttention")
|
| 1036 |
+
|
| 1037 |
+
if self.config.alibi and self.config.rope:
|
| 1038 |
+
raise Exception("ALiBi and RoPE are mutually exclusive")
|
| 1039 |
+
|
| 1040 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
| 1041 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 1042 |
+
raise Exception("embedding size should be at least as big as vocab size")
|
| 1043 |
+
elif self.config.embedding_size % 128 != 0:
|
| 1044 |
+
import warnings
|
| 1045 |
+
|
| 1046 |
+
warnings.warn(
|
| 1047 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 1051 |
+
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
|
| 1052 |
+
|
| 1053 |
+
if not (
|
| 1054 |
+
0 < self.config.block_group_size <= self.config.n_layers
|
| 1055 |
+
and self.config.n_layers % self.config.block_group_size == 0
|
| 1056 |
+
):
|
| 1057 |
+
raise Exception("n layers must be divisible by block group size")
|
| 1058 |
+
|
| 1059 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 1060 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
| 1061 |
+
|
| 1062 |
+
self.transformer = nn.ModuleDict(
|
| 1063 |
+
dict(
|
| 1064 |
+
wte=nn.Embedding(
|
| 1065 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
| 1066 |
+
),
|
| 1067 |
+
emb_drop=Dropout(config.embedding_dropout),
|
| 1068 |
+
ln_f=LayerNorm.build(config),
|
| 1069 |
+
)
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
| 1073 |
+
if self.config.block_group_size > 1:
|
| 1074 |
+
block_groups = [
|
| 1075 |
+
LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
|
| 1076 |
+
for i in range(0, config.n_layers, config.block_group_size)
|
| 1077 |
+
]
|
| 1078 |
+
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
|
| 1079 |
+
else:
|
| 1080 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
| 1081 |
+
|
| 1082 |
+
if not (self.config.alibi or self.config.rope):
|
| 1083 |
+
self.transformer.update(
|
| 1084 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
| 1085 |
+
)
|
| 1086 |
+
if not config.weight_tying:
|
| 1087 |
+
self.transformer.update(
|
| 1088 |
+
{
|
| 1089 |
+
"ff_out": nn.Linear(
|
| 1090 |
+
config.d_model,
|
| 1091 |
+
config.embedding_size or config.vocab_size,
|
| 1092 |
+
bias=config.include_bias,
|
| 1093 |
+
device=config.init_device,
|
| 1094 |
+
)
|
| 1095 |
+
}
|
| 1096 |
+
)
|
| 1097 |
+
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1098 |
+
if init_params and self.config.init_device != "meta":
|
| 1099 |
+
self.reset_parameters()
|
| 1100 |
+
self.__num_fwd_flops: Optional[int] = None
|
| 1101 |
+
|
| 1102 |
+
# Warm up cache.
|
| 1103 |
+
if self.config.alibi:
|
| 1104 |
+
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
|
| 1105 |
+
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
|
| 1106 |
+
|
| 1107 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1108 |
+
self.activation_checkpointing_strategy = strategy
|
| 1109 |
+
if self.config.block_group_size != 1:
|
| 1110 |
+
for block_group in self.transformer.block_groups:
|
| 1111 |
+
block_group.set_activation_checkpointing(strategy)
|
| 1112 |
+
else:
|
| 1113 |
+
for block in self.transformer.blocks:
|
| 1114 |
+
block.set_activation_checkpointing(strategy)
|
| 1115 |
+
|
| 1116 |
+
@property
|
| 1117 |
+
def device(self) -> torch.device:
|
| 1118 |
+
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
| 1119 |
+
if device.type == "meta":
|
| 1120 |
+
return _non_meta_init_device(self.config)
|
| 1121 |
+
else:
|
| 1122 |
+
return device
|
| 1123 |
+
|
| 1124 |
+
def reset_parameters(self):
|
| 1125 |
+
log.info("Initializing model parameters...")
|
| 1126 |
+
# Top-level embeddings / linear layers.
|
| 1127 |
+
init_weights(
|
| 1128 |
+
self.config,
|
| 1129 |
+
self.transformer.wte, # type: ignore
|
| 1130 |
+
std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
|
| 1131 |
+
type_of_module=ModuleType.emb,
|
| 1132 |
+
)
|
| 1133 |
+
if hasattr(self.transformer, "wpe"):
|
| 1134 |
+
init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
|
| 1135 |
+
|
| 1136 |
+
# Top-level layer norm.
|
| 1137 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
| 1138 |
+
|
| 1139 |
+
# Output weights.
|
| 1140 |
+
if hasattr(self.transformer, "ff_out"):
|
| 1141 |
+
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
|
| 1142 |
+
|
| 1143 |
+
# Let the blocks handle themselves.
|
| 1144 |
+
if self.config.block_group_size == 1:
|
| 1145 |
+
for block in self.transformer.blocks:
|
| 1146 |
+
block.reset_parameters()
|
| 1147 |
+
else:
|
| 1148 |
+
for block_group in self.transformer.block_groups:
|
| 1149 |
+
block_group.reset_parameters()
|
| 1150 |
+
|
| 1151 |
+
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 1152 |
+
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
|
| 1153 |
+
-1
|
| 1154 |
+
] >= seq_len:
|
| 1155 |
+
if alibi_bias.device != device:
|
| 1156 |
+
alibi_bias = alibi_bias.to(device)
|
| 1157 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1158 |
+
return alibi_bias
|
| 1159 |
+
with torch.autocast(device.type, enabled=False):
|
| 1160 |
+
alibi_bias = alibi_attention_bias(seq_len, self.config, device)
|
| 1161 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1162 |
+
return alibi_bias
|
| 1163 |
+
|
| 1164 |
+
def forward(
|
| 1165 |
+
self,
|
| 1166 |
+
input_ids: torch.LongTensor,
|
| 1167 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
| 1168 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1169 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1170 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1171 |
+
use_cache: bool = False,
|
| 1172 |
+
update_kvcache: bool = False,
|
| 1173 |
+
last_logits_only: bool = False,
|
| 1174 |
+
output_hidden_states: Optional[bool] = None,
|
| 1175 |
+
) -> LLaDAOutput:
|
| 1176 |
+
"""
|
| 1177 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
| 1178 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
| 1179 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
| 1180 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1181 |
+
which input IDs are masked. A `1` value in the mask means that
|
| 1182 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
| 1183 |
+
that the corresponding input ID is masked.
|
| 1184 |
+
|
| 1185 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
| 1186 |
+
library.
|
| 1187 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
| 1188 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
| 1189 |
+
to introduce causal or other biases.
|
| 1190 |
+
|
| 1191 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
| 1192 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
| 1193 |
+
element in the sequence.
|
| 1194 |
+
|
| 1195 |
+
If the tensor is a float tensor, it will just be added to the attention
|
| 1196 |
+
scores before the softmax.
|
| 1197 |
+
|
| 1198 |
+
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
|
| 1199 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
| 1200 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
| 1201 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
| 1202 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
| 1203 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
| 1204 |
+
This can speed up decoding when you only care about the next token.
|
| 1205 |
+
"""
|
| 1206 |
+
# Add Basic MDM Model config check
|
| 1207 |
+
# print(input_ids.dtype)
|
| 1208 |
+
assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
|
| 1209 |
+
assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
|
| 1210 |
+
# assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM."
|
| 1211 |
+
|
| 1212 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 1213 |
+
|
| 1214 |
+
if past_key_values:
|
| 1215 |
+
assert len(past_key_values) == self.config.n_layers
|
| 1216 |
+
|
| 1217 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
| 1218 |
+
if past_key_values is None:
|
| 1219 |
+
past_length = 0
|
| 1220 |
+
else:
|
| 1221 |
+
past_length = past_key_values[0][0].size(-2)
|
| 1222 |
+
|
| 1223 |
+
# Get embeddings of input.
|
| 1224 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1225 |
+
# print(input_ids.dtype,"wte")
|
| 1226 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
| 1227 |
+
|
| 1228 |
+
if self.config.input_emb_norm:
|
| 1229 |
+
x = x * (self.config.d_model**0.5)
|
| 1230 |
+
|
| 1231 |
+
if not (self.config.alibi or self.config.rope):
|
| 1232 |
+
# Get positional embeddings.
|
| 1233 |
+
# shape: (1, seq_len)
|
| 1234 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
| 1235 |
+
# shape: (1, seq_len, d_model)
|
| 1236 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
| 1237 |
+
x = pos_emb + x
|
| 1238 |
+
|
| 1239 |
+
# Add input + positional embeddings and apply dropout.
|
| 1240 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1241 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
| 1242 |
+
|
| 1243 |
+
# Transform the attention mask into what the blocks expect.
|
| 1244 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1245 |
+
# shape: (batch_size, 1, 1, seq_len)
|
| 1246 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
| 1247 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
| 1248 |
+
else:
|
| 1249 |
+
attention_mask = None
|
| 1250 |
+
|
| 1251 |
+
# Merge attention mask with attention bias.
|
| 1252 |
+
if (
|
| 1253 |
+
attention_bias is not None
|
| 1254 |
+
or attention_mask is not None
|
| 1255 |
+
or self.config.alibi
|
| 1256 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
| 1257 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
| 1258 |
+
# scores correctly.
|
| 1259 |
+
or past_key_values is not None
|
| 1260 |
+
):
|
| 1261 |
+
if attention_bias is None and self.config.alibi:
|
| 1262 |
+
attention_bias = get_causal_attention_bias(
|
| 1263 |
+
self.__cache, past_length + seq_len, x.device
|
| 1264 |
+
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
|
| 1265 |
+
elif attention_bias is None:
|
| 1266 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
| 1267 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
| 1268 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
| 1269 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
| 1270 |
+
|
| 1271 |
+
# Transform to the right shape and data type.
|
| 1272 |
+
mask_len = seq_len
|
| 1273 |
+
if attention_mask is not None:
|
| 1274 |
+
mask_len = attention_mask.shape[-1]
|
| 1275 |
+
elif past_key_values is not None:
|
| 1276 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
| 1277 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
| 1278 |
+
|
| 1279 |
+
# Add in the masking bias.
|
| 1280 |
+
if attention_mask is not None:
|
| 1281 |
+
attention_bias = attention_bias + attention_mask
|
| 1282 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
| 1283 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
| 1284 |
+
# it can produce NaNs.
|
| 1285 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
| 1286 |
+
|
| 1287 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1288 |
+
|
| 1289 |
+
# decoder layers
|
| 1290 |
+
all_hidden_states = []
|
| 1291 |
+
|
| 1292 |
+
# Apply blocks one-by-one.
|
| 1293 |
+
if self.config.block_group_size == 1:
|
| 1294 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
| 1295 |
+
if output_hidden_states:
|
| 1296 |
+
# add hidden states
|
| 1297 |
+
all_hidden_states.append(x)
|
| 1298 |
+
|
| 1299 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
| 1300 |
+
if (
|
| 1301 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 1302 |
+
or (
|
| 1303 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 1304 |
+
and block_idx % 2 == 0
|
| 1305 |
+
)
|
| 1306 |
+
or (
|
| 1307 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 1308 |
+
and block_idx % 3 == 0
|
| 1309 |
+
)
|
| 1310 |
+
or (
|
| 1311 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1312 |
+
and block_idx % 4 == 0
|
| 1313 |
+
)
|
| 1314 |
+
):
|
| 1315 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1316 |
+
x, cache = self._activation_checkpoint_fn(
|
| 1317 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1318 |
+
)
|
| 1319 |
+
else:
|
| 1320 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1321 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1322 |
+
if attn_key_values is not None:
|
| 1323 |
+
if update_kvcache == True:
|
| 1324 |
+
attn_key_values.append(cache)
|
| 1325 |
+
else:
|
| 1326 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
| 1327 |
+
if output_hidden_states:
|
| 1328 |
+
# add hidden states
|
| 1329 |
+
all_hidden_states.append(x)
|
| 1330 |
+
|
| 1331 |
+
layers_past = (
|
| 1332 |
+
None
|
| 1333 |
+
if past_key_values is None
|
| 1334 |
+
else past_key_values[
|
| 1335 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
| 1336 |
+
]
|
| 1337 |
+
)
|
| 1338 |
+
x, cache = block_group(
|
| 1339 |
+
x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
|
| 1340 |
+
)
|
| 1341 |
+
if attn_key_values is not None:
|
| 1342 |
+
assert cache is not None
|
| 1343 |
+
attn_key_values.extend(cache)
|
| 1344 |
+
|
| 1345 |
+
if last_logits_only:
|
| 1346 |
+
# shape: (batch_size, 1, d_model)
|
| 1347 |
+
x = x[:, -1, :].unsqueeze(1)
|
| 1348 |
+
|
| 1349 |
+
# Apply final layer norm.
|
| 1350 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
| 1351 |
+
x = self.transformer.ln_f(x) # type: ignore
|
| 1352 |
+
if output_hidden_states:
|
| 1353 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
| 1354 |
+
all_hidden_states.append(x)
|
| 1355 |
+
|
| 1356 |
+
# Get logits.
|
| 1357 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
| 1358 |
+
if self.config.weight_tying:
|
| 1359 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1360 |
+
else:
|
| 1361 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 1362 |
+
if self.config.scale_logits:
|
| 1363 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 1364 |
+
if use_cache == True and update_kvcache == False:
|
| 1365 |
+
attn_key_values=past_key_values
|
| 1366 |
+
return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
|
| 1367 |
+
|
| 1368 |
+
|
| 1369 |
+
def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
| 1370 |
+
"""
|
| 1371 |
+
Utility function
|
| 1372 |
+
"""
|
| 1373 |
+
|
| 1374 |
+
kwargs = {}
|
| 1375 |
+
for field in fields(ModelConfig):
|
| 1376 |
+
kwargs[field.name] = getattr(config, field.name)
|
| 1377 |
+
|
| 1378 |
+
model_config = ModelConfig(**kwargs)
|
| 1379 |
+
return model_config
|
| 1380 |
+
|
| 1381 |
+
|
| 1382 |
+
class LLaDAModelLM(PreTrainedModel):
|
| 1383 |
+
"""
|
| 1384 |
+
Extremely barebones HF model wrapper.
|
| 1385 |
+
"""
|
| 1386 |
+
|
| 1387 |
+
config_class = LLaDAConfig
|
| 1388 |
+
base_model_prefix = "model"
|
| 1389 |
+
_no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
|
| 1390 |
+
|
| 1391 |
+
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
|
| 1392 |
+
super().__init__(config)
|
| 1393 |
+
|
| 1394 |
+
if not model:
|
| 1395 |
+
model_config = create_model_config_from_pretrained_config(config)
|
| 1396 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
| 1397 |
+
model_config.init_device = "cpu"
|
| 1398 |
+
self.model = LLaDAModel(model_config, init_params=init_params)
|
| 1399 |
+
else:
|
| 1400 |
+
self.model = model
|
| 1401 |
+
|
| 1402 |
+
def forward(
|
| 1403 |
+
self,
|
| 1404 |
+
input_ids: torch.LongTensor = None,
|
| 1405 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1406 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1407 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1408 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1409 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1410 |
+
use_cache: Optional[bool] = None,
|
| 1411 |
+
output_attentions: Optional[bool] = None,
|
| 1412 |
+
output_hidden_states: Optional[bool] = None,
|
| 1413 |
+
return_dict: Optional[bool] = None,
|
| 1414 |
+
cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
|
| 1415 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1416 |
+
if use_cache is None:
|
| 1417 |
+
use_cache = self.config.use_cache
|
| 1418 |
+
|
| 1419 |
+
if output_attentions:
|
| 1420 |
+
raise ValueError("output_attentions is not yet supported in LLaDA")
|
| 1421 |
+
|
| 1422 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1423 |
+
|
| 1424 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1425 |
+
outputs = self.model.forward(
|
| 1426 |
+
input_ids=input_ids,
|
| 1427 |
+
input_embeddings=inputs_embeds,
|
| 1428 |
+
attention_mask=attention_mask,
|
| 1429 |
+
attention_bias=attention_bias,
|
| 1430 |
+
past_key_values=past_key_values,
|
| 1431 |
+
use_cache=use_cache,
|
| 1432 |
+
output_hidden_states=output_hidden_states,
|
| 1433 |
+
)
|
| 1434 |
+
|
| 1435 |
+
logits = outputs.logits
|
| 1436 |
+
hidden_states = outputs.hidden_states
|
| 1437 |
+
|
| 1438 |
+
loss = None
|
| 1439 |
+
if labels is not None:
|
| 1440 |
+
import warnings
|
| 1441 |
+
warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
|
| 1442 |
+
if not return_dict:
|
| 1443 |
+
output = (logits,) + outputs[1:]
|
| 1444 |
+
return (loss,) + output if loss is not None else output
|
| 1445 |
+
|
| 1446 |
+
return CausalLMOutputWithPast(
|
| 1447 |
+
logits=logits,
|
| 1448 |
+
past_key_values=outputs.attn_key_values,
|
| 1449 |
+
hidden_states=hidden_states,
|
| 1450 |
+
)
|
| 1451 |
+
|
| 1452 |
+
def can_generate(self) -> bool:
|
| 1453 |
+
return True
|
| 1454 |
+
|
| 1455 |
+
def prepare_inputs_for_generation(
|
| 1456 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 1457 |
+
):
|
| 1458 |
+
if past_key_values:
|
| 1459 |
+
# This is because we want the model to only process the last generated token.
|
| 1460 |
+
input_ids = input_ids[:, -1:]
|
| 1461 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 1462 |
+
|
| 1463 |
+
model_inputs.update(kwargs)
|
| 1464 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 1465 |
+
return model_inputs
|
| 1466 |
+
|
| 1467 |
+
# TODO: these are required to make the implementation complete.
|
| 1468 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
| 1469 |
+
# pass
|
| 1470 |
+
#
|
| 1471 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
| 1472 |
+
# pass
|
| 1473 |
+
#
|
| 1474 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
| 1475 |
+
# pass
|
| 1476 |
+
|
| 1477 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1478 |
+
return self.model.transformer.wte
|
| 1479 |
+
|
| 1480 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 1481 |
+
self.model.transformer.wte = value
|
| 1482 |
+
|
| 1483 |
+
def get_output_embeddings(self):
|
| 1484 |
+
if self.config.weight_tying:
|
| 1485 |
+
return self.model.transformer.wte
|
| 1486 |
+
else:
|
| 1487 |
+
return self.model.transformer.ff_out
|
| 1488 |
+
|
| 1489 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 1490 |
+
if self.config.weight_tying:
|
| 1491 |
+
self.model.transformer.wte = value
|
| 1492 |
+
else:
|
| 1493 |
+
self.model.transformer.ff_out = value
|
| 1494 |
+
|
| 1495 |
+
def tie_weights(self):
|
| 1496 |
+
if self.config.weight_tying:
|
| 1497 |
+
self.model.transformer.ff_out = self.model.transformer.wte
|
| 1498 |
+
|
| 1499 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
| 1500 |
+
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
Discrete-Diffusion-Forcing/D2F-train/train.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from peft import PeftModel, PeftConfig, get_peft_model
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from peft import PeftModel, PeftConfig, get_peft_model
|
| 5 |
+
from utils.util import flatten_dict,shift_logits
|
| 6 |
+
from utils.data import get_bs17k_dataloader,get_llada_bs17k_dataloader,get_dataloader_by_config
|
| 7 |
+
from utils.model import get_model,get_llada,get_model_by_config
|
| 8 |
+
from utils.loss import compute_loss,compute_llada_loss,compute_normal_loss,compute_loss_by_config
|
| 9 |
+
from utils.generation import sample_tokens
|
| 10 |
+
# import dataloader
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import torch
|
| 14 |
+
import argparse
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
from accelerate.utils import ProjectConfiguration
|
| 20 |
+
|
| 21 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 22 |
+
|
| 23 |
+
def get_accelerator(config, global_config):
|
| 24 |
+
# Select experiment path based on config
|
| 25 |
+
if hasattr(global_config, 'paths') and hasattr(global_config.paths, 'experiment'):
|
| 26 |
+
root_path = global_config.paths.experiment
|
| 27 |
+
else:
|
| 28 |
+
root_path = config.root if hasattr(config, 'root') else '/tmp/experiment'
|
| 29 |
+
|
| 30 |
+
output_dir = os.path.join(root_path, config.exp_name, config.output_dir)
|
| 31 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 32 |
+
logging_dir = os.path.join(output_dir, config.logging_dir)
|
| 33 |
+
project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
|
| 34 |
+
accelerator = Accelerator(
|
| 35 |
+
log_with=None if config.report_to == 'no' else config.report_to,
|
| 36 |
+
mixed_precision=config.mixed_precision,
|
| 37 |
+
project_config=project_config,
|
| 38 |
+
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return accelerator, output_dir
|
| 42 |
+
|
| 43 |
+
def main(args):
|
| 44 |
+
config = OmegaConf.load(args.config)
|
| 45 |
+
accelerator, output_dir = get_accelerator(config.train, config)
|
| 46 |
+
|
| 47 |
+
# Use unified model and data loading functions
|
| 48 |
+
denoiser, tokenizer = get_model_by_config(config)
|
| 49 |
+
dataloader = get_dataloader_by_config(tokenizer, config.data, config)
|
| 50 |
+
|
| 51 |
+
if config.train.decoder_resume_path is not None:
|
| 52 |
+
ckpt = torch.load(config.train.decoder_resume_path, map_location='cpu', weights_only=True)
|
| 53 |
+
if config.train.skipped_keys:
|
| 54 |
+
ckpt = {k: v for k, v in ckpt.items() if k not in config.train.skipped_keys}
|
| 55 |
+
m, u = denoiser.load_state_dict(ckpt, strict=False)
|
| 56 |
+
if accelerator.is_main_process:
|
| 57 |
+
print(f'model ckpt loaded from {config.train.decoder_resume_path}')
|
| 58 |
+
|
| 59 |
+
# ckpt = torch.load(config.train.head_resume_path, map_location='cpu', weights_only=True)
|
| 60 |
+
# if config.train.skipped_keys:
|
| 61 |
+
# ckpt = {k: v for k, v in ckpt.items() if k not in config.train.skipped_keys}
|
| 62 |
+
# m, u = denoiser.lm_head.load_state_dict(ckpt, strict=False)
|
| 63 |
+
# if accelerator.is_main_process:
|
| 64 |
+
# print(f'model ckpt loaded from {config.train.head_resume_path}')
|
| 65 |
+
|
| 66 |
+
global_step = config.train.global_step if config.train.global_step is not None else 0
|
| 67 |
+
params_to_learn = list(param for param in denoiser.parameters() if param.requires_grad)
|
| 68 |
+
optimizer = torch.optim.AdamW(
|
| 69 |
+
params_to_learn,
|
| 70 |
+
lr = config.train.lr,
|
| 71 |
+
betas = (0.9, 0.95),
|
| 72 |
+
weight_decay = 5e-2,
|
| 73 |
+
eps = 1e-8,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
denoiser, dataloader, optimizer = accelerator.prepare(denoiser, dataloader, optimizer)
|
| 77 |
+
|
| 78 |
+
config.device_count = accelerator.num_processes
|
| 79 |
+
if accelerator.is_main_process:
|
| 80 |
+
accelerator.init_trackers(config.train.wandb_proj, config=flatten_dict(config))
|
| 81 |
+
|
| 82 |
+
training_done = False
|
| 83 |
+
epoch = 0
|
| 84 |
+
progress_bar = tqdm(
|
| 85 |
+
total = config.train.num_iters,
|
| 86 |
+
initial = global_step,
|
| 87 |
+
desc = 'Steps',
|
| 88 |
+
disable = not accelerator.is_local_main_process,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if accelerator.is_main_process:
|
| 92 |
+
print(f'Learnable parameters: {sum(p.numel() for p in params_to_learn if p.requires_grad) / 1e9} B')
|
| 93 |
+
|
| 94 |
+
while not training_done:
|
| 95 |
+
if accelerator.is_main_process:
|
| 96 |
+
print(f'Epoch: {epoch}')
|
| 97 |
+
for batch in dataloader:
|
| 98 |
+
with accelerator.accumulate([denoiser]):
|
| 99 |
+
denoiser.train()
|
| 100 |
+
input_ids = batch['data']
|
| 101 |
+
# print("input_ids",input_ids.dtype)
|
| 102 |
+
question_length = batch['question_length']
|
| 103 |
+
|
| 104 |
+
# Use unified loss function selection
|
| 105 |
+
losses = compute_loss_by_config(
|
| 106 |
+
input_ids,
|
| 107 |
+
denoiser,
|
| 108 |
+
question_length,
|
| 109 |
+
block_size = config.train.block_size,
|
| 110 |
+
mask_id = config.denoiser.encoder.mask_id,
|
| 111 |
+
enable_shift = config.train.enable_shift,
|
| 112 |
+
share_steps = config.train.share_steps,
|
| 113 |
+
self_align = config.train.self_align,
|
| 114 |
+
feature_align = config.train.feature_align,
|
| 115 |
+
self_step = config.train.self_step,
|
| 116 |
+
eos_id = tokenizer.eos_token_id,
|
| 117 |
+
config = config
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if config.train.share_steps > 1:
|
| 121 |
+
loss_tgt = losses['loss']
|
| 122 |
+
# loss_1 = losses['loss_1']
|
| 123 |
+
# loss_2 = losses['loss_2']
|
| 124 |
+
else:
|
| 125 |
+
raise NotImplementedError
|
| 126 |
+
torch.cuda.empty_cache()
|
| 127 |
+
accelerator.backward(loss_tgt)
|
| 128 |
+
if accelerator.sync_gradients:
|
| 129 |
+
accelerator.clip_grad_norm_(params_to_learn, 1.0)
|
| 130 |
+
|
| 131 |
+
optimizer.step()
|
| 132 |
+
optimizer.zero_grad()
|
| 133 |
+
|
| 134 |
+
if accelerator.sync_gradients:
|
| 135 |
+
global_step += 1
|
| 136 |
+
progress_bar.update(1)
|
| 137 |
+
logs = dict()
|
| 138 |
+
loss_tgt = accelerator.gather(loss_tgt.detach()).mean().item()
|
| 139 |
+
logs['loss'] = loss_tgt
|
| 140 |
+
# if config.train.share_steps > 1:
|
| 141 |
+
# loss_1 = accelerator.gather(loss_1.detach()).mean().item()
|
| 142 |
+
# loss_2 = accelerator.gather(loss_2.detach()).mean().item()
|
| 143 |
+
# logs['loss_1'] = loss_1
|
| 144 |
+
# logs['loss_2'] = loss_2
|
| 145 |
+
|
| 146 |
+
accelerator.log(logs, step=global_step)
|
| 147 |
+
progress_bar.set_postfix(**logs)
|
| 148 |
+
|
| 149 |
+
if global_step > 0 and global_step % config.train.eval_every == 0 and accelerator.is_main_process:
|
| 150 |
+
denoiser.eval();
|
| 151 |
+
question = 'Henry made two stops during his 60-mile bike trip. He first stopped after 20 miles. His second stop was 15 miles before the end of the trip. How many miles did he travel between his first and second stops?'
|
| 152 |
+
# prompt = tokenizer(question)['input_ids']
|
| 153 |
+
# prompt = torch.tensor(prompt).to(accelerator.device).unsqueeze(0)
|
| 154 |
+
messages = [
|
| 155 |
+
{"role": "user", "content": question}
|
| 156 |
+
]
|
| 157 |
+
prompt = tokenizer.apply_chat_template(
|
| 158 |
+
messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
|
| 159 |
+
).input_ids
|
| 160 |
+
prompt = prompt.to(accelerator.device)
|
| 161 |
+
|
| 162 |
+
mask_id = 151666
|
| 163 |
+
gen_len = 512 - prompt.shape[1]
|
| 164 |
+
temperature = 0.2
|
| 165 |
+
top_p = 0.95
|
| 166 |
+
|
| 167 |
+
x_t = torch.cat([prompt, torch.tensor([[mask_id]*gen_len]).to(accelerator.device)], dim=1)
|
| 168 |
+
with torch.inference_mode():
|
| 169 |
+
for i in range(gen_len):
|
| 170 |
+
mask_index = (x_t == mask_id)
|
| 171 |
+
if i % 2 == 0:
|
| 172 |
+
z_t = denoiser.module.encoder(x_t, output_hidden_states=True).hidden_states[-1]
|
| 173 |
+
hidden_state = denoiser.module.decoder(x_t, z_t)
|
| 174 |
+
logits = denoiser.module.encoder.lm_head(hidden_state)
|
| 175 |
+
else:
|
| 176 |
+
hidden_state = denoiser.module.decoder(x_t, z_t)
|
| 177 |
+
logits = denoiser.module.lm_head(hidden_state)
|
| 178 |
+
|
| 179 |
+
if config.train.enable_shift:
|
| 180 |
+
logits = shift_logits(logits)
|
| 181 |
+
|
| 182 |
+
mask_logits = logits[mask_index]
|
| 183 |
+
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=None, neg_entropy=True)
|
| 184 |
+
|
| 185 |
+
number_transfer_tokens = 1
|
| 186 |
+
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
|
| 187 |
+
x0_ = torch.zeros_like(x0, device=accelerator.device, dtype=torch.long) + mask_id
|
| 188 |
+
x0_[transfer_index] = x0[transfer_index].clone()
|
| 189 |
+
x_t[mask_index] = x0_
|
| 190 |
+
|
| 191 |
+
answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=True)[0]
|
| 192 |
+
print(answer)
|
| 193 |
+
|
| 194 |
+
accelerator.wait_for_everyone()
|
| 195 |
+
|
| 196 |
+
if global_step > 0 and global_step % config.train.save_every == 0 and accelerator.is_main_process:
|
| 197 |
+
denoiser.eval()
|
| 198 |
+
decoder_state_dict = accelerator.unwrap_model(denoiser).save_pretrained(os.path.join(output_dir, f"Decoder-{config.train.exp_name}-{global_step // 1000}k"))
|
| 199 |
+
# lmhead_state_dict = accelerator.unwrap_model(denoiser).lm_head.state_dict()
|
| 200 |
+
# torch.save(lmhead_state_dict, os.path.join(output_dir, f"LMhead-{config.train.exp_name}-{global_step // 1000}k"))
|
| 201 |
+
accelerator.wait_for_everyone()
|
| 202 |
+
if global_step >= config.train.num_iters:
|
| 203 |
+
training_done = True
|
| 204 |
+
break
|
| 205 |
+
epoch += 1
|
| 206 |
+
accelerator.end_training()
|
| 207 |
+
if dist.is_initialized():
|
| 208 |
+
dist.destroy_process_group()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == '__main__':
|
| 212 |
+
parser = argparse.ArgumentParser()
|
| 213 |
+
parser.add_argument('--config', type=str, default='config/dream.yaml')
|
| 214 |
+
args = parser.parse_args()
|
| 215 |
+
main(args)
|
Discrete-Diffusion-Forcing/D2F-train/train.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# export CUDA_LAUNCH_BLOCKING=1
|
| 2 |
+
CUDA_VISIBLE_DEVICES=4 accelerate launch --config_file config/acc_config --num_processes 1 --main_process_port 29577 train.py --config config/llada.yaml
|
| 3 |
+
|
| 4 |
+
CUDA_VISIBLE_DEVICES=4 accelerate launch --config_file config/acc_config --num_processes 1 --main_process_port 29577 train.py --config config/dream_eagle.yaml
|
Discrete-Diffusion-Forcing/D2F-train/utils/__init__.py
ADDED
|
File without changes
|
Discrete-Diffusion-Forcing/D2F-train/utils/data.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
from torch.utils.data import DataLoader,Dataset
|
| 3 |
+
from peft import PeftModel, PeftConfig, get_peft_model
|
| 4 |
+
# from modelscope.msdatasets import MsDataset
|
| 5 |
+
import torch
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
def extract_answer(text):
|
| 9 |
+
pattern = r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>"
|
| 10 |
+
match = re.search(pattern, text, re.DOTALL)
|
| 11 |
+
|
| 12 |
+
if match:
|
| 13 |
+
solution_content = match.group(1).strip()
|
| 14 |
+
# print("Extracted content:\n")
|
| 15 |
+
# print(solution_content)
|
| 16 |
+
return solution_content
|
| 17 |
+
else:
|
| 18 |
+
# print("No matching content found.")
|
| 19 |
+
return None
|
| 20 |
+
def collate_fn(batch, tokenizer, max_length):
|
| 21 |
+
"""
|
| 22 |
+
batch: list of raw text samples (str)
|
| 23 |
+
tokenizer: huggingface tokenizer
|
| 24 |
+
max_length: maximum length to pad to (int)
|
| 25 |
+
"""
|
| 26 |
+
encoded_batch = []
|
| 27 |
+
for text in batch:
|
| 28 |
+
# Encode text, return dictionary, note no automatic padding
|
| 29 |
+
enc = tokenizer(text["text"], add_special_tokens=False, return_tensors="pt")
|
| 30 |
+
input_ids = enc["input_ids"].squeeze(0) # (seq_len,)
|
| 31 |
+
|
| 32 |
+
# Add eos_token_id
|
| 33 |
+
eos_id = tokenizer.eos_token_id
|
| 34 |
+
if eos_id is None:
|
| 35 |
+
raise ValueError("tokenizer does not have eos_token_id")
|
| 36 |
+
|
| 37 |
+
input_ids = torch.cat([input_ids, torch.tensor([eos_id], device=input_ids.device)])
|
| 38 |
+
|
| 39 |
+
# Padding to max_length
|
| 40 |
+
pad_id = tokenizer.pad_token_id
|
| 41 |
+
if pad_id is None:
|
| 42 |
+
raise ValueError("tokenizer does not have pad_token_id")
|
| 43 |
+
|
| 44 |
+
seq_len = input_ids.size(0)
|
| 45 |
+
if seq_len > max_length:
|
| 46 |
+
# Truncate if too long
|
| 47 |
+
input_ids = input_ids[:max_length]
|
| 48 |
+
else:
|
| 49 |
+
# Pad right side if not long enough
|
| 50 |
+
pad_len = max_length - seq_len
|
| 51 |
+
padding = torch.full((pad_len,), pad_id, device=input_ids.device, dtype=input_ids.dtype)
|
| 52 |
+
input_ids = torch.cat([input_ids, padding])
|
| 53 |
+
|
| 54 |
+
encoded_batch.append(input_ids)
|
| 55 |
+
|
| 56 |
+
return torch.stack(encoded_batch)
|
| 57 |
+
|
| 58 |
+
def prepare_dataloader(data, tokenizer, batch_size, max_length):
|
| 59 |
+
dataset = CustomDataset(data)
|
| 60 |
+
dataloader = DataLoader(
|
| 61 |
+
dataset,
|
| 62 |
+
batch_size = batch_size,
|
| 63 |
+
collate_fn = lambda x: collate_fn(x, tokenizer, max_length=max_length),
|
| 64 |
+
num_workers = 0,
|
| 65 |
+
shuffle = True,
|
| 66 |
+
pin_memory = True,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
return dataloader
|
| 70 |
+
|
| 71 |
+
def read_math():
|
| 72 |
+
math_data = []
|
| 73 |
+
dataset = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
|
| 74 |
+
for item in dataset:
|
| 75 |
+
math_data.append({"question": item['question'], "answer": item['answer']})
|
| 76 |
+
return math_data
|
| 77 |
+
|
| 78 |
+
def read_python():
|
| 79 |
+
python_data = []
|
| 80 |
+
dataset = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
|
| 81 |
+
for item in dataset:
|
| 82 |
+
python_data.append({"question": item['question'], "answer": item['answer']})
|
| 83 |
+
return python_data
|
| 84 |
+
|
| 85 |
+
def read_numinamath():
|
| 86 |
+
math_data = read_math()
|
| 87 |
+
python_data = read_python()
|
| 88 |
+
return math_data + python_data
|
| 89 |
+
|
| 90 |
+
def read_bs(config=None):
|
| 91 |
+
data=[]
|
| 92 |
+
# Get path from config, use default path if no config
|
| 93 |
+
if config and hasattr(config, 'paths') and hasattr(config.paths, 'data') and hasattr(config.paths.data, 'bs'):
|
| 94 |
+
dataset_path = config.paths.data.bs
|
| 95 |
+
else:
|
| 96 |
+
dataset_path = "/data1/xck/dllm_block_wx/data/Lansechen/bs17k_collection_filtered_hard_maxlength600"
|
| 97 |
+
|
| 98 |
+
dataset=load_dataset(dataset_path, split="train")
|
| 99 |
+
for item in dataset:
|
| 100 |
+
data.append({"question": item['question'], "answer": item['qwen7b_answer']})
|
| 101 |
+
return data
|
| 102 |
+
|
| 103 |
+
def read_bs_easy(config=None):
|
| 104 |
+
data=[]
|
| 105 |
+
# Get path from config, use default path if no config
|
| 106 |
+
if config and hasattr(config, 'paths') and hasattr(config.paths, 'data') and hasattr(config.paths.data, 'bs_easy'):
|
| 107 |
+
dataset_path = config.paths.data.bs_easy
|
| 108 |
+
else:
|
| 109 |
+
dataset_path = "/data1/xck/dllm_block_wx/data/Lansechen/bs17k_collection_filtered_easy_maxlength600"
|
| 110 |
+
|
| 111 |
+
dataset=load_dataset(dataset_path, split="train")
|
| 112 |
+
for item in dataset:
|
| 113 |
+
data.append({"question": item['question'], "answer": item['qwen7b_answer']})
|
| 114 |
+
return data
|
| 115 |
+
|
| 116 |
+
def read_bs_17k():
|
| 117 |
+
data=[]
|
| 118 |
+
dataset=load_dataset("/data/wx/dataset/bespokelabs/Bespoke-Stratos-17k",split="train")
|
| 119 |
+
for item in dataset:
|
| 120 |
+
item=item["conversations"]
|
| 121 |
+
data.append({"question": item[0]['value'], "answer": extract_answer(item[1]['value'])})
|
| 122 |
+
return data
|
| 123 |
+
class CustomDataset(Dataset):
|
| 124 |
+
def __init__(self, data):
|
| 125 |
+
self.data = data
|
| 126 |
+
|
| 127 |
+
def __len__(self):
|
| 128 |
+
return len(self.data)
|
| 129 |
+
|
| 130 |
+
def __getitem__(self, idx):
|
| 131 |
+
return self.data[idx]
|
| 132 |
+
def read_llada(file_path="/home/wx/dllm_block/data/merged_bs17k_easy_hard_llada_collected.jsonl"):
|
| 133 |
+
data = []
|
| 134 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
| 135 |
+
for line in file:
|
| 136 |
+
try:
|
| 137 |
+
json_obj = json.loads(line)
|
| 138 |
+
data.append(json_obj)
|
| 139 |
+
except json.JSONDecodeError:
|
| 140 |
+
print(f'JSONDecodeError: {line}')
|
| 141 |
+
return data
|
| 142 |
+
def get_bs17k_dataloader(tokenizer, config, max_length=1024):
|
| 143 |
+
train_dataset = []
|
| 144 |
+
# Pass global config to data reading functions
|
| 145 |
+
global_config = getattr(config, '_parent', config) # Try to get parent config
|
| 146 |
+
data_dict=read_bs(global_config)+read_bs_easy(global_config)
|
| 147 |
+
for data in data_dict:
|
| 148 |
+
question = data['question']
|
| 149 |
+
answer = data['answer']
|
| 150 |
+
|
| 151 |
+
# messages = [
|
| 152 |
+
# {"role": "user", "content": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"},
|
| 153 |
+
# ]
|
| 154 |
+
messages = [
|
| 155 |
+
{"role": "user", "content": question}
|
| 156 |
+
]
|
| 157 |
+
question = tokenizer.apply_chat_template(
|
| 158 |
+
messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
|
| 159 |
+
).input_ids[0]
|
| 160 |
+
|
| 161 |
+
# question = tokenizer(question, return_tensors='pt')['input_ids'][0]
|
| 162 |
+
answer = tokenizer(answer, return_tensors='pt')['input_ids'][0]
|
| 163 |
+
answer = torch.cat((answer, torch.tensor([tokenizer.eos_token_id])), dim=-1)
|
| 164 |
+
|
| 165 |
+
question_length = question.shape[-1]
|
| 166 |
+
answer_length = answer.shape[-1]
|
| 167 |
+
combined_length = question_length + answer_length
|
| 168 |
+
if question_length > max_length-100:
|
| 169 |
+
continue
|
| 170 |
+
if combined_length > max_length:
|
| 171 |
+
padded_data = torch.cat((question, answer), dim=-1)
|
| 172 |
+
padded_data = padded_data[:max_length] # Truncate to max_length
|
| 173 |
+
else:
|
| 174 |
+
padding_length = max_length - combined_length
|
| 175 |
+
padding = torch.full((padding_length,), tokenizer.eos_token_id, dtype=question.dtype)
|
| 176 |
+
padded_data = torch.cat((question, answer, padding), dim=-1)
|
| 177 |
+
|
| 178 |
+
train_dataset.append(
|
| 179 |
+
dict(
|
| 180 |
+
data = padded_data,
|
| 181 |
+
question_length = question_length,
|
| 182 |
+
length = combined_length,
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
dataset = CustomDataset(train_dataset)
|
| 187 |
+
dataloader = DataLoader(
|
| 188 |
+
dataset,
|
| 189 |
+
batch_size = config.batch_size,
|
| 190 |
+
num_workers = 0,
|
| 191 |
+
shuffle = True,
|
| 192 |
+
pin_memory = True,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
return dataloader
|
| 196 |
+
|
| 197 |
+
# def get_gsm8k_dataloader(tokenizer, config, max_length=1024):
|
| 198 |
+
# train_dataset = []
|
| 199 |
+
# data_dict = read_numinamath()
|
| 200 |
+
# for data in data_dict:
|
| 201 |
+
# question = data['question']
|
| 202 |
+
# answer = data['answer']
|
| 203 |
+
|
| 204 |
+
# question = tokenizer(question, return_tensors='pt')['input_ids'][0]
|
| 205 |
+
# answer = tokenizer(answer, return_tensors='pt')['input_ids'][0]
|
| 206 |
+
# answer = torch.cat((answer, torch.tensor([tokenizer.eos_token_id])), dim=-1)
|
| 207 |
+
|
| 208 |
+
# question_length = question.shape[-1]
|
| 209 |
+
# answer_length = answer.shape[-1]
|
| 210 |
+
# combined_length = question_length + answer_length
|
| 211 |
+
|
| 212 |
+
# if combined_length > max_length:
|
| 213 |
+
# continue
|
| 214 |
+
|
| 215 |
+
# padding_length = max_length - combined_length
|
| 216 |
+
# padding = torch.full((padding_length,), tokenizer.eos_token_id, dtype=question.dtype)
|
| 217 |
+
# padded_data = torch.cat((question, answer, padding), dim=-1)
|
| 218 |
+
|
| 219 |
+
# train_dataset.append(
|
| 220 |
+
# dict(
|
| 221 |
+
# data = padded_data,
|
| 222 |
+
# question_length = question_length,
|
| 223 |
+
# length = combined_length,
|
| 224 |
+
# )
|
| 225 |
+
# )
|
| 226 |
+
|
| 227 |
+
# dataset = CustomDataset(train_dataset)
|
| 228 |
+
# dataloader = DataLoader(
|
| 229 |
+
# dataset,
|
| 230 |
+
# batch_size = config.batch_size,
|
| 231 |
+
# collate_fn = lambda x: collate_fn_pad(x, tokenizer, max_length=max_length),
|
| 232 |
+
# num_workers = 0,
|
| 233 |
+
# shuffle = True,
|
| 234 |
+
# pin_memory = True,
|
| 235 |
+
# )
|
| 236 |
+
|
| 237 |
+
# return dataloader
|
| 238 |
+
def get_llada_bs17k_dataloader(tokenizer, config, max_length=1024):
|
| 239 |
+
train_dataset = []
|
| 240 |
+
# Pass global config to data reading functions
|
| 241 |
+
global_config = getattr(config, '_parent', config) # Try to get parent config
|
| 242 |
+
data_dict = read_bs(global_config)
|
| 243 |
+
python_dict=read_bs_easy(global_config)
|
| 244 |
+
data_dict=data_dict+python_dict
|
| 245 |
+
print("Data length:",len(data_dict))
|
| 246 |
+
# data_dict = read_llada()
|
| 247 |
+
for data in data_dict:
|
| 248 |
+
question = data['question']
|
| 249 |
+
answer = data['answer']
|
| 250 |
+
|
| 251 |
+
# messages = [
|
| 252 |
+
# {"role": "user", "content": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"},
|
| 253 |
+
# ]
|
| 254 |
+
messages = [
|
| 255 |
+
{"role": "user", "content": question}
|
| 256 |
+
]
|
| 257 |
+
question = tokenizer.apply_chat_template(
|
| 258 |
+
messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
|
| 259 |
+
).input_ids[0]
|
| 260 |
+
|
| 261 |
+
# question = tokenizer(question, return_tensors='pt')['input_ids'][0]
|
| 262 |
+
answer = tokenizer(answer, return_tensors='pt')['input_ids'][0]
|
| 263 |
+
answer = torch.cat((answer, torch.tensor([126348])), dim=-1)
|
| 264 |
+
|
| 265 |
+
question_length = question.shape[-1]
|
| 266 |
+
answer_length = answer.shape[-1]
|
| 267 |
+
combined_length = question_length + answer_length
|
| 268 |
+
|
| 269 |
+
if combined_length > max_length:
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
+
padding_length = max_length - combined_length
|
| 273 |
+
padding = torch.full((padding_length,), tokenizer.eos_token_id, dtype=question.dtype)
|
| 274 |
+
padded_data = torch.cat((question, answer, padding), dim=-1)
|
| 275 |
+
|
| 276 |
+
train_dataset.append(
|
| 277 |
+
dict(
|
| 278 |
+
data = padded_data,
|
| 279 |
+
question_length = question_length,
|
| 280 |
+
length = combined_length,
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
dataset = CustomDataset(train_dataset)
|
| 285 |
+
dataloader = DataLoader(
|
| 286 |
+
dataset,
|
| 287 |
+
batch_size = config.batch_size,
|
| 288 |
+
num_workers = 0,
|
| 289 |
+
shuffle = True,
|
| 290 |
+
pin_memory = True,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
return dataloader
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
text="<|begin_of_thought|>\n\nOkay, let me try to figure out this problem. So, we have this operation defined as a⊗b = a²/b. And we need to compute [(1⊗2)⊗3] - [1⊗(2⊗3)]. Then choose the correct answer from the options given. Alright, let's break it down step by step.\n\nFirst, I need to remember that the operation ⊗ is not associative, right? Because the problem is asking for the difference between two different groupings: (1⊗2)⊗3 and 1⊗(2⊗3). So, the order in which we perform the operations matters here. That's probably why there's a subtraction between them.\n\nLet me start by computing each part separately. Let's tackle the first part: (1⊗2)⊗3.\n\nStarting with the innermost operation, which is 1⊗2. According to the definition, a⊗b = a²/b. So here, a is 1 and b is 2. Plugging those in: 1² / 2 = 1/2. So, 1⊗2 equals 1/2.\n\nNow, we take that result and perform the next operation with 3. So, (1⊗2)⊗3 becomes (1/2)⊗3. Again, using the same definition: a is now 1/2 and b is 3. So, ( (1/2)² ) / 3 = (1/4) / 3 = 1/12. So, (1⊗2)⊗3 equals 1/12.\n\nAlright, that's the first part. Now let's compute the second part: 1⊗(2⊗3). Again, starting with the innermost operation, which is 2⊗3. Applying the definition: a is 2 and b is 3. So, 2² / 3 = 4/3. Therefore, 2⊗3 equals 4/3.\n\nNow, we need to compute 1⊗(4/3). Here, a is 1 and b is 4/3. Using the operation definition: 1² / (4/3) = 1 / (4/3) = 3/4. So, 1⊗(2⊗3) equals 3/4.\n\nNow, the problem asks for the difference between the two results: [(1⊗2)⊗3] - [1⊗(2⊗3)] = (1/12) - (3/4). To subtract these fractions, they need a common denominator. The denominators are 12 and 4, so 12 is the common denominator.\n\nConverting 3/4 to twelfths: 3/4 = 9/12. So, 1/12 - 9/12 = (1 - 9)/12 = -8/12. Simplifying that fraction by dividing numerator and denominator by 4: -8/12 = -2/3.\n\nHmm, looking at the answer choices, option A is -2/3. So, is that the answer? Wait, but let me double-check my calculations to make sure I didn't make a mistake somewhere.\n\nFirst, checking (1⊗2): 1² / 2 = 1/2. Correct. Then, (1/2)⊗3: (1/2)² / 3 = (1/4)/3 = 1/12. That seems right.\n\nNow, for 2⊗3: 2² / 3 = 4/3. Correct. Then, 1⊗(4/3): 1² / (4/3) = 1 / (4/3) = 3/4. Yes, that's correct.\n\nSubtracting 3/4 from 1/12: 1/12 - 3/4. Convert 3/4 to 9/12, so 1/12 - 9/12 = -8/12 = -2/3. Yes, that all checks out. So the answer should be -2/3, which is option A.\n\nWait, but let me think again. The operation is defined for all nonzero numbers, so we don't have any issues with division by zero here. 2⊗3 is 4/3, which is fine, and then 1⊗(4/3) is 3/4. Correct.\n\nAlternatively, maybe there's a different way to approach the problem? Let me try expanding both expressions using variables to see if there's a pattern.\n\nLet's denote the first expression: (a⊗b)⊗c. Using the definition:\n\nFirst, compute a⊗b = a²/b.\n\nThen, take that result and ⊗ with c: (a²/b)⊗c = ( (a²/b)² ) / c = a⁴ / (b² c).\n\nNow, the second expression: a⊗(b⊗c). First compute b⊗c = b²/c.\n\nThen, a⊗(b²/c) = a² / (b²/c) = a² * (c / b²) = (a² c) / b².\n\nTherefore, the difference between the two expressions is:\n\n(a⁴ / (b² c)) - (a² c / b²) = (a⁴ - a² c²) / (b² c) = a² (a² - c²) / (b² c).\n\nHmm, factoring that, it's a² (a - c)(a + c) / (b² c).\n\nBut in our specific problem, a = 1, b = 2, c = 3. Plugging those values in:\n\n1² (1 - 3)(1 + 3) / (2² * 3) = 1 * (-2)(4) / (4 * 3) = (-8) / 12 = -2/3. Same result. So that confirms the answer is indeed -2/3.\n\nTherefore, I think my initial calculation was correct, and the answer is option A.\n\n**Final Answer**\n\\boxed{A}\n\n<|end_of_thought|>\n\n<|begin_of_solution|>\n\nTo determine the value of \\([(1 \\otimes 2) \\otimes 3] - [1 \\otimes (2 \\otimes 3)]\\) where the operation \\(\\otimes\\) is defined by \\(a \\otimes b = \\frac{a^2}{b}\\), we proceed as follows:\n\nFirst, compute \\(1 \\otimes 2\\):\n\\[\n1 \\otimes 2 = \\frac{1^2}{2} = \\frac{1}{2}\n\\]\nNext, use this result to compute \\((1 \\otimes 2) \\otimes 3\\):\n\\[\n\\left(\\frac{1}{2}\\right) \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n\\]\n\nNow, compute \\(2 \\otimes 3\\):\n\\[\n2 \\otimes 3 = \\frac{2^2}{3} = \\frac{4}{3}\n\\]\nThen, use this result to compute \\(1 \\otimes (2 \\otimes 3)\\):\n\\[\n1 \\otimes \\left(\\frac{4}{3}\\right) = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n\\]\n\nFinally, find the difference between the two results:\n\\[\n\\frac{1}{12} - \\frac{3}{4} = \\frac{1}{12} - \\frac{9}{12} = \\frac{1 - 9}{12} = \\frac{-8}{12} = -\\frac{2}{3}\n\\]\n\nThus, the answer is \\(\\boxed{A}\\).\n\n<|end_of_solution|>"
|
| 296 |
+
print(extract_answer(text))
|
| 297 |
+
|
| 298 |
+
def get_dataloader_by_config(tokenizer, config, global_config=None, max_length=1024):
|
| 299 |
+
"""Select different data loaders based on config file"""
|
| 300 |
+
if global_config is None:
|
| 301 |
+
global_config = config
|
| 302 |
+
|
| 303 |
+
training_mode = global_config.get('training_mode', 'dream')
|
| 304 |
+
|
| 305 |
+
# Add reference to global config for data loading functions to access
|
| 306 |
+
config._parent = global_config
|
| 307 |
+
|
| 308 |
+
if training_mode == 'llada':
|
| 309 |
+
return get_llada_bs17k_dataloader(tokenizer, config, max_length)
|
| 310 |
+
elif training_mode == 'dream':
|
| 311 |
+
return get_bs17k_dataloader(tokenizer, config, max_length)
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Unsupported training mode: {training_mode}")
|
Discrete-Diffusion-Forcing/D2F-train/utils/generation.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.distributions as dists
|
| 4 |
+
from peft import PeftModel, PeftConfig
|
| 5 |
+
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
|
| 6 |
+
B,seq_len= input_ids.shape
|
| 7 |
+
# 初始化为全 -inf
|
| 8 |
+
attn_mask = torch.full((B,1,seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
|
| 9 |
+
# 1. Prompt部分:每个token可以注意整个prompt
|
| 10 |
+
for i in range(B):
|
| 11 |
+
attn_mask[i,:,:,:prompt_length[i]] = 0.0 # 允许所有 token 看 prompt
|
| 12 |
+
|
| 13 |
+
# 2. 块划分:从 prompt_length 开始划分 block
|
| 14 |
+
num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
|
| 15 |
+
|
| 16 |
+
for b in range(num_blocks):
|
| 17 |
+
block_start = prompt_length[i] + b * block_size
|
| 18 |
+
# print(block_start,block_size,seq_len)
|
| 19 |
+
block_end = min(block_start + block_size, seq_len)
|
| 20 |
+
|
| 21 |
+
# 块内全注意
|
| 22 |
+
attn_mask[i,:,block_start:block_end, block_start:block_end] = 0.0
|
| 23 |
+
|
| 24 |
+
# 块之间因果注意(只能看前面块)
|
| 25 |
+
for prev_b in range(b):
|
| 26 |
+
prev_start = prompt_length[i] + prev_b * block_size
|
| 27 |
+
prev_end = min(prev_start + block_size, seq_len)
|
| 28 |
+
|
| 29 |
+
# 当前块可以看前面块
|
| 30 |
+
attn_mask[i,:,block_start:block_end, prev_start:prev_end] = 0.0
|
| 31 |
+
|
| 32 |
+
return attn_mask
|
| 33 |
+
def top_p_logits(logits, top_p=None):
|
| 34 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 35 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 36 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 37 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 38 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 39 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 40 |
+
|
| 41 |
+
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 42 |
+
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 43 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 44 |
+
return logits
|
| 45 |
+
|
| 46 |
+
def top_k_logits(logits, top_k=None):
|
| 47 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 48 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 49 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 50 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 51 |
+
return logits
|
| 52 |
+
|
| 53 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 54 |
+
if temperature > 0:
|
| 55 |
+
logits = logits / temperature
|
| 56 |
+
if top_p is not None and top_p < 1:
|
| 57 |
+
logits = top_p_logits(logits, top_p)
|
| 58 |
+
if top_k is not None:
|
| 59 |
+
logits = top_k_logits(logits, top_k)
|
| 60 |
+
probs = torch.softmax(logits, dim=-1)
|
| 61 |
+
|
| 62 |
+
if temperature > 0:
|
| 63 |
+
try:
|
| 64 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 65 |
+
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 66 |
+
except:
|
| 67 |
+
confidence, x0 = probs.max(dim=-1)
|
| 68 |
+
else:
|
| 69 |
+
confidence, x0 = probs.max(dim=-1)
|
| 70 |
+
|
| 71 |
+
if margin_confidence:
|
| 72 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 73 |
+
# Extract top1 and top2 probabilities
|
| 74 |
+
top1_probs = sorted_probs[:, 0]
|
| 75 |
+
top2_probs = sorted_probs[:, 1]
|
| 76 |
+
# Calculate confidence as top1 - top2
|
| 77 |
+
confidence = top1_probs - top2_probs
|
| 78 |
+
|
| 79 |
+
if neg_entropy:
|
| 80 |
+
epsilon = 1e-10
|
| 81 |
+
log_probs = torch.log(probs + epsilon)
|
| 82 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 83 |
+
|
| 84 |
+
return confidence, x0
|
| 85 |
+
# def generate(model,prompt,block_size,max_length,mask_id):
|
| 86 |
+
# def generate(model, prompt, block_size, max_length, mask_id, eos_token_id=None):
|
| 87 |
+
# device = prompt.device
|
| 88 |
+
# output = prompt.clone()
|
| 89 |
+
|
| 90 |
+
# while output.shape[1] < max_length:
|
| 91 |
+
# # 添加一个 block 的 mask
|
| 92 |
+
# mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
|
| 93 |
+
# input_ids = torch.cat([output, mask_block], dim=1)
|
| 94 |
+
# attention_mask = build_custom_float_attention_mask(input_ids, torch.tensor([[prompt.shape[1]]]), block_size, device=device)
|
| 95 |
+
# attention_mask = attention_mask.to(torch.bfloat16)
|
| 96 |
+
# for i in range(block_size):
|
| 97 |
+
def generate_block(denoiser, block_size, mask_id,tokenizer,device):
|
| 98 |
+
denoiser.eval()
|
| 99 |
+
question = 'please give me a code about transformer model'
|
| 100 |
+
# prompt = tokenizer(question)['input_ids']
|
| 101 |
+
# prompt = torch.tensor(prompt).to(accelerator.device).unsqueeze(0)
|
| 102 |
+
messages = [
|
| 103 |
+
{"role": "user", "content": question}
|
| 104 |
+
]
|
| 105 |
+
prompt = tokenizer.apply_chat_template(
|
| 106 |
+
messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
|
| 107 |
+
).input_ids
|
| 108 |
+
prompt = prompt.to(device)
|
| 109 |
+
|
| 110 |
+
mask_id = 151666
|
| 111 |
+
gen_len = (384 - prompt.shape[1])//block_size
|
| 112 |
+
print(gen_len)
|
| 113 |
+
temperature = 0.2
|
| 114 |
+
top_p = 0.95
|
| 115 |
+
with torch.inference_mode():
|
| 116 |
+
for i in range(gen_len):
|
| 117 |
+
if i==0:
|
| 118 |
+
x_t = torch.cat([prompt, torch.tensor([[mask_id]*block_size]).to(device)], dim=1)
|
| 119 |
+
else:
|
| 120 |
+
x_t = torch.cat([x_t, torch.tensor([[mask_id]*block_size]).to(device)], dim=1)
|
| 121 |
+
attention_mask = build_custom_float_attention_mask(x_t, torch.tensor([[prompt.shape[1]]]), block_size, device=device)
|
| 122 |
+
attention_mask = attention_mask.to(torch.bfloat16)
|
| 123 |
+
for n in range(block_size):
|
| 124 |
+
mask_index = (x_t == mask_id)
|
| 125 |
+
if mask_index.sum() == 0:
|
| 126 |
+
break
|
| 127 |
+
logits =denoiser(x_t, attention_mask=attention_mask).logits
|
| 128 |
+
logits = shift_logits(logits)
|
| 129 |
+
mask_logits = logits[mask_index]
|
| 130 |
+
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=None, neg_entropy=True)
|
| 131 |
+
number_transfer_tokens = 1
|
| 132 |
+
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
|
| 133 |
+
x0_ = torch.zeros_like(x0, device=device, dtype=torch.long) + mask_id
|
| 134 |
+
x0_[transfer_index] = x0[transfer_index].clone()
|
| 135 |
+
x_t[mask_index] = x0_
|
| 136 |
+
answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
|
| 137 |
+
print(answer)
|
| 138 |
+
answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
|
| 139 |
+
print(answer)
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
|
| 143 |
+
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
| 144 |
+
lora_model = PeftModel.from_pretrained(model, "ybelkada/opt-350m-lora")
|
Discrete-Diffusion-Forcing/D2F-train/utils/loss.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from utils.util import forward_process_length, shift_logits,forward_process
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def compute_loss_by_config(
|
| 6 |
+
input_ids,
|
| 7 |
+
denoiser,
|
| 8 |
+
question_length,
|
| 9 |
+
mask_id,
|
| 10 |
+
block_size,
|
| 11 |
+
enable_shift,
|
| 12 |
+
share_steps,
|
| 13 |
+
self_align,
|
| 14 |
+
feature_align,
|
| 15 |
+
self_step,
|
| 16 |
+
eos_id,
|
| 17 |
+
config
|
| 18 |
+
):
|
| 19 |
+
"""Select different loss functions based on config file"""
|
| 20 |
+
training_mode = config.get('training_mode', 'dream')
|
| 21 |
+
|
| 22 |
+
if training_mode == 'llada':
|
| 23 |
+
return compute_llada_loss(
|
| 24 |
+
input_ids, denoiser, question_length, mask_id, block_size,
|
| 25 |
+
enable_shift, share_steps, self_align, feature_align, self_step, eos_id
|
| 26 |
+
)
|
| 27 |
+
elif training_mode == 'dream':
|
| 28 |
+
return compute_loss(
|
| 29 |
+
input_ids, denoiser, question_length, mask_id, block_size,
|
| 30 |
+
enable_shift, share_steps, self_align, feature_align, self_step, eos_id
|
| 31 |
+
)
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"Unsupported training mode: {training_mode}")
|
| 34 |
+
|
| 35 |
+
def compute_loss(
|
| 36 |
+
input_ids,
|
| 37 |
+
denoiser,
|
| 38 |
+
question_length,
|
| 39 |
+
mask_id,
|
| 40 |
+
block_size,
|
| 41 |
+
enable_shift,
|
| 42 |
+
share_steps,
|
| 43 |
+
self_align,
|
| 44 |
+
feature_align,
|
| 45 |
+
self_step,
|
| 46 |
+
eos_id,
|
| 47 |
+
):
|
| 48 |
+
B, L = input_ids.shape
|
| 49 |
+
noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
|
| 50 |
+
token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
|
| 51 |
+
prompt_mask = (token_positions < question_length.unsqueeze(1))
|
| 52 |
+
noisy_batch[prompt_mask] = input_ids[prompt_mask]
|
| 53 |
+
# prompt_mask = prompt_mask.to(torch.int64)
|
| 54 |
+
noisy_batch = noisy_batch.to(denoiser.device)
|
| 55 |
+
attention_mask=build_custom_float_attention_mask(noisy_batch, question_length, block_size, device=noisy_batch.device)
|
| 56 |
+
attention_mask=attention_mask.to(torch.float16)
|
| 57 |
+
logits=denoiser(noisy_batch,attention_mask=attention_mask).logits
|
| 58 |
+
logits=shift_logits(logits)
|
| 59 |
+
if self_align:
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
with denoiser.disable_adapter():
|
| 62 |
+
# ref_model = denoiser
|
| 63 |
+
# ref_model.eval()
|
| 64 |
+
# print(type(ref_model))
|
| 65 |
+
# denoiser.eval()
|
| 66 |
+
ref_logits=denoiser(noisy_batch,attention_mask=torch.zeros([1,1,noisy_batch.shape[1],noisy_batch.shape[1]],dtype=torch.float16,device=denoiser.device)).logits
|
| 67 |
+
ref_logits=shift_logits(ref_logits)
|
| 68 |
+
ref_logits = torch.nn.functional.softmax(ref_logits, dim=-1)
|
| 69 |
+
# denoiser.train()
|
| 70 |
+
token_loss_2 = F.cross_entropy(logits[masked_indices], ref_logits[masked_indices], reduction='none') / p_mask[masked_indices]
|
| 71 |
+
# print("token_loss_2",token_loss_2.shape)
|
| 72 |
+
else:
|
| 73 |
+
token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
|
| 74 |
+
losses = {
|
| 75 |
+
# 'loss_1': token_loss_2.mean() * 0,
|
| 76 |
+
'loss': token_loss_2.mean(),
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
return losses
|
| 80 |
+
def compute_normal_loss(
|
| 81 |
+
input_ids,
|
| 82 |
+
denoiser,
|
| 83 |
+
question_length,
|
| 84 |
+
mask_id,
|
| 85 |
+
block_size,
|
| 86 |
+
enable_shift,
|
| 87 |
+
share_steps,
|
| 88 |
+
self_align,
|
| 89 |
+
feature_align,
|
| 90 |
+
self_step,
|
| 91 |
+
eos_id,
|
| 92 |
+
):
|
| 93 |
+
B, L = input_ids.shape
|
| 94 |
+
noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
|
| 95 |
+
token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
|
| 96 |
+
prompt_mask = (token_positions < question_length.unsqueeze(1))
|
| 97 |
+
noisy_batch[prompt_mask] = input_ids[prompt_mask]
|
| 98 |
+
# prompt_mask = prompt_mask.to(torch.int64)
|
| 99 |
+
noisy_batch = noisy_batch.to(denoiser.device)
|
| 100 |
+
logits=denoiser(noisy_batch).logits
|
| 101 |
+
logits=shift_logits(logits)
|
| 102 |
+
token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
|
| 103 |
+
losses = {
|
| 104 |
+
# 'loss_1': token_loss_2.mean() * 0,
|
| 105 |
+
'loss': token_loss_2.mean(),
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
return losses
|
| 109 |
+
import torch
|
| 110 |
+
def compute_llada_loss(
|
| 111 |
+
input_ids,
|
| 112 |
+
denoiser,
|
| 113 |
+
question_length,
|
| 114 |
+
mask_id,
|
| 115 |
+
block_size,
|
| 116 |
+
enable_shift,
|
| 117 |
+
share_steps,
|
| 118 |
+
self_align,
|
| 119 |
+
feature_align,
|
| 120 |
+
self_step,
|
| 121 |
+
eos_id,
|
| 122 |
+
):
|
| 123 |
+
mask_id=126336
|
| 124 |
+
B, L = input_ids.shape
|
| 125 |
+
noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
|
| 126 |
+
token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
|
| 127 |
+
prompt_mask = (token_positions < question_length.unsqueeze(1))
|
| 128 |
+
noisy_batch[prompt_mask] = input_ids[prompt_mask]
|
| 129 |
+
# prompt_mask = prompt_mask.to(torch.int64)
|
| 130 |
+
noisy_batch = noisy_batch.to(denoiser.device)
|
| 131 |
+
# print(noisy_batch)
|
| 132 |
+
attention_mask=build_custom_float_attention_mask(noisy_batch, question_length, block_size, device=noisy_batch.device)
|
| 133 |
+
attention_mask=attention_mask.to(torch.float16)
|
| 134 |
+
# print(type(denoiser),noisy_batch.shape,attention_mask.shape)
|
| 135 |
+
logits=denoiser(noisy_batch,attention_bias=attention_mask).logits
|
| 136 |
+
# logits=shift_logits(logits)
|
| 137 |
+
if self_align:
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
with denoiser.disable_adapter():
|
| 140 |
+
# ref_model = denoiser
|
| 141 |
+
# ref_model.eval()
|
| 142 |
+
# print(type(ref_model))
|
| 143 |
+
ref_logits=denoiser(noisy_batch,attention_bias=torch.zeros([1,1,noisy_batch.shape[1],noisy_batch.shape[1]],dtype=torch.float16,device=denoiser.device)).logits
|
| 144 |
+
# ref_logits=shift_logits(ref_logits)
|
| 145 |
+
ref_logits = torch.nn.functional.softmax(ref_logits, dim=-1)
|
| 146 |
+
token_loss_2 = F.cross_entropy(logits[masked_indices], ref_logits[masked_indices], reduction='none') / p_mask[masked_indices]
|
| 147 |
+
# print("token_loss_2",token_loss_2.shape)
|
| 148 |
+
else:
|
| 149 |
+
token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
|
| 150 |
+
losses = {
|
| 151 |
+
# 'loss_1': token_loss_2.mean() * 0,
|
| 152 |
+
'loss': token_loss_2.mean(),
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
return losses
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
|
| 159 |
+
B,seq_len= input_ids.shape
|
| 160 |
+
# 初始化为全 -inf
|
| 161 |
+
attn_mask = torch.full((B,1,seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
|
| 162 |
+
# 1. Prompt部分:每个token可以注意整个prompt
|
| 163 |
+
for i in range(B):
|
| 164 |
+
attn_mask[i,:,:,:prompt_length[i]] = 0.0 # 允许所有 token 看 prompt
|
| 165 |
+
|
| 166 |
+
# 2. 块划分:从 prompt_length 开始划分 block
|
| 167 |
+
num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
|
| 168 |
+
|
| 169 |
+
for b in range(num_blocks):
|
| 170 |
+
block_start = prompt_length[i] + b * block_size
|
| 171 |
+
# print(block_start,block_size,seq_len)
|
| 172 |
+
block_end = min(block_start + block_size, seq_len)
|
| 173 |
+
|
| 174 |
+
# 块内全注意
|
| 175 |
+
attn_mask[i,:,block_start:block_end, block_start:block_end] = 0.0
|
| 176 |
+
|
| 177 |
+
# 块之间因果注意(只能看前面块)
|
| 178 |
+
for prev_b in range(b):
|
| 179 |
+
prev_start = prompt_length[i] + prev_b * block_size
|
| 180 |
+
prev_end = min(prev_start + block_size, seq_len)
|
| 181 |
+
|
| 182 |
+
# 当前块可以看前面块
|
| 183 |
+
attn_mask[i,:,block_start:block_end, prev_start:prev_end] = 0.0
|
| 184 |
+
|
| 185 |
+
return attn_mask # [seq_len, seq_len], float, 0.0 for allowed, -inf for disallowed
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
seq_len = 10
|
| 188 |
+
input_ids = torch.randint(0, 100, (2, seq_len)) # 示例输入
|
| 189 |
+
block_size = 4
|
| 190 |
+
prompt_length = torch.tensor([2, 4]) # 示例prompt长度
|
| 191 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 192 |
+
attn_mask = build_custom_float_attention_mask(input_ids, prompt_length, block_size, device)
|
| 193 |
+
print(attn_mask)
|
Discrete-Diffusion-Forcing/D2F-train/utils/model.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
from transformers import AutoModel, AutoTokenizer
|
| 3 |
+
from peft import LoraConfig,get_peft_model
|
| 4 |
+
from model.modeling_llada import LLaDAModelLM
|
| 5 |
+
from model.configuration_llada import LLaDAConfig
|
| 6 |
+
|
| 7 |
+
def get_model_by_config(config):
|
| 8 |
+
"""Select different models based on config file"""
|
| 9 |
+
training_mode = config.get('training_mode', 'dream')
|
| 10 |
+
|
| 11 |
+
if training_mode == 'llada':
|
| 12 |
+
return get_llada(config)
|
| 13 |
+
elif training_mode == 'dream':
|
| 14 |
+
return get_model(config)
|
| 15 |
+
else:
|
| 16 |
+
raise ValueError(f"Unsupported training mode: {training_mode}")
|
| 17 |
+
|
| 18 |
+
def get_model(config):
|
| 19 |
+
# Use path from config, use default path if no config
|
| 20 |
+
model_path = config.paths.model if hasattr(config, 'paths') and hasattr(config.paths, 'model') else "/home/wx/data/model/Dream-org/Dream-v0-Base-7B"
|
| 21 |
+
|
| 22 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
| 23 |
+
# print(model.named_modules())
|
| 24 |
+
# print(model,"model
|
| 25 |
+
for param in model.parameters():
|
| 26 |
+
param.requires_grad = False
|
| 27 |
+
tokenizer=AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 28 |
+
peft_config = LoraConfig(r=32, lora_alpha=32, lora_dropout=0.1,target_modules=["q_proj", "v_proj","k_proj", "o_proj"],)
|
| 29 |
+
model = get_peft_model(model, peft_config)
|
| 30 |
+
model.print_trainable_parameters()
|
| 31 |
+
return model, tokenizer
|
| 32 |
+
|
| 33 |
+
def get_llada(config):
|
| 34 |
+
# Use path from config, use default path if no config
|
| 35 |
+
model_path = config.paths.model if hasattr(config, 'paths') and hasattr(config.paths, 'model') else "/data1/xck/models/llada-8b-instruct"
|
| 36 |
+
|
| 37 |
+
config_obj=LLaDAConfig.from_pretrained(model_path)
|
| 38 |
+
model = LLaDAModelLM.from_pretrained(model_path,config=config_obj)
|
| 39 |
+
# print(model.named_modules())
|
| 40 |
+
# print(model,"model
|
| 41 |
+
# print(model)
|
| 42 |
+
# exit()
|
| 43 |
+
for param in model.parameters():
|
| 44 |
+
param.requires_grad = False
|
| 45 |
+
tokenizer=AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 46 |
+
peft_config = LoraConfig(r=32, lora_alpha=32, lora_dropout=0.1,target_modules=["q_proj", "v_proj","k_proj", "attn_out"],)
|
| 47 |
+
model = get_peft_model(model, peft_config)
|
| 48 |
+
model.print_trainable_parameters()
|
| 49 |
+
return model, tokenizer
|
| 50 |
+
# def create_attention_mask(input_ids, mask_id):
|
| 51 |
+
# """
|
| 52 |
+
# Create an attention mask based on the input_ids and mask_id.
|
| 53 |
+
|
| 54 |
+
# Args:
|
| 55 |
+
# input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
|
| 56 |
+
# mask_id (int): The ID of the mask token.
|
| 57 |
+
|
| 58 |
+
# Returns:
|
| 59 |
+
# torch.Tensor: The attention mask of shape (batch_size, sequence_length, sequence_length).
|
Discrete-Diffusion-Forcing/D2F-train/utils/util.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.distributions import Uniform
|
| 3 |
+
|
| 4 |
+
def forward_process_block_fixed_p(x, mask_id, p_mask):
|
| 5 |
+
B, L = x.shape
|
| 6 |
+
if isinstance(p_mask, float):
|
| 7 |
+
p_mask = torch.full((B, 1), p_mask, device=x.device)
|
| 8 |
+
elif p_mask.ndim == 1:
|
| 9 |
+
p_mask = p_mask[:, None]
|
| 10 |
+
rand = torch.rand((B, L), device=x.device)
|
| 11 |
+
mask = rand < p_mask
|
| 12 |
+
x_masked = torch.where(mask, mask_id, x)
|
| 13 |
+
return x_masked, mask
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
def generate_monotonic_pmasks(batch_size, max_blocks, device):
|
| 18 |
+
"""
|
| 19 |
+
生成 shape (B, max_blocks) 的单调非降随机序列,每行第一个元素在[0,1]随机,后续不小于前一个
|
| 20 |
+
"""
|
| 21 |
+
# 第一个block p_mask随机
|
| 22 |
+
p0 = torch.rand(batch_size, 1, device=device)/2+0.2
|
| 23 |
+
# print(p0)
|
| 24 |
+
# 后续blocks生成增量 [0, 1],加起来保证不超过1(之后用 clamp)
|
| 25 |
+
increments = torch.rand(batch_size, max_blocks - 1, device=device) * (0.7 - p0)/ (max_blocks - 1)
|
| 26 |
+
# print(increments)
|
| 27 |
+
# 逐元素累加,保证非降
|
| 28 |
+
cum_increments = torch.cumsum(increments, dim=1)
|
| 29 |
+
# print(cum_increments)
|
| 30 |
+
# 总 p_mask = p0 + 累积增量,保证不超过1
|
| 31 |
+
p_masks = torch.cat([p0, p0 + cum_increments], dim=1)
|
| 32 |
+
p_masks = torch.clamp(p_masks, max=1.0)
|
| 33 |
+
# print(p_masks)
|
| 34 |
+
return p_masks # (B, max_blocks)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def forward_process_length(input_ids, mask_id, block_size, prompt_lengths,eos_id=None):
|
| 38 |
+
"""
|
| 39 |
+
Args:
|
| 40 |
+
input_ids: (B, L)
|
| 41 |
+
prompt_lengths: (B,)
|
| 42 |
+
Returns:
|
| 43 |
+
noisy_batch, masked_indices, p_mask_tensor
|
| 44 |
+
"""
|
| 45 |
+
B, L = input_ids.shape
|
| 46 |
+
device = input_ids.device
|
| 47 |
+
noisy_batch = input_ids.clone()
|
| 48 |
+
eos_indices= (input_ids==eos_id)
|
| 49 |
+
masked_indices = torch.zeros_like(input_ids,dtype=torch.bool)
|
| 50 |
+
p_mask_tensor = torch.zeros((B, L), device=device)
|
| 51 |
+
|
| 52 |
+
# 计算每个样本block数
|
| 53 |
+
non_prompt_lens = L - prompt_lengths
|
| 54 |
+
full_blocks = non_prompt_lens // block_size
|
| 55 |
+
remainders = non_prompt_lens % block_size
|
| 56 |
+
total_blocks = full_blocks + (remainders > 0).long()
|
| 57 |
+
|
| 58 |
+
max_blocks = total_blocks.max().item()
|
| 59 |
+
|
| 60 |
+
# 生成每个样本block的mask比率,单调非降且第一个随机
|
| 61 |
+
p_masks = generate_monotonic_pmasks(B, max_blocks, device) # shape (B, max_blocks)
|
| 62 |
+
|
| 63 |
+
for i in range(B):
|
| 64 |
+
prompt_len = prompt_lengths[i].item()
|
| 65 |
+
num_blocks = total_blocks[i].item()
|
| 66 |
+
start_block = torch.tensor([0]) # 随机选择一个block开始
|
| 67 |
+
for block_idx in range(num_blocks):
|
| 68 |
+
if block_idx < start_block:
|
| 69 |
+
continue
|
| 70 |
+
start = prompt_len + block_idx * block_size
|
| 71 |
+
end = min(start + block_size, L)
|
| 72 |
+
|
| 73 |
+
p_block = p_masks[i, block_idx-start_block].item()
|
| 74 |
+
|
| 75 |
+
block = noisy_batch[i, start:end].unsqueeze(0)
|
| 76 |
+
masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
|
| 77 |
+
|
| 78 |
+
noisy_batch[i, start:end] = masked_block.squeeze(0)
|
| 79 |
+
masked_indices[i, start:end] = mask.squeeze(0)
|
| 80 |
+
# if torch.all(input_ids[i, start:end] == eos_id):
|
| 81 |
+
# masked_indices[i,start:end]== False
|
| 82 |
+
# print("1")
|
| 83 |
+
|
| 84 |
+
p_mask_tensor[i, start:end] = p_block
|
| 85 |
+
|
| 86 |
+
return noisy_batch, masked_indices, p_mask_tensor
|
| 87 |
+
|
| 88 |
+
# def forward_process_length(input_ids, mask_id, block_size, prompt_lengths, p_min=0.2, p_max=0.9):
|
| 89 |
+
# """
|
| 90 |
+
# 返回每个 token 的实际 mask 概率 tensor(非prompt区域),其余为0。
|
| 91 |
+
# """
|
| 92 |
+
# B, L = input_ids.shape
|
| 93 |
+
# device = input_ids.device
|
| 94 |
+
# noisy_batch = input_ids.clone()
|
| 95 |
+
# masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
|
| 96 |
+
# p_mask_tensor = torch.zeros((B, L), device=device) # 最终返回值
|
| 97 |
+
|
| 98 |
+
# for i in range(B):
|
| 99 |
+
# prompt_len = prompt_lengths[i].item()
|
| 100 |
+
# non_prompt_len = L - prompt_len
|
| 101 |
+
# full_blocks = non_prompt_len // block_size
|
| 102 |
+
# remainder = non_prompt_len % block_size
|
| 103 |
+
# total_blocks = full_blocks + (1 if remainder > 0 else 0)
|
| 104 |
+
|
| 105 |
+
# for block_idx in range(total_blocks):
|
| 106 |
+
# start = prompt_len + block_idx * block_size
|
| 107 |
+
# end = min(start + block_size, L)
|
| 108 |
+
|
| 109 |
+
# # block的 mask 概率(线性递增)
|
| 110 |
+
# if total_blocks > 1:
|
| 111 |
+
# p_block = p_min + (p_max - p_min) * (block_idx / (total_blocks - 1))
|
| 112 |
+
# else:
|
| 113 |
+
# p_block = p_max
|
| 114 |
+
|
| 115 |
+
# block = noisy_batch[i, start:end].unsqueeze(0)
|
| 116 |
+
# masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
|
| 117 |
+
# noisy_batch[i, start:end] = masked_block.squeeze(0)
|
| 118 |
+
# masked_indices[i, start:end] = mask.squeeze(0)
|
| 119 |
+
|
| 120 |
+
# # 记录 p_mask 到 tensor 中
|
| 121 |
+
# p_mask_tensor[i, start:end] = p_block
|
| 122 |
+
|
| 123 |
+
# return noisy_batch, masked_indices, p_mask_tensor
|
| 124 |
+
def forward_process(input_ids,mask_id ,t_max=1.0, eps=1e-4):
|
| 125 |
+
B, L = input_ids.shape
|
| 126 |
+
# t = torch.rand(B, device=input_ids.device)
|
| 127 |
+
dist = Uniform(0., t_max)
|
| 128 |
+
t = dist.sample((B,)).to(input_ids.device)
|
| 129 |
+
p_mask = (1 - eps) * t + eps
|
| 130 |
+
p_mask = p_mask[:, None].repeat(1, L)
|
| 131 |
+
masked_indices = torch.rand((B, L), device=input_ids.device) < p_mask
|
| 132 |
+
noisy_batch = torch.where(masked_indices, mask_id, input_ids)
|
| 133 |
+
|
| 134 |
+
return noisy_batch, masked_indices, p_mask
|
| 135 |
+
def flatten_dict(d, parent_key='', sep='_'):
|
| 136 |
+
items = []
|
| 137 |
+
for k, v in d.items():
|
| 138 |
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
| 139 |
+
if isinstance(v, dict):
|
| 140 |
+
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
| 141 |
+
else:
|
| 142 |
+
items.append((new_key, v))
|
| 143 |
+
return dict(items)
|
| 144 |
+
|
| 145 |
+
def shift_logits(logits):
|
| 146 |
+
shifted_logits = torch.zeros_like(logits)
|
| 147 |
+
shifted_logits[:, 1:, :] = logits[:, :-1, :]
|
| 148 |
+
shifted_logits[:, 0, :] = 1.0
|
| 149 |
+
|
| 150 |
+
return shifted_logits
|
| 151 |
+
if __name__ == '__main__':
|
| 152 |
+
input_ids= torch.tensor([[1,5,4,3,25,6,7,9,5,8,7,6],[1,3,8,9,7,34,6,9,5,8,7,6]])
|
| 153 |
+
mask_id=0
|
| 154 |
+
block_size=3
|
| 155 |
+
prompt_length=torch.tensor([2,1])
|
| 156 |
+
noisy_batch, masked_indices,p_mask = forward_process_length(input_ids, mask_id, block_size, prompt_length)
|
| 157 |
+
print("noisy_batch:", noisy_batch)
|
| 158 |
+
print("masked_indices:", masked_indices)
|
| 159 |
+
print("p_mask:", p_mask)
|
Discrete-Diffusion-Forcing/Discrete Diffusion Forcing.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:882a505b24c5a09dacae481b637388774c24f0f53e27545da67291b07e5b157d
|
| 3 |
+
size 788772
|
Discrete-Diffusion-Forcing/LICENCE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 DENG Lab @ SJTU
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Discrete-Diffusion-Forcing/README.md
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="docs/assets/img/d2f/logo_lr.png" width="300">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
## Discrete Diffusion Forcing (D2F): dLLMs Can Do Faster-Than-AR Inference
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<a href="https://arxiv.org/abs/2508.09192"><b>📄 Paper</b></a> •
|
| 9 |
+
<a href="https://zhijie-group.github.io/Discrete-Diffusion-Forcing/"><b>📝 Blog Post</b></a> •
|
| 10 |
+
<a href="https://huggingface.co/spaces/zhijie3/D2F-LLaDA-Instruct-8B"><b>🚀 Online Demo</b></a> •
|
| 11 |
+
<a href="https://huggingface.co/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora"><b>🤗 D2F-Dream LoRA</b></a> •
|
| 12 |
+
<a href="https://huggingface.co/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora"><b>🤗 D2F-LLaDA LoRA</b></a>
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<a href="https://discord.gg/aDWgxT6S2q"><b>💬 Discord</b></a> •
|
| 17 |
+
<a href="docs/assets/img/d2f/wechat.png"><b>💬 Wechat</b></a>
|
| 18 |
+
</p>
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
https://github.com/user-attachments/assets/d9de6450-68d6-4caf-85c2-c7f384395c42
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<br>
|
| 27 |
+
<small><b>Real-time generation demo:</b> our D2F model (left) uses parallel block decoding, while the AR baseline (right) generates tokens sequentially. This visualizes the source of D2F's significant throughput advantage.</small>
|
| 28 |
+
</p>
|
| 29 |
+
|
| 30 |
+
<hr>
|
| 31 |
+
|
| 32 |
+
<p align="center">
|
| 33 |
+
<img src="docs/assets/img/d2f/fig1_main_result.png" width="800">
|
| 34 |
+
<br>
|
| 35 |
+
<small><b>Inference throughput comparison:</b> D2F dLLMs surpass similarly-sized AR models in inference speed for the first time, achieving up to a <b>2.5x speedup</b> over LLaMA3 and a <b>>50x speedup</b> over vanilla dLLM baselines (Speed tests conducted on NVIDIA A100-PCIe-40GB GPUs).</small>
|
| 36 |
+
</p>
|
| 37 |
+
|
| 38 |
+
**Discrete Diffusion Forcing (D2F)** is a novel training and inference paradigm that, for the first time, enables open-source Diffusion Language Models (dLLMs) to surpass their autoregressive (AR) counterparts in inference speed. By introducing a highly efficient AR-diffusion hybrid model, D2F achieves:
|
| 39 |
+
- Up to a **2.5x speedup** over leading AR models like LLaMA3-8B.
|
| 40 |
+
- A staggering **50x+ acceleration** over vanilla dLLM baselines.
|
| 41 |
+
- Comparable generation quality on standard reasoning and coding benchmarks.
|
| 42 |
+
- **Integration with vLLM** to unlock the next tier of extreme inference acceleration.
|
| 43 |
+
|
| 44 |
+
This repository provides the code to reproduce our evaluation results and run generation demos.
|
| 45 |
+
|
| 46 |
+
## 🔥 News!
|
| 47 |
+
* Aug 20, 2025: We've released the training pipeline of D2F!
|
| 48 |
+
* Aug 8, 2025: We've released the inference code of D2F!
|
| 49 |
+
## Contents
|
| 50 |
+
- [🤔 How It Works](#-how-it-works)
|
| 51 |
+
- [📊 Performance Highlights](#-performance-highlights)
|
| 52 |
+
- [⚡️ Extreme Acceleration with vLLM Integration](#️-extreme-acceleration-with-vllm-integration)
|
| 53 |
+
- [🚀 Usage Guide](#-usage-guide)
|
| 54 |
+
- [🙏 Acknowledgements](#-acknowledgements)
|
| 55 |
+
- [©️ Citation](#️-citation)
|
| 56 |
+
|
| 57 |
+
## 🤔 How It Works
|
| 58 |
+
|
| 59 |
+
D2F overcomes the historical speed bottlenecks of dLLMs (KV Cache incompatibility and strict sequential dependencies) by restructuring the generation process.
|
| 60 |
+
|
| 61 |
+
**1. Hybrid Architecture:** D2F employs a **block-wise causal attention** mechanism. Attention *within* a block is bidirectional, preserving rich local context, while attention *between* blocks is causal. This simple but powerful change makes the model fully compatible with the standard KV Cache, drastically reducing redundant computations.
|
| 62 |
+
|
| 63 |
+
**2. Efficient Training via Asymmetric Distillation:** Instead of training from scratch, we distill a powerful, pre-trained bidirectional dLLM (teacher) into our cache-friendly D2F model (student). The student learns to match the teacher's output with only a limited, causal view of the context.
|
| 64 |
+
|
| 65 |
+
<p align="center">
|
| 66 |
+
<img src="docs/assets/img/d2f/fig3_overview.png" width="800">
|
| 67 |
+
<br>
|
| 68 |
+
<small><b>Overview of Discrete Diffusion Forcing (D2F):</b> A D2F model (student) with a KV-cache-friendly block-wise causal attention mask is trained to mimic a powerful, pre-trained bidirectional dLLM (teacher), efficiently inheriting its capabilities.</small>
|
| 69 |
+
</p>
|
| 70 |
+
|
| 71 |
+
**3. High-Throughput Pipelined Decoding:** D2F is trained to predict future blocks based on *partially incomplete* prefixes. This enables a **pipelined parallel decoding** algorithm during inference, where multiple blocks are refined simultaneously in an asynchronous workflow, maximizing GPU utilization and throughput.
|
| 72 |
+
|
| 73 |
+
<p align="center">
|
| 74 |
+
<img src="docs/assets/img/d2f/fig4_pipeline.png" width="800">
|
| 75 |
+
<br>
|
| 76 |
+
<small><b>Visualization of our pipelined parallel decoding:</b> New blocks are dynamically added and decoded in parallel with their predecessors, moving from a conservative "semi-activated" state to an aggressive "fully-activated" state. This creates a continuous, high-throughput generation flow.</small>
|
| 77 |
+
</p>
|
| 78 |
+
|
| 79 |
+
https://github.com/user-attachments/assets/41a0176b-e4ae-4f8b-95a6-daed7af2a027
|
| 80 |
+
|
| 81 |
+
<p align="center">
|
| 82 |
+
<br>
|
| 83 |
+
<small><b>A slow-motion demonstration of the parallel decoding process within a single block of D2F. Watch as multiple tokens within the block are refined simultaneously, showcasing the efficiency of our approach.</small>
|
| 84 |
+
</p>
|
| 85 |
+
|
| 86 |
+
## 📊 Performance Highlights
|
| 87 |
+
|
| 88 |
+
D2F delivers transformative speedups while maintaining or improving scores. Below is a comprehensive summary of performance on **LLaDA-Instruct-8B** and **Dream-Base-7B**, comparing our method against the original baseline and the previous SOTA acceleration method, Fast-dLLM.
|
| 89 |
+
|
| 90 |
+
<center>
|
| 91 |
+
|
| 92 |
+
**Performance on LLaDA-Instruct-8B**
|
| 93 |
+
<table style="width:100%; border-collapse: collapse; text-align: center;">
|
| 94 |
+
<thead style="background-color:#f2f2f2;">
|
| 95 |
+
<tr>
|
| 96 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Benchmark</th>
|
| 97 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Metric</th>
|
| 98 |
+
<th style="padding: 8px; border: 1px solid #ddd;">LLaDA-Instruct (Baseline)</th>
|
| 99 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Fast-dLLM (SOTA)</th>
|
| 100 |
+
<th style="padding: 8px; border: 1px solid #ddd;">D2F-LLaDA (Ours)</th>
|
| 101 |
+
</tr>
|
| 102 |
+
</thead>
|
| 103 |
+
<tbody>
|
| 104 |
+
<tr>
|
| 105 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle;"><strong>GSM8K-4-shot</strong></td>
|
| 106 |
+
<td style="padding: 8px; border: 1px solid #ddd;">TPS ↑</td>
|
| 107 |
+
<td style="padding: 8px; border: 1px solid #ddd;">7.2</td>
|
| 108 |
+
<td style="padding: 8px; border: 1px solid #ddd;">35.2</td>
|
| 109 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>52.5 <font color="green">(7.3x)</font></strong></td>
|
| 110 |
+
</tr>
|
| 111 |
+
<tr>
|
| 112 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Score ↑</td>
|
| 113 |
+
<td style="padding: 8px; border: 1px solid #ddd;">77.4</td>
|
| 114 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><b>78.9</b></td>
|
| 115 |
+
<td style="padding: 8px; border: 1px solid #ddd;">77.3</td>
|
| 116 |
+
</tr>
|
| 117 |
+
<tr>
|
| 118 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle; background-color: #fafafa;"><strong>MBPP-3-shot</strong></td>
|
| 119 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">TPS ↑</td>
|
| 120 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">0.9</td>
|
| 121 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">15.3</td>
|
| 122 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><strong>47.6 <font color="green">(52.9x)</font></strong></td>
|
| 123 |
+
</tr>
|
| 124 |
+
<tr>
|
| 125 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">Score ↑</td>
|
| 126 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><b>39.0</b></td>
|
| 127 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">36.4</td>
|
| 128 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">38.0</td>
|
| 129 |
+
</tr>
|
| 130 |
+
<tr>
|
| 131 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle;"><strong>HumanEval-0-shot</strong></td>
|
| 132 |
+
<td style="padding: 8px; border: 1px solid #ddd;">TPS ↑</td>
|
| 133 |
+
<td style="padding: 8px; border: 1px solid #ddd;">2.8</td>
|
| 134 |
+
<td style="padding: 8px; border: 1px solid #ddd;">19.2</td>
|
| 135 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>81.6 <font color="green">(29.1x)</font></strong></td>
|
| 136 |
+
</tr>
|
| 137 |
+
<tr>
|
| 138 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Score ↑</td>
|
| 139 |
+
<td style="padding: 8px; border: 1px solid #ddd;">36.0</td>
|
| 140 |
+
<td style="padding: 8px; border: 1px solid #ddd;">35.4</td>
|
| 141 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><b>40.2</b></td>
|
| 142 |
+
</tr>
|
| 143 |
+
<tr>
|
| 144 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle; background-color: #fafafa;"><strong>Math-4-shot</strong></td>
|
| 145 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">TPS ↑</td>
|
| 146 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">21.1</td>
|
| 147 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">42.5</td>
|
| 148 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><strong>90.2 <font color="green">(4.3x)</font></strong></td>
|
| 149 |
+
</tr>
|
| 150 |
+
<tr>
|
| 151 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">Score ↑</td>
|
| 152 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">23.7</td>
|
| 153 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">22.4</td>
|
| 154 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><b>29.1</b></td>
|
| 155 |
+
</tr>
|
| 156 |
+
</tbody>
|
| 157 |
+
</table>
|
| 158 |
+
|
| 159 |
+
**Performance on Dream-Base-7B**
|
| 160 |
+
<table style="width:100%; border-collapse: collapse; text-align: center;">
|
| 161 |
+
<thead style="background-color:#f2f2f2;">
|
| 162 |
+
<tr>
|
| 163 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Benchmark</th>
|
| 164 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Metric</th>
|
| 165 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Dream-Base (Baseline)</th>
|
| 166 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Fast-dLLM (SOTA)</th>
|
| 167 |
+
<th style="padding: 8px; border: 1px solid #ddd;">D2F-Dream (Ours)</th>
|
| 168 |
+
</tr>
|
| 169 |
+
</thead>
|
| 170 |
+
<tbody>
|
| 171 |
+
<tr>
|
| 172 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle;"><strong>GSM8K-CoT-8-shot</strong></td>
|
| 173 |
+
<td style="padding: 8px; border: 1px solid #ddd;">TPS ↑</td>
|
| 174 |
+
<td style="padding: 8px; border: 1px solid #ddd;">9.5</td>
|
| 175 |
+
<td style="padding: 8px; border: 1px solid #ddd;">49.8</td>
|
| 176 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>91.2 <font color="green">(9.6x)</font></strong></td>
|
| 177 |
+
</tr>
|
| 178 |
+
<tr>
|
| 179 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Score ↑</td>
|
| 180 |
+
<td style="padding: 8px; border: 1px solid #ddd;">75.0</td>
|
| 181 |
+
<td style="padding: 8px; border: 1px solid #ddd;">75.0</td>
|
| 182 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><b>77.6</b></td>
|
| 183 |
+
</tr>
|
| 184 |
+
<tr>
|
| 185 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle; background-color: #fafafa;"><strong>MBPP-3-shot</strong></td>
|
| 186 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">TPS ↑</td>
|
| 187 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">10.4</td>
|
| 188 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">73.2</td>
|
| 189 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><strong>105 <font color="green">(10.1x)</font></strong></td>
|
| 190 |
+
</tr>
|
| 191 |
+
<tr>
|
| 192 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">Score ↑</td>
|
| 193 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">56.2</td>
|
| 194 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">51.0</td>
|
| 195 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><b>56.4</b></td>
|
| 196 |
+
</tr>
|
| 197 |
+
<tr>
|
| 198 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle;"><strong>HumanEval-0-shot</strong></td>
|
| 199 |
+
<td style="padding: 8px; border: 1px solid #ddd;">TPS ↑</td>
|
| 200 |
+
<td style="padding: 8px; border: 1px solid #ddd;">20.2</td>
|
| 201 |
+
<td style="padding: 8px; border: 1px solid #ddd;">60.0</td>
|
| 202 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>73.2 <font color="green">(3.6x)</font></strong></td>
|
| 203 |
+
</tr>
|
| 204 |
+
<tr>
|
| 205 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Score ↑</td>
|
| 206 |
+
<td style="padding: 8px; border: 1px solid #ddd;">54.3</td>
|
| 207 |
+
<td style="padding: 8px; border: 1px solid #ddd;">53.0</td>
|
| 208 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><b>55.5</b></td>
|
| 209 |
+
</tr>
|
| 210 |
+
<tr>
|
| 211 |
+
<td rowspan="2" style="padding: 8px; border: 1px solid #ddd; vertical-align: middle; background-color: #fafafa;"><strong>Math-4-shot</strong></td>
|
| 212 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">TPS ↑</td>
|
| 213 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">9.9</td>
|
| 214 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">67.0</td>
|
| 215 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><strong>98.8 <font color="green">(10.0x)</font></strong></td>
|
| 216 |
+
</tr>
|
| 217 |
+
<tr>
|
| 218 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">Score ↑</td>
|
| 219 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">35.8</td>
|
| 220 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;"><b>37.6</b></td>
|
| 221 |
+
<td style="padding: 8px; border: 1px solid #ddd; background-color: #fafafa;">35.4</td>
|
| 222 |
+
</tr>
|
| 223 |
+
</tbody>
|
| 224 |
+
</table>
|
| 225 |
+
</center>
|
| 226 |
+
|
| 227 |
+
## ⚡️ Extreme Acceleration with vLLM Integration
|
| 228 |
+
|
| 229 |
+
To push the boundaries of inference speed, we've integrated D2F with a **preliminary vLLM-based engine**. This unlocks a multiplicative speedup on top of our already-accelerated model, showcasing the immense potential for production environments.
|
| 230 |
+
|
| 231 |
+
<center>
|
| 232 |
+
|
| 233 |
+
<strong>HumanEval-0-shot with vLLM</strong>
|
| 234 |
+
<table style="width:100%; border-collapse: collapse; text-align: center;">
|
| 235 |
+
<thead style="background-color:#f2f2f2;">
|
| 236 |
+
<tr>
|
| 237 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Model</th>
|
| 238 |
+
<th style="padding: 8px; border: 1px solid #ddd;">TPS ↑</th>
|
| 239 |
+
<th style="padding: 8px; border: 1px solid #ddd;">Score ↑</th>
|
| 240 |
+
</tr>
|
| 241 |
+
</thead>
|
| 242 |
+
<tbody>
|
| 243 |
+
<tr>
|
| 244 |
+
<td style="padding: 8px; border: 1px solid #ddd;">Dream-Base (Baseline)</td>
|
| 245 |
+
<td style="padding: 8px; border: 1px solid #ddd;">20.2 <font color="green">(1.0x)</font></td>
|
| 246 |
+
<td style="padding: 8px; border: 1px solid #ddd;">54.3</td>
|
| 247 |
+
</tr>
|
| 248 |
+
<tr>
|
| 249 |
+
<td style="padding: 8px; border: 1px solid #ddd;">D2F-Dream (Ours)</td>
|
| 250 |
+
<td style="padding: 8px; border: 1px solid #ddd;">73.2 <font color="green">(3.6x)</font></td>
|
| 251 |
+
<td style="padding: 8px; border: 1px solid #ddd;">54.3</td>
|
| 252 |
+
</tr>
|
| 253 |
+
<tr style="background-color:#E8F5E9;">
|
| 254 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>D2F-Dream + vLLM (Ours)</strong></td>
|
| 255 |
+
<td style="padding: 8px; border: 1px solid #ddd;"><strong>131.7 <font color="green">(6.5x)</font></strong></td>
|
| 256 |
+
<td style="padding: 8px; border: 1px solid #ddd;">40.2</td>
|
| 257 |
+
</tr>
|
| 258 |
+
</tbody>
|
| 259 |
+
</table>
|
| 260 |
+
<br>
|
| 261 |
+
<small>Our D2F-Dream model with a preliminary vLLM engine achieves a <b>6.5x speedup</b> over the original Dream-Base, though we observe a score drop that we are actively working to resolve through optimized kernels.</small>
|
| 262 |
+
|
| 263 |
+
</center>
|
| 264 |
+
|
| 265 |
+
> **Implementation Notes:**
|
| 266 |
+
> The current vLLM integration is an initial proof-of-concept. It already provides a significant performance boost by leveraging Flex Attention, but there is substantial room for further optimization. Our future work will focus on implementing specialized CUDA kernels and other advanced vLLM features to maximize speed while restoring the score.
|
| 267 |
+
|
| 268 |
+
## 🚀 Usage Guide
|
| 269 |
+
|
| 270 |
+
### 1. Installation
|
| 271 |
+
|
| 272 |
+
First, clone the repository and set up the environment.
|
| 273 |
+
|
| 274 |
+
```shell
|
| 275 |
+
# Clone the repository
|
| 276 |
+
git clone https://github.com/zhijie-group/Discrete-Diffusion-Forcing.git
|
| 277 |
+
cd Discrete-Diffusion-Forcing
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
#### Environment Configuration
|
| 281 |
+
|
| 282 |
+
##### UV (Recommended)
|
| 283 |
+
|
| 284 |
+
```shell
|
| 285 |
+
uv sync
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
##### Conda
|
| 289 |
+
|
| 290 |
+
```shell
|
| 291 |
+
# Create and activate a conda environment
|
| 292 |
+
conda create -n d2f python=3.10
|
| 293 |
+
conda activate d2f
|
| 294 |
+
|
| 295 |
+
# Install dependencies
|
| 296 |
+
pip install -r requirements.txt
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
#### vLLM Installation
|
| 300 |
+
|
| 301 |
+
vLLM is comming soon, right now we only implemented the basic functions of vLLM.
|
| 302 |
+
|
| 303 |
+
### 2. Evaluation
|
| 304 |
+
All evaluation scripts are located in the `D2F-eval` directory.
|
| 305 |
+
|
| 306 |
+
```shell
|
| 307 |
+
cd D2F-eval
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
To evaluate the **D2F-Dream** model on all benchmarks, run:
|
| 311 |
+
|
| 312 |
+
```shell
|
| 313 |
+
shell eval_dream.sh
|
| 314 |
+
```
|
| 315 |
+
|
| 316 |
+
To evaluate the **D2F-LLaDA** model on all benchmarks, run:
|
| 317 |
+
|
| 318 |
+
```shell
|
| 319 |
+
shell eval_llada.sh
|
| 320 |
+
```
|
| 321 |
+
The results will be saved in the `output_path` specified within the shell scripts.
|
| 322 |
+
|
| 323 |
+
> ### ❗️ Important Notice for HumanEval
|
| 324 |
+
> The `HumanEval` benchmark requires a post-processing step to sanitize the generated code and calculate the final `pass@1` score. After the evaluation script finishes, run the following command:
|
| 325 |
+
> ```shell
|
| 326 |
+
> python postprocess_code.py {path/to/your/samples_humaneval_xxx.jsonl}
|
| 327 |
+
> ```
|
| 328 |
+
> Replace the path with the actual path to your generated samples file, which can be found in the specified `output_path`.
|
| 329 |
+
|
| 330 |
+
### 3. Training
|
| 331 |
+
All training scripts and configurations are located in the `D2F-train` directory.
|
| 332 |
+
```shell
|
| 333 |
+
# Navigate to the training directory
|
| 334 |
+
cd D2F-train
|
| 335 |
+
```
|
| 336 |
+
Before starting the training, you need to configure the paths for your dataset, models, and output directories. Modify the relevant paths in the configuration files located inside the `config` folder.
|
| 337 |
+
|
| 338 |
+
Once the configuration is set, you can start the training process by running:
|
| 339 |
+
```shell
|
| 340 |
+
bash train.sh
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
### 4. Generation Demo
|
| 344 |
+
|
| 345 |
+
We provide simple scripts to demonstrate the generation process and compare D2F with a standard AR baseline.
|
| 346 |
+
```shell
|
| 347 |
+
# To run a demo with the D2F pipelined block generation method:
|
| 348 |
+
python generate_llada_demo_block.py
|
| 349 |
+
|
| 350 |
+
# To compare, run a demo with the baseline AR generation method:
|
| 351 |
+
python generate_llada_demo_ar.py
|
| 352 |
+
```
|
| 353 |
+
You can inspect these files to see how to use the D2F model for inference in your own projects.
|
| 354 |
+
|
| 355 |
+
## 📚 Future Works
|
| 356 |
+
|
| 357 |
+
- [x] Implement dLLM-suported vLLM (preliminary).
|
| 358 |
+
- [ ] Implement fused dLLM-specific decoding kernels for vLLM to maximize performance and restore scores.
|
| 359 |
+
- [ ] Implement distributed inference with multi-GPUs in vLLM.
|
| 360 |
+
- [ ] Implement CUDA graph capturing for dynamic sequences in vLLM.
|
| 361 |
+
|
| 362 |
+
## 🙏 Acknowledgements
|
| 363 |
+
Our work builds upon the foundations laid by the original **LLaDA** and **Dream** models. We thank their authors for making their work public. We are also grateful for the powerful open-source tools from Hugging Face and the vLLM team that made this research possible.
|
| 364 |
+
|
| 365 |
+
## ©️ Citation
|
| 366 |
+
If you find our work useful for your research, please consider citing our paper:
|
| 367 |
+
```bibtex
|
| 368 |
+
@article{wang2025diffusion,
|
| 369 |
+
title={Diffusion llms can do faster-than-ar inference via discrete diffusion forcing},
|
| 370 |
+
author={Wang, Xu and Xu, Chenkai and Jin, Yijie and Jin, Jiachun and Zhang, Hao and Deng, Zhijie},
|
| 371 |
+
journal={arXiv preprint arXiv:2508.09192},
|
| 372 |
+
year={2025}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
```
|
Discrete-Diffusion-Forcing/docs/_config.yml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
theme: jekyll-theme-minimal
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig1_main_result.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig2_tradeoff.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig3_overview.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/fig4_pipeline.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_lr.jpg
ADDED
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_lr.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_ud.jpg
ADDED
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/logo_ud.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/table1_llada_results.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/table2_dream_results.png
ADDED
|
Git LFS Details
|
Discrete-Diffusion-Forcing/docs/assets/img/d2f/wechat.png
ADDED
|
Git LFS Details
|