Commit
·
51be264
0
Parent(s):
first push
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +186 -0
- Dockerfile +55 -0
- LICENSE +201 -0
- README.md +34 -0
- c2cite.py +300 -0
- c2cite/__init__.py +52 -0
- c2cite/adapters/__init__.py +104 -0
- c2cite/adapters/loramoe/__init__.py +7 -0
- c2cite/adapters/loramoe/config.py +42 -0
- c2cite/adapters/loramoe/model.py +62 -0
- c2cite/adapters/mixlora/__init__.py +19 -0
- c2cite/adapters/mixlora/config.py +144 -0
- c2cite/adapters/mixlora/model.py +610 -0
- c2cite/adapters/mola/__init__.py +8 -0
- c2cite/adapters/mola/config.py +57 -0
- c2cite/adapters/mola/model.py +159 -0
- c2cite/common/__init__.py +92 -0
- c2cite/common/abstracts.py +194 -0
- c2cite/common/attention.py +293 -0
- c2cite/common/cache.py +554 -0
- c2cite/common/checkpoint.py +33 -0
- c2cite/common/config.py +234 -0
- c2cite/common/feed_forward.py +70 -0
- c2cite/common/lora_linear.py +511 -0
- c2cite/common/moe_utils.py +57 -0
- c2cite/common/rope.py +88 -0
- c2cite/dispatcher.py +378 -0
- c2cite/evaluator.py +518 -0
- c2cite/executors/__init__.py +54 -0
- c2cite/executors/common.py +77 -0
- c2cite/executors/cpu.py +51 -0
- c2cite/executors/cuda.py +53 -0
- c2cite/executors/mps.py +71 -0
- c2cite/generator.py +669 -0
- c2cite/model.py +1039 -0
- c2cite/models/__init__.py +40 -0
- c2cite/models/modeling_chatglm.py +855 -0
- c2cite/models/modeling_gemma.py +131 -0
- c2cite/models/modeling_gemma2.py +528 -0
- c2cite/models/modeling_llama.py +579 -0
- c2cite/models/modeling_mistral.py +255 -0
- c2cite/models/modeling_phi.py +576 -0
- c2cite/models/modeling_phi3.py +581 -0
- c2cite/prompter.py +63 -0
- c2cite/solutions.py +9 -0
- c2cite/tasks/__init__.py +29 -0
- c2cite/tasks/attribute_tasks.py +567 -0
- c2cite/tasks/common.py +1045 -0
- c2cite/tasks/glue_tasks.py +90 -0
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
paper_wsdm_c2cite.pdf filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
.python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
.idea/
|
| 161 |
+
|
| 162 |
+
# IDEs
|
| 163 |
+
.vscode/
|
| 164 |
+
|
| 165 |
+
# MoE-PEFT
|
| 166 |
+
__pycache__/
|
| 167 |
+
*.egg-info/
|
| 168 |
+
*.egg
|
| 169 |
+
moe_peft.json
|
| 170 |
+
moe_peft_train_*.json
|
| 171 |
+
|
| 172 |
+
# macOS junk files
|
| 173 |
+
.DS_Store
|
| 174 |
+
|
| 175 |
+
# PEFT adapters
|
| 176 |
+
adapter_model.bin
|
| 177 |
+
adapter_config.json
|
| 178 |
+
|
| 179 |
+
result/
|
| 180 |
+
checkpoints/
|
| 181 |
+
cases/
|
| 182 |
+
dataset/
|
| 183 |
+
tblogs/
|
| 184 |
+
*.png
|
| 185 |
+
*.svg
|
| 186 |
+
logs
|
Dockerfile
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.5.1-devel-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
ARG PYTHON_VERSION=3.11
|
| 4 |
+
ARG http_proxy
|
| 5 |
+
ARG https_proxy
|
| 6 |
+
|
| 7 |
+
RUN apt-get update
|
| 8 |
+
|
| 9 |
+
RUN apt-get install -y \
|
| 10 |
+
locales \
|
| 11 |
+
build-essential \
|
| 12 |
+
git \
|
| 13 |
+
git-lfs \
|
| 14 |
+
vim \
|
| 15 |
+
cmake \
|
| 16 |
+
pkg-config \
|
| 17 |
+
zlib1g-dev libncurses5-dev \
|
| 18 |
+
libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev wget \
|
| 19 |
+
liblzma-dev libsqlite3-dev libbz2-dev
|
| 20 |
+
|
| 21 |
+
RUN apt-get clean
|
| 22 |
+
|
| 23 |
+
ENV LANG=en_US.UTF-8
|
| 24 |
+
ENV LANGUAGE=en_US:en
|
| 25 |
+
ENV LC_ALL=en_US.UTF-8
|
| 26 |
+
|
| 27 |
+
RUN sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen && locale-gen
|
| 28 |
+
|
| 29 |
+
ENV PYENV_ROOT=/root/.pyenv
|
| 30 |
+
ENV PATH="$PYENV_ROOT/bin/:$PATH"
|
| 31 |
+
|
| 32 |
+
RUN /usr/bin/echo -e '#!/bin/bash\neval "$(pyenv init -)"\neval "$(pyenv virtualenv-init -)"\ncd /moe_peft\nbash' | tee /opt/init.sh \
|
| 33 |
+
&& chmod +x /opt/init.sh \
|
| 34 |
+
&& /usr/bin/echo -e 'export PYENV_ROOT=/root/.pyenv' >> ~/.bashrc \
|
| 35 |
+
&& /usr/bin/echo -e 'export PATH=/root/.pyenv/bin:$PATH' >> ~/.bashrc \
|
| 36 |
+
&& /usr/bin/echo -e 'eval "$(pyenv init -)"' >> ~/.bashrc \
|
| 37 |
+
&& /usr/bin/echo -e 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc \
|
| 38 |
+
&& git clone https://github.com/pyenv/pyenv.git /root/.pyenv \
|
| 39 |
+
&& git clone https://github.com/pyenv/pyenv-virtualenv.git /root/.pyenv/plugins/pyenv-virtualenv \
|
| 40 |
+
&& cd /root/.pyenv && src/configure && make -C src \
|
| 41 |
+
&& eval "$(pyenv init -)" \
|
| 42 |
+
&& eval "$(pyenv virtualenv-init -)"
|
| 43 |
+
|
| 44 |
+
RUN . ~/.bashrc \
|
| 45 |
+
&& pyenv install $PYTHON_VERSION \
|
| 46 |
+
&& pyenv global $PYTHON_VERSION \
|
| 47 |
+
&& git clone https://github.com/TUDB-Labs/MoE-PEFT /moe_peft \
|
| 48 |
+
&& cd /moe_peft \
|
| 49 |
+
&& pyenv virtualenv $PYTHON_VERSION moe_peft \
|
| 50 |
+
&& pyenv local moe_peft \
|
| 51 |
+
&& pip install -r ./requirements.txt --upgrade --no-compile --no-cache-dir
|
| 52 |
+
|
| 53 |
+
WORKDIR /moe_peft
|
| 54 |
+
|
| 55 |
+
CMD ["/bin/bash", "/opt/init.sh"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This repository contains the code for the paper “C$^2$-Cite: Contextual-Aware Citation Generation for Attributed Large Language Models”. The project is based on the open-source repository "[TUDB-Labs/MoE-PEFT](https://github.com/TUDB-Labs/MoE-PEFT)". C$^2$-Cite is a model that can answer questions with citation markers.
|
| 2 |
+
## File description
|
| 3 |
+
- **config**: Including the configurations of training or evaluating
|
| 4 |
+
- **c2cite/backends**: Some backend tools for GMoE.
|
| 5 |
+
- **c2cite/common**: The implementation of Transformer architecture.
|
| 6 |
+
- **c2cite/models**: The implementation of some series of Transformer-based models.
|
| 7 |
+
- **c2cite/tasks**: The implementation of datasets.
|
| 8 |
+
- **c2cite.py** The start file of this project.
|
| 9 |
+
## Environment Requirements
|
| 10 |
+
- python3=3.11
|
| 11 |
+
- pytorch >= 2.1.2
|
| 12 |
+
- Other dependencies, See ```requirements.txt```
|
| 13 |
+
## Quick Start
|
| 14 |
+
### STEP 1: Download Base models
|
| 15 |
+
- [Llama-3-8B-inst]
|
| 16 |
+
### STEP 2: Download training datasets
|
| 17 |
+
To get Training dataset proposed in paper "Towards Faithful and Robust LLM Specialists for Evidence-Based Question-Answering", you can download [SynSciQA](https://github.com/EdisonNi-hku/Robust_Evidence_Based_QA) here. And please put SynSciQA.json, SynSciQA+.json, SynSciQA++.json in ./dataset/SynSciQA
|
| 18 |
+
### STEP 3: Download evaluation datasets
|
| 19 |
+
We evaluate our model and baselines using [ALCE](https://github.com/princeton-nlp/ALCE). To get Evaluate datasets, please run
|
| 20 |
+
```bash
|
| 21 |
+
bash download_test_data.sh
|
| 22 |
+
```
|
| 23 |
+
### STEP 4: Start training
|
| 24 |
+
Replace the **[base model]** and the **[train/evaluate config]** below with the directory of base model and the configuration in Folder "config".
|
| 25 |
+
``````python
|
| 26 |
+
python c2cite.py --dir ./checkpoint --log_file ./logs --verbose --seed 42 --attn_impl eager --base_model [base model] --config [train/evaluate config] --device cuda:0
|
| 27 |
+
``````
|
| 28 |
+
### STEP 5: Conduct evaluation
|
| 29 |
+
After the training process, we can conduct the evaluation step with the command below:
|
| 30 |
+
``````python
|
| 31 |
+
python c2cite.py --dir ./checkpoint --log_file ./logs --verbose --seed 42 --attn_impl eager --base_model [base model] --config [train/evaluate config] --device cuda:0 --evaluate
|
| 32 |
+
``````
|
| 33 |
+
***Note***: **Do not** change the information in the **train config** after training step, or it won't find the right adapter.
|
| 34 |
+
|
c2cite.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from typing import Dict, List, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import is_flash_attn_2_available
|
| 10 |
+
|
| 11 |
+
import moe_peft
|
| 12 |
+
import moe_peft.adapters
|
| 13 |
+
|
| 14 |
+
# Command Line Arguments
|
| 15 |
+
parser = argparse.ArgumentParser(description="MoE-PEFT main program")
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--base_model", type=str, required=True, help="Path to or name of base model"
|
| 18 |
+
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--inference", action="store_true", help="The inference mode (just for test)"
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--evaluate", action="store_true", help="The evaluate mode (just for test)"
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--disable_prompter", action="store_true", help="Disable prompter when inference"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--load_adapter",
|
| 30 |
+
action="store_true",
|
| 31 |
+
help="Load adapter from file instead of init randomly",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--disable_adapter", action="store_true", help="Disable the adapter modules"
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--attn_impl", type=str, help="Specify the implementation of attention"
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--sliding_window",
|
| 41 |
+
action="store_true",
|
| 42 |
+
help="Use sliding window attention (requires flash attention)",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--disable_cache",
|
| 46 |
+
action="store_true",
|
| 47 |
+
help="Disable cache when inference",
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--cache_implementation",
|
| 51 |
+
type=str,
|
| 52 |
+
help="Specify the implementation of cache",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--fp16", action="store_true", help="Load base model in float16 precision"
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--bf16", action="store_true", help="Load base model in bfloat16 precision"
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--tf32", action="store_true", help="Use tfloat32 instead of float32 if available"
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--load_8bit", action="store_true", help="Load base model with 8bit quantization"
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--load_4bit", action="store_true", help="Load base model with 4bit quantization"
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument("--device", type=str, help="Specify which GPU to be used")
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--config", type=str, required=True, help="Path to finetune configuration"
|
| 72 |
+
)
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"--seed", type=int, default=42, help="Random seed in integer, default is 42"
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--dir", type=str, default=".", help="Path to read or save checkpoints"
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
|
| 80 |
+
parser.add_argument("--log_file", type=str, help="Save log to specific file")
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--verbose", action="store_true", help="Show extra informations such as parameters"
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--overwrite",
|
| 86 |
+
action="store_true",
|
| 87 |
+
help="Overwrite adapter model when older one existed",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument("--debug", action="store_true", help="Enabling debugging mode")
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--deterministic",
|
| 92 |
+
action="store_true",
|
| 93 |
+
help="Use deterministic algorithms to improve the reproducibility",
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
args = parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def query_yes_no(question, default="no"):
    """Ask a yes/no question on stdout and read the answer from stdin.

    Re-prompts until a recognised answer is typed. An empty answer selects
    *default* when one is configured.

    Args:
        question: Text shown to the user.
        default: "yes", "no", or None (None means the user must answer).

    Returns:
        True for an affirmative answer, False for a negative one.

    Raises:
        ValueError: If *default* is not one of "yes", "no", or None.
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    # Map each allowed default onto the prompt suffix that advertises it.
    prompt_by_default = {None: " [y/n] ", "yes": " [Y/n] ", "no": " [y/N] "}
    if default not in prompt_by_default:
        raise ValueError("invalid default answer: '%s'" % default)
    prompt = prompt_by_default[default]

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if choice == "" and default is not None:
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def load_base_model() -> Tuple[moe_peft.Tokenizer, moe_peft.LLMModel]:
    """Load the pre-trained base model and tokenizer selected by the CLI args.

    Returns:
        (tokenizer, model) pair built from ``args.base_model``.
    """
    logging.info("Initializing pre-trained model.")

    # Quantization: 8-bit wins over 4-bit when both flags are set.
    if args.load_8bit:
        quant_bits = 8
    elif args.load_4bit:
        quant_bits = 4
    else:
        quant_bits = None

    # Load dtype: bf16 takes precedence over fp16; default is fp32.
    if args.bf16:
        load_dtype = torch.bfloat16
    elif args.fp16:
        load_dtype = torch.float16
    else:
        load_dtype = torch.float32

    model = moe_peft.LLMModel.from_pretrained(
        name_or_path=args.base_model,
        device=args.device,
        attn_impl=args.attn_impl,
        use_sliding_window=args.sliding_window,
        bits=quant_bits,
        load_dtype=load_dtype,
    )
    tokenizer = moe_peft.Tokenizer(args.base_model)

    return tokenizer, model
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def init_adapter_config(
    config: Dict[str, any],
    llm_model: moe_peft.LLMModel,
) -> List[Union[moe_peft.GenerateConfig, moe_peft.TrainConfig]]:
    """Load or initialize every adapter listed in *config* on *llm_model*.

    Also builds one generate/evaluate/train configuration per adapter,
    depending on the CLI mode, and asks the user before clobbering an
    existing checkpoint unless ``--overwrite`` was given.

    Args:
        config: Parsed JSON finetune configuration (mutated: ``cutoff_len``
            of -1 is replaced by the model's max sequence length).
        llm_model: The loaded base model to attach adapters to.

    Returns:
        A list of per-adapter config objects matching the selected mode.
    """
    config_list = []

    if config["cutoff_len"] == -1:
        max_len = llm_model.config_.max_seq_len_
        config["cutoff_len"] = max_len
        logging.info(f"Setting cutoff_len to {max_len} automatically.")

    for lora_config in config["lora"]:
        adapter_name = lora_config["name"]
        adapter_path = f"{args.dir}{os.sep}{adapter_name}"

        # Guard against silently overwriting a previous training run.
        if os.path.exists(adapter_path) and not args.load_adapter:
            if args.overwrite:
                logging.warning(
                    f"Overwriting existed adapter model file: {adapter_path}"
                )
            elif not query_yes_no(
                f"Existed adapter model file detected: {adapter_path}\nOverwrite?"
            ):
                logging.info("User canceled training due to file conflict.")
                exit(0)

        if args.load_adapter:
            llm_model.load_adapter(adapter_path, adapter_name)
        else:
            llm_model.init_adapter(moe_peft.adapters.lora_config_factory(lora_config))

        if args.inference:
            generate_config = moe_peft.GenerateConfig(adapter_name=adapter_name)
            if not args.disable_prompter:
                generate_config.prompt_template = lora_config.get("prompt", None)
            config_list.append(generate_config)
        elif args.evaluate:
            config_list.extend(moe_peft.EvaluateConfig.from_config(lora_config))
        else:
            config_list.append(moe_peft.TrainConfig.from_config(lora_config))

        if args.verbose:
            logging.info(config_list[-1].__dict__)

    return config_list
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def inference_callback(cur_pos, outputs):
    """Streaming callback: print the decode position and each adapter's partial text."""
    print(f"POSITION: {cur_pos}")
    for name in outputs:
        print(f"{name} OUTPUT: {outputs[name][0]}")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def inference(
    model: moe_peft.LLMModel,
    tokenizer: moe_peft.Tokenizer,
    configs: List[moe_peft.GenerateConfig],
    concurrent_jobs: int,
):
    """Interactive generation loop: read a raw prompt, run every adapter, print results.

    Typing "QUIT" exits the loop. Streaming output is printed through
    ``inference_callback`` unless ``--disable_log`` was given.
    """
    separator = f"\n{'='*10}\n"
    while True:
        input_raw = input("INPUT WITHOUT PROMPT: ")
        if input_raw == "QUIT":
            return
        for generate_config in configs:
            generate_config.prompts = [input_raw]
        if args.disable_log:
            callback = None
        else:
            callback = inference_callback
        outputs = moe_peft.generate(
            model,
            tokenizer,
            configs,
            max_gen_len=128,
            use_cache=not args.disable_cache,
            concurrent_jobs=concurrent_jobs,
            cache_implementation=args.cache_implementation,
            stream_callback=callback,
        )
        print(separator)
        print(f"PROMPT: {input_raw}")
        for adapter_name, output in outputs.items():
            print(f"{adapter_name} OUTPUT:")
            print(output[0])
        print(separator)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Main Function
if __name__ == "__main__":
    # Surface the origin of NaN/Inf gradients when debugging.
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # Inference and evaluation both run against previously trained adapters,
    # so force loading them from disk instead of fresh initialization.
    if args.inference or args.evaluate:
        args.load_adapter = True
        inference_mode = True
    else:
        inference_mode = False
        #args.load_adapter = False##############################
    moe_peft.setup_logging("INFO", args.log_file)

    moe_peft_executor = moe_peft.executor

    # Abort early when the selected compute backend is unusable.
    if not moe_peft_executor.check_available():
        exit(-1)

    # Auto-select the attention implementation: flash-attn only for CUDA
    # inference when the package is installed, otherwise plain eager.
    if args.attn_impl is None:
        if (
            inference_mode
            and moe_peft_executor.device_name() == "cuda"
            and is_flash_attn_2_available()
        ):
            args.attn_impl = "flash_attn"
        else:
            args.attn_impl = "eager"

    if args.device is None:
        args.device = moe_peft.executor.default_device_name()

    # Reproducibility / precision knobs from the CLI.
    moe_peft_executor.use_deterministic_algorithms(args.deterministic)
    moe_peft_executor.allow_tf32(args.tf32)
    moe_peft_executor.manual_seed(args.seed)

    with open(args.config, "r", encoding="utf8") as fp:
        config = json.load(fp)

    tokenizer, model = load_base_model()
    adapters = init_adapter_config(config, model)

    # Release temporary buffers created while loading the adapters.
    moe_peft_executor.empty_cache()

    if os.getenv("MOE_PEFT_EVALUATE_MODE") is None:
        logging.info("Using efficient operators.")
    else:
        logging.info("Using deterministic operators.")

    # Dispatch to the selected mode: interactive inference, batch
    # evaluation, or training.
    if args.inference:
        inference(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
        )
    elif args.evaluate:
        moe_peft.evaluate(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
            retrying_steps=config.get("eval_rollback_retrying_steps", 20),
            max_seq_len=config["cutoff_len"],
            save_file=config.get("evaluate_result", None),
            require_attention = -1,
            require_hide = -1,
        )
    else:
        moe_peft.train(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
            strategy=config["train_strategy"],
            cutoff_len=config["cutoff_len"],
            save_step=config["save_step"],
            save_dir=args.dir,
        )
|
c2cite/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .common import (
    AdapterConfig,
    LLMBatchConfig,
    LLMCache,
    LLMForCausalLM,
    LLMModelConfig,
    LLMModelInput,
    LLMModelOutput,
    LoraConfig,
    cache_factory,
)
from .dispatcher import Dispatcher, TrainTask
from .evaluator import EvaluateConfig, evaluate
from .executors import executor
from .generator import GenerateConfig, generate
from .model import LLMModel
from .prompter import Prompter
from .tokenizer import Tokenizer
from .trainer import TrainConfig, train
from .utils import is_package_available, setup_logging

# Fail fast with a clear error when required runtime versions are missing.
# Raise explicitly instead of using `assert`, which is stripped under
# `python -O` and would let an unsupported environment slip through.
if not is_package_available("torch", "2.3.0"):
    raise ImportError("MoE-PEFT requires torch>=2.3.0")
if not is_package_available("transformers", "4.43.0"):
    raise ImportError("MoE-PEFT requires transformers>=4.43.0")

setup_logging()

# Public API of the package.
__all__ = [
    "LLMCache",
    "cache_factory",
    "LLMModelConfig",
    "LLMModelOutput",
    "LLMForCausalLM",
    "LLMBatchConfig",
    "LLMModelInput",
    "AdapterConfig",
    "LoraConfig",
    "TrainTask",
    "Dispatcher",
    "EvaluateConfig",
    "evaluate",
    "GenerateConfig",
    "generate",
    "TrainConfig",
    "train",
    "LLMModel",
    "Prompter",
    "Tokenizer",
    "setup_logging",
    "executor",
]
|
c2cite/adapters/__init__.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, TypeAlias
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from moe_peft.common import AdapterConfig, LoraConfig
|
| 6 |
+
|
| 7 |
+
from .loramoe import LoraMoe, LoraMoeConfig
|
| 8 |
+
from .mixlora import (
|
| 9 |
+
DynamicRouterLoss,
|
| 10 |
+
DynamicSparseMoe,
|
| 11 |
+
MixLoraConfig,
|
| 12 |
+
MixtralRouterLoss,
|
| 13 |
+
MixtralSparseMoe,
|
| 14 |
+
SwitchRouterLoss,
|
| 15 |
+
SwitchSparseMoe,
|
| 16 |
+
)
|
| 17 |
+
from .mola import MolaConfig, MolaRouterLoss, MolaSparseMoe
|
| 18 |
+
|
| 19 |
+
# Maps explicit "peft_type" strings (as written by the export methods of
# the corresponding config classes) to their configuration classes.
peft_type_dict = {
    "LORA": LoraConfig,
    "MIXLORA": MixLoraConfig,
    "LORAMOE": LoraMoeConfig,
    "MOLA": MolaConfig,
}

# Maps "routing_strategy" strings to configuration classes; used as a
# fallback when a raw config omits "peft_type".
routing_strategy_dict = {
    "mixlora": MixLoraConfig,
    "mixlora-dynamic": MixLoraConfig,
    "mixlora-switch": MixLoraConfig,
    "loramoe": LoraMoeConfig,
    "mola": MolaConfig,
}

# Auxiliary router-loss module per routing strategy ("loramoe" has none).
router_loss_dict = {
    "mixlora": MixtralRouterLoss,
    "mixlora-dynamic": DynamicRouterLoss,
    "mixlora-switch": SwitchRouterLoss,
    "mola": MolaRouterLoss,
}

# Sparse/dense MoE block implementation per routing strategy.
moe_layer_dict = {
    "mixlora": MixtralSparseMoe,
    "mixlora-dynamic": DynamicSparseMoe,
    "mixlora-switch": SwitchSparseMoe,
    "loramoe": LoraMoe,
    "mola": MolaSparseMoe,
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def lora_config_factory(config: Dict[str, any]) -> LoraConfig:
    """Build and validate an adapter configuration from a raw config dict.

    Resolution order: an explicit "peft_type" wins, then a recognized
    "routing_strategy", otherwise plain LoRA.

    Note: the previous version annotated the local with
    ``TypeAlias[AdapterConfig]`` — ``typing.TypeAlias`` marks alias
    declarations and is not a generic; a plain assignment is correct.

    Args:
        config: Raw adapter configuration (parsed JSON).

    Returns:
        The checked adapter configuration instance.
    """
    peft_type = config.get("peft_type", "")
    routing_strategy = config.get("routing_strategy", "")
    if peft_type in peft_type_dict:
        config_class = peft_type_dict[peft_type]
    elif routing_strategy in routing_strategy_dict:
        config_class = routing_strategy_dict[routing_strategy]
    else:
        config_class = LoraConfig

    return config_class.from_config(config).check()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
    """Return the router-loss module for *config*, or None when disabled or unknown."""
    loss_class = router_loss_dict.get(config.routing_strategy_)
    if loss_class is None or not config.router_loss_:
        return None
    return loss_class(config)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def moe_layer_factory(
    in_features: int,
    device: torch.device,
    config: MolaConfig,
    gate: Optional[torch.Tensor] = None,
) -> torch.nn.Module:
    """Instantiate the MoE block matching *config*'s routing strategy.

    Raises:
        ValueError: If the routing strategy has no registered block.
    """
    strategy = config.routing_strategy_
    if strategy not in moe_layer_dict:
        raise ValueError(f"Unknown routing strategy {strategy}")
    block_class = moe_layer_dict[strategy]
    return block_class(in_features, device, config, gate)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Public API of c2cite.adapters.
__all__ = [
    "MixLoraConfig",
    "MixtralRouterLoss",
    "MixtralSparseMoe",
    "DynamicRouterLoss",
    "DynamicSparseMoe",
    "SwitchRouterLoss",
    "SwitchSparseMoe",
    "LoraMoeConfig",
    "LoraMoe",
    "MolaConfig",
    "MolaSparseMoe",
    "peft_type_dict",
    "routing_strategy_dict",
    "router_loss_dict",
    "moe_layer_dict",
    "lora_config_factory",
    "router_loss_factory",
    "moe_layer_factory",
]
|
c2cite/adapters/loramoe/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import LoraMoeConfig
|
| 2 |
+
from .model import LoraMoe
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"LoraMoeConfig",
|
| 6 |
+
"LoraMoe",
|
| 7 |
+
]
|
c2cite/adapters/loramoe/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from moe_peft.common import LoraConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class LoraMoeConfig(LoraConfig):
    """LoraMoe adapter configuration: base LoRA settings plus MoE routing fields."""

    # Number of LoRA experts behind the gate (required, > 0).
    num_experts_: int = None
    # Kaiming init range for the router gate weight (>= 0).
    router_init_range_: float = None
    routing_strategy_: str = "loramoe"

    def check(self) -> "LoraMoeConfig":
        """Validate the MoE-specific fields on top of the base LoRA checks."""
        super().check()
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
        """Build a LoraMoeConfig from a raw dict; router_init_range defaults to 5.0."""
        return LoraMoeConfig(
            num_experts_=config["num_experts"],
            router_init_range_=config.get("router_init_range", 5.0),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialize to a dict, tagging peft_type/routing_strategy for reload."""
        config = super().export()
        config["peft_type"] = "LORAMOE"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the per-expert config named ``moe.<adapter>.experts.<idx>``.

        Fix: the previous implementation did ``copy.deepcopy(super())``,
        which only works by accident — the super proxy delegates
        ``__reduce_ex__`` to the wrapped instance, so it deep-copies
        ``self`` anyway. Say that explicitly.
        """
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
|
c2cite/adapters/loramoe/model.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from moe_peft.common import Linear, LLMMoeBlock
|
| 8 |
+
|
| 9 |
+
from .config import LoraMoeConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LoraMoe(LLMMoeBlock):
    """Dense mixture of LoRA experts weighted by a softmax gate.

    Every expert contributes to every token, scaled by its routing weight
    (no top-k sparsification in the forward pass below).
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: LoraMoeConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        """Create the router gate.

        Args:
            in_features: Hidden size of the incoming activations.
            device: Device to allocate the gate on.
            config: LoraMoe configuration (expert count, router init range).
            gate: Optional pre-trained gate weight loaded instead of
                random initialization.
        """
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # The router always runs in float32 regardless of the model dtype.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=torch.float32,
        )
        self.experts_: int = config.num_experts_
        # Refreshed on every forward pass; read by the router-loss machinery.
        self.router_logits_: torch.Tensor = None

        if gate is None:
            torch.nn.init.kaiming_uniform_(
                self.gate_.weight, a=math.sqrt(config.router_init_range_)
            )
        else:
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Optional[Linear] = None,
    ) -> torch.Tensor:
        """Add the routing-weighted output of every LoRA expert to *residual*.

        Fix: the return annotation previously said ``Tuple`` although a
        single tensor is returned.

        Args:
            residual: Base-layer output the expert deltas are added to.
            hidden_states: Input activations fed to the gate and experts.
            lora_linear: Linear layer holding the per-expert LoRA weights
                under keys ``moe.<adapter>.experts.<idx>`` (required).
        """
        assert lora_linear is not None
        router_logits = self.gate_(hidden_states.to(self.dtype_))
        self.router_logits_ = router_logits.reshape(-1, self.experts_)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        for expert_idx in range(self.experts_):
            expert_lora = lora_linear.loras_[
                f"moe.{self.adapter_name_}.experts.{expert_idx}"
            ]
            residual = residual + (
                torch.unsqueeze(routing_weights[:, :, expert_idx], -1)
                * expert_lora.lora_forward(hidden_states)
            ).to(hidden_states.dtype)

        return residual
|
c2cite/adapters/mixlora/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import MixLoraConfig
|
| 2 |
+
from .model import (
|
| 3 |
+
DynamicRouterLoss,
|
| 4 |
+
DynamicSparseMoe,
|
| 5 |
+
MixtralRouterLoss,
|
| 6 |
+
MixtralSparseMoe,
|
| 7 |
+
SwitchRouterLoss,
|
| 8 |
+
SwitchSparseMoe,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"MixLoraConfig",
|
| 13 |
+
"MixtralRouterLoss",
|
| 14 |
+
"MixtralSparseMoe",
|
| 15 |
+
"DynamicRouterLoss",
|
| 16 |
+
"DynamicSparseMoe",
|
| 17 |
+
"SwitchRouterLoss",
|
| 18 |
+
"SwitchSparseMoe",
|
| 19 |
+
]
|
c2cite/adapters/mixlora/config.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict, Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers.activations import ACT2FN
|
| 7 |
+
|
| 8 |
+
from moe_peft.common import LoraConfig
|
| 9 |
+
|
| 10 |
+
available_routing_strategies = ["mixlora", "mixlora-dynamic", "mixlora-switch"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class MixLoraConfig(LoraConfig):
    """MixLoRA adapter configuration.

    Extends a base LoRA config with per-expert settings and one of three
    routing strategies: "mixlora" (Mixtral-style top-k), "mixlora-dynamic"
    (top-p with temperature), and "mixlora-switch" (Switch Transformers).
    """

    # expert lora: optional override config shared by all experts
    expert_config_: LoraConfig = None
    # router config
    router_aux_loss_coef_: float = None
    router_init_range_: float = None
    routing_strategy_: str = None
    jitter_noise_: float = None
    router_loss_: bool = True
    num_experts_: int = None
    act_fn_: Optional[Union[str, torch.nn.Module]] = None
    # mixtral config
    top_k_: int = None
    # dynamic config
    top_p_: float = None
    temperature_: float = None
    # switch transformers config
    router_z_loss_coef_: float = None
    expert_capacity_: int = None
    ffn_dropout_: float = None
    sparse_step_: int = None

    def check(self) -> "MixLoraConfig":
        """Validate all fields, including the strategy-specific ones."""
        super().check()
        if self.expert_config_ is not None:
            self.expert_config_.check()
        assert (
            isinstance(self.router_aux_loss_coef_, float)
            and self.router_aux_loss_coef_ >= 0
        )
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        assert (
            isinstance(self.routing_strategy_, str)
            and self.routing_strategy_ in available_routing_strategies
        )
        assert isinstance(self.jitter_noise_, float) and self.jitter_noise_ >= 0
        assert isinstance(self.router_loss_, bool)
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert self.act_fn_ is None or (
            isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
        )
        if self.routing_strategy_ == "mixlora":
            assert isinstance(self.top_k_, int) and self.top_k_ > 0
        elif self.routing_strategy_ == "mixlora-dynamic":
            assert (
                isinstance(self.top_p_, float) and self.top_p_ > 0 and self.top_p_ <= 1
            )
            assert isinstance(self.temperature_, float) and self.temperature_ >= 0
        elif self.routing_strategy_ == "mixlora-switch":
            assert (
                isinstance(self.router_z_loss_coef_, float)
                and self.router_z_loss_coef_ >= 0
            )
            if self.sparse_step_ is not None:
                assert isinstance(self.sparse_step_, int) and self.sparse_step_ > 0
            assert isinstance(self.expert_capacity_, int) and self.expert_capacity_ > 0
            assert isinstance(self.ffn_dropout_, float) and self.ffn_dropout_ >= 0

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "MixLoraConfig":
        """Build a MixLoraConfig from a raw dict, filling strategy-specific defaults."""
        lora_config = MixLoraConfig(**LoraConfig.from_config(config).__dict__)
        if "expert_lora" in config:
            # Experts inherit the base settings and overlay "expert_lora".
            expert_config = copy.deepcopy(config)
            expert_config.update(config["expert_lora"])
            # Fix: call the staticmethod on the class instead of through a
            # throwaway `LoraConfig()` instance.
            lora_config.expert_config_ = LoraConfig.from_config(expert_config)
        lora_config.router_aux_loss_coef_ = config.get(
            "router_aux_loss_coef", 0.001
        )  # for training
        lora_config.routing_strategy_ = config["routing_strategy"]
        lora_config.router_loss_ = config.get("router_loss", True)
        lora_config.num_experts_ = config["num_experts"]
        # silu for mixtral or gelu_new for switch transformers
        # left blank to automatically use the original act_fn of FFN
        lora_config.act_fn_ = config.get("act_fn", None)
        if lora_config.routing_strategy_ == "mixlora":
            lora_config.router_init_range_ = config.get("router_init_range", 0.02)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
            lora_config.top_k_ = config.get("top_k", 2)
        elif lora_config.routing_strategy_ == "mixlora-dynamic":
            lora_config.router_init_range_ = config.get("router_init_range", 0.02)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
            lora_config.top_p_ = config.get("top_p", 0.8)
            lora_config.temperature_ = config.get("temperature", 0.0)
        elif lora_config.routing_strategy_ == "mixlora-switch":
            lora_config.router_init_range_ = config.get("router_init_range", 1.0)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.01)
            lora_config.router_z_loss_coef_ = config.get(
                "router_z_loss_coef", 0.001
            )  # for training
            # expert_capacity = (max_sequence_length / num_experts) * capacity_factor
            # common values of capacity_factor: 1.0, 1.25, 2.0
            lora_config.expert_capacity_ = config.get("expert_capacity", 32)
            lora_config.ffn_dropout_ = config.get("ffn_dropout", 0.0)
            lora_config.sparse_step_ = config.get("sparse_step", None)

        return lora_config

    def export(self) -> Dict[str, any]:
        """Serialize to a dict, tagging peft_type and strategy-specific fields."""
        config = super().export()
        config["peft_type"] = "MIXLORA"
        if self.expert_config_ is not None:
            expert_config = self.expert_config_.export()
            expert_config.pop("peft_type")
            expert_config.pop("target_modules")
            config["expert_lora"] = expert_config
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        if self.act_fn_ is not None and isinstance(self.act_fn_, str):
            config["act_fn"] = self.act_fn_
        if self.routing_strategy_ == "mixlora":
            config["top_k"] = self.top_k_
        elif self.routing_strategy_ == "mixlora-dynamic":
            config["top_p"] = self.top_p_
            config["temperature"] = self.temperature_
        elif self.routing_strategy_ == "mixlora-switch":
            config["expert_capacity"] = self.expert_capacity_
            config["sparse_step"] = self.sparse_step_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the config for one expert, named ``moe.<adapter>.experts.<idx>``.

        Fix: the previous fallback did ``copy.deepcopy(super())``, which
        only works because the super proxy delegates ``__reduce_ex__`` to
        the wrapped instance and thus deep-copies ``self``; state that
        explicitly.
        """
        if self.expert_config_ is None:
            config = copy.deepcopy(self)
        else:
            config = copy.deepcopy(self.expert_config_)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
|
c2cite/adapters/mixlora/model.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from transformers.activations import ACT2FN
|
| 6 |
+
|
| 7 |
+
from moe_peft.common import LLMFeedForward, LLMModelInput, LLMMoeBlock, slice_tensor
|
| 8 |
+
|
| 9 |
+
from .config import MixLoraConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _mixlora_compatible_forward(
    ffn_layer: LLMFeedForward,
    moe_name: str,
    act_fn: torch.nn.Module,
    expert_mask: torch.Tensor,
    hidden_states: torch.Tensor,
    input_dtype: torch.dtype,
):
    """Run each expert's LoRA feed-forward over the tokens routed to it.

    Fix: ``input_dtype`` was annotated ``torch.device`` although it is
    passed to ``slice_tensor`` as the dtype to cast sliced inputs to.

    Args:
        ffn_layer: Feed-forward layer holding the per-expert LoRA weights,
            looked up as ``moe.<moe_name>.experts.<idx>``.
        moe_name: Adapter name used to build the lookup keys.
        act_fn: Activation module applied inside the expert FFN.
        expert_mask: One-hot routing mask indexed by expert along dim 0;
            ``torch.where`` over ``expert_mask[expert_idx]`` yields the
            token indices for that expert.
        hidden_states: Token activations to slice per expert.
        input_dtype: Dtype the sliced expert inputs are cast to.

    Returns:
        List with one output tensor per expert, in expert order.
    """
    final_expert_states = []
    for expert_idx in range(expert_mask.shape[0]):
        # Column indices of the tokens assigned to this expert.
        _, top_x = torch.where(expert_mask[expert_idx])
        lora_name = f"moe.{moe_name}.experts.{expert_idx}"
        lora_data = slice_tensor(hidden_states, top_x, input_dtype)
        final_expert_states.append(
            ffn_layer._lora_forward(lora_name, act_fn, lora_data)
        )

    return final_expert_states
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _mixtral_load_balancing_loss_func(
|
| 33 |
+
gate_logits: torch.Tensor,
|
| 34 |
+
num_experts: int,
|
| 35 |
+
top_k: int,
|
| 36 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 37 |
+
) -> float:
|
| 38 |
+
routing_weights = torch.nn.functional.softmax(gate_logits, dim=-1)
|
| 39 |
+
|
| 40 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 41 |
+
|
| 42 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 43 |
+
|
| 44 |
+
if attention_mask is None:
|
| 45 |
+
# Compute the percentage of tokens routed to each experts
|
| 46 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 47 |
+
|
| 48 |
+
# Compute the average probability of routing to these experts
|
| 49 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 50 |
+
else:
|
| 51 |
+
batch_size, sequence_length = attention_mask.shape
|
| 52 |
+
num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
|
| 53 |
+
|
| 54 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 55 |
+
expert_attention_mask = (
|
| 56 |
+
attention_mask[None, :, :, None, None]
|
| 57 |
+
.expand(
|
| 58 |
+
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
|
| 59 |
+
)
|
| 60 |
+
.reshape(-1, top_k, num_experts)
|
| 61 |
+
.to(routing_weights.device)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Compute the percentage of tokens routed to each experts
|
| 65 |
+
tokens_per_expert = torch.sum(
|
| 66 |
+
expert_mask.float() * expert_attention_mask, dim=0
|
| 67 |
+
) / torch.sum(expert_attention_mask, dim=0)
|
| 68 |
+
|
| 69 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 70 |
+
router_per_expert_attention_mask = (
|
| 71 |
+
attention_mask[None, :, :, None]
|
| 72 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 73 |
+
.reshape(-1, num_experts)
|
| 74 |
+
.to(routing_weights.device)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Compute the average probability of routing to these experts
|
| 78 |
+
router_prob_per_expert = torch.sum(
|
| 79 |
+
routing_weights * router_per_expert_attention_mask, dim=0
|
| 80 |
+
) / torch.sum(router_per_expert_attention_mask, dim=0)
|
| 81 |
+
|
| 82 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 83 |
+
return overall_loss * num_experts
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MixtralRouterLoss(torch.nn.Module):
    """Scaled Mixtral load-balancing loss for top-k routing."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.topk = config.top_k_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        # Balance penalty scaled by the configured auxiliary coefficient.
        raw_loss = _mixtral_load_balancing_loss_func(
            gate_logits,
            self.experts,
            self.topk,
            attention_mask,
        )
        return self.aux_loss_coef * raw_loss
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class MixtralSparseMoe(LLMMoeBlock):
    """Mixtral-style top-k sparse MoE block built on LoRA experts.

    A float32 linear gate scores each token; the top-k experts (LoRA
    adapters named ``moe.<adapter>.experts.<i>`` living inside the wrapped
    feed-forward layer) run on the tokens routed to them, and their outputs
    are combined with the renormalized routing weights.
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Routing math is always carried out in float32.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        # Accept either an activation name (looked up in ACT2FN) or a module.
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.topk_: int = config.top_k_
        self.jitter_noise_: float = config.jitter_noise_
        # Per-expert routing-pressure statistics, filled lazily by _profiling().
        self.router_profile_: bool = False
        self.profiler_: Optional[List[float]] = None

        if gate is None:
            # Fresh adapter: random router initialization.
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            # Restoring from a checkpoint: load the saved gate weights.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # Only the router gate is owned by this block; expert LoRA weights
        # live in the wrapped feed-forward layer.
        return {"gate": self.gate_.weight}

    def _profiling(
        self, batch_size: int, sequence_length: int, selected_experts: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure (tokens routed / total tokens).

        No-op unless ``router_profile_`` is enabled; repeated calls keep a
        running average with the previously recorded pressure.
        """
        if not self.router_profile_:
            return

        router_statistic_ = list(0 for _ in range(self.experts_))
        for selected in selected_experts.tolist():
            for idx in selected:
                router_statistic_[idx] += 1

        if self.profiler_ is None:
            self.profiler_ = list(0 for _ in range(self.experts_))
            for idx in range(self.experts_):
                self.profiler_[idx] = (
                    router_statistic_[idx] / batch_size
                ) / sequence_length
        else:
            for idx in range(self.experts_):
                pressure = (router_statistic_[idx] / batch_size) / sequence_length
                self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple:
        """Route tokens to their top-k experts and return the mixed output.

        Returns ``(final_hidden_states, router_logits)`` where the logits
        have shape ``(batch * seq_len, num_experts)`` for the router loss.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            # NOTE(review): `*=` mutates the caller's tensor in place.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate_(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.topk_, dim=-1
        )

        self._profiling(batch_size, sequence_length, selected_experts)

        # Renormalize so the kept top-k weights sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.experts_
        ).permute(2, 1, 0)

        # Perform the computation on each expert
        if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"):
            expert_states = ffn_layer._mixlora_forward(
                self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype
            )
        else:
            expert_states = _mixlora_compatible_forward(
                ffn_layer,
                self.adapter_name_,
                self.act_,
                expert_mask,
                hidden_states,
                input_dtype,
            )

        # Unpack
        for expert_idx in range(self.experts_):
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_hidden_states = (
                expert_states[expert_idx] * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(self.dtype_)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        ).to(input_dtype)

        return final_hidden_states, router_logits
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _dynamic_top_p(router_logits: torch.Tensor, top_p: float, temperature: float = 0.0):
|
| 244 |
+
if temperature > 0.0:
|
| 245 |
+
router_logits = router_logits / temperature
|
| 246 |
+
sorted_logits, sorted_indices = torch.sort(router_logits, dim=-1, descending=True)
|
| 247 |
+
cumulative_probs = sorted_logits.cumsum(dim=-1)
|
| 248 |
+
expert_mask = cumulative_probs > top_p
|
| 249 |
+
threshold_indices = expert_mask.long().argmax(dim=-1)
|
| 250 |
+
threshold_mask = torch.nn.functional.one_hot(
|
| 251 |
+
threshold_indices, num_classes=sorted_indices.size(-1)
|
| 252 |
+
).bool()
|
| 253 |
+
expert_mask = expert_mask & ~threshold_mask
|
| 254 |
+
sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
|
| 255 |
+
sorted_indices = sorted_indices.masked_fill(expert_mask, -1)
|
| 256 |
+
return sorted_logits, sorted_indices
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _dynamic_load_balancing_loss_func(
    routing_weights: torch.Tensor,
    num_experts: int,
    top_p: float,
    temperature: float,
) -> float:
    """Load-balancing loss for the dynamic (top-p) routing strategy.

    Mirrors the Mixtral balance loss, but derives the dispatch mask from
    the variable-size top-p expert selection instead of a fixed top-k.
    ``routing_weights`` is expected to already be softmaxed, shape
    ``(tokens, num_experts)``.
    """
    selected = _dynamic_top_p(routing_weights, top_p, temperature)[1]

    num_tokens = routing_weights.size(0)
    dispatch = torch.empty(
        (num_experts, num_experts, num_tokens),
        dtype=routing_weights.dtype,
        device=routing_weights.device,
    )
    for expert_id in range(num_experts):
        dispatch[expert_id] = (selected == expert_id).transpose(0, 1)
    # Rearrange to (tokens, slots, experts), matching the Mixtral layout.
    dispatch = dispatch.permute(2, 1, 0)

    # Fraction of tokens dispatched to each expert.
    fraction_per_expert = dispatch.float().mean(dim=0)

    # Mean router probability of each expert.
    prob_per_expert = routing_weights.mean(dim=0)

    balance = torch.sum(fraction_per_expert * prob_per_expert.unsqueeze(0))
    return balance * num_experts
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class DynamicRouterLoss(torch.nn.Module):
    """Scaled auxiliary balance loss for the dynamic (top-p) router."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.top_p = config.top_p_
        self.temperature = config.temperature_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        # The dynamic balance loss expects probabilities, not raw logits.
        probs = torch.nn.functional.softmax(gate_logits, dim=-1)
        raw_loss = _dynamic_load_balancing_loss_func(
            probs,
            self.experts,
            self.top_p,
            self.temperature,
        )
        return self.aux_loss_coef * raw_loss
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class DynamicSparseMoe(LLMMoeBlock):
    """Top-p (dynamic expert count) sparse MoE block over LoRA experts.

    Like ``MixtralSparseMoe`` but the number of experts per token varies:
    experts are taken in descending probability order until the cumulative
    probability exceeds ``top_p_`` (see ``_dynamic_top_p``).
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Routing math is always carried out in float32.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        # Accept either an activation name (looked up in ACT2FN) or a module.
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.top_p_: float = config.top_p_
        self.temperature_: float = config.temperature_
        self.jitter_noise_: float = config.jitter_noise_
        # Per-expert routing-pressure statistics, filled lazily by _profiling().
        self.router_profile_: bool = False
        self.profiler_: Optional[List[float]] = None

        if gate is None:
            # Fresh adapter: random router initialization.
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            # Restoring from a checkpoint: load the saved gate weights.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # Only the router gate is owned by this block.
        return {"gate": self.gate_.weight}

    def _profiling(
        self, batch_size: int, sequence_length: int, selected_experts: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure; no-op unless profiling is on.

        NOTE(review): with top-p routing, ``selected_experts`` contains ``-1``
        for dropped slots; ``router_statistic_[idx]`` with ``idx == -1`` then
        increments the LAST expert's count, inflating its reported pressure —
        confirm whether that is intended.
        """
        if not self.router_profile_:
            return

        router_statistic_ = list(0 for _ in range(self.experts_))
        for selected in selected_experts.tolist():
            for idx in selected:
                router_statistic_[idx] += 1

        if self.profiler_ is None:
            self.profiler_ = list(0 for _ in range(self.experts_))
            for idx in range(self.experts_):
                self.profiler_[idx] = (
                    router_statistic_[idx] / batch_size
                ) / sequence_length
        else:
            for idx in range(self.experts_):
                pressure = (router_statistic_[idx] / batch_size) / sequence_length
                self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple:
        """Route each token to its top-p expert set and return the mixed output.

        Returns ``(final_hidden_states, router_logits)``. Unlike the Mixtral
        block, the kept routing weights are NOT renormalized here — the raw
        softmax probabilities weight the expert outputs.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            # NOTE(review): `*=` mutates the caller's tensor in place.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate_(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)
        routing_weights, selected_experts = _dynamic_top_p(
            routing_weights, self.top_p_, self.temperature_
        )

        self._profiling(batch_size, sequence_length, selected_experts)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        # Build an (experts, slots, tokens) dispatch mask equivalent to the
        # one-hot mask used by the Mixtral block.
        expert_mask = torch.empty(
            (self.experts_, self.experts_, batch_size * sequence_length),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        for expert_idx in range(self.experts_):
            expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

        # Perform the computation on each expert
        if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"):
            expert_states = ffn_layer._mixlora_forward(
                self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype
            )
        else:
            expert_states = _mixlora_compatible_forward(
                ffn_layer,
                self.adapter_name_,
                self.act_,
                expert_mask,
                hidden_states,
                input_dtype,
            )

        # Unpack
        for expert_idx in range(self.experts_):
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_hidden_states = (
                expert_states[expert_idx] * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(self.dtype_)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        ).to(input_dtype)

        return final_hidden_states, router_logits
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def _switch_router_z_loss_func(router_logits: torch.Tensor) -> float:
|
| 453 |
+
log_z = torch.logsumexp(router_logits, dim=-1)
|
| 454 |
+
z_loss = log_z**2
|
| 455 |
+
return torch.sum(z_loss) / (router_logits.size(0))
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _switch_load_balancing_loss_func(router_probs: torch.Tensor) -> float:
|
| 459 |
+
num_experts = router_probs.size(-1)
|
| 460 |
+
|
| 461 |
+
expert_mask = torch.argmax(router_probs, dim=-1)
|
| 462 |
+
expert_mask = torch.nn.functional.one_hot(expert_mask, num_classes=num_experts)
|
| 463 |
+
|
| 464 |
+
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=0)
|
| 465 |
+
|
| 466 |
+
router_prob_per_group_and_expert = torch.mean(router_probs, dim=0)
|
| 467 |
+
return torch.mean(
|
| 468 |
+
tokens_per_group_and_expert * router_prob_per_group_and_expert
|
| 469 |
+
) * (num_experts**2)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class SwitchRouterLoss(torch.nn.Module):
    """Combined z-loss and load-balancing loss for Switch-style routing."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.experts = config.num_experts_
        self.expert_capacity_ = config.expert_capacity_
        self.z_loss_coef = config.router_z_loss_coef_
        self.aux_loss_coef = config.router_aux_loss_coef_

    def forward(self, router_logits, attention_mask) -> torch.Tensor:
        z_loss = _switch_router_z_loss_func(router_logits)
        # recompute expert indexes due to MoE-PEFT constraints
        probs = F.softmax(router_logits, dim=-1)
        aux_loss = _switch_load_balancing_loss_func(probs)
        return self.z_loss_coef * z_loss + self.aux_loss_coef * aux_loss
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class SwitchSparseMoe(LLMMoeBlock):
    """Switch-Transformers-style top-1 sparse MoE block over LoRA experts.

    Each token is dispatched to its single highest-probability expert,
    subject to a per-expert capacity limit; the expert output is scaled by
    the token's router probability, as in the Switch Transformers
    formulation.
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Routing math is always carried out in float32.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        # Accept either an activation name (looked up in ACT2FN) or a module.
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.dropout_ = (
            torch.nn.Dropout(config.ffn_dropout_)
            if config.ffn_dropout_ > 0
            else torch.nn.Identity()
        )
        self.expert_capacity_: int = config.expert_capacity_
        self.jitter_noise_: float = config.jitter_noise_
        # Per-expert routing-pressure statistics, filled lazily by _profiling().
        self.router_profile_: bool = False
        self.profiler_: Optional[List[float]] = None

        if gate is None:
            # Fresh adapter: random router initialization.
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            # Restoring from a checkpoint: load the saved gate weights.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def _profiling(
        self, batch_size: int, sequence_length: int, router_mask: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure; no-op unless profiling is on."""
        if not self.router_profile_:
            return

        selected_experts = torch.argmax(router_mask, dim=-1)

        router_statistic_ = list(0 for _ in range(self.experts_))
        for selected in selected_experts.tolist():
            for idx in selected:
                router_statistic_[idx] += 1

        if self.profiler_ is None:
            self.profiler_ = list(0 for _ in range(self.experts_))
            for idx in range(self.experts_):
                self.profiler_[idx] = (
                    router_statistic_[idx] / batch_size
                ) / sequence_length
        else:
            for idx in range(self.experts_):
                pressure = (router_statistic_[idx] / batch_size) / sequence_length
                self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2

    def route(self, hidden_states: torch.Tensor, input_args: LLMModelInput) -> Tuple:
        """Compute the capacity-limited top-1 dispatch for every token.

        Returns ``(expert_index, router_probs, router_logits)`` where
        ``expert_index`` is a one-hot mask zeroed for tokens routed over
        capacity and ``router_probs`` is each token's winning probability.
        """
        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        # Apply Softmax
        router_logits = self.gate_(hidden_states)
        router_probs = F.softmax(router_logits, dim=-1, dtype=self.dtype_)

        expert_index = torch.argmax(router_probs, dim=-1)
        expert_index = torch.nn.functional.one_hot(
            expert_index, num_classes=self.experts_
        )

        # Mask tokens outside expert capacity. Sum over each sequence
        token_priority = torch.cumsum(expert_index, dim=-2)
        # mask if the token routed to the expert will overflow
        expert_capacity_mask = token_priority <= self.expert_capacity_
        expert_index = expert_index * expert_capacity_mask

        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple:
        """Dispatch every token to its expert and return the weighted output.

        Returns ``(hidden_states, router_logits)`` with the logits flattened
        to ``(batch * seq_len, num_experts)`` for the router loss.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype_)

        router_mask, router_probs, router_logits = self.route(hidden_states, input_args)

        self._profiling(batch_size, sequence_length, router_mask)

        # Tokens over capacity keep their input value (identity path).
        next_states = hidden_states.clone()
        for expert_idx in range(self.experts_):
            token_indices = router_mask[:, :, expert_idx].bool()
            lora_name = f"moe.{self.adapter_name_}.experts.{expert_idx}"
            next_states[token_indices] = ffn_layer._lora_forward(
                lora_name, self.act_, hidden_states[token_indices].to(input_dtype)
            ).to(next_states.dtype)

        # BUGFIX: the expert outputs must be scaled by the router probability
        # in both modes; the previous code returned the un-routed input during
        # inference, silently discarding every expert's output. Dropout
        # remains training-only.
        if input_args.inference_mode_:
            hidden_states = (router_probs * next_states).to(input_dtype)
        else:
            hidden_states = self.dropout_(router_probs * next_states).to(input_dtype)

        return hidden_states, router_logits.reshape(-1, self.experts_)
|
c2cite/adapters/mola/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import MolaConfig
|
| 2 |
+
from .model import MolaRouterLoss, MolaSparseMoe
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"MolaConfig",
|
| 6 |
+
"MolaSparseMoe",
|
| 7 |
+
"MolaRouterLoss",
|
| 8 |
+
]
|
c2cite/adapters/mola/config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from moe_peft.common import LoraConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class MolaConfig(LoraConfig):
    """Configuration of a MoLA (Mixture-of-LoRA) adapter.

    Extends a plain ``LoraConfig`` with router settings: the number of LoRA
    experts per layer and the top-k used when routing tokens to them.
    """

    top_k_: int = None
    num_experts_: int = None
    routing_strategy_: str = "mola"
    router_init_range_: float = None
    # this router loss is copied from MixLoRA
    # and only for test MoE-PEFT propose
    router_aux_loss_coef_: float = None
    router_loss_: bool = True

    def check(self) -> "MolaConfig":
        """Validate every router field; returns self for call chaining."""
        super().check()
        assert isinstance(self.top_k_, int) and self.top_k_ > 0
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        assert (
            isinstance(self.router_aux_loss_coef_, float)
            and self.router_aux_loss_coef_ >= 0
        )
        assert isinstance(self.router_loss_, bool)

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "MolaConfig":
        """Build a MolaConfig from a plain dict, applying MoLA defaults.

        NOTE(review): the dataclass default for ``router_loss_`` is True but
        this loader defaults it to False — confirm which is intended.
        """
        return MolaConfig(
            top_k_=config.get("top_k", 2),
            num_experts_=config["num_experts"],
            router_init_range_=config.get("router_init_range", 5.0),
            router_aux_loss_coef_=config.get("router_aux_loss_coef", 0.001),
            router_loss_=config.get("router_loss", False),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialize to a PEFT-style plain dict (inverse of ``from_config``)."""
        config = super().export()
        config["peft_type"] = "MOLA"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        config["top_k"] = self.top_k_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the per-expert adapter config with a unique adapter name.

        FIX: this previously did ``copy.deepcopy(super())``. A ``super``
        proxy delegates ``__reduce_ex__`` back to the instance, so that
        happened to deep-copy ``self`` — but only by accident, and it breaks
        as soon as a base class customizes copying. Copy ``self`` explicitly.
        """
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
|
c2cite/adapters/mola/model.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from moe_peft.common import Linear, LLMMoeBlock
|
| 8 |
+
|
| 9 |
+
from .config import MolaConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# copied from mixlora.model._mixtral_load_balancing_loss_func
|
| 13 |
+
def _mixtral_load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: int,
    top_k: int,
    attention_mask: Optional[torch.Tensor] = None,
) -> float:
    """Mixtral-style auxiliary load-balancing loss (duplicate of the mixlora
    adapter's implementation).

    Penalizes the correlation between the fraction of tokens dispatched to
    each expert and that expert's mean router probability. When
    ``attention_mask`` is given, padded positions are excluded from both
    averages. Returns a scalar tensor (despite the ``float`` annotation).
    """
    routing_weights = torch.nn.functional.softmax(gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # (tokens, top_k, num_experts) one-hot of the selected experts.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand(
                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
            )
            .reshape(-1, top_k, num_experts)
            .to(routing_weights.device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(
            expert_mask.float() * expert_attention_mask, dim=0
        ) / torch.sum(expert_attention_mask, dim=0)

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(routing_weights.device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MolaRouterLoss(torch.nn.Module):
    """MixLoRA-style auxiliary balance loss applied to the MoLA router.

    As the surrounding comments state, this loss is borrowed from MixLoRA
    and exists only for MoE-PEFT testing purposes.
    """

    def __init__(self, config: MolaConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.topk = config.top_k_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        raw_loss = _mixtral_load_balancing_loss_func(
            gate_logits,
            self.experts,
            self.topk,
            attention_mask,
        )
        return self.aux_loss_coef * raw_loss
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class MolaSparseMoe(LLMMoeBlock):
|
| 81 |
+
def __init__(
    self,
    in_features: int,
    device: torch.device,
    config: MolaConfig,
    gate: Optional[torch.Tensor] = None,
) -> None:
    """Build the MoLA router gate.

    Creates a float32 linear gate over ``config.num_experts_`` LoRA experts,
    either Kaiming-initialized (fresh adapter) or loaded from ``gate``
    (checkpoint restore). The expert LoRA weights themselves live in the
    ``Linear`` layer passed to ``forward``.
    """
    super().__init__()

    self.adapter_name_: str = config.adapter_name
    # Routing math is always carried out in float32.
    self.dtype_: torch.dtype = torch.float32
    self.gate_ = torch.nn.Linear(
        in_features,
        config.num_experts_,
        bias=False,
        device=device,
        dtype=torch.float32,
    )
    self.experts_ = config.num_experts_
    self.topk_ = config.top_k_
    # Raw router logits of the latest forward pass — presumably read by the
    # trainer for the router loss; confirm against the caller.
    self.router_logits_: torch.Tensor = None

    if gate is None:
        # Fresh adapter: Kaiming-uniform router initialization.
        torch.nn.init.kaiming_uniform_(
            self.gate_.weight, a=math.sqrt(config.router_init_range_)
        )
    else:
        # Restoring from a checkpoint: load the saved gate weights.
        with torch.no_grad():
            self.gate_.weight.copy_(gate)
|
| 110 |
+
|
| 111 |
+
def forward(
|
| 112 |
+
self,
|
| 113 |
+
residual: torch.Tensor,
|
| 114 |
+
hidden_states: torch.Tensor,
|
| 115 |
+
lora_linear: Optional[Linear] = None,
|
| 116 |
+
):
|
| 117 |
+
assert lora_linear is not None
|
| 118 |
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 119 |
+
input_dtype = hidden_states.dtype
|
| 120 |
+
hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
|
| 121 |
+
router_logits = self.gate_(hidden_states)
|
| 122 |
+
self.router_logits_ = router_logits.reshape(-1, self.experts_)
|
| 123 |
+
routing_weights_before = F.softmax(router_logits, dim=1, dtype=self.dtype_)
|
| 124 |
+
|
| 125 |
+
routing_weights, selected_experts = torch.topk(
|
| 126 |
+
routing_weights_before, self.topk_, dim=-1
|
| 127 |
+
)
|
| 128 |
+
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
| 129 |
+
|
| 130 |
+
expert_mask = torch.nn.functional.one_hot(
|
| 131 |
+
selected_experts, num_classes=self.experts_
|
| 132 |
+
).permute(2, 1, 0)
|
| 133 |
+
|
| 134 |
+
final_hidden_states = torch.zeros(
|
| 135 |
+
(batch_size * sequence_length, lora_linear.out_features_),
|
| 136 |
+
dtype=self.dtype_,
|
| 137 |
+
device=hidden_states.device,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
for expert_idx in range(self.experts_):
|
| 141 |
+
expert_lora = lora_linear.loras_[
|
| 142 |
+
f"moe.{self.adapter_name_}.experts.{expert_idx}"
|
| 143 |
+
]
|
| 144 |
+
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 145 |
+
|
| 146 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
| 147 |
+
current_hidden_states = (
|
| 148 |
+
expert_lora.lora_forward(current_state)
|
| 149 |
+
* routing_weights[top_x, idx, None]
|
| 150 |
+
)
|
| 151 |
+
final_hidden_states.index_add_(
|
| 152 |
+
0, top_x, current_hidden_states.to(self.dtype_)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
final_hidden_states = final_hidden_states.reshape(
|
| 156 |
+
batch_size, sequence_length, lora_linear.out_features_
|
| 157 |
+
).to(input_dtype)
|
| 158 |
+
|
| 159 |
+
return residual + final_hidden_states
|
c2cite/common/__init__.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Basic Abstract Class
|
| 2 |
+
from .abstracts import (
|
| 3 |
+
LLMAttention,
|
| 4 |
+
LLMCache,
|
| 5 |
+
LLMDecoder,
|
| 6 |
+
LLMFeedForward,
|
| 7 |
+
LLMForCausalLM,
|
| 8 |
+
LLMMoeBlock,
|
| 9 |
+
LLMOutput,
|
| 10 |
+
)
|
| 11 |
+
from .attention import (
|
| 12 |
+
eager_attention_forward,
|
| 13 |
+
flash_attention_forward,
|
| 14 |
+
prepare_4d_causal_attention_mask,
|
| 15 |
+
)
|
| 16 |
+
from .cache import (
|
| 17 |
+
DynamicCache,
|
| 18 |
+
HybridCache,
|
| 19 |
+
SlidingWindowCache,
|
| 20 |
+
StaticCache,
|
| 21 |
+
cache_factory,
|
| 22 |
+
)
|
| 23 |
+
from .checkpoint import (
|
| 24 |
+
CHECKPOINT_CLASSES,
|
| 25 |
+
CheckpointNoneFunction,
|
| 26 |
+
CheckpointOffloadFunction,
|
| 27 |
+
CheckpointRecomputeFunction,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Model Configuration
|
| 31 |
+
from .config import (
|
| 32 |
+
AdapterConfig,
|
| 33 |
+
InputData,
|
| 34 |
+
Labels,
|
| 35 |
+
LLMBatchConfig,
|
| 36 |
+
LLMModelConfig,
|
| 37 |
+
LLMModelInput,
|
| 38 |
+
LLMModelOutput,
|
| 39 |
+
LoraConfig,
|
| 40 |
+
Masks,
|
| 41 |
+
Prompt,
|
| 42 |
+
Tokens,
|
| 43 |
+
)
|
| 44 |
+
from .feed_forward import FeedForward
|
| 45 |
+
|
| 46 |
+
# LoRA
|
| 47 |
+
from .lora_linear import Linear, Lora, get_range_tensor
|
| 48 |
+
|
| 49 |
+
# MoEs
|
| 50 |
+
from .moe_utils import collect_plugin_router_logtis, slice_tensor, unpack_router_logits
|
| 51 |
+
from .rope import ROPE_INIT_FUNCTIONS
|
| 52 |
+
|
| 53 |
+
# Public re-export surface of c2cite.common; keep in sync with the imports above.
__all__ = [
    "prepare_4d_causal_attention_mask",
    "eager_attention_forward",
    "flash_attention_forward",
    "LLMCache",
    "DynamicCache",
    "HybridCache",
    "SlidingWindowCache",
    "StaticCache",
    "cache_factory",
    "CheckpointNoneFunction",
    "CheckpointOffloadFunction",
    "CheckpointRecomputeFunction",
    "CHECKPOINT_CLASSES",
    "FeedForward",
    "slice_tensor",
    "unpack_router_logits",
    "collect_plugin_router_logtis",
    "get_range_tensor",
    "Lora",
    "Linear",
    "LLMAttention",
    "LLMFeedForward",
    "LLMMoeBlock",
    "LLMDecoder",
    "LLMOutput",
    "LLMForCausalLM",
    "Tokens",
    "Labels",
    "Masks",
    "Prompt",
    "InputData",
    "LLMModelConfig",
    "LLMModelOutput",
    "LLMBatchConfig",
    "LLMModelInput",
    "AdapterConfig",
    "LoraConfig",
    "ROPE_INIT_FUNCTIONS",
]
|
c2cite/common/abstracts.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABCMeta
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .config import LLMModelConfig, LLMModelInput
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LLMCache(torch.nn.Module):
    """Base class for key/value caches used during generation.

    Concrete subclasses store per-layer key/value tensors (by convention in
    ``self.key_cache`` / ``self.value_cache`` lists) and implement the
    accessors below.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append states for ``layer_idx``; return the full cached key/value pair."""
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the number of tokens currently cached for ``layer_idx``."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError(
            "Make sure to implement `get_seq_length` in a subclass."
        )

    def get_max_length(self) -> Optional[int]:
        """Return the cache capacity in tokens, or None if unbounded."""
        raise NotImplementedError(
            "Make sure to implement `get_max_length` in a subclass."
        )

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Return how many cached tokens are usable before appending ``new_seq_length`` more."""
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        # For bounded caches that would overflow, report the remaining room.
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the batch dimension of every cached layer (beam search).

        Assumes the subclass exposes ``key_cache`` / ``value_cache`` lists.
        """
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class LLMAttention(metaclass=ABCMeta):
    """Interface for model attention modules.

    NOTE(review): these stubs are decorated with ``@classmethod`` yet take
    ``self``; subclasses appear to override them as plain instance methods,
    so the decorator looks vestigial — confirm before relying on it.
    """

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        # Trainable submodules to persist; empty by default.
        return {}

    @classmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # Attention forward pass; implemented by subclasses.
        pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LLMFeedForward(metaclass=ABCMeta):
    """Interface for decoder feed-forward (MLP) blocks.

    NOTE(review): stubs use ``@classmethod`` with a ``self`` parameter;
    subclasses override them as instance methods.
    """

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        # Trainable submodules to persist; empty by default.
        return {}

    @classmethod
    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        # Batched forward pass; implemented by subclasses.
        pass

    @classmethod
    def _lora_forward(
        self, lora_name: str, act_fn: torch.nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        # Forward through a single named LoRA adapter; implemented by subclasses.
        pass
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class LLMMoeBlock(metaclass=ABCMeta):
    """Interface for mixture-of-experts blocks plugged into decoder layers."""

    def __init__(self) -> None:
        super().__init__()

        self.adapter_name_: str = None  # name of the owning adapter
        self.dtype_: torch.dtype = None  # dtype used for router math
        self.gate_: torch.nn.Linear = None  # router projection
        self.experts_: int = None  # number of experts
        self.router_profile_: bool = False  # when True, record routing stats
        self.profiler_: List[int] = None  # per-expert routing counters

    @classmethod
    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> Tuple:
        # MoE forward pass; implemented by subclasses (e.g. MolaSparseMoe).
        pass
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class LLMDecoder(metaclass=ABCMeta):
    """Interface for a single transformer decoder layer."""

    def __init__(self) -> None:
        super().__init__()
        self.self_attn_: LLMAttention = None  # attention sub-module
        self.mlp_: LLMFeedForward = None  # feed-forward sub-module

    @classmethod
    def state_dict(
        self,
    ) -> Tuple[Dict[str, torch.nn.Module], Dict[str, torch.nn.Module]]:
        # NOTE(review): annotation promises a tuple of dicts but the stub
        # returns a single dict — confirm against subclass implementations.
        return {}

    @classmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # Decoder-layer forward pass; implemented by subclasses.
        pass
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class LLMOutput(metaclass=ABCMeta):
    """Interface for the model output head (logits projection and loss)."""

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        # Trainable submodules to persist; empty by default.
        return {}

    @classmethod
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # Project hidden states to output logits; implemented by subclasses.
        pass

    @classmethod
    def loss(
        self,
        input_ids: torch.Tensor,
        output_logits: torch.Tensor,
        labels: List[List[int]],
    ) -> torch.Tensor:
        # Compute the task loss from logits and labels; implemented by subclasses.
        pass
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class LLMForCausalLM(metaclass=ABCMeta):
    """Interface a causal-LM backbone must implement to be driven by this package."""

    @classmethod
    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Token-id -> embedding lookup.
        pass

    @classmethod
    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns a pair of rotary-embedding tensors — presumably (cos, sin);
        # confirm against the concrete model implementations.
        pass

    @classmethod
    def decoder_stack(self) -> List[LLMDecoder]:
        # Ordered list of decoder layers.
        pass

    @classmethod
    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Final normalization applied before the output head.
        pass

    @classmethod
    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        # Build the 4D additive causal attention mask.
        pass

    @classmethod
    def cache_implementation(self) -> str:
        # Default KV-cache flavor; see `cache_factory`.
        return "dynamic"

    @classmethod
    def model_config(self) -> LLMModelConfig:
        # Static configuration describing the wrapped model.
        pass

    @staticmethod
    def from_pretrained(llm_model, **kwargs):
        # Wrap a loaded transformers model into this interface.
        pass
|
c2cite/common/attention.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers.utils import is_flash_attn_2_available
|
| 8 |
+
|
| 9 |
+
from .cache import LLMCache, StaticCache
|
| 10 |
+
|
| 11 |
+
if is_flash_attn_2_available():
|
| 12 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 13 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 14 |
+
|
| 15 |
+
_flash_supports_window_size = "window_size" in list(
|
| 16 |
+
inspect.signature(flash_attn_func).parameters
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def prepare_4d_causal_attention_mask(
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: LLMCache,
) -> torch.Tensor:
    """Build a (batch, 1, query_len, target_len) additive causal mask.

    Disallowed key positions hold the dtype's most negative value; allowed
    positions hold 0. Padding positions from ``attention_mask`` (2D, 1 for
    real tokens) are also filled with the minimum value.
    """
    past_seen_tokens = (
        past_key_values.get_seq_length() if past_key_values is not None else 0
    )

    if past_seen_tokens is None:
        past_seen_tokens = 0

    using_static_cache = isinstance(past_key_values, StaticCache)

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
    # Static caches have fixed capacity; otherwise size the mask to the
    # attention mask (if given) or past + current length (+1 slack).
    if using_static_cache:
        target_length = past_key_values.get_max_length()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )

    # Start fully masked, then open up the causal (lower-triangular) region.
    causal_mask = torch.full(
        (sequence_length, target_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=device,
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the mask only for key positions strictly after each query's
    # cache position (positions at or before it become 0 = attendable).
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(
        -1, 1
    )
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # Where both the causal mask allows (0) and the padding mask is 0,
        # the sum is 0 — those padding positions get re-masked below.
        padding_mask = (
            causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[
            :, :, :, :mask_length
        ].masked_fill(padding_mask, min_dtype)

    return causal_mask
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def eager_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Plain PyTorch scaled-dot-product attention.

    Args:
        query_states/key_states/value_states: (batch, heads, seq, head_dim).
        attention_mask: optional additive mask; sliced to the key length.

    Returns:
        A tuple of (context, probs): the attention output transposed to
        (batch, seq, heads, head_dim) and the float-softmaxed attention
        probabilities cast back to the value dtype.
    """
    head_dim = query_states.size(-1)
    scores = torch.matmul(query_states, key_states.transpose(2, 3))
    scores = scores / math.sqrt(head_dim)
    if attention_mask is not None:
        # Only the first key_len columns of the mask are relevant.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]
    # Softmax in float32 for stability, then cast back.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(value_states.dtype)
    context = torch.matmul(probs, value_states)
    context = context.transpose(1, 2).contiguous()
    return context, probs
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _get_unpad_data(
|
| 95 |
+
attention_mask: torch.Tensor,
|
| 96 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 97 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 98 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 99 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 100 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 101 |
+
return (
|
| 102 |
+
indices,
|
| 103 |
+
cu_seqlens,
|
| 104 |
+
max_seqlen_in_batch,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """Strip padding tokens ahead of a flash-attn varlen call.

    Returns unpadded q/k/v plus the index / cu_seqlens / max_seqlen
    bookkeeping expected by ``flash_attn_varlen_func``.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(
        key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    if query_length == kv_seq_len:
        # Prefill: queries share the keys' layout, so reuse their unpadding.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Decode step: exactly one query per sequence.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
            query_layer, attention_mask
        )

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """Re-pack batched q/k/v into the flattened layout flash-attn varlen expects.

    Sequence boundaries are recovered from ``position_ids``: every position 0
    starts a new packed sequence.
    """

    def _merge_batch_and_seq(t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, heads, dim) -> (batch*seq, heads, dim)
        return t.view(-1, t.size(-2), t.size(-1))

    query = _merge_batch_and_seq(query)
    key = _merge_batch_and_seq(key)
    value = _merge_batch_and_seq(value)

    flat_positions = position_ids.flatten()
    token_indices = torch.arange(
        flat_positions.size(0), device=flat_positions.device, dtype=torch.int32
    )

    # Sequence start offsets, terminated by the total token count.
    sequence_starts = token_indices[flat_positions == 0]
    total_tokens = torch.tensor(
        flat_positions.size(), device=flat_positions.device, dtype=torch.int32
    )
    cu_seq_lens = torch.cat((sequence_starts, total_tokens))

    max_length = flat_positions.max() + 1

    return (
        query,
        key,
        value,
        token_indices,
        (cu_seq_lens, cu_seq_lens),
        (max_length, max_length),
    )
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
):
    """Dispatch attention to the appropriate FlashAttention-2 kernel.

    Chooses the varlen kernel when padding or packed sequences are present,
    and the dense kernel otherwise. Requires ``flash_attn`` to be installed
    (guarded at module import time).
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        # Top-left-aligned masks: a single decode token must not be causal.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )
    flash_kwargs = (
        {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
    )

    if deterministic is not None:
        flash_kwargs["deterministic"] = deterministic

    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
            _upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        # Scatter the unpadded rows back into a (batch, seq, ...) tensor.
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    elif (
        position_ids is not None
        and not (torch.diff(position_ids, dim=-1) >= 0).all()
        and query_length != 1
    ):
        # Non-monotonic position ids signal packed sequences in one batch row.
        batch_size = query_states.size(0)
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
            prepare_fa2_from_position_ids(
                query_states, key_states, value_states, position_ids
            )
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(
            batch_size, -1, attn_output.size(-2), attn_output.size(-1)
        )

    else:
        # Dense path: no padding and monotonic positions.
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

    return attn_output
|
c2cite/common/cache.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers.utils import is_torchdynamo_compiling
|
| 6 |
+
|
| 7 |
+
from .abstracts import LLMCache
|
| 8 |
+
from .config import LLMModelConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DynamicCache(LLMCache):
|
| 12 |
+
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self._seen_tokens = (
            0  # Used in `generate` to keep tally of how many tokens the cache has seen
        )
        # One tensor per decoder layer, grown lazily in `update`.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
|
| 19 |
+
|
| 20 |
+
    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        # NOTE(review): annotation says List[Tuple[...]] but a plain tuple is
        # returned — confirm callers only unpack (key, value).
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(
                f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
            )
|
| 31 |
+
|
| 32 |
+
    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        # Yields one (key, value) pair per layer.
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
| 39 |
+
|
| 40 |
+
    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        # key_cache and value_cache always grow in lockstep.
        return len(self.key_cache)
|
| 46 |
+
|
| 47 |
+
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append states for ``layer_idx`` and return that layer's full cache.

        ``cache_kwargs`` is accepted for interface compatibility and unused.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif (
            len(self.key_cache[layer_idx]) == 0
        ):  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Concatenate along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 80 |
+
|
| 81 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Return the cached sequence length for `layer_idx` (0 when empty)."""
    # TODO: deprecate this function in favor of `cache_position`
    if len(self.key_cache) == 0 or len(self.key_cache) <= layer_idx:
        # No cache at all, or `layer_idx` was skipped and never filled later.
        return 0
    if len(self.key_cache[layer_idx]) == 0:
        # Empty-list placeholder: this layer holds no cache yet.
        return 0
    return self.key_cache[layer_idx].shape[-2]
|
| 94 |
+
|
| 95 |
+
def get_max_length(self) -> Optional[int]:
    """A DynamicCache grows without bound, so there is no maximum length."""
    return None
|
| 98 |
+
|
| 99 |
+
def crop(self, max_length: int):
    """Crop the cached states to at most `max_length` tokens.

    A negative `max_length` removes that many tokens from the end instead.
    Used in assisted decoding and contrastive search.
    """
    # Translate a negative length into an absolute one.
    if max_length < 0:
        max_length = self.get_seq_length() - abs(max_length)

    if self.get_seq_length() <= max_length:
        return

    self._seen_tokens = max_length
    for idx in range(len(self.key_cache)):
        # Fix: use len() to detect the empty-list placeholder of a skipped
        # layer. The previous `!= []` comparison runs element-wise against a
        # tensor — the same pitfall `update()` documents as causing errors.
        if len(self.key_cache[idx]) != 0:
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
|
| 115 |
+
|
| 116 |
+
def batch_split(
    self, full_batch_size: int, split_size: int
) -> List["DynamicCache"]:
    """Split this cache along the batch dimension into `DynamicCache` pieces
    of at most `split_size` rows each.

    Used by `_split_model_inputs()` in `generation.utils`.
    """
    pieces = []
    for start in range(0, full_batch_size, split_size):
        piece = DynamicCache()
        piece._seen_tokens = self._seen_tokens
        piece.key_cache = [
            layer[start : start + split_size] for layer in self.key_cache
        ]
        piece.value_cache = [
            layer[start : start + split_size] for layer in self.value_cache
        ]
        pieces.append(piece)
    return pieces
|
| 133 |
+
|
| 134 |
+
@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
    """Rebuild a single cache from `batch_split()` output (the inverse
    operation); used by `stack_model_outputs` in `generation.utils`.

    Fixes two defects of the original:
    * the value-cache gather read `current.key_cache[idx]`, silently
      duplicating keys into the value slots of the merged cache;
    * `!= []` compared tensors with a list element-wise (the pitfall
      `update()` documents), replaced by len() checks.
    """
    cache = cls()
    for idx in range(len(splits[0])):
        # Skip empty-list placeholders left for skipped layers.
        key_cache = [
            current.key_cache[idx]
            for current in splits
            if len(current.key_cache[idx]) != 0
        ]
        value_cache = [
            current.value_cache[idx]
            for current in splits
            if len(current.value_cache[idx]) != 0
        ]
        if len(key_cache) != 0:
            layer_keys = torch.cat(key_cache, dim=0)
            layer_values = torch.cat(value_cache, dim=0)
            cache.update(layer_keys, layer_values, idx)
    return cache
|
| 155 |
+
|
| 156 |
+
def batch_repeat_interleave(self, repeats: int):
    """Repeat every cached layer `repeats` times along the batch dimension
    (used in contrastive search)."""
    for idx in range(len(self)):
        self.key_cache[idx] = self.key_cache[idx].repeat_interleave(repeats, dim=0)
        self.value_cache[idx] = self.value_cache[idx].repeat_interleave(repeats, dim=0)
|
| 165 |
+
|
| 166 |
+
def batch_select_indices(self, indices: torch.Tensor):
    """Keep only the batch rows selected by `indices` in every cached layer
    (used in contrastive search)."""
    for idx in range(len(self)):
        self.key_cache[idx] = self.key_cache[idx][indices, ...]
        self.value_cache[idx] = self.value_cache[idx][indices, ...]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class StaticCache(LLMCache):
    """Pre-allocated, fixed-shape KV cache suitable for `torch.compile`.

    One zero-initialized (batch, kv_heads, max_cache_len, head_dim) tensor
    pair is allocated per layer up front; `update()` writes new states in
    place at the positions given by `cache_kwargs["cache_position"]`.
    """

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        # Fall back to the model's maximum sequence length when unspecified.
        self.max_cache_len = (
            config.max_seq_len_ if max_cache_len is None else max_cache_len
        )

        self.head_dim = config.head_dim_

        self.dtype = dtype
        self.num_key_value_heads = config.n_kv_heads_

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        for idx in range(config.n_layers_):
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            # Notes:
            # 1. `mark_static_address` tags the cache as a fixed data pointer,
            #    preventing cuda graph breaks when updating the cache. It can't
            #    be used while the cache code itself is being compiled (but in
            #    that case it is not needed anyway).
            # 2. `torch.export()` requires mutations to be registered as
            #    buffers. Fix: register the tensors allocated above instead of
            #    allocating a second, immediately-discarded zeros pair — same
            #    behavior, half the startup allocation.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `key_states`/`value_states` into layer `layer_idx` at
        `cache_kwargs["cache_position"]` (whole-buffer copy when absent) and
        return the full static buffers."""
        cache_position = cache_kwargs.get("cache_position")

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # `tensor.index_copy_(dim, index, src)` is equivalent to
            # `tensor[:, :, index] = src`, but is compile-friendly and an
            # explicit in-place operation (no extra copies, less memory).
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # 'aten::index_copy.out' is not implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot along the sequence dim holds a non-zero
        # value; the check is limited to batch member 0, head 0 to save compute.
        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class SlidingWindowCache(StaticCache):
    """Static cache whose per-layer length is capped at the model's sliding
    window; once the window is full, older tokens are rotated out."""

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        # Fix: the original began with a bare `super().__init__()` before the
        # fully-argued call below; `StaticCache.__init__` has required
        # parameters, so that call raised a TypeError. Only the single,
        # fully-argued parent initialization is needed.
        if not hasattr(config, "sliding_window_") or config.sliding_window_ is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        # The cache never needs to be longer than the attention window.
        max_cache_len = min(config.sliding_window_, max_cache_len)
        super().__init__(
            config=config,
            batch_size=batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """Rotate `key_states`/`value_states` into the windowed buffers of
        `layer_idx` at `cache_kwargs["cache_position"]`."""
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # Assume this only happens in the prefill phase, when the prompt is
        # longer than the sliding window (= max_cache_len): keep only the
        # trailing window in the buffers.
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # Return the whole states (not k_out/v_out) so that attention over
            # the full prompt still sees tokens outside the window.
            return key_states, value_states

        slicing = torch.ones(
            self.max_cache_len, dtype=torch.long, device=value_states.device
        ).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        to_shift = cache_position >= self.max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # 'aten::index_copy.out' is not implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly
        # (no graph breaks due to assignment).
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_length(self) -> Optional[int]:
        # In theory there is no limit: the window size stays fixed no matter
        # how long the sequence grows.
        return None

    def reset(self):
        """Zero the buffers in place (keeps the static addresses intact)."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class HybridCache(LLMCache):
    """Fixed-shape cache mixing per-layer sliding-window and global buffers.

    NOTE(review): `self.is_sliding` marks even-indexed layers (i % 2 == 0) as
    sliding, but `update()` dispatches on `cache_kwargs["sliding_window"]`
    supplied by the caller — confirm the two conventions stay consistent.
    """

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        # A sliding-window size is mandatory for the hybrid layout.
        if not hasattr(config, "sliding_window_") or config.sliding_window_ is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.batch_size = batch_size
        self.head_dim = config.head_dim_

        self.dtype = dtype
        self.num_key_value_heads = config.n_kv_heads_
        # Per-layer flag: True for layers using the (smaller) sliding buffers.
        self.is_sliding = torch.tensor(
            [not bool(i % 2) for i in range(config.n_layers_)],
            dtype=torch.bool,
            device=device,
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # (batch, kv_heads, seq_len, head_dim) shapes; sliding layers are
        # capped at the window size.
        global_cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            max_cache_len,
            self.head_dim,
        )
        sliding_cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            min(config.sliding_window_, max_cache_len),
            self.head_dim,
        )
        for i in range(config.n_layers_):
            # Note: `mark_static_address` is used to tag the cache as a fixed
            # data pointer, preventing cuda graph breaks when updating the cache.
            cache_shape = (
                global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            )
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(
        self,
        cache_position,
        layer_idx,
        key_states,
        value_states,
        k_out,
        v_out,
        max_cache_len,
    ):
        """Windowed write: rotate older tokens out once the window is full."""
        # Prefill longer than the window: keep only the trailing window.
        if cache_position.shape[0] > max_cache_len:
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # We should return the whole states instead of k_out/v_out to take
            # the whole prompt into consideration when building the kv cache,
            # instead of just throwing away tokens outside of the window.
            return key_states, value_states

        slicing = torch.ones(
            max_cache_len, dtype=torch.long, device=value_states.device
        ).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly
        # (without graph breaks due to assignment).
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(
        self,
        cache_position,
        layer_idx,
        key_states,
        value_states,
        k_out,
        v_out,
        max_cache_len,
    ):
        """Global-layer write: plain in-place slot assignment at cache_position."""
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """Write new states into layer `layer_idx`, choosing the sliding or
        static strategy from `cache_kwargs["sliding_window"]`."""
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_length(self) -> Optional[int]:
        # In theory there is no limit because the sliding window size is fixed
        # no matter how long the sentence is.
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a
        # non-zero value. To save on compute, the check is limited to the first
        # batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# Registry: user-facing cache implementation name -> cache class.
cache_dict = dict(
    dynamic=DynamicCache,
    static=StaticCache,
    sliding_window=SlidingWindowCache,
    hybrid=HybridCache,
)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def cache_factory(
    cache_implementation: str,
    config: LLMModelConfig,
    batch_size: int,
    max_cache_len: int,
):
    """Instantiate the cache class registered under `cache_implementation`,
    clamping `max_cache_len` to the sliding window when applicable.

    Device and dtype are taken from `config`.
    """
    assert (
        cache_implementation in cache_dict
    ), f"Unknown cache type. {cache_implementation}"
    logging.info(f"Use {cache_implementation} as cache implementation.")
    if cache_implementation == "sliding_window":
        assert hasattr(config, "sliding_window_")
        max_cache_len = min(config.sliding_window_, max_cache_len)
    cache_cls = cache_dict[cache_implementation]
    return cache_cls(
        config=config,
        batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=config.device_,
        dtype=config.dtype_,
    )
|
c2cite/common/checkpoint.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def pack_hook(to_offload: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
    """Saved-tensor pack hook: offload the tensor to CPU, remembering the
    device it came from so `unpack_hook` can restore it."""
    original_device = to_offload.device
    return original_device, to_offload.to("cpu")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def unpack_hook(to_offload_info: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
    """Saved-tensor unpack hook: move an offloaded tensor back to the device
    recorded by `pack_hook`."""
    device, offloaded = to_offload_info
    return offloaded.to(device)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def CheckpointNoneFunction(run_function: Callable, *args):
    """No-op checkpoint strategy: invoke `run_function` on `args` directly."""
    result = run_function(*args)
    return result
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def CheckpointOffloadFunction(run_function: Callable, *args):
    """Run `run_function` with autograd-saved tensors offloaded to CPU (via
    the pack/unpack hooks), trading transfer time for GPU memory."""
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        return run_function(*args)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def CheckpointRecomputeFunction(run_function: Callable, *args):
    """Gradient-checkpoint `run_function`: drop intermediate activations in
    the forward pass and recompute them during backward (reentrant variant)."""
    checkpoint_fn = torch.utils.checkpoint.checkpoint
    return checkpoint_fn(run_function, *args, use_reentrant=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Dispatch table: gradient-checkpoint mode name -> strategy callable.
CHECKPOINT_CLASSES = dict(
    none=CheckpointNoneFunction,
    offload=CheckpointOffloadFunction,
    recompute=CheckpointRecomputeFunction,
)
|
c2cite/common/config.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeAlias, Union

import torch
|
| 7 |
+
|
| 8 |
+
# Semantic aliases for the per-sample lists passed around the pipeline.
Tokens: TypeAlias = List[int]
Labels: TypeAlias = List[int]
Masks: TypeAlias = List[bool]
Ground: TypeAlias = List[str]
Citations: TypeAlias = List[str]
Query: TypeAlias = List[str]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class Prompt:
    """Free-form prompt triple: an instruction, its input, and the expected
    label text. All fields are optional."""

    instruction: Optional[str] = None
    input: Optional[str] = None
    label: Optional[str] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class InputData:
    """One sample together with fields derived from it.

    NOTE(review): most fields stay None until presumably filled by later
    preprocessing stages (tokenization, citation extraction) not visible here.
    """

    # Raw inputs: a Prompt, a list of strings, or a plain string per entry.
    inputs: Optional[List[Union[Prompt, List[str], str]]] = None
    prefix_length_: Optional[int] = None
    tokens: Optional[Tokens] = None
    labels: Optional[Labels] = None
    grounds: Optional[Ground] = None
    citations: Optional[Citations] = None
    citation_tokens: Optional[List] = None
    citation_embeds: Optional[List] = None
    query: Optional[Query] = None
    token_len: Optional[int] = None
    prompt: Optional[str] = None
    prompt_len: Optional[int] = None
    test_citations: Optional[Citations] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class LLMModelConfig:
    """Architecture hyper-parameters of the wrapped base LLM.

    Fields default to None and are expected to be populated from the loaded
    model's configuration. `head_dim_`, `n_kv_heads_`, `n_layers_` and
    `max_seq_len_` size the KV caches; `device_`/`dtype_` are consumed by
    `cache_factory`.
    """

    name_or_path_: Optional[str] = None
    device_: Optional[str] = None
    dim_: Optional[int] = None
    head_dim_: Optional[int] = None
    intermediate_: Optional[int] = None
    n_heads_: Optional[int] = None
    n_kv_heads_: Optional[int] = None
    n_layers_: Optional[int] = None
    hidden_act_: Optional[str] = None
    hidden_dropout_: Optional[float] = None
    vocab_size_: Optional[int] = None
    pad_token_id_: Optional[int] = None
    rope_theta_: Optional[float] = None
    partial_rotary_factor_: Optional[float] = None
    max_seq_len_: Optional[int] = None
    # eager or flash_attn
    attn_implementation_: str = "eager"
    # data type
    dtype_: Optional[torch.dtype] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class LLMModelOutput:
    """Per-adapter output bundle produced by a forward pass."""

    adapter_name: Optional[str] = None
    logits: Optional[torch.Tensor] = None
    router_logits: Optional[torch.Tensor] = None
    loss: Optional[torch.Tensor] = None
    cite_flag: bool = False
    aux_loss: Optional[torch.Tensor] = None
    # for internal use
    batch_start_idx_: int = -1
    batch_end_idx_: int = -1
    loss_fn_: Optional[Callable] = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
class LLMBatchConfig:
    """Associates one adapter with index bounds in the batched input.

    NOTE(review): start/end look like this adapter's row range within the
    batch — confirm against the dispatcher that builds these configs.
    """

    adapter_name_: str = ""
    batch_start_idx_: int = -1
    batch_end_idx_: int = -1
+
|
| 83 |
+
|
| 84 |
+
def _efficient_operator_factory():
    """Default factory for `LLMModelInput.efficient_operator_`: True unless
    the MOE_PEFT_EVALUATE_MODE environment variable is set."""
    return os.getenv("MOE_PEFT_EVALUATE_MODE") is None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@dataclass
class LLMModelInput:
    """Batched, tokenized model input plus execution flags."""

    # One entry per adapter, mapping it to its portion of the batch lists.
    batch_configs_: Optional[List[LLMBatchConfig]] = None
    batch_tokens_: Optional[List[Tokens]] = None
    batch_labels_: Optional[List[Labels]] = None
    batch_grounds_: Optional[List[Ground]] = None
    batch_cites: Optional[List[List]] = None
    batch_cites_value: Optional[List[List]] = None
    batch_masks_: Optional[List[Masks]] = None
    batch_docs: Optional[List[str]] = None
    batch_prompt_len: Optional[List[int]] = None

    # Whether MoE layers should also collect router logits.
    output_router_logits_: bool = True

    # NOTE(review): values appear to name keys of CHECKPOINT_CLASSES
    # ("none"/"offload"/"recompute") in checkpoint.py — confirm.
    gradient_checkpoint_: str = "none"
    # True unless the MOE_PEFT_EVALUATE_MODE environment variable is set.
    efficient_operator_: bool = field(default_factory=_efficient_operator_factory)
    inference_mode_: bool = False
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@dataclass
class AdapterConfig:
    """Base configuration shared by all adapter types."""

    adapter_name: str = ""
    # NOTE(review): "casual" is likely a typo for "causal"; kept verbatim
    # because other code may compare against this exact string.
    task_name: str = "casual"

    @staticmethod
    def from_config(config: Dict[str, Any]) -> "AdapterConfig":
        """Build from a raw config dict (reads keys "name" and "task_name")."""
        return AdapterConfig(
            adapter_name=config.get("name", None),
            task_name=config.get("task_name", None),
        )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Union of LoRA-targetable projection-module names across the supported model
# families, all disabled by default. Fix: the original literal repeated
# several keys (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `qkv_proj`,
# `down_proj`, `dense`) across family sections; duplicate keys in a dict
# literal silently overwrite each other, so each name is now listed once.
lora_target_modules = {
    # LLaMA names
    "q_proj": False,
    "k_proj": False,
    "v_proj": False,
    "o_proj": False,
    "gate_proj": False,
    "down_proj": False,
    "up_proj": False,
    # Phi names (q/k/v_proj shared with LLaMA)
    "dense": False,
    "fc1": False,
    "fc2": False,
    # Phi3 names (o_proj/down_proj shared with LLaMA)
    "qkv_proj": False,
    "gate_up_proj": False,
    # GLM names (qkv_proj shared with Phi3, dense with Phi)
    "dense_h_to_4h": False,
    "dense_4h_to_h": False,
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@dataclass
class LoraConfig(AdapterConfig):
    """LoRA adapter hyper-parameters plus project-specific loss coefficients."""

    # Weight-Decomposed Low-Rank Adaptation
    use_dora_: bool = False
    # Rank-Stabilized LoRA
    # sets the adapter scaling factor to `alpha/math.sqrt(r)`
    use_rslora_: bool = False
    # can be original or gaussian
    lora_init_: str = "original"
    lora_r_: Optional[int] = None
    lora_alpha_: Optional[int] = None
    lora_dropout_: Optional[float] = None
    target_modules_: Optional[Dict[str, bool]] = None
    # Loss-term coefficients. NOTE(review): their exact semantics are defined
    # where the losses are combined, which is not visible in this module.
    atten_coin: Optional[float] = None
    router_coin: Optional[float] = None
    cite_coin: Optional[float] = None
    learning_rate: Optional[float] = None

    def check(self) -> "LoraConfig":
        """Validate field types and value ranges; returns self for chaining."""
        assert isinstance(self.use_dora_, bool)
        assert isinstance(self.use_rslora_, bool)
        assert isinstance(self.lora_init_, str) and self.lora_init_ in [
            "original",
            "gaussian",
        ]
        assert isinstance(self.lora_r_, int) and self.lora_r_ > 0
        assert isinstance(self.lora_alpha_, int) and self.lora_alpha_ > 0
        assert isinstance(self.lora_dropout_, float) and self.lora_dropout_ >= 0
        assert isinstance(self.target_modules_, Dict)
        for key, value in self.target_modules_.items():
            assert isinstance(key, str) and len(key) > 0
            assert isinstance(value, bool)

        return self

    @staticmethod
    def from_config(config: Dict[str, Any]) -> "LoraConfig":
        """Build a LoraConfig from a raw dict.

        Required keys: "r", "lora_alpha", "lora_dropout", "atten_mat_coin",
        "router_coin", "cite_coin", "lr" and "target_modules" (list or dict).
        """
        lora_config = LoraConfig(**AdapterConfig.from_config(config).__dict__)
        lora_config.use_dora_ = config.get("use_dora", False)
        lora_config.use_rslora_ = config.get("use_rslora", False)
        lora_config.lora_init_ = config.get("lora_init", "original")
        lora_config.lora_r_ = config["r"]
        lora_config.lora_alpha_ = config["lora_alpha"]
        lora_config.lora_dropout_ = config["lora_dropout"]
        lora_config.target_modules_ = copy.deepcopy(lora_target_modules)
        lora_config.atten_coin = config["atten_mat_coin"]
        lora_config.router_coin = config["router_coin"]
        lora_config.cite_coin = config["cite_coin"]
        lora_config.learning_rate = config["lr"]
        if isinstance(config["target_modules"], List):
            # List form: enable the listed (known) module names.
            for target in config["target_modules"]:
                if target in lora_target_modules:
                    lora_config.target_modules_[target] = True
        elif isinstance(config["target_modules"], Dict):
            # Dict form: copy explicit per-module flags for known names.
            for target, value in config["target_modules"].items():
                if target in lora_target_modules:
                    lora_config.target_modules_[target] = value
        else:
            raise ValueError("broken config item: target_modules")

        return lora_config

    def export(self) -> Dict[str, Any]:
        """Serialize back to a PEFT-style config dict (inverse of from_config)."""
        config = {}
        if self.use_dora_:
            config["use_dora"] = True
        if self.use_rslora_:
            config["use_rslora"] = True
        config["bias"] = "none"
        config["peft_type"] = "LORA"
        config["r"] = self.lora_r_
        config["lora_alpha"] = self.lora_alpha_
        config["lora_dropout"] = self.lora_dropout_
        # Only enabled target modules are exported, as a list of names.
        tgt_list = []
        for target, value in self.target_modules_.items():
            if value:
                tgt_list.append(target)
        config["target_modules"] = tgt_list

        config["atten_mat_coin"] = self.atten_coin
        config["router_coin"] = self.router_coin
        config["cite_coin"] = self.cite_coin
        config["lr"] = self.learning_rate

        return config
|
c2cite/common/feed_forward.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from moe_peft.executors import executor
|
| 6 |
+
|
| 7 |
+
from .abstracts import LLMFeedForward, LLMMoeBlock
|
| 8 |
+
from .config import LLMModelInput
|
| 9 |
+
from .lora_linear import Linear, get_range_tensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FeedForward(torch.nn.Module):
    """Wrapper around a model's MLP block that optionally routes disjoint
    batch slices through per-adapter Mixture-of-Experts (MoE) blocks.

    When ``moes_`` is empty, the wrapped MLP handles the whole batch.
    """

    def __init__(self, mlp: LLMFeedForward) -> None:
        super().__init__()
        # Underlying dense feed-forward implementation shared by all adapters.
        self.mlp_: LLMFeedForward = mlp
        # mix of experts
        # Maps adapter name -> MoE block; empty means the plain MLP path.
        self.moes_: Dict[str, LLMMoeBlock] = {}

    def state_dict(self) -> Dict[str, Linear]:
        # Delegates to the wrapped MLP's projection-layer map.
        return self.mlp_.state_dict()

    def forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> Tuple[torch.Tensor, List]:
        """Run the feed-forward pass.

        Returns ``(hidden_states, router_logits)``; ``router_logits`` is an
        empty list when no MoE adapters are attached, or a per-config list
        (``None``-filled) when router outputs were requested.
        """
        if len(self.moes_) == 0:
            return self.mlp_._batch_forward(data, input_args)
        else:
            return self._moe_forward(data, input_args)

    def _moe_forward(self, data: torch.Tensor, input_args: LLMModelInput):
        # Output buffer; each adapter's batch slice is scattered back into
        # its original position via index_copy below.
        final_hidden_states = executor.init_tensor(data)

        if input_args.output_router_logits_:
            # One slot per batch config so logits line up with adapters.
            router_logits = [None for _ in range(len(input_args.batch_configs_))]
        else:
            router_logits = []

        lora_range = get_range_tensor(data.device, data.shape[0])
        for idx, lora_config in enumerate(input_args.batch_configs_):
            moe_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if moe_name in self.moes_:
                # MoE adapter: the expert block wraps the shared MLP layer.
                current_hidden_states, current_router_outputs = self.moes_[
                    moe_name
                ].forward(
                    hidden_states=data[start_idx:end_idx],
                    ffn_layer=self.mlp_,
                    input_args=input_args,
                )

                if (
                    input_args.output_router_logits_
                    and current_router_outputs is not None
                ):
                    router_logits[idx] = current_router_outputs
            else:
                # Plain (non-MoE) adapter: run the MLP's LoRA forward path.
                current_hidden_states = self.mlp_._lora_forward(
                    moe_name, self.mlp_.act_, data[start_idx:end_idx]
                )

            # Scatter this adapter's slice back into the batch-wide buffer.
            executor.index_copy(
                final_hidden_states,
                0,
                lora_range[start_idx:end_idx],
                current_hidden_states,
            )

        return final_hidden_states, router_logits
|
c2cite/common/lora_linear.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers.utils import is_bitsandbytes_available
|
| 7 |
+
|
| 8 |
+
from moe_peft.executors import executor
|
| 9 |
+
|
| 10 |
+
from .abstracts import LLMMoeBlock
|
| 11 |
+
from .config import LLMModelInput, LoraConfig
|
| 12 |
+
|
| 13 |
+
if is_bitsandbytes_available():
|
| 14 |
+
import bitsandbytes as bnb
|
| 15 |
+
from bitsandbytes.nn import Linear4bit, Linear8bitLt
|
| 16 |
+
else:
|
| 17 |
+
from moe_peft.utils import Linear8bitLt, Linear4bit
|
| 18 |
+
|
| 19 |
+
from typing import Any, Dict, List, Tuple
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """Dequantize a bitsandbytes-quantized weight back to a dense tensor.

    Handles 4-bit (``Params4bit``) weights directly and reconstructs 8-bit
    (``Int8Params``) weights via bnb's int8 matmul against an identity
    matrix. ``state`` is the owning 8-bit layer's quantization state; it is
    unused on the 4-bit path.
    """
    # BNB requires CUDA weights
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu:
        weight = weight.to(torch.device("cuda"))

    cls_name = weight.__class__.__name__
    if cls_name == "Params4bit":
        dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        if is_cpu:
            # Move the result back to the caller's original (CPU) device.
            dequantized = dequantized.to(device)
        return dequantized

    # 8-bit path.
    # NOTE(review): ``state`` is dereferenced unconditionally here — a
    # non-bnb weight or a missing layer state would raise AttributeError;
    # confirm callers (dequantize_module_weight) always filter first.
    if state.SCB is None:
        state.SCB = weight.SCB

    # Multiply the quantized weight by an identity matrix through bnb's
    # int8 kernels to recover the dense matrix.
    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(
            weight.data, to_order=state.formatB
        )
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    dequantized = bnb.functional.mm_dequant(
        out32, Sout32, SCim, state.SCB, bias=None
    ).t()
    if is_cpu:
        dequantized = dequantized.to(device)
    return dequantized
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
    """Return the dense (dequantized) weight of *module*.

    Supports HQQ-quantized modules (via their ``dequantize()`` method),
    bitsandbytes 4-/8-bit layers (via :func:`dequantize_bnb_weight`), and
    plain modules, whose ``weight`` parameter is returned unchanged.

    Raises:
        TypeError: if ``module.weight`` is not an ``nn.Parameter``.
    """
    if hasattr(module, "W_q"):  # For handling HQQ quantized weight
        weight = module.dequantize()
        return weight

    weight = module.weight
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(
            f"Input weight should be of type nn.Parameter, got {type(weight)} instead"
        )

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        # Not bnb-quantized: the parameter itself is already dense.
        return weight

    quant_state = getattr(module, "state", None)
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
    if is_cpu:
        # dequantize_bnb_weight for 8bit moves the device in-place, thus we need to move it back to CPU if necessary
        module.weight = module.weight.to(device)
    return weight
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Per-device cache of arange index tensors used to scatter/gather batch
# slices (see index_copy/index_add call sites).
g_cached_range_tensor: Dict[torch.device, torch.Tensor] = {}
# also max batch size
g_max_range = 128


def get_range_tensor(device: torch.device, batch_size: int = 1024):
    """Return a cached ``arange`` tensor on *device* holding at least
    ``batch_size`` consecutive indices starting from 0.

    Fix: the previous implementation regenerated a device's tensor only
    when the *global* maximum grew. A device cached before another device
    raised ``g_max_range`` could then return a tensor shorter than the
    requested ``batch_size``. We now regenerate whenever the cached tensor
    itself is too small.
    """
    global g_max_range
    cached = g_cached_range_tensor.get(device)
    if cached is None or batch_size > cached.numel():
        g_max_range = max(g_max_range, batch_size)
        g_cached_range_tensor[device] = torch.arange(
            0, g_max_range, step=1, device=device
        )
    return g_cached_range_tensor[device]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class LoraFunction(torch.autograd.Function):
    """Fused autograd op applying multiple LoRA adapters to disjoint slices
    of one batch.

    ``args`` is a flat sequence of ``(lora_a, lora_b)`` weight pairs, one
    pair per entry in ``input_args.batch_configs_``; a ``(None, None)``
    pair means "no adapter for this slice".
    """

    @staticmethod
    def forward(
        ctx,
        result: torch.Tensor,
        data: torch.Tensor,
        input_args: LLMModelInput,
        dropouts: List[float],
        scalings: List[float],
        *args,
    ):
        # the lora module is f32 precision
        data = data.to(torch.float32)

        # Tensors saved for backward: the input plus (a, b, dropped-input)
        # triples per adapter; None padding keeps positions aligned.
        save_inputs: Tuple[torch.Tensor | None, ...] = (data,)

        lora_range = get_range_tensor(data.device, data.shape[0])
        for lora_a, lora_b, lora_config, dropout, scaling in zip(
            args[::2],
            args[1::2],
            input_args.batch_configs_,
            dropouts,
            scalings,
        ):
            # A and B must be both present or both absent.
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                save_inputs += (None, None, None)
                continue

            assert not ((lora_a.requires_grad) ^ (lora_b.requires_grad))
            if not lora_a.requires_grad and not lora_b.requires_grad:
                # NOTE(review): frozen adapters are skipped entirely, so
                # they contribute nothing to the forward result either —
                # presumably this efficient path is training-only; confirm.
                save_inputs += (None, None, None)
                continue

            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            # must ensure the dropout is not zero
            # is dropout == 0, dropdata is a data's referece, so the data will be changed
            assert dropout > 0.0

            drop_data = F.dropout(data[start_idx:end_idx], p=dropout)
            drop_data.mul_(scaling)
            drop_data = drop_data @ lora_a.transpose(0, 1)
            lora_data = drop_data @ lora_b.transpose(0, 1)

            lora_data = lora_data.to(result.dtype)

            # Accumulate this adapter's delta into its slice of the result.
            result.index_add_(0, lora_range[start_idx:end_idx], lora_data)

            save_inputs += (lora_a, lora_b, drop_data)

        ctx.input_args = input_args
        ctx.dropouts = dropouts
        ctx.scalings = scalings
        ctx.save_for_backward(*save_inputs)

        return result

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Propagate gradients to the result, the input data, and each
        adapter's A/B matrices (None for skipped adapters)."""
        grad_output: torch.Tensor = grad_outputs[0]
        grad_result = None
        grad_data: torch.Tensor | None = None
        grad_input_args = None
        grad_dropouts = None
        grad_scalings = None
        grad_loras: Tuple[torch.Tensor | None, ...] = ()

        # saved layout: (data, a0, b0, drop0, a1, b1, drop1, ...)
        data, *loras = ctx.saved_tensors

        if ctx.needs_input_grad[0]:
            grad_result = grad_output
        if ctx.needs_input_grad[1]:
            grad_data = executor.init_tensor(data)

        # the lora module is fp32 precision
        grad_output = grad_output.to(torch.float32)
        lora_range = get_range_tensor(
            grad_output.device, batch_size=grad_output.shape[0]
        )
        for lora_a, lora_b, drop_data, dropout, scaling, lora_config in zip(
            loras[::3],
            loras[1::3],
            loras[2::3],
            ctx.dropouts,
            ctx.scalings,
            ctx.input_args.batch_configs_,
        ):
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                grad_loras += (None, None)
                if grad_data is not None:
                    # This slice had no adapter: its input grad is zero.
                    executor.index_fill(grad_data, 0, lora_range[start_idx:end_idx], 0)
                continue

            # lora_data shape is batch_size * seq_len * in_dim
            lora_data = data[start_idx:end_idx]
            # grad_y shape is batch_size * seq_len * out_dim
            grad_y = grad_output[start_idx:end_idx]

            # drop_data shape is batch_size * seq_len * r

            # bstage shape is batch_size * seq_len * r
            bstage = grad_y @ lora_b
            # Undo dropout's 1/(1-p) inverse scaling and apply LoRA scaling.
            bstage *= scaling / (1 - dropout)

            grad_a = torch.sum(bstage.transpose(1, 2) @ lora_data, dim=0)
            grad_b = torch.sum(grad_y.transpose(1, 2) @ drop_data, dim=0)
            grad_loras += (grad_a, grad_b)

            # grad_data shape is batch_size * seq_len * in_dim
            if grad_data is not None:
                grad_x = bstage @ lora_a
                executor.index_copy(grad_data, 0, lora_range[start_idx:end_idx], grad_x)

        return (
            grad_result,
            grad_data,
            grad_input_args,
            grad_dropouts,
            grad_scalings,
            *grad_loras,
        )
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Lora(nn.Module):
    """A single LoRA adapter (optionally DoRA-style) attached to one base
    linear layer.

    The low-rank factors A (``in -> r``) and B (``r -> out``) are kept in
    float32 regardless of the base layer's dtype.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        shape: Tuple[int, int],
        config: LoraConfig,
        device: str,
    ):

        super().__init__()

        self.base_layer_ = base_layer
        self.device_ = torch.device(device)

        # Weight-init scheme: "original" (kaiming) or "gaussian".
        self.initializer_ = config.lora_init_
        self.r_ = config.lora_r_
        self.alpha_ = config.lora_alpha_

        # rsLoRA scales by alpha/sqrt(r) instead of alpha/r.
        if config.use_rslora_:
            self.scaling_ = self.alpha_ / math.sqrt(self.r_)
        else:
            self.scaling_ = self.alpha_ / self.r_

        self.in_features_, self.out_features_ = shape

        # Zero dropout is disallowed: LoraFunction.forward relies on dropout
        # producing a copy of the input slice (see its comment).
        assert config.lora_dropout_ > 0.0
        self.dropout_ = nn.Dropout(p=config.lora_dropout_)

        self.lora_a_ = nn.Linear(
            self.in_features_,
            self.r_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )
        self.lora_b_ = nn.Linear(
            self.r_,
            self.out_features_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )

        self.use_dora_: bool = config.use_dora_
        # Populated by reset_parameters() when DoRA is enabled.
        self.magnitude_vector_: nn.Parameter = None

    def _get_weight_norm(self, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = dequantize_module_weight(self.base_layer_).to(dtype)
        lora_weight = self.lora_b_.weight @ self.lora_a_.weight
        weight = weight + self.scaling_ * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    def reset_parameters(self, lora_tensor=(None, None)) -> None:
        """Initialize or load the adapter weights.

        If *lora_tensor* is a pair of tensors they are copied into A and B;
        otherwise A is randomly initialized per ``initializer_`` and B is
        zeroed so the adapter starts as a no-op perturbation.
        """
        # if the lora_tensor is not (None, None), use it to init the lora weight
        assert isinstance(lora_tensor, Tuple)
        assert len(lora_tensor) == 2
        assert ((lora_tensor[0] is None) and (lora_tensor[1] is None)) or (
            isinstance(lora_tensor[0], torch.Tensor)
            and isinstance(lora_tensor[1], torch.Tensor)
        )

        if lora_tensor == (None, None):
            if self.initializer_ == "original":
                nn.init.kaiming_uniform_(self.lora_a_.weight, a=math.sqrt(5))
            elif self.initializer_ == "gaussian":
                nn.init.normal_(self.lora_a_.weight, std=1 / self.r_)
            else:
                raise ValueError(f"Unknown initialization {self.initializer_}")
            nn.init.zeros_(self.lora_b_.weight)
        else:
            with torch.no_grad():
                self.lora_a_.weight.copy_(lora_tensor[0])
                self.lora_b_.weight.copy_(lora_tensor[1])

        if self.use_dora_:
            # DoRA magnitude vector, initialized to the merged weight norm.
            self.magnitude_vector_ = nn.Parameter(
                self._get_weight_norm(), requires_grad=True
            )

    def apply_dora(
        self,
        residual: torch.Tensor,
        result_lora: torch.Tensor,
    ):
        # m/||W + s*BA|| * (Wx + BAx): scales both the base output and the
        # LoRA delta by the learned magnitude over the (detached) norm.
        weight_norm = self._get_weight_norm().detach()
        mag_norm_scale = (self.magnitude_vector_ / weight_norm).view(1, -1)
        return mag_norm_scale * residual + mag_norm_scale * result_lora

    def lora_forward(self, hidden_states: torch.Tensor):
        # scaling * B(A(dropout(x))), computed in float32.
        return (
            self.lora_b_(self.lora_a_(self.dropout_(hidden_states.to(torch.float32))))
            * self.scaling_
        )

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Add this adapter's contribution to *residual* (the base layer's
        output for the same *hidden_states*)."""
        result_lora = self.lora_forward(hidden_states)
        if self.use_dora_:
            return self.apply_dora(residual, result_lora).to(residual.dtype)
        else:
            return residual + result_lora.to(residual.dtype)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class Linear(nn.Module):
    """A base linear layer (optionally bnb-quantized) augmented with named
    LoRA adapters and/or MoE blocks, applied per batch slice.
    """

    def __init__(self, base_layer: nn.Module, device: str):
        super().__init__()

        if not isinstance(base_layer, nn.Linear):
            # Quantized layers are the only accepted alternative.
            assert isinstance(base_layer, Linear8bitLt) or isinstance(
                base_layer, Linear4bit
            ), f"error type - {type(base_layer)}."
        else:
            # The dense base weight is frozen; only adapters train.
            base_layer.requires_grad_(False)

        self.device_ = torch.device(device)
        self.base_layer_ = base_layer.to(self.device_)
        # adapter name -> LoRA module / MoE block.
        self.loras_: Dict[str, Lora] = {}
        self.moes_: Dict[str, LLMMoeBlock] = {}

        if isinstance(self.base_layer_, Linear4bit):
            # 4-bit layers pack the weight, so read the declared features.
            self.out_features_, self.in_features_ = (
                self.base_layer_.out_features,
                self.base_layer_.in_features,
            )
        else:
            self.out_features_, self.in_features_ = self.base_layer_.weight.shape

    def init_lora_weight(
        self, lora_config: LoraConfig, lora_tensor=(None, None), adapter_name=None
    ):
        """Create (if needed) and initialize the adapter named either by
        *adapter_name* or by the config."""
        if adapter_name is None:
            # NOTE(review): reads ``adapter_name`` while batch configs use
            # ``adapter_name_`` — confirm LoraConfig exposes this attribute.
            adapter_name = lora_config.adapter_name

        if adapter_name not in self.loras_:
            self.loras_[adapter_name] = Lora(
                self.base_layer_,
                (self.in_features_, self.out_features_),
                lora_config,
                self.device_,
            )

        self.loras_[adapter_name].reset_parameters(lora_tensor)

    def _appy_dora(
        self,
        residual: torch.Tensor,
        lora_delta: torch.Tensor,
        input_args: LLMModelInput,
    ):
        # (Name keeps the original "_appy" typo to avoid breaking callers.)
        # Combine residual and the fused LoRA delta per batch slice, using
        # DoRA rescaling for adapters that enable it.
        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(
            next_states.device, batch_size=next_states.shape[0]
        )
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if adapter_name == "" or adapter_name not in self.loras_:
                # NOTE(review): this leaves the slice as init_tensor output
                # (residual not copied through) — confirm intended.
                continue

            if self.loras_[adapter_name].use_dora_:
                lora_data = self.loras_[adapter_name].apply_dora(
                    residual[start_idx:end_idx],
                    lora_delta[start_idx:end_idx],
                )
            else:
                lora_data = residual[start_idx:end_idx] + lora_delta[start_idx:end_idx]

            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def _efficient_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Fused path: apply all adapters in one LoraFunction call."""
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states)

        if len(self.loras_) == 0:
            return residual

        # split the data and result
        dropouts: List[float] = []
        scalings: List[float] = []
        loras: Tuple[torch.Tensor] = ()
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_

            if adapter_name not in self.loras_:
                # Placeholder entries keep positions aligned with configs.
                loras += (None, None)
                dropouts.append(None)
                scalings.append(None)
                continue

            loras += (
                self.loras_[adapter_name].lora_a_.weight,
                self.loras_[adapter_name].lora_b_.weight,
            )
            dropouts.append(self.loras_[adapter_name].dropout_.p)
            scalings.append(self.loras_[adapter_name].scaling_)

        have_dora = any(lora.use_dora_ for lora in self.loras_.values())

        if have_dora:
            # Compute the raw LoRA delta separately, then merge with DoRA
            # rescaling per adapter.
            lora_delta = torch.zeros_like(residual, dtype=torch.float32)
            lora_delta = LoraFunction.apply(
                lora_delta,
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )
            next_states = self._appy_dora(
                residual.to(torch.float32), lora_delta, input_args
            )
        else:
            # Accumulate deltas directly onto the base output.
            next_states = LoraFunction.apply(
                residual.to(torch.float32),
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )

        return next_states.to(hidden_states.dtype)

    def _compatible_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Fallback path: per-slice module forwards (supports MoE blocks)."""
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states)

        if len(self.loras_) == 0:
            return residual

        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(hidden_states.device, hidden_states.shape[0])

        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if adapter_name in self.loras_:
                fwd_fn = self.loras_[adapter_name].forward
                kwargs = {}
            elif adapter_name in self.moes_:
                fwd_fn = self.moes_[adapter_name].forward
                # MoE blocks need access back to this layer's adapters.
                kwargs = {"lora_linear": self}
            else:
                # No adapter: pass the base output through unchanged.
                executor.index_copy(
                    next_states,
                    0,
                    lora_range[start_idx:end_idx],
                    residual[start_idx:end_idx],
                )
                continue

            lora_data = fwd_fn(
                residual=residual[start_idx:end_idx],
                hidden_states=hidden_states[start_idx:end_idx],
                **kwargs,
            )
            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Dispatch to the fused path when allowed (no MoE blocks attached),
        otherwise to the per-slice compatible path."""
        if input_args.efficient_operator_ and len(self.moes_) == 0:
            return self._efficient_impl(hidden_states, input_args)
        else:
            return self._compatible_impl(hidden_states, input_args)
|
c2cite/common/moe_utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .abstracts import LLMDecoder, LLMModelInput
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def slice_tensor(
    data: torch.Tensor,
    slice: torch.Tensor,
    dtype: torch.dtype,
    last_value: Optional[torch.Tensor] = None,
):
    """Select rows of *data* via the *slice* index/mask and cast to *dtype*.

    If *last_value* is given it is returned as-is, acting as a memoized
    result from a previous call. (The parameter name ``slice`` shadows the
    builtin but is kept for keyword-argument compatibility.)
    """
    if last_value is not None:
        return last_value
    # for macOS debugging, please uncomment this line
    # assert data.dtype in (torch.float, torch.int, torch.bool)
    selected = data[None, slice]
    return selected.reshape(-1, data.shape[-1]).to(dtype)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def unpack_router_logits(gate_logits: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate per-layer router logits into a single tensor.

    All tensors are moved to the device of the first entry so the
    concatenation is valid even when layers live on different devices.
    """
    target_device = gate_logits[0].device
    moved = [layer_gate.to(target_device) for layer_gate in gate_logits]
    return torch.cat(moved, dim=0)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def collect_plugin_router_logtis(
    router_logits, input_args: LLMModelInput, decoder_layer: LLMDecoder
):
    """Gather router logits cached on plugin MoE blocks inside a decoder
    layer, for batch configs that did not report logits directly.

    (The name keeps the original "logtis" typo; renaming would break
    existing callers.)
    """
    if router_logits is None or len(router_logits) == 0:
        # One slot per batch config so logits line up with adapters.
        router_logits = [None for _ in range(len(input_args.batch_configs_))]

    # Merge attention and MLP projection maps into a single lookup.
    attn_proj, mlp_proj = decoder_layer.state_dict()
    all_proj = copy.copy(attn_proj)
    all_proj.update(mlp_proj)
    for idx, config in enumerate(input_args.batch_configs_):
        if router_logits[idx] is not None:
            # This adapter already produced logits; keep them.
            continue
        adapter_name = config.adapter_name_
        for proj in all_proj.values():
            if adapter_name in proj.moes_ and hasattr(
                proj.moes_[adapter_name], "router_logits_"
            ):
                if router_logits[idx] is None:
                    router_logits[idx] = []
                router_logits[idx].append(proj.moes_[adapter_name].router_logits_)
                # Clear the cache so logits are not collected twice.
                proj.moes_[adapter_name].router_logits_ = None

    # Stack each adapter's per-projection logits into one tensor.
    for idx, logits in enumerate(router_logits):
        if isinstance(logits, list):
            router_logits[idx] = torch.cat(logits, 0)

    return router_logits
|
c2cite/common/rope.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .config import LLMModelConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _compute_default_rope_parameters(
    config: Optional[LLMModelConfig] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple[torch.Tensor, float]:
    """Compute standard RoPE inverse frequencies.

    ``base`` and ``dim`` come either from explicit *rope_kwargs* or are
    derived from *config*. Returns ``(inv_freq, attention_factor)``; the
    attention factor is always 1.0 for default RoPE.
    """
    if rope_kwargs:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta_
        # Use a full rotary factor when none is configured.
        partial_rotary_factor = (
            1.0
            if config.partial_rotary_factor_ is None
            else config.partial_rotary_factor_
        )
        if config.head_dim_ is None:
            head_dim = config.dim_ // config.n_heads_
        else:
            head_dim = config.head_dim_
        dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # inv_freq[i] = 1 / base^(2i / dim): the classic RoPE frequency ladder.
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim
    inv_freq = 1.0 / (base**exponents)
    return inv_freq, attention_factor
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _compute_llama3_parameters(
    config: LLMModelConfig,
    device: torch.device,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple[torch.Tensor, float]:
    """Compute RoPE inverse frequencies with Llama-3 style scaling.

    Low-frequency components (long wavelengths) are divided by ``factor``,
    high-frequency ones are left untouched, and the band in between is
    smoothly interpolated between the two.
    """
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )

    factor = config.rope_scaling_["factor"]  # `8` in the original implementation
    low_freq_factor = config.rope_scaling_[
        "low_freq_factor"
    ]  # `1` in the original implementation
    high_freq_factor = config.rope_scaling_[
        "high_freq_factor"
    ]  # `4` in the original implementation
    old_context_len = config.rope_scaling_[
        "original_max_position_embeddings"
    ]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Each frequency's wavelength in positions.
    wavelen = 2 * math.pi / inv_freq
    # wavelen < high_freq_wavelen: do nothing
    # wavelen > low_freq_wavelen: divide by factor
    inv_freq_llama = torch.where(
        wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
    )
    # otherwise: interpolate between the two, using a smooth factor
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed_inv_freq = (
        1 - smooth_factor
    ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    # Medium band: neither strictly high- nor strictly low-frequency.
    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Registry mapping a rope-scaling "rope_type" name to the function that
# computes its inverse frequencies and attention factor.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "llama3": _compute_llama3_parameters,
}
|
c2cite/dispatcher.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
import sys
|
| 4 |
+
from abc import abstractmethod
|
| 5 |
+
from typing import Callable, Dict, List
|
| 6 |
+
|
| 7 |
+
import datasets
|
| 8 |
+
import copy
|
| 9 |
+
|
| 10 |
+
from .common import InputData, LLMBatchConfig, LLMModelInput, Masks, Tokens
|
| 11 |
+
from .tokenizer import Tokenizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Event:
    """A minimal callback registry.

    Callbacks are invoked in reverse registration order (newest first);
    dispatch stops at the first callback that returns a truthy value.
    """

    # Most recently registered callbacks sit at the front of this list.
    __callback_list: List[Callable] = None

    def __init__(self):
        self.__callback_list = []

    def register(self, func: Callable) -> "Event":
        """Prepend *func* so it runs before earlier registrations (fluent)."""
        self.__callback_list.insert(0, func)
        return self

    def activate(self, **kwargs) -> bool:
        """Call each registered callback with **kwargs until one returns
        truthy; report whether any callback handled the event."""
        return any(func(**kwargs) for func in self.__callback_list)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_dataset(data_path: str):
    """Load a dataset from a local JSON/JSONL file or by hub/name.

    A path ending in ``.json``/``.jsonl`` is loaded through the ``json``
    loader; a ``"name:subset"`` string selects a dataset configuration;
    anything else is passed through unchanged.
    """
    if data_path.endswith((".json", ".jsonl")):
        return datasets.load_dataset("json", data_files=data_path)
    if ":" in data_path:
        parts = data_path.split(":")
        return datasets.load_dataset(parts[0], parts[1])
    return datasets.load_dataset(data_path)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TrainTask:
    """Holds one adapter's training data together with its epoch/batch cursors.

    Instances are created by ``Dispatcher`` from a ``DispatcherConfig``
    context; data is loaded lazily via ``load_data`` and consumed in
    micro-batches through ``get_train_data`` until ``is_train_done``.
    """

    # Tokenizer handed to the dataload function.
    tokenizer_: Tokenizer = None

    adapter_name_: str = ""
    data_path_: str = ""
    # Callable(tokenizer) -> List[InputData]; invoked lazily by load_data().
    dataload_function_: Callable = None
    train_token_data_: List[InputData] = None

    # train parameter
    total_epoch_num_: int = -1
    max_train_batch_size_: int = -1
    max_train_micro_batch_size_: int = -1
    max_test_batch_size_: int = -1

    # Samples longer than this many tokens are truncated in load_data().
    train_cutoff_len_: int = -1
    group_by_length_: bool = False

    # count the stat of train and test data
    epoch_cnt_: int = 1
    next_train_data_start_idx_: int = 0
    next_test_data_start_idx_: int = 0

    def __init__(
        self,
        # NOTE(review): parameter name is misspelled ("tokenzer") but callers
        # (Dispatcher.__init__) pass it by this exact keyword — do not rename
        # without updating them.
        tokenzer: Tokenizer,
        adapter_name: str,
        dataload_function: Callable,
        total_epoch_num: int,
        max_train_batch_size: int,
        max_train_micro_batch_size: int,
        train_cutoff_len: int = 256,
        group_by_length: bool = True,
    ):
        self.tokenizer_ = tokenzer
        self.adapter_name_ = adapter_name
        self.dataload_function_ = dataload_function
        self.total_epoch_num_ = total_epoch_num
        self.max_train_batch_size_ = max_train_batch_size
        self.max_train_micro_batch_size_ = max_train_micro_batch_size
        self.train_cutoff_len_ = train_cutoff_len
        self.group_by_length_ = group_by_length

    def load_data(self):
        """Run the dataload function, truncate samples to the cutoff length,
        and order them (longest-first when group_by_length_, else shuffled)."""
        self.train_token_data_ = self.dataload_function_(self.tokenizer_)
        max_train_tokens_len = 0
        for data in self.train_token_data_:
            max_train_tokens_len = max(max_train_tokens_len, len(data.tokens))
            if len(data.tokens) > self.train_cutoff_len_:
                data.tokens = data.tokens[: self.train_cutoff_len_]

        logging.info(
            f"Max train tokens length: {max_train_tokens_len}/{self.train_cutoff_len_}"
        )
        if self.group_by_length_:
            # Longest first so each micro batch pads against similar lengths.
            self.train_token_data_.sort(key=lambda x: len(x.tokens), reverse=True)
        else:
            random.shuffle(self.train_token_data_)

    def is_train_done(self):
        # Training finishes once the epoch counter has passed the target.
        if self.epoch_cnt_ <= self.total_epoch_num_:
            return False
        return True

    def is_test_done(self):
        # NOTE(review): ``test_token_data_`` is never assigned anywhere in
        # this class, so calling this raises AttributeError — confirm where
        # the test split is supposed to be attached.
        if self.next_test_data_start_idx_ < len(self.test_token_data_):
            return False
        return True

    def reset_test_status(self):
        # Rewind the test cursor so evaluation can run again from the start.
        self.next_test_data_start_idx_ = 0

    # reentry function
    def get_train_deta_max_seq_len(self) -> int:
        """Length (in tokens) of the next sample to be dispatched.

        NOTE(review): method name carries a typo ("deta"), but Dispatcher
        calls this exact spelling — keep it.
        """
        start_idx = self.next_train_data_start_idx_
        assert start_idx < len(self.train_token_data_)
        # in this strategy must sort
        return len(self.train_token_data_[start_idx].tokens)

    # non reentry function
    def get_train_data(self) -> List[InputData]:
        """Pop the next micro batch and advance the cursor; wrapping past the
        end of the data increments the epoch counter."""
        start_idx = self.next_train_data_start_idx_
        end_idx = start_idx + self.max_train_micro_batch_size_

        ret_data = self.train_token_data_[start_idx:end_idx]

        logging.info(f"{self.adapter_name_} train data:")
        logging.info(
            f"    epoch: {self.epoch_cnt_}/{self.total_epoch_num_} \
            step in epoch: {start_idx}/{len(self.train_token_data_)}"
        )

        self.next_train_data_start_idx_ += self.max_train_micro_batch_size_
        if self.next_train_data_start_idx_ >= len(self.train_token_data_):
            self.next_train_data_start_idx_ = 0
            self.epoch_cnt_ += 1

        return ret_data
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DispatcherConfig:
    """Interface for config objects that can seed a ``TrainTask``.

    Implementations return the keyword arguments ``Dispatcher`` forwards to
    the ``TrainTask`` constructor.
    """

    @abstractmethod
    def dispatcher_context(self) -> Dict[str, any]:
        """Keyword context for building a TrainTask; empty by default."""
        return dict()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Dispatcher:
    """Schedules multiple adapters' TrainTasks and assembles mixed batches.

    Tasks move ready -> running -> done; ``get_train_data`` selects which
    running tasks contribute to the next batch (per ``strategy_``), pads
    their token sequences to a common length, and packages everything into
    a single ``LLMModelInput``.
    """

    config_ = None
    tokenizer_: Tokenizer = None

    # all train task
    ready_train_task_: List[TrainTask] = None
    running_train_task_: List[TrainTask] = None
    done_train_task_: List[TrainTask] = None

    # train task in event
    train_task_in_event_: Event = None
    train_task_out_event_: Event = None

    # the number of max candidate training lora model
    # can chose train data from this dataset
    train_lora_candidate_num_: int = 0
    # the number of simultaneously train lora model
    train_lora_simultaneously_num_: int = 0

    strategy_: str = ""

    def __init__(
        self,
        tokenizer: Tokenizer,
        configs: List[DispatcherConfig],
        max_concurrent_jobs: int = None,
        strategy: str = "optim",
        cutoff_len: int = 256,
    ) -> None:
        if max_concurrent_jobs is None:
            max_concurrent_jobs = len(configs)

        self.tokenizer_ = tokenizer

        self.ready_train_task_ = []
        self.running_train_task_ = []
        self.done_train_task_ = []

        self.train_task_in_event_ = Event()
        self.train_task_out_event_ = Event()

        # NOTE(review): candidate_num is effectively unbounded, so every
        # ready task is promoted to running immediately; only the dispatch
        # strategies honor max_concurrent_jobs.  Confirm this is intended.
        self.train_lora_candidate_num_ = sys.maxsize
        self.train_lora_simultaneously_num_ = max_concurrent_jobs
        self.strategy_ = strategy

        # create ready task
        for config_class in configs:
            kwargs = config_class.dispatcher_context()
            self.ready_train_task_.append(
                TrainTask(
                    tokenzer=self.tokenizer_, train_cutoff_len=cutoff_len, **kwargs
                )
            )

    def optim_dispatch_strategy(self) -> Dict[str, List[InputData]]:
        """Pick the window of ``train_lora_simultaneously_num_`` running
        tasks whose next sequence lengths are closest, minimizing padding."""
        task_len = {}
        for idx, task in enumerate(self.running_train_task_):
            task_len[idx] = task.get_train_deta_max_seq_len()
        # sort to get the seq most similar data
        task_len = sorted(task_len.items(), key=lambda x: x[1], reverse=True)
        # find the mini diff
        min_need_pad_len = sys.maxsize
        win_start_idx = 0
        for sidx in range(0, len(task_len) - self.train_lora_simultaneously_num_ + 1):
            win = task_len[sidx : sidx + self.train_lora_simultaneously_num_]
            need_pad_len = 0
            for i in range(1, len(win)):
                # align to the max seq len (win[0] is the longest in window)
                need_pad_len += abs(win[i][1] - win[0][1])
            if need_pad_len < min_need_pad_len:
                min_need_pad_len = need_pad_len
                win_start_idx = sidx
        # the result is win_start_idx
        result_win = task_len[
            win_start_idx : win_start_idx + self.train_lora_simultaneously_num_
        ]
        ret_train_data = {}
        for result_task_len in result_win:
            task_idx = result_task_len[0]
            ret_train_data[self.running_train_task_[task_idx].adapter_name_] = (
                self.running_train_task_[task_idx].get_train_data()
            )

        return ret_train_data

    def none_dispatch_strategy(self) -> Dict[str, List[InputData]]:
        """Take data from running tasks in order, up to the concurrency cap."""
        ret_train_data = {}
        cnt = 0
        for task in self.running_train_task_:
            assert not task.is_train_done()
            if cnt >= self.train_lora_simultaneously_num_:
                break
            ret_train_data[task.adapter_name_] = task.get_train_data()
            cnt += 1
        return ret_train_data

    def check_task_done(self) -> bool:
        """True once no task remains ready or running."""
        if len(self.ready_train_task_) == 0 and len(self.running_train_task_) == 0:
            return True
        return False

    def check_test_done(self) -> bool:
        for task in self.running_train_task_:
            if task.is_train_done():
                return False
        return True

    def reset_test_task(self):
        for task in self.running_train_task_:
            task.reset_test_status()

    # ready task -> running task
    def __dispatch_task_in(self):
        assert len(self.running_train_task_) <= self.train_lora_candidate_num_
        if len(self.running_train_task_) == self.train_lora_candidate_num_:
            return
        # chose task into running
        while (
            len(self.running_train_task_) < self.train_lora_candidate_num_
            and len(self.ready_train_task_) > 0
        ):
            # TODO to dispatch task
            task = self.ready_train_task_.pop(0)
            # to lazy load data
            task.load_data()
            self.train_task_in_event_.activate(task=task)
            self.running_train_task_.append(task)

    # running task -> done task
    def __dispatch_task_out(self):
        for task in self.running_train_task_:
            if task.is_train_done():
                self.train_task_out_event_.activate(task=task)
                self.done_train_task_.append(task)

        self.running_train_task_ = [
            task for task in self.running_train_task_ if not task.is_train_done()
        ]

    def get_test_data(self) -> LLMModelInput:
        # Placeholder — evaluation batching is not implemented here.
        pass

    def get_train_data(self) -> LLMModelInput:
        """Build the next mixed training batch across selected adapters.

        Returns an ``LLMModelInput`` containing padded tokens, labels,
        attention masks, citation positions/values, documents, and the
        per-adapter batch boundaries.

        Raises:
            ValueError: if ``strategy_`` is neither "none" nor "optim".
        """
        self.__dispatch_task_in()

        # get task train data
        all_train_data: Dict[str, List[InputData]] = {}
        if self.strategy_ == "none":
            all_train_data = self.none_dispatch_strategy()
        elif self.strategy_ == "optim":
            all_train_data = self.optim_dispatch_strategy()
        else:
            # Fix: the original raised a plain string, which is itself a
            # TypeError at runtime in Python 3.
            raise ValueError(f"unknown strategy {self.strategy_}")

        batch_seq_len: int = -1
        # to align batch token data
        for adapter in all_train_data:
            for data in all_train_data[adapter]:
                batch_seq_len = max(batch_seq_len, len(data.tokens))
        # all prompts and tokens / config
        batch_tokens: List[Tokens] = []
        attention_masks: List[Masks] = []
        batch_labels: List[List] = []
        lora_batch_data_config: List[LLMBatchConfig] = []

        cites = []
        cites_value = []
        docs = []
        prompt_len = []

        # Hoisted out of the per-sample loop (was redefined each iteration).
        # NOTE(review): ids 128002-128008 / 128010-128255 look like Llama-3
        # reserved special-token ids used as citation markers — confirm for
        # other tokenizers.
        def is_citation_token(i):
            return (128010 <= i <= 128255) or i in {
                128004,
                128002,
                128003,
                128005,
                128008,
            }

        # batch the all adapter data
        adapter_start_idx: int = 0
        for adapter in all_train_data:
            adapter_end_idx: int = adapter_start_idx + len(all_train_data[adapter])
            for data in all_train_data[adapter]:
                tokens: Tokens = data.tokens.copy()
                prompt_len.append(data.prompt_len)
                # Positions and values of citation-marker tokens in this sample.
                cite = [idx for idx, value in enumerate(tokens) if is_citation_token(value)]
                cite_value = [value for value in tokens if is_citation_token(value)]
                # Fix: the original assert message was `print(...)` (always
                # None) and reported len(cites) instead of len(cite).
                assert len(cite) < 40, f"too long!!! need:{len(cite)}"
                if len(cite) > 0:
                    # Ensure the sample's total token length terminates the
                    # citation-position list.
                    if cite[-1] != data.token_len:
                        cite.append(data.token_len)
                pad_side = self.tokenizer_.padding_side_
                assert pad_side == "right" or pad_side == "left"
                # pad the tokens to align
                while len(tokens) < batch_seq_len:
                    if pad_side == "right":
                        tokens.append(self.tokenizer_.pad_id_)
                    else:
                        tokens.insert(0, self.tokenizer_.pad_id_)
                batch_tokens.append(tokens)
                cites.append(cite.copy())
                cites_value.append(cite_value.copy())
                # Fix: identity comparison for None (was `== None`).
                if data.citation_embeds is None:
                    docs.append(data.citation_tokens)
                else:
                    docs.append(data.citation_embeds)
                attention_masks.append(self.tokenizer_.mask_from(tokens))
                labels = data.labels
                if labels is None:
                    # Causal LM default: labels are the (padded) tokens.
                    labels = tokens.copy()
                else:
                    labels = labels.copy()
                batch_labels.append(labels)

            lora_batch_data_config.append(
                LLMBatchConfig(
                    adapter_name_=adapter,
                    batch_start_idx_=adapter_start_idx,
                    batch_end_idx_=adapter_end_idx,
                )
            )
            adapter_start_idx = adapter_end_idx

        self.__dispatch_task_out()

        return LLMModelInput(
            batch_cites=cites,
            batch_cites_value=cites_value,
            batch_docs=docs,
            batch_prompt_len=prompt_len,
            batch_configs_=lora_batch_data_config,
            batch_tokens_=batch_tokens,
            batch_labels_=batch_labels,
            batch_masks_=attention_masks,
            gradient_checkpoint_="recompute",
        )
|
c2cite/evaluator.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, List, Tuple, Union, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .adapters import MixLoraConfig
|
| 12 |
+
from .common import InputData, LLMBatchConfig, LLMModelInput, Prompt, Tokens
|
| 13 |
+
from .model import LLMModel
|
| 14 |
+
from .tasks import BasicMetric, BasicTask, CommonSenseTask, task_dict
|
| 15 |
+
from .tokenizer import Tokenizer
|
| 16 |
+
from moe_peft.prompter import Prompter
|
| 17 |
+
from moe_peft.generator import _batch_generate
|
| 18 |
+
from moe_peft.solutions import get_output
|
| 19 |
+
|
| 20 |
+
@dataclass
class GenerateData:
    # Per-prompt state tracked during batched generation.
    # adapter_name_: adapter this prompt belongs to (presumably matches a
    #   GenerateConfig.adapter_name — confirm against the generator).
    adapter_name_: str = None
    # prompt_index_: position of the prompt within its config's prompt list.
    prompt_index_: int = None
    # prefix_length_: token length of the prompt prefix before generation.
    prefix_length_: int = None
    # raw_tokens_: token ids for this sample.
    raw_tokens_: Tokens = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class GenerateConfig:
    """Per-adapter text-generation settings plus lazily-built prompt state."""

    adapter_name: str = None
    # Each entry is either a bare instruction string or an
    # (instruction, input) pair.
    prompts: List[Union[str, Tuple[str, str]]] = None
    prompt_template: str = None
    # Generate Arguments
    batch_size: int = 8
    stop_token: str = None
    temperature: float = 1
    top_p: float = 0.9
    # NOTE(review): top_k is conceptually an integer count; annotated float.
    top_k: float = 50
    do_sample: bool = True
    repetition_penalty: float = 1.1
    renormalize_logits: bool = True
    # Do not set these manually
    prompter_: Prompter = None
    stop_token_: torch.Tensor = None
    data_: List[GenerateData] = None

    # Set prompt_template_ to enable the prompter
    def generate_prompt(self, instruction: str, input: str = None) -> str:
        """Render one prompt through the (lazily created) Prompter."""
        if self.prompter_ is None:
            self.prompter_ = Prompter(self.prompt_template)

        return self.prompter_.generate_prompt(instruction=instruction, input=input)

    def get_prompts(self) -> List[str]:
        """Render every configured prompt; tuples are treated as
        (instruction, input) pairs, bare strings as instruction-only."""
        prompts = []
        for prompt in self.prompts:
            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
            prompts.append(self.generate_prompt(*args))

        return prompts

    def get_response(self, output: str) -> str:
        """Extract the response from raw model output; without a prompter
        the output is merely stripped of surrounding whitespace."""
        if self.prompter_ is None:
            return output.strip()
        else:
            return self.prompter_.get_response(output)

    def reset_parameters(self):
        # Rebuild the prompter and clear all per-run generation state.
        self.prompter_ = Prompter(self.prompt_template)
        self.stop_token_ = None
        self.data_ = []
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
class EvaluateConfig:
    """One adapter/task evaluation unit with its batching cursor state."""

    adapter_name: str = None
    task_name: str = None
    data_path: str = None
    batch_size: int = 16
    # When True, per-expert MixLoRA router load is collected during eval.
    router_profile: bool = False
    # Do not set these manually
    task_: BasicTask = None
    data_: List[InputData] = None
    metric_: BasicMetric = None
    rollback_start_idx_: int = 0
    batch_start_idx_: int = 0
    batch_end_idx_: int = 0

    def _dataload_fn(self, tokenizer: Tokenizer, **tokenizer_kwargs):
        """Load the task's eval split and tokenize inputs (and citations)."""
        data = self.task_.loading_data(False, self.data_path)
        for idx, data_point in enumerate(data):
            # Evaluation data must already be plain text, not Prompt objects.
            assert not isinstance(data_point.inputs, Prompt)

            data_point.tokens = tokenizer.encode(data_point.inputs, **tokenizer_kwargs)
            data_point.prefix_length_ = len(data_point.tokens)
            if data_point.citations is not None:
                # Precomputed embeddings take precedence over raw citation text.
                if data_point.citation_embeds is None:
                    data_point.citation_tokens = [tokenizer.encode(c, **tokenizer_kwargs)
                                                  for c in data_point.citations]
                else:
                    data_point.citation_tokens = data_point.citation_embeds
            if idx % 10000 == 0:
                logging.info(f"Encode text data: {idx}/{len(data)}")

        return data

    @staticmethod
    def from_config(config: Dict[str, any]) -> List["EvaluateConfig"]:
        """Expand one raw config dict into one EvaluateConfig per task.

        ``task_name`` and ``data`` are ';'-separated parallel lists; unknown
        task names are silently skipped.
        """
        adapter_name = config["name"]
        data_path = config.get("data", None)
        task_list = config.get("task_name", "casual").split(";")
        path_list = (
            [None] * len(task_list) if data_path is None else data_path.split(";")
        )
        config_list = []
        for task_name_, data_path_ in zip(task_list, path_list):
            if task_name_ not in task_dict:
                continue
            config_list.append(
                EvaluateConfig(
                    adapter_name=adapter_name,
                    task_name=task_name_,
                    data_path=data_path_,
                    batch_size=config["evaluate_batch_size"],
                )
            )

        return config_list

    def prepare(self, tokenizer: Tokenizer, device: str):
        """Resolve task/metric/data and, for common-sense tasks, build the
        tensor of label token ids used to slice the vocabulary logits."""
        self.reset_parameters()
        assert (
            self.task_name != "casual"
        ), "Auto evaluation is not currently available for casual supervised fine-tuning tasks."
        self.task_ = task_dict[self.task_name]
        self.data_ = self._dataload_fn(tokenizer)
        self.metric_ = self.task_.loading_metric()
        if isinstance(self.task_, CommonSenseTask):
            labels = self.task_.label_list()
            label_indices = [0] * len(labels)
            for idx, label in enumerate(labels):
                # Use the last token id of " <label>" as the label's logit index.
                ids = tokenizer.encode(" " + label)
                label_indices[idx] = ids[-1]
            self.label_indices_ = torch.tensor(
                label_indices, dtype=torch.int64, device=device
            )
        else:
            self.label_indices_ = None

    def reset_parameters(self):
        # Clear resolved task state and rewind all batch cursors.
        self.task_ = None
        self.data_ = None
        self.metric_ = None
        self.rollback_start_idx_ = 0
        self.batch_start_idx_ = 0
        self.batch_end_idx_ = 0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _prepare_tasks(model, tokenizer, configs):
    """Prepare every evaluation config and, for MixLoRA adapters, push the
    config's router-profiling flag into each layer's MoE module."""
    for cfg in configs:
        cfg.prepare(tokenizer, model.device_)
        adapter_cfg = model.adapter_configs_[cfg.adapter_name]
        if not isinstance(adapter_cfg, MixLoraConfig):
            continue
        for layer in model.model_.layers_:
            moes = layer.mlp_.moes_
            if cfg.adapter_name in moes:
                moes[cfg.adapter_name].router_profile_ = cfg.router_profile
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _dispatch_task_in(tokenizer, configs, concurrent_jobs, max_seq_len):
    """Assemble the next evaluation micro-batch from up to ``concurrent_jobs``
    configs.

    Each selected config contributes its next ``batch_size`` samples; tokens
    are truncated to ``max_seq_len``, right-padded to the longest sequence in
    the batch, and wrapped into one ``LLMModelInput``.  Each config's
    ``batch_start_idx_``/``batch_end_idx_`` cursors are advanced in place.

    Returns:
        (current_configs, sequence_lengths, batch_labels, more_grounds,
        LLMModelInput) — sequence_lengths holds the last-real-token index of
        each (pre-padding) sequence.
    """
    batch_data_config = []
    sequence_lengths = []
    current_configs = []
    batch_tokens = []
    batch_labels = []
    more_grounds = []
    atten_masks = []
    max_tokens_len = 0
    for config in configs:
        if len(current_configs) >= concurrent_jobs:
            break
        if config.batch_start_idx_ >= len(config.data_):
            # This config's data is exhausted.
            continue
        config.batch_end_idx_ = min(
            config.batch_start_idx_ + config.batch_size, len(config.data_)
        )
        batch_start_idx = len(batch_tokens)
        for idx in range(config.batch_start_idx_, config.batch_end_idx_):
            data_point = config.data_[idx]
            # Fix: always slice-copy.  The original appended the stored list
            # itself when no truncation happened, so the padding loop below
            # mutated config.data_ in place, corrupting later passes.
            tokens = data_point.tokens[:max_seq_len]
            max_tokens_len = max(len(tokens), max_tokens_len)
            batch_tokens.append(tokens)
            labels = data_point.labels
            if labels:
                # NOTE(review): labels are wrapped in a singleton list —
                # presumably downstream expects shape (batch, 1); confirm.
                batch_labels.append([labels])
            grounds = data_point.grounds
            if grounds:
                more_grounds.append(grounds.copy())

        config.batch_start_idx_ = config.batch_end_idx_
        current_configs.append(config)
        batch_data_config.append(
            LLMBatchConfig(
                adapter_name_=config.adapter_name,
                batch_start_idx_=batch_start_idx,
                batch_end_idx_=len(batch_tokens),
            )
        )

    # Shrink the padding target to the longest sequence actually present.
    max_seq_len = min(max_seq_len, max_tokens_len)

    for tokens in batch_tokens:
        sequence_lengths.append(len(tokens) - 1)
        while len(tokens) < max_seq_len:
            tokens.append(tokenizer.pad_id_)
        atten_masks.append(tokenizer.mask_from(tokens))

    return (
        current_configs,
        sequence_lengths,
        batch_labels,
        more_grounds,
        LLMModelInput(
            batch_configs_=batch_data_config,
            batch_tokens_=batch_tokens,
            batch_masks_=atten_masks,
            inference_mode_=True,
        ),
    )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _compute_metrcis(model, current_configs, sequence_lengths, batch_labels, outputs):
    """Pool last-token logits per sample and feed predictions into each
    config's metric accumulator.

    (Function name carries a typo — "metrcis" — kept because this is the
    module's call convention.)
    """
    for idx, output in enumerate(outputs):
        config: EvaluateConfig = current_configs[idx]
        task: BasicTask = config.task_
        metric: BasicMetric = config.metric_
        start_idx = output.batch_start_idx_
        end_idx = output.batch_end_idx_
        logits = output.logits

        if config.router_profile:
            adapter_config = model.adapter_configs_[config.adapter_name]
            if isinstance(adapter_config, MixLoraConfig):
                # Sum each expert's activation count across all MoE layers.
                # NOTE: the inner loops reuse `idx`, shadowing the outer loop
                # variable; harmless here since the outer `for` rebinds it.
                router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
                for idx, val in enumerate(
                    layer.mlp_.moes_[config.adapter_name].profiler_
                ):
                    router_statistic_[idx] += val
                for idx, val in enumerate(router_statistic_):
                    logging.info(
                        f"{config.adapter_name}: expert {idx}, load = {val/32}"
                    )

        batch_size = logits.shape[0]
        # Select, for each row, the logits at its last real (non-pad) token.
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device),
            sequence_lengths[start_idx:end_idx],
        ]
        labels = torch.tensor(
            batch_labels[start_idx:end_idx],
            dtype=task.label_dtype_,
            device=logits.device,
        )
        if task.task_type_ == "common_sense":
            # Restrict logits to the candidate label token ids first.
            pooled_logits = pooled_logits[:, config.label_indices_]
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
        elif task.task_type_ == "single_label_classification":
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
            pooled_logits = pooled_logits.to(task.label_dtype_)
        elif task.task_type_ != "multi_label_classification":
            raise ValueError(f"unknown task type {task.task_type_}")

        metric.add_batch(
            predictions=pooled_logits.detach().cpu(), references=labels.detach().cpu()
        )
        logging.info(f"{config.adapter_name} evaluate data:")
        logging.info(f"    step: {config.batch_start_idx_}/{len(config.data_)}")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _compute_result(model, configs, save_file):
|
| 288 |
+
results = []
|
| 289 |
+
for config in configs:
|
| 290 |
+
result = {
|
| 291 |
+
"adapter_name": config.adapter_name,
|
| 292 |
+
"task_name": config.task_name,
|
| 293 |
+
"date_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
| 294 |
+
"metrics": {},
|
| 295 |
+
}
|
| 296 |
+
compute_results = config.metric_.compute()
|
| 297 |
+
result["metrics"] = compute_results
|
| 298 |
+
if config.router_profile:
|
| 299 |
+
adapter_config = model.adapter_configs_[config.adapter_name]
|
| 300 |
+
if isinstance(adapter_config, MixLoraConfig):
|
| 301 |
+
router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
|
| 302 |
+
for layer in model.model_.layers_:
|
| 303 |
+
if config.adapter_name not in layer.mlp_.moes_:
|
| 304 |
+
continue
|
| 305 |
+
for idx, val in enumerate(
|
| 306 |
+
layer.mlp_.moes_[config.adapter_name].profiler_
|
| 307 |
+
):
|
| 308 |
+
router_statistic_[idx] += val
|
| 309 |
+
layer.mlp_.moes_[config.adapter_name].profiler_ = None
|
| 310 |
+
result["router_profile"] = list(val / 32 for val in router_statistic_)
|
| 311 |
+
|
| 312 |
+
results.append(result)
|
| 313 |
+
|
| 314 |
+
if save_file is not None:
|
| 315 |
+
if not os.path.exists(save_file):
|
| 316 |
+
os.makedirs(save_file)
|
| 317 |
+
file_path = save_file + os.sep + f"{config.adapter_name}.json"
|
| 318 |
+
with open(file_path, "w") as f:
|
| 319 |
+
json.dump(results, f, indent=4)
|
| 320 |
+
logging.info(f"saving evaluation result to {file_path}")
|
| 321 |
+
else:
|
| 322 |
+
print(json.dumps(results, indent=4))
|
| 323 |
+
|
| 324 |
+
return results
|
| 325 |
+
|
| 326 |
+
def _dispatch_task_in2(
    tokenizer,
    configs: List[GenerateConfig],
    concurrent_jobs: int,
    strategy: str = "fair",
):
    """Pop pending generation jobs off the configs and build one batch.

    With the "fair" strategy each config gets an even share of the available
    slots; "fifo" lets the first config with data fill the whole batch.
    Returns (jobs, per-adapter batch configs, token lists, max len, min len).
    """
    assert strategy in ["fair", "fifo"], f"Unknown dispatch strategy {strategy}"
    running_jobs = []
    llm_batch_configs = []
    token_batches = []
    longest = 0
    shortest = sys.maxsize
    for cfg in configs:
        if len(llm_batch_configs) >= concurrent_jobs:
            break

        if not cfg.data_:
            continue
        print(f"count down:{len(cfg.data_)}")
        if strategy == "fair":
            slots = max(concurrent_jobs // len(configs), 1)
        else:
            slots = concurrent_jobs

        # Never take more jobs than the config's own batch size allows.
        slots = min(slots, cfg.batch_size)

        first_idx = len(token_batches)
        while slots > 0 and cfg.data_:
            slots -= 1
            job = cfg.data_.pop(0)
            running_jobs.append(job)
            job_tokens = job.tokens
            longest = max(longest, len(job_tokens))
            shortest = min(shortest, len(job_tokens))
            token_batches.append(job_tokens)

        llm_batch_configs.append(
            LLMBatchConfig(
                adapter_name_=cfg.adapter_name,
                batch_start_idx_=first_idx,
                batch_end_idx_=len(token_batches),
            )
        )

    return (
        running_jobs,
        llm_batch_configs,
        token_batches,
        longest,
        shortest,
    )
| 378 |
+
|
| 379 |
+
def _generate_then_compute_metrics(
    model, tokenizer, concurrent_jobs, \
    max_gen_len, current_configs: List[EvaluateConfig],\
    require_attention: Optional[int] = -1, require_hide: Optional[int] = -1
):
    """Run batched free-form generation for attribution tasks and feed each
    batch's outputs into the shared metric.

    Loops until every config's pending data queue is drained; unfinished jobs
    returned by the generator are re-queued onto the first config.
    """
    # grounds here are the qa_pairs
    metric = current_configs[0].metric_.metric_

    ###outputs, hidden_output, hidden_atten = model.forward(input_args)

    #!!! current_configs should be converted to GenerateConfig here; it is currently EvaluateConfig
    #cnt = 50
    #cases = []
    while True:  # config.data_ shrinks each round and is the sole scheduling signal
        dispatch_args = _dispatch_task_in2(tokenizer, current_configs, concurrent_jobs)
        # dispatch_args contains: current_jobs,
        # batch_config (LLMBatchConfig(taskname, start, end)),
        # batch_tokens, max_length, min_length

        if len(dispatch_args[0]) == 0:
            break
        use_cache = True
        cache_implementation = model.model_.cache_implementation()
        if cache_implementation is None:
            logging.warn(
                "Cache disabled by model, use cache_implementation to force enable."
            )
            use_cache = False
        outputs, running_jobs = _batch_generate(
            model,
            tokenizer,
            max_gen_len,
            use_cache,
            require_attention,
            require_hide,
            cache_implementation,
            None,
            *dispatch_args,
        )
        # Jobs that did not finish this round go back on the queue.
        for data in running_jobs:
            current_configs[0].data_.append(data)

        print(f"\noutput:{outputs[0]}\n")
        # NOTE(review): add_batch receives a single positional dict here —
        # confirm the attribution metric implements this signature.
        metric.add_batch(
            {
                'output': outputs[0],
                'qa_pairs': dispatch_args[0][0].grounds,
                'answer': dispatch_args[0][0].labels,
                'docs': dispatch_args[0][0].citations,
                'query': dispatch_args[0][0].query,
            }
        )
|
| 432 |
+
@torch.inference_mode()
def evaluate(
    model: LLMModel,
    tokenizer: Tokenizer,
    configs: List[EvaluateConfig],
    max_concurrent_jobs: int = None,
    retrying_steps: int = 20,
    max_seq_len: int = 512,
    save_file: str = None,
    require_attention: Optional[int] = -1,
    require_hide: Optional[int] = -1,
) -> Dict:
    """Evaluate every config's task, adapting batch size on OOM.

    Dispatches batches until all configs are drained.  On a CUDA OOM the
    concurrency is reduced, the configs are rolled back to their last
    checkpointed start index, and after ``retrying_steps`` successful rounds
    the concurrency is raised again.  Results are saved/printed at the end.
    """

    if max_concurrent_jobs is None:
        max_concurrent_jobs = len(configs)
        logging.info(
            f"Setting max_concurrent_jobs to {max_concurrent_jobs} automatically"
        )

    assert max_concurrent_jobs > 0
    assert retrying_steps > 0

    _prepare_tasks(model, tokenizer, configs)

    concurrent_jobs = max_concurrent_jobs
    retrying_count = 0
    while True:
        # Gradually recover concurrency after an earlier OOM backoff.
        if concurrent_jobs < max_concurrent_jobs and retrying_count > 0:
            retrying_count -= 1
            if retrying_count == 0:
                concurrent_jobs += 1
                logging.info(f"recovering concurrent jobs to {concurrent_jobs}")

        current_configs, sequence_lengths, batch_labels, grounds, input_args = _dispatch_task_in(
            tokenizer, configs, concurrent_jobs, max_seq_len
        )
        # current_configs: the configs scheduled into this batch
        # sequence_lengths: token lengths of the batch entries
        # batch_labels, grounds: references for metric computation
        # input_args: LLMModelInput with batch_config (adapter_name, start, end),
        # tokens, attention_mask
        if len(current_configs) == 0:
            break

        try:
            if current_configs[0].task_.task_type_ == 'attribute':
                # Free-form generation path (citation/attribution tasks).
                _generate_then_compute_metrics(
                    model,
                    tokenizer,
                    concurrent_jobs,
                    max_seq_len,
                    current_configs,
                    require_attention,
                    require_hide
                )
            else:
                # Classification path: single forward pass, pooled logits.
                # (name typo `_compute_metrcis` is the actual sibling function)
                _compute_metrcis(
                    model,
                    current_configs,
                    sequence_lengths,
                    batch_labels,
                    model.forward(input_args),
                )

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                concurrent_jobs -= 1
                if concurrent_jobs == 0:
                    raise e
                logging.warn(
                    f"deprecating concurrent jobs to {concurrent_jobs} due to OOM."
                )
                # rollback: restore each config to its last good batch index
                retrying_count = retrying_steps
                for config in current_configs:
                    config.batch_start_idx_ = config.rollback_start_idx_
                    logging.info(
                        f"{config.adapter_name}: rollback to {config.batch_start_idx_}/{len(config.data_)}"
                    )
                continue
            else:
                raise e

        # Batch succeeded: checkpoint the progress for future rollbacks.
        for config in current_configs:
            config.rollback_start_idx_ = config.batch_start_idx_

    return _compute_result(model, configs, save_file)
|
c2cite/executors/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .common import BasicExecutor
|
| 7 |
+
from .cpu import CPUExecutor
|
| 8 |
+
from .cuda import CUDAExecutor
|
| 9 |
+
from .mps import MPSExecutor
|
| 10 |
+
|
| 11 |
+
executor_dict = {
|
| 12 |
+
"CUDA": CUDAExecutor,
|
| 13 |
+
"MPS": MPSExecutor,
|
| 14 |
+
"CPU": CPUExecutor,
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _init_executor():
    """Instantiate the process-wide executor backend.

    An explicit MOE_PEFT_EXECUTOR_TYPE environment override wins; otherwise
    prefer CUDA, then Apple MPS, falling back to the CPU backend.
    """
    env = os.getenv("MOE_PEFT_EXECUTOR_TYPE")
    if env is not None:
        env = env.upper()
        if env not in executor_dict:
            raise ValueError(f"Assigning unknown executor type {env}")
        return executor_dict[env]()
    if torch.cuda.is_available():
        return CUDAExecutor()
    if torch.backends.mps.is_available():
        return MPSExecutor()
    return CPUExecutor()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
executor: BasicExecutor = _init_executor()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class no_cache(object):
    """Context manager that flushes the active executor's device cache and
    runs Python garbage collection on both entry and exit, minimizing
    residual allocations around a memory-hungry region."""

    def __enter__(self):
        executor.empty_cache()
        gc.collect()
        return self

    def __exit__(self, type, value, traceback):
        # Flush again so memory freed inside the block is returned promptly.
        executor.empty_cache()
        gc.collect()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
__all__ = [
|
| 48 |
+
"BasicExecutor",
|
| 49 |
+
"CUDAExecutor",
|
| 50 |
+
"MPSExecutor",
|
| 51 |
+
"CPUExecutor",
|
| 52 |
+
"executor",
|
| 53 |
+
"no_cache",
|
| 54 |
+
]
|
c2cite/executors/common.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers.utils import is_torch_bf16_available_on_device
|
| 6 |
+
|
| 7 |
+
from moe_peft.utils import NoneContexts
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BasicExecutor:
    """Abstract interface over a compute backend (CUDA / MPS / CPU).

    Subclasses override the ``NotImplementedError`` stubs with device
    specifics; the concrete methods defined here (seeding, bf16 probe,
    fork_rng, tensor helpers) are shared defaults.
    """

    def name(self) -> str:
        # Human-readable backend name, e.g. "NVIDIA CUDA".
        raise NotImplementedError()

    def device_name(self) -> str:
        # torch device string, e.g. "cuda", "mps", "cpu".
        raise NotImplementedError()

    def default_device_name(self) -> str:
        # Default placement target; same as device_name() unless overridden.
        return self.device_name()

    def is_available(self) -> bool:
        raise NotImplementedError()

    def is_initialized(self) -> bool:
        raise NotImplementedError()

    def is_bf16_supported(self) -> bool:
        # Delegate to transformers' capability probe for this device.
        return is_torch_bf16_available_on_device(self.device_name())

    def manual_seed(self, seed: int):
        # Seed Python's RNG and torch's CPU RNG; subclasses add device RNGs.
        random.seed(seed)
        torch.manual_seed(seed)

    def empty_cache(self):
        raise NotImplementedError()

    def use_deterministic_algorithms(self, mode: bool):
        torch.use_deterministic_algorithms(mode)

    def allow_tf32(self, mode: bool):
        raise NotImplementedError()

    def set_rng_state(self, device, state):
        raise NotImplementedError()

    def get_rng_state(self, device):
        raise NotImplementedError()

    def fork_rng(self, rng_devices: list):
        # Context manager that restores RNG state for the given devices on exit.
        return torch.random.fork_rng(
            devices=rng_devices, device_type=self.device_name()
        )

    def autocast(self, **kwargs):
        # Mixed precision disabled by default; NoneContexts is a no-op manager.
        return NoneContexts()

    def init_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        # Uninitialized allocation by default; MPS overrides with zeros_like.
        return torch.empty_like(tensor)

    def index_fill(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, value: torch.Tensor
    ):
        input.index_fill_(dim, index, value)

    def index_copy(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, source: torch.Tensor
    ):
        input.index_copy_(dim, index, source)

    def check_available(self):
        # Log and report whether this backend can actually be used.
        if not self.is_available():
            logging.error(f"{self.name()} not available.")
            return False
        if not self.is_initialized():
            logging.error(f"{self.name()} not initialized.")
            return False
        logging.info(f"{self.name()} initialized successfully.")
        return True
|
c2cite/executors/cpu.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .common import BasicExecutor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CPUExecutor(BasicExecutor):
    """Executor backend that runs everything on the host CPU."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "CPU"

    def device_name(self) -> str:
        return "cpu"

    def is_available(self) -> bool:
        # The host CPU is always present.
        return True

    def is_initialized(self) -> bool:
        # There is no device runtime to initialize.
        return False

    def empty_cache(self):
        # Nothing is cached on the CPU backend.
        pass

    def allow_tf32(self, mode: bool):
        assert not mode, "Enabling tf32 for CPU."

    def set_rng_state(self, device: int, state: torch.Tensor):
        assert device == 0
        torch.set_rng_state(state)

    def get_rng_state(self, device: int):
        assert device == 0
        return torch.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        assert len(rng_devices) == 0
        saved_state = torch.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(saved_state)

    def check_available(self):
        # The base-class initialization check does not apply to the CPU.
        logging.info(f"{self.name()} initialized successfully.")
        return True
|
c2cite/executors/cuda.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .common import BasicExecutor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CUDAExecutor(BasicExecutor):
    """Executor backend for NVIDIA GPUs via ``torch.cuda``."""

    def __init__(self) -> None:
        super().__init__()
        # Eagerly initialize the CUDA runtime.
        torch.cuda.init()

    def name(self) -> str:
        return "NVIDIA CUDA"

    def device_name(self) -> str:
        return "cuda"

    def default_device_name(self) -> str:
        # Pin default placement to the first GPU.
        return "cuda:0"

    def is_available(self) -> bool:
        return torch.cuda.is_available()

    def is_initialized(self) -> bool:
        return torch.cuda.is_initialized()

    def is_bf16_supported(self) -> bool:
        return torch.cuda.is_bf16_supported()

    def manual_seed(self, seed: int):
        # Seed CPU/Python RNGs first, then every visible GPU.
        super().manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def empty_cache(self):
        torch.cuda.empty_cache()

    def use_deterministic_algorithms(self, mode: bool):
        # Benchmark autotuning is disabled when determinism is requested.
        torch.backends.cudnn.benchmark = not mode
        torch.backends.cudnn.deterministic = mode

    def allow_tf32(self, mode: bool):
        torch.backends.cudnn.allow_tf32 = mode
        torch.backends.cuda.matmul.allow_tf32 = mode

    def set_rng_state(self, device, state):
        with torch.cuda.device(device):
            return torch.cuda.set_rng_state(state)

    def get_rng_state(self, device):
        with torch.cuda.device(device):
            return torch.cuda.get_rng_state()

    def autocast(self, **kwargs):
        # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch
        # releases in favor of torch.amp.autocast("cuda") — confirm the
        # minimum supported torch version before changing.
        return torch.cuda.amp.autocast(**kwargs)
|
c2cite/executors/mps.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .common import BasicExecutor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MPSExecutor(BasicExecutor):
    """Executor backend for Apple Silicon GPUs via ``torch.backends.mps``."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "APPLE MPS"

    def device_name(self) -> str:
        return "mps"

    def is_available(self) -> bool:
        # Requires both a runtime-visible MPS device and an MPS-enabled build.
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()

    def is_initialized(self) -> bool:
        # TODO: change to official implementation
        return not torch.mps._is_in_bad_fork()

    def manual_seed(self, seed: int):
        super().manual_seed(seed)
        torch.mps.manual_seed(seed)

    def empty_cache(self):
        torch.mps.empty_cache()

    def allow_tf32(self, mode: bool):
        # tf32 is a CUDA-only feature.
        assert not mode, "Enabling tf32 for MPS devices."

    def set_rng_state(self, device: int, state: torch.Tensor):
        # MPS exposes a single device, hence index 0 only.
        assert device == 0
        return torch.mps.set_rng_state(state)

    def get_rng_state(self, device: int):
        assert device == 0
        return torch.mps.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        assert len(rng_devices) == 1 and rng_devices[0] == 0
        cpu_rng_state = torch.get_rng_state()
        device_rng_states = torch.mps.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.mps.set_rng_state(device_rng_states)

    def autocast(self, **kwargs):
        # TODO: change to official implementation
        # running with compatible mode
        return torch.cuda.amp.autocast(**kwargs)

    def init_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        # zeros_like instead of empty_like, so index_copy's index_add_
        # fallback below accumulates onto a clean buffer.
        return torch.zeros_like(tensor)

    def index_fill(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, value: torch.Tensor
    ):
        # NOTE(review): intentionally a no-op here — presumably an MPS
        # workaround; confirm callers do not rely on the fill happening.
        pass

    def index_copy(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, source: torch.Tensor
    ):
        # NOTE(review): uses index_add_ rather than index_copy_ — assumes the
        # target rows are zero-initialized (see init_tensor); confirm.
        input.index_add_(dim, index, source)
|
c2cite/generator.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import re
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
from moe_peft.common import LLMBatchConfig, LLMModelInput, Tokens, cache_factory
|
| 11 |
+
from moe_peft.executors import executor
|
| 12 |
+
from moe_peft.model import LLMModel
|
| 13 |
+
from moe_peft.prompter import Prompter
|
| 14 |
+
from moe_peft.tokenizer import Tokenizer
|
| 15 |
+
from moe_peft.solutions import get_output
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class GenerateData:
    """One in-flight generation job tracked by the dispatcher."""

    # Name of the adapter that owns this prompt.
    adapter_name_: str = None
    # Index of the source prompt within its GenerateConfig.prompts list.
    prompt_index_: int = None
    # Prompt prefix length; used to strip the prompt from decoded output.
    prefix_length_: int = None
    # Token ids for this job (fed as model input by the dispatcher).
    raw_tokens_: Tokens = None
| 25 |
+
|
| 26 |
+
@dataclass
class GenerateConfig:
    """Per-adapter generation settings plus its queue of pending jobs."""

    # Adapter identity and raw prompts (a str, or (instruction, input) pair).
    adapter_name: str = None
    prompts: List[Union[str, Tuple[str, str]]] = None
    prompt_template: str = None
    # Generate Arguments
    batch_size: int = 8
    stop_token: str = None
    temperature: float = 1
    top_p: float = 0.9
    top_k: float = 50
    do_sample: bool = True
    repetition_penalty: float = 1.1
    renormalize_logits: bool = True
    # Do not set these manually
    prompter_: Prompter = None
    stop_token_: torch.Tensor = None
    data_: List[GenerateData] = None

    # Set prompt_template_ to enable the prompter
    def generate_prompt(self, instruction: str, input: str = None) -> str:
        """Render one prompt through the (lazily created) Prompter."""
        if self.prompter_ is None:
            self.prompter_ = Prompter(self.prompt_template)

        return self.prompter_.generate_prompt(instruction=instruction, input=input)

    def get_prompts(self) -> List[str]:
        """Render every configured prompt; tuples become (instruction, input)."""
        prompts = []
        for prompt in self.prompts:
            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
            prompts.append(self.generate_prompt(*args))

        return prompts

    def get_response(self, output: str) -> str:
        """Extract the response from raw model output via the prompter, or
        just strip whitespace when no prompter has been created."""
        if self.prompter_ is None:
            return output.strip()
        else:
            return self.prompter_.get_response(output)

    def reset_parameters(self):
        """Re-create the prompter and clear cached stop token / job queue."""
        self.prompter_ = Prompter(self.prompt_template)
        self.stop_token_ = None
        self.data_ = []
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _logits_sample_top_p(probs, p, filter_value=float("-inf"), min_tokens_to_keep=1):
|
| 73 |
+
sorted_logits, sorted_indices = torch.sort(probs, descending=False)
|
| 74 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 75 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - p)
|
| 76 |
+
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
|
| 77 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 78 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 79 |
+
)
|
| 80 |
+
return probs.masked_fill(indices_to_remove, filter_value)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _logits_sample_top_k(probs, k, filter_value=float("-inf")):
|
| 84 |
+
top_k = min(k, probs.size(-1)) # Safety check
|
| 85 |
+
indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
|
| 86 |
+
return probs.masked_fill(indices_to_remove, filter_value)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _logits_repetition_penalty(prev_tokens, probs, penalty):
|
| 90 |
+
score = torch.gather(probs, 1, prev_tokens)
|
| 91 |
+
score = torch.where(score < 0, score * penalty, score / penalty)
|
| 92 |
+
probs.scatter_(1, prev_tokens, score)
|
| 93 |
+
return probs
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def id2token(x):
    """Map a router/expert index to its reserved special-token id.

    Indices 0-4 map to fixed reserved ids; 5 and above continue linearly
    from 128005. Negative indices are invalid.
    """
    fixed = {0: 128002, 1: 128003, 2: 128004, 3: 128005, 4: 128008}
    if x in fixed:
        return fixed[x]
    if x >= 5:
        return 128005 + x
    assert False, "wrong router"
|
| 111 |
+
|
| 112 |
+
def logits_process(
    probs: torch.Tensor,
    prev_tokens: torch.Tensor,
    cite_flag = False,
    temperature=0.9,
    top_p=0,
    top_k=0,
    do_sample=True,
    repetition_penalty=1.01,
    renormalize_logits=True,
):
    """Turn next-token logits into one token id per batch row.

    In normal mode (cite_flag=False) this applies repetition penalty,
    temperature, top-k and top-p filtering before (optionally) sampling.
    In citation mode (cite_flag=True) decoding is forced to greedy argmax
    and the chosen index is remapped to a reserved special-token id.
    """
    if cite_flag == False:
        # Condition lists so further penalties/filters can be or-ed in later.
        process_conditions = any([repetition_penalty > 0])
        sample_conditions = any([temperature > 0, top_p > 0 and top_p <= 1.0, top_k > 0])

        if not do_sample and sample_conditions:
            # Sampling knobs are set, so greedy decoding would ignore them.
            do_sample = True
            logging.warn("do_sample force to enabled.")

        if repetition_penalty > 0:
            # Note: mutates `probs` in place as well as returning it.
            probs = _logits_repetition_penalty(prev_tokens, probs, repetition_penalty)

        if process_conditions and renormalize_logits:
            probs = probs.log_softmax(-1)

        if temperature > 0:
            probs = probs / temperature

        if top_k > 0:
            probs = _logits_sample_top_k(probs, top_k)

        if top_p > 0 and top_p <= 1.0:
            probs = _logits_sample_top_p(probs, top_p)

        if sample_conditions and renormalize_logits:
            probs = probs.log_softmax(-1)
    else:
        # Citation mode: pick the highest-scoring index deterministically.
        do_sample = False

    if do_sample:
        probs = torch.softmax(probs, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        next_token = torch.argmax(probs, dim=-1)

    if cite_flag:
        for i in range(probs.shape[0]):
            # NOTE(review): the +1 offset shifts the argmax index before the
            # special-token remap — confirm it matches the router id scheme.
            next_token[i] = id2token(next_token[i] + 1)
    return next_token.reshape(-1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _extract_effective_tokens(
|
| 164 |
+
tokenizer: Tokenizer,
|
| 165 |
+
prefix_length: int,
|
| 166 |
+
tokens: Tokens,
|
| 167 |
+
remove_prefix=True,
|
| 168 |
+
remove_pad=True,
|
| 169 |
+
remove_eos=True,
|
| 170 |
+
):
|
| 171 |
+
if remove_prefix:
|
| 172 |
+
tokens = tokens[prefix_length:]
|
| 173 |
+
|
| 174 |
+
if remove_pad and tokenizer.pad_id_ in tokens:
|
| 175 |
+
pad_idx = tokens.index(tokenizer.pad_id_)
|
| 176 |
+
tokens = tokens[:pad_idx]
|
| 177 |
+
|
| 178 |
+
if remove_eos and tokenizer.eos_id_ in tokens:
|
| 179 |
+
stop_idx = tokens.index(tokenizer.eos_id_)
|
| 180 |
+
tokens = tokens[:stop_idx]
|
| 181 |
+
|
| 182 |
+
return tokens
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _gen_outputs(
    tokenizer: Tokenizer,
    config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    tokens: torch.Tensor,
):
    """Decode each finished sequence and bucket the responses by adapter.

    For every job, the prompt prefix, padding and EOS are stripped before
    decoding; the owning config then extracts the response text.
    """
    token_lists = tokens.tolist()
    packed_outputs: Dict[str, List[str]] = {}
    for job_idx, job in enumerate(current_jobs):
        effective = _extract_effective_tokens(
            tokenizer,
            job.prefix_length_,
            token_lists[job_idx],
            remove_prefix=True,
            remove_pad=True,
            remove_eos=True,
        )
        response = config_dict[job.adapter_name_].get_response(
            tokenizer.decode(effective)
        )
        packed_outputs.setdefault(job.adapter_name_, []).append(response)

    return packed_outputs
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _dispatch_task_in(
    configs: List[GenerateConfig],  # uses config.data_, config.batch_size, config.adapter_name
    concurrent_jobs: int,
    strategy: str = "fair",
):
    """Pull pending jobs off each config's queue and pack them into one batch.

    With the ``"fair"`` strategy each task gets roughly an equal share of the
    concurrent slots; ``"fifo"`` lets earlier tasks claim all slots.  Each
    task's share is further capped by its own ``batch_size``.

    Returns a tuple of
    ``(current_jobs, batch_config, input_tokens, max_tokens_len, min_tokens_len)``.
    Note: when no job is dispatched, ``min_tokens_len`` stays ``sys.maxsize``.
    """
    assert strategy in ["fair", "fifo"], f"Unknown dispatch strategy {strategy}"

    current_jobs = []
    batch_config = []
    input_tokens = []
    max_tokens_len = 0
    min_tokens_len = sys.maxsize

    for config in configs:
        # Stop once one task-group per concurrent slot has been created.
        if len(batch_config) >= concurrent_jobs:
            break
        if not config.data_:
            continue

        if strategy == "fair":
            quota = max(concurrent_jobs // len(configs), 1)
        else:  # "fifo": first tasks take everything
            quota = concurrent_jobs
        quota = min(quota, config.batch_size)

        first_idx = len(input_tokens)
        while quota > 0 and config.data_:
            quota -= 1
            job = config.data_.pop(0)
            current_jobs.append(job)
            seq = job.raw_tokens_
            max_tokens_len = max(max_tokens_len, len(seq))
            min_tokens_len = min(min_tokens_len, len(seq))
            input_tokens.append(seq)

        batch_config.append(
            LLMBatchConfig(
                adapter_name_=config.adapter_name,
                batch_start_idx_=first_idx,
                batch_end_idx_=len(input_tokens),
            )
        )

    return (
        current_jobs,
        batch_config,
        input_tokens,
        max_tokens_len,
        min_tokens_len,
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _dispatch_task_out(
    tokenizer: Tokenizer,
    # config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    tokens: torch.Tensor,
    stop_reached: torch.Tensor,
    attentions,
    hides,
    require_attention,
    require_hide
):
    """Split the batch into finished and still-running jobs.

    For each job whose ``stop_reached`` flag is set, decodes its generated
    tokens (prompt/pad/EOS stripped) and rewrites Llama-3 reserved special
    tokens ``<|reserved_special_token_N|>`` into ``[N]`` citation markers.
    Jobs that have not stopped keep their full token sequence (pad-stripped)
    and are returned for re-dispatch.

    Returns ``(packed_outputs, running_jobs)`` where ``packed_outputs`` is a
    flat list of decoded strings, in ``current_jobs`` order.

    ``attentions``/``hides``/``require_attention``/``require_hide`` are
    currently unused (see the disabled block below).
    """
    # Disabled attention / hidden-state post-processing, kept for reference:
    # hide = []
    # if require_hide != -1:
    #     ans_len = len(hides)
    #     for i in range(len(hides[0])):
    #         hide.append(torch.cat([t[i] for t in hides], dim = 1))
    # if require_attention != -1:
    #     ans_len = len(attentions)
    #     for i in range(len(hides[0])):
    #         hide.append(torch.cat([t[i] for t in attentions], dim = 1))
    tokens = tokens.tolist()
    stop_reached = stop_reached.view(-1).tolist()
    packed_outputs: List[str] = []
    packed_add = []  # NOTE(review): populated nowhere and never returned
    running_jobs: List[GenerateData] = []
    for idx, data in enumerate(current_jobs):  # data here is a GenerateData (original note said it looked like evaluate data)
        if stop_reached[idx]:
            # Finished: keep only the generated part of the sequence.
            output_tokens = _extract_effective_tokens(
                tokenizer,
                data.prefix_length_,
                tokens[idx],
                remove_prefix=True,
                remove_pad=True,
                remove_eos=True,
            )
            #if len(hide):
            #    get_output(hide, output_tokens, ans_len)
            output_s = tokenizer.decode(output_tokens).strip()
            # Map reserved special tokens back to "[N]" citation text.
            output = re.sub(r'<\|reserved_special_token_(\d+)\|>', r'[\1]', output_s)
            packed_outputs.append(output)
        else:
            # Still running: keep the whole sequence (minus padding) so the
            # job can be re-dispatched and continue decoding.
            # NOTE(review): this assigns to `data.tokens`, but re-dispatch in
            # _dispatch_task_in reads `data.raw_tokens_` — confirm whether
            # this should update `raw_tokens_` instead.
            data.tokens = _extract_effective_tokens(
                tokenizer,
                data.prefix_length_,
                tokens[idx],
                remove_prefix=False,
                remove_pad=True,
                remove_eos=False,
            )
            running_jobs.append(data)

    return packed_outputs, running_jobs
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _batch_generate(
    model: LLMModel,
    tokenizer: Tokenizer,
    max_gen_len: Optional[int],
    use_cache: bool,
    require_attention: Optional[int],
    require_hide: Optional[int],
    cache_implementation: Optional[str],
    stream_callback: Optional[Callable],
    #config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    batch_config: List[LLMBatchConfig],
    input_tokens: List[Tokens],
    max_tokens_len: int,
    min_tokens_len: int,
):
    """Citation-aware autoregressive decoding for one dispatched batch.

    Decodes token-by-token from ``min_tokens_len`` up to the sequence limit,
    threading citation metadata (positions/values of citation-marker tokens,
    the jobs' document tokens, and the prompt length) through every forward
    call so the model can condition on cited documents.

    NOTE(review): citation bookkeeping (``prompt_len``, ``cite``, ``cite_v``,
    ``batch_docs``) is derived from ``input_tokens[0]`` / ``current_jobs[0]``
    only — this effectively assumes a batch of size 1 (or identical
    prompts/documents across the batch); confirm before batching > 1.

    Returns whatever ``_dispatch_task_out`` returns:
    ``(finished_output_strings, still_running_jobs)``.
    """
    executor.empty_cache()
    device = torch.device(model.device_)
    batch_size = len(input_tokens)
    if max_gen_len is None:
        # Default: generate up to the model's remaining context budget.
        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)
    past_key_values = (
        cache_factory(
            cache_implementation=cache_implementation,
            config=model.model_.model_config(),
            batch_size=batch_size,
            max_cache_len=total_len,
        )
        if cache_implementation is not None
        else None
    )

    # Right-pad every prompt into one (batch_size, total_len) id tensor.
    tokens = torch.full(
        (batch_size, total_len), tokenizer.pad_id_, dtype=torch.int64, device=device
    )
    # print(f"prompt:\n{tokenizer.decode(input_tokens[0])}")
    for k, t in enumerate(input_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.int64, device=device)
    def condition(i):
        # True for token ids used as citation markers.  Assumes a Llama-3
        # tokenizer, where 128010..128255 and the listed ids are reserved
        # special tokens — TODO confirm for other tokenizers.
        return (128010 <= i <= 128255) or i in {128005, 128004, 128003, 128002, 128008}
    prompt_len = len(input_tokens[0])
    # Positions and values of citation-marker tokens already in the prompt.
    cite = [index for index, value in enumerate(input_tokens[0]) if condition(value)]
    cite_v = [value for value in input_tokens[0] if condition(value)]

    prev_pos = 0
    stop_reached = torch.tensor([False] * batch_size, device=device)
    # True wherever the tensor still holds prompt tokens (not padding);
    # those positions are never overwritten by sampled tokens.
    input_text_mask = tokens != tokenizer.pad_id_

    hidden_states = []
    hidden_attentions = []
    #arti_mask = torch.ones(batch_size, total_len, device=device, dtype=torch.int64)
    cite_start = -1  # position of an open '[' citation bracket, -1 if none
    #flag = -1
    plac = []  # positions of digit tokens emitted inside a citation bracket
    for cur_pos in range(min_tokens_len, total_len):
        input_data = LLMModelInput(
            batch_configs_=batch_config,
            batch_tokens_=tokens[:, prev_pos:cur_pos].tolist(),
            #batch_masks_ = arti_mask,############
            batch_cites = [cite],
            batch_cites_value = [cite_v],
            batch_docs = [current_jobs[0].citation_tokens],
            batch_prompt_len = [prompt_len],
            inference_mode_=True,
        )
        # print(f"context so far:\n{tokenizer.decode(tokens[0, prev_pos:cur_pos])}")
        outputs = model.forward(input_data, past_key_values)
        #hidden_states.append(hidden_state)
        #hidden_attentions.append(hidden_attention)

        #if flag != -1:
        # (disabled) dump attention weights

        for output in outputs:
            #config = config_dict[output.adapter_name]
            start_idx = output.batch_start_idx_
            end_idx = output.batch_end_idx_

            next_token = logits_process(
                output.logits[:, -1],  # last-position logits (post doc-mixing) — check dimensions
                tokens[start_idx:end_idx, :cur_pos],
                cite_flag = output.cite_flag,
            )

            # Keep prompt tokens where they exist; only fill padding slots.
            next_token = torch.where(
                input_text_mask[start_idx:end_idx, cur_pos],
                tokens[start_idx:end_idx, cur_pos],
                next_token,
            ).to(torch.int64)
            #print(tokenizer.decode(next_token))
            if output.cite_flag == True:  # check input_text_mask's shape when batching
                # Record newly generated citation tokens (skip prompt slots).
                # NOTE(review): appends a tensor to cite_v, whose prompt-time
                # entries are plain ints — confirm consumers accept both.
                for i in range(start_idx, end_idx):
                    if input_text_mask[i, cur_pos]:  # only relevant once multi-batch is supported
                        continue
                    cite.append(cur_pos)
                    cite_v.append(next_token)

            tokens[start_idx:end_idx, cur_pos] = next_token
            # A sequence stops when it emits EOS in a generated (non-prompt) slot.
            stop_criteria = (~input_text_mask[start_idx:end_idx, cur_pos]) & (
                next_token == torch.tensor(
                    [tokenizer.eos_id_], dtype=torch.int64, device=device
                )
            )
            stop_reached[start_idx:end_idx] |= stop_criteria
            if cite_start != -1:
                # Inside an open '[' bracket: sentence-ending punctuation closes it.
                if tokenizer.decode(next_token)[-1] in ['.','!','?']:
                    #arti_mask[start_idx:end_idx, cite_start:cur_pos] = 0
                    #tokens[start_idx:end_idx, cur_pos] = tokenizer.encode(tokenizer.decode(next_token)[-1])[-1]
                    cite_start = -1
                # Digits inside the bracket are citation indices; remember them.
                if tokenizer.decode(next_token)[-1] in ['0','1','2','3','4','5','6','7','8','9']:
                    plac.append(cur_pos)
                    # tokens[start_idx:end_idx, cur_pos] = (tokens[start_idx:end_idx, cur_pos] + 2)

            # A '[' opens a citation bracket (if none is already open).
            if tokenizer.decode(next_token)[-1] == '[' or tokenizer.decode(next_token) == '[':
                if cite_start == -1:
                    cite_start = cur_pos
                    #flag = cur_pos

        # Force-stop everything when the context window is exhausted.
        stop_reached |= total_len - cur_pos == 1

        # NOTE(review): `any(...)` stops the whole batch as soon as ONE
        # sequence finishes; unfinished sequences are re-dispatched by the
        # caller via _dispatch_task_out's running_jobs.
        if any(stop_reached):
            break

        if use_cache:
            # With a KV cache only the new token needs to be fed next step.
            prev_pos = cur_pos

    """input_data = LLMModelInput(
        batch_configs_=batch_config,
        batch_tokens_=tokens[:,:hidden_attention.shape[0]].tolist(),
        inference_mode_=True,
    )"""
    # print(f"context so far:\n{tokenizer.decode(tokens[0, prev_pos:cur_pos])}")
    #outputs, _, attn = model.forward(input_data, None, require_attention, require_hide)
    """for i in plac:

        plt.figure(figsize=(hidden_attention.shape[0], 5), dpi = 50)
        print("painting")
        plt.bar(range(hidden_attention.shape[0]), attn[:,i].cpu().numpy())
        plt.xticks(range(hidden_attention.shape[0]), [tokenizer.decode(j) for j in tokens[0][:hidden_attention.shape[0]]], fontsize = 8)
        plt.savefig("high_res_heatmap.svg", dpi=50)
        print("ok~")
        input()
    """
    """attn[torch.arange(hidden_attention.shape[0]), torch.arange(hidden_attention.shape[0])] = 0.0
    attn = torch.nn.functional.normalize(attn, p=2, dim=1)
    attn = attn[min_tokens_len:hidden_attention.shape[0],min_tokens_len:hidden_attention.shape[0]]

    plt.figure(figsize=(hidden_attention.shape[0] - min_tokens_len, hidden_attention.shape[0] - min_tokens_len)) # 调整图像大小
    plt.imshow(attn.cpu().numpy(), cmap='viridis', vmin = 0, vmax = 0.1)
    plt.colorbar(label='Value')
    plt.xticks(range(hidden_attention.shape[0] - min_tokens_len), [tokenizer.decode(i) for i in tokens[0][min_tokens_len:hidden_attention.shape[0]]], fontsize = 10)
    plt.yticks(range(hidden_attention.shape[0] - min_tokens_len), [tokenizer.decode(i) for i in tokens[0][min_tokens_len:hidden_attention.shape[0]]], fontsize = 10)
    plt.savefig("high_res_heatmap.png", dpi=200) # 保存为高分辨率图像
    plt.show()
    print("ok~")
    input()"""
    """token2 = tokens * arti_mask
    lst = token2[0].tolist()
    lst = [ele for ele in lst if ele != 0]
    tokens = torch.tensor(lst, dtype=torch.int64, device=device).unsqueeze(0)"""

    return _dispatch_task_out(
        tokenizer, current_jobs, tokens, stop_reached, hidden_states, hidden_attentions, require_attention, require_hide
    )
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def _batch_generate_original(
    model: LLMModel,
    tokenizer: Tokenizer,
    max_gen_len: Optional[int],
    use_cache: bool,
    cache_implementation: Optional[str],
    stream_callback: Optional[Callable],
    config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    batch_config: List[LLMBatchConfig],
    input_tokens: List[Tokens],
    max_tokens_len: int,
    min_tokens_len: int,
):
    """Plain (non-citation) autoregressive decoding for one dispatched batch.

    The original MoE-PEFT generation loop: right-pads the prompts, samples
    token-by-token with each adapter's own sampling parameters, stops a
    sequence when it emits its configured stop token in a generated slot,
    and invokes ``stream_callback`` with partial decodes when provided.

    Returns ``(finished_output_strings, still_running_jobs)`` from
    ``_dispatch_task_out``.
    """
    executor.empty_cache()
    device = torch.device(model.device_)
    batch_size = len(input_tokens)
    if max_gen_len is None:
        # Default: generate up to the model's remaining context budget.
        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)

    past_key_values = (
        cache_factory(
            cache_implementation=cache_implementation,
            config=model.model_.model_config(),
            batch_size=batch_size,
            max_cache_len=total_len,
        )
        if cache_implementation is not None
        else None
    )

    # Right-pad every prompt into one (batch_size, total_len) id tensor.
    tokens = torch.full(
        (batch_size, total_len), tokenizer.pad_id_, dtype=torch.int64, device=device
    )
    for k, t in enumerate(input_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.int64, device=device)

    prev_pos = 0
    stop_reached = torch.tensor([False] * batch_size, device=device)
    # True wherever the tensor still holds prompt tokens (not padding).
    input_text_mask = tokens != tokenizer.pad_id_
    for cur_pos in range(min_tokens_len, total_len):
        input_data = LLMModelInput(
            batch_configs_=batch_config,
            batch_tokens_=tokens[:, prev_pos:cur_pos].tolist(),
            inference_mode_=True,
        )
        outputs = model.forward(input_data, past_key_values)
        for output in outputs:
            config = config_dict[output.adapter_name]
            start_idx = output.batch_start_idx_
            end_idx = output.batch_end_idx_

            # NOTE(review): this positional call must match the current
            # logits_process signature (which also accepts cite_flag) —
            # confirm parameter order.
            next_token = logits_process(
                output.logits[:, -1],
                tokens[start_idx:end_idx, :cur_pos],
                config.temperature,
                config.top_p,
                config.top_k,
                config.do_sample,
                config.repetition_penalty,
                config.renormalize_logits,
            )

            # Keep prompt tokens where they exist; only fill padding slots.
            next_token = torch.where(
                input_text_mask[start_idx:end_idx, cur_pos],
                tokens[start_idx:end_idx, cur_pos],
                next_token,
            ).to(torch.int64)
            tokens[start_idx:end_idx, cur_pos] = next_token
            # A sequence stops when it emits the adapter's stop token in a
            # generated (non-prompt) slot.
            stop_criteria = (~input_text_mask[start_idx:end_idx, cur_pos]) & (
                next_token == config.stop_token_
            )
            stop_reached[start_idx:end_idx] |= stop_criteria

        # Force-stop everything when the context window is exhausted.
        stop_reached |= total_len - cur_pos == 1

        if any(stop_reached):
            break

        if stream_callback is not None:
            stream_callback(
                cur_pos,
                _gen_outputs(
                    tokenizer,
                    config_dict,
                    current_jobs,
                    tokens,
                ),
            )

        if use_cache:
            # With a KV cache only the new token needs to be fed next step.
            prev_pos = cur_pos

    # Fixed: match _dispatch_task_out's actual signature — config_dict is no
    # longer a parameter, and attentions/hides/require_attention/require_hide
    # were added (unused here, so pass empty/disabled values).
    return _dispatch_task_out(
        tokenizer, current_jobs, tokens, stop_reached, [], [], -1, -1
    )
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
@torch.inference_mode()
def generate(
    model: LLMModel,
    tokenizer: Tokenizer,
    configs: List[GenerateConfig],
    max_gen_len: Optional[int] = None,
    use_cache: bool = True,
    dispatch_strategy: str = "fair",
    concurrent_jobs: Optional[int] = None,
    cache_implementation: Optional[str] = None,
    stream_callback: Optional[Callable] = None,
):
    """Generate completions for every prompt in every config.

    Encodes all prompts into GenerateData jobs, then repeatedly dispatches
    batches (``_dispatch_task_in``) and decodes them (``_batch_generate``)
    until every job has finished.  Returns a dict mapping adapter name to
    its list of decoded output strings.
    """
    if concurrent_jobs is None:
        concurrent_jobs = len(configs)
        logging.info(f"Setting concurrent jobs to {concurrent_jobs} automatically")

    assert concurrent_jobs > 0

    # prepare for generation
    device = torch.device(model.device_)
    config_dict = {}
    for config in configs:
        config.reset_parameters()
        config_dict[config.adapter_name] = config
        # Resolve the per-adapter stop token (defaults to EOS).
        if config.stop_token is not None:
            stop_token = tokenizer.encode(" " + config.stop_token, False)[-1]
        else:
            stop_token = tokenizer.eos_id_
        config.stop_token_ = torch.tensor(
            [stop_token], dtype=torch.int64, device=device
        )
        for idx, prompt in enumerate(config.prompts):
            args = prompt if isinstance(prompt, tuple) else (prompt, None)
            tokens = tokenizer.encode(config.generate_prompt(*args))
            assert (
                len(tokens) < model.config_.max_seq_len_
            ), "Inputs exceeded max sequence length of model."
            config.data_.append(
                GenerateData(
                    adapter_name_=config.adapter_name,
                    prompt_index_=idx,
                    prefix_length_=len(tokens),
                    raw_tokens_=tokens,
                )
            )

    if use_cache and cache_implementation is None:
        cache_implementation = model.model_.cache_implementation()
        if cache_implementation is None:
            logging.warning(
                "Cache disabled by model, use cache_implementation to force enable."
            )
            use_cache = False

    packed_outputs: Dict[str, List] = {}

    while True:  # each config's data_ queue shrinks as jobs are dispatched
        # dispatch_args = (current_jobs, batch_config, input_tokens,
        #                  max_tokens_len, min_tokens_len)
        dispatch_args = _dispatch_task_in(configs, concurrent_jobs, dispatch_strategy)
        if len(dispatch_args[0]) == 0:
            break

        # Fixed: _batch_generate no longer takes config_dict, and takes
        # require_attention / require_hide before cache_implementation; the
        # previous call shifted every argument by one and dropped one.
        outputs, running_jobs = _batch_generate(
            model,
            tokenizer,
            max_gen_len,
            use_cache,
            None,  # require_attention: attention dumping disabled
            None,  # require_hide: hidden-state dumping disabled
            cache_implementation,
            stream_callback,
            *dispatch_args,
        )

        # _batch_generate returns a flat list of finished outputs (see
        # _dispatch_task_out), in dispatch order; recover each output's
        # adapter by pairing it with the finished jobs of this dispatch.
        still_running = {id(job) for job in running_jobs}
        finished_jobs = [
            job for job in dispatch_args[0] if id(job) not in still_running
        ]
        for data, output in zip(finished_jobs, outputs):
            packed_outputs.setdefault(data.adapter_name_, []).append(output)

        # Unfinished jobs go back to their queue for the next dispatch round.
        for data in running_jobs:
            config_dict[data.adapter_name_].data_.append(data)

    return packed_outputs
|
c2cite/model.py
ADDED
|
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, List, Optional, Tuple
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from huggingface_hub import snapshot_download
|
| 12 |
+
from transformers import AutoModelForCausalLM
|
| 13 |
+
|
| 14 |
+
from moe_peft.adapters import (
|
| 15 |
+
LoraMoeConfig,
|
| 16 |
+
MixLoraConfig,
|
| 17 |
+
MolaConfig,
|
| 18 |
+
lora_config_factory,
|
| 19 |
+
moe_layer_factory,
|
| 20 |
+
router_loss_factory,
|
| 21 |
+
)
|
| 22 |
+
from moe_peft.common import (
|
| 23 |
+
CHECKPOINT_CLASSES,
|
| 24 |
+
AdapterConfig,
|
| 25 |
+
Linear,
|
| 26 |
+
LLMCache,
|
| 27 |
+
LLMDecoder,
|
| 28 |
+
LLMForCausalLM,
|
| 29 |
+
LLMModelConfig,
|
| 30 |
+
LLMModelInput,
|
| 31 |
+
LLMModelOutput,
|
| 32 |
+
LLMMoeBlock,
|
| 33 |
+
LLMOutput,
|
| 34 |
+
LoraConfig,
|
| 35 |
+
unpack_router_logits,
|
| 36 |
+
)
|
| 37 |
+
from moe_peft.executors import executor
|
| 38 |
+
from moe_peft.models import from_pretrained
|
| 39 |
+
from moe_peft.tasks import SequenceClassificationTask, task_dict
|
| 40 |
+
from moe_peft.utils import is_package_available
|
| 41 |
+
|
| 42 |
+
if is_package_available("bitsandbytes"):
|
| 43 |
+
from transformers import BitsAndBytesConfig
|
| 44 |
+
else:
|
| 45 |
+
from moe_peft.utils import BitsAndBytesConfig
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class CasualOutputLayer(LLMOutput):
    """Causal-LM output head.

    Projects hidden states to vocabulary logits and computes a cross-entropy
    LM loss that can optionally (a) ignore citation-token positions and
    (b) restrict the loss to the completion part of each sample.
    """

    def __init__(self, vocab_size: int, weight: torch.nn.Linear):
        super().__init__()
        self.vocab_size_: int = vocab_size
        self.lm_head_: torch.nn.Module = weight

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # Upcast logits to float32 for a numerically stable loss.
        return self.lm_head_(data).float()

    def loss(
        self, input_ids: torch.Tensor, output_logits: torch.Tensor, labels,
        cites: Optional[List] = None, cites_v: Optional[List] = None, prompt_lens: Optional[List] = None
    ) -> torch.Tensor:
        """Cross-entropy next-token loss.

        Args:
            input_ids: batch of input token ids (unused directly).
            output_logits: (batch, seq, vocab) logits from ``forward``.
            labels: target token ids (tensor or nested list).
            cites: per-sample positions of citation tokens to exclude from
                the loss (set to ignore_index -100).
            cites_v: per-sample citation token values; only each sample's
                length bounds the masking loop.
            prompt_lens: per-sample prompt lengths; when given, the loss is
                averaged over per-sample completion-only losses.  When None
                (the default), falls back to the standard shifted loss over
                the whole sequence — previously this case raised TypeError.
        """
        if isinstance(labels, torch.Tensor):
            labels = (
                labels.clone()
                .detach()
                .to(dtype=torch.long, device=output_logits.device)
            )
        else:
            labels = torch.tensor(labels, dtype=torch.long, device=output_logits.device)

        loss_fn = torch.nn.CrossEntropyLoss()
        if cites:
            # Mask citation positions so they do not contribute to the loss.
            for i in range(len(labels)):
                for j in range(len(cites_v[i])):
                    labels[i][cites[i][j]] = -100
            loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

        if not prompt_lens:
            # Generic next-token loss: logits at t predict labels at t + 1.
            return loss_fn(
                output_logits[..., :-1, :].contiguous().view(-1, self.vocab_size_),
                labels[..., 1:].contiguous().view(-1),
            )

        # Completion-only loss: for sample i, logits[prompt_len-1 : -1]
        # predict labels[prompt_len :] (the usual one-position shift).
        ans = 0
        for i in range(len(prompt_lens)):
            ans += loss_fn(
                output_logits[i, prompt_lens[i] - 1:-1, :].contiguous().view(-1, self.vocab_size_),
                labels[i, prompt_lens[i]:].contiguous().view(-1),
            )
        return ans / len(prompt_lens)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ClassificationOutputLayer(LLMOutput):
    """Sequence-classification output head.

    Pools the logit of the last non-pad position of every sequence and
    computes either a single-label (cross-entropy) or multi-label
    (BCE-with-logits) classification loss.
    """

    def __init__(
        self,
        task_type: str,
        num_labels: int,
        label_dtype: torch.dtype,
        hidden_size: int,
        pad_token_id: int,
        device: str,
        weight: Optional[torch.Tensor],
    ):
        super().__init__()
        self.label_dtype_ = label_dtype
        self.num_labels_ = num_labels
        self.task_type_ = task_type
        self.pad_id_ = pad_token_id
        self.score_ = torch.nn.Linear(
            hidden_size,
            self.num_labels_,
            bias=False,
            dtype=torch.float32,
            device=device,
        )
        # Fresh head: Kaiming init; otherwise restore the saved classifier.
        if weight is not None:
            with torch.no_grad():
                self.score_.weight.copy_(weight["classifier"])
        else:
            torch.nn.init.kaiming_normal_(self.score_.weight, a=math.sqrt(5))

    def state_dict(self):
        return {"classifier": self.score_.weight}

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # The classifier weight is float32, so cast inputs to match.
        return self.score_(data.to(torch.float32))

    def loss(
        self, input_ids: torch.Tensor, output_logits: torch.Tensor, labels
    ) -> torch.Tensor:
        """Classification loss over the last non-pad position of each row."""
        dev = output_logits.device
        if isinstance(labels, torch.Tensor):
            labels = labels.clone().detach().to(dtype=self.label_dtype_, device=dev)
        else:
            labels = torch.tensor(labels, dtype=self.label_dtype_, device=dev)

        # Index of the last real token per row: one before the first pad.
        last_positions = (torch.eq(input_ids, self.pad_id_).int().argmax(-1) - 1).to(dev)
        rows = torch.arange(input_ids.shape[0], device=dev)
        pooled_logits = output_logits[rows, last_positions]

        if self.task_type_ == "single_label_classification":
            return torch.nn.CrossEntropyLoss()(
                pooled_logits.view(-1, self.num_labels_), labels.view(-1)
            )
        if self.task_type_ == "multi_label_classification":
            return torch.nn.BCEWithLogitsLoss()(pooled_logits, labels)
        raise ValueError(f"unknown task type {self.task_type_}")
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class OutputLayer(torch.nn.Module):
    """Routes per-adapter slices of the hidden states to per-adapter heads.

    ``layers_`` maps adapter name -> output head (causal LM or classifier);
    ``forward`` slices the batch by each batch-config's index range and runs
    that adapter's head on its slice.
    """

    def __init__(self):
        super().__init__()
        self.layers_: Dict[str, torch.nn.Module] = {}

    def forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> List[LLMModelOutput]:
        results: List[LLMModelOutput] = []
        for cfg in input_args.batch_configs_:
            name = cfg.adapter_name_
            assert name != "" and name in self.layers_
            head = self.layers_[name]
            batch_slice = data[cfg.batch_start_idx_ : cfg.batch_end_idx_]
            results.append(
                LLMModelOutput(
                    adapter_name=name,
                    logits=head.forward(batch_slice),
                    loss_fn_=head.loss,
                )
            )
        return results
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def init_lora_layer_weight(
    transformer_layer: LLMDecoder,
    llm_config: LLMModelConfig,
    lora_config: LoraConfig,
    lora_weights: Optional[Dict[str, torch.Tensor]],
):
    """Attach LoRA (and, for MoE variants, gating) weights to one decoder layer.

    When ``lora_weights`` is None all weights are freshly initialized;
    otherwise tensors are pulled from it using the naming scheme of the
    adapter family (mixlora / loramoe / mola / plain PEFT prefix).
    The layer's modules are mutated in place; nothing is returned.
    """
    target_modules = lora_config.target_modules_
    # transformer_layer.state_dict() returns the projection modules, not
    # tensors, split into attention and MLP groups.
    attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
    attn_state_dict: Dict[str, torch.Tensor]
    mlp_state_dict: Dict[str, torch.Tensor]
    # shallow copy: values are shared module references, which is intended
    all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
    all_state_dict.update(mlp_state_dict)
    moe_init_strategy = "none"
    # Select weight-key prefix and MoE strategy by adapter family:
    # - MixLoRA fuses one MoE over the whole MLP ("fused_mlp")
    # - LoRAMoE plugs per-projection MoEs into the MLP only
    # - MoLA plugs per-projection MoEs into attention AND MLP
    if isinstance(lora_config, MixLoraConfig):
        model_prefix_name = "mixlora"
        moe_layer_name_list = list(mlp_state_dict.keys())
        moe_init_strategy = "fused_mlp"
    elif isinstance(lora_config, LoraMoeConfig):
        model_prefix_name = "loramoe"
        moe_layer_name_list = list(mlp_state_dict.keys())
        moe_init_strategy = "plugin"
    elif isinstance(lora_config, MolaConfig):
        model_prefix_name = "mola"
        moe_layer_name_list = list(all_state_dict.keys())
        moe_init_strategy = "plugin"
    else:
        # plain LoRA: PEFT-compatible key prefix, no MoE layers
        model_prefix_name = "base_model.model.model"
        moe_layer_name_list = []

    assert len(moe_layer_name_list) == 0 or moe_init_strategy in ["plugin", "fused_mlp"]

    if moe_init_strategy == "fused_mlp":
        # one gate for the whole MLP block; note the direct indexing here
        # (missing gate key in a checkpoint raises KeyError by design)
        transformer_layer.mlp_.moes_[lora_config.adapter_name] = moe_layer_factory(
            llm_config.dim_,
            llm_config.device_,
            lora_config,
            (
                None
                if lora_weights is None
                else lora_weights[
                    f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
                ]
            ),
        )

    for proj_name, lora_linear in all_state_dict.items():
        lora_linear: Linear
        # skip projections not selected by the adapter config
        if proj_name not in target_modules or not target_modules[proj_name]:
            continue
        module_name = (
            "self_attn"
            if proj_name in attn_state_dict
            else ("mlp" if proj_name in mlp_state_dict else None)
        )
        module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
        if proj_name in moe_layer_name_list:
            if moe_init_strategy == "plugin":
                # init for gating mechanisms
                lora_linear.moes_[lora_config.adapter_name] = moe_layer_factory(
                    lora_linear.in_features_,
                    llm_config.device_,
                    lora_config,
                    (
                        lora_weights.get(f"{module_name}.moe_gate.weight", None)
                        if lora_weights is not None
                        else None
                    ),
                )

            # one LoRA pair (A, B) per expert of this projection
            for expert_idx in range(lora_config.num_experts_):
                if lora_weights is None:
                    lora_a = None
                    lora_b = None
                else:
                    lora_a = lora_weights.get(
                        f"{module_name}.experts.{expert_idx}.lora_A.weight", None
                    )
                    lora_b = lora_weights.get(
                        f"{module_name}.experts.{expert_idx}.lora_B.weight", None
                    )

                lora_linear.init_lora_weight(
                    lora_config.expert_config(expert_idx), (lora_a, lora_b)
                )
        else:
            # plain (non-MoE) LoRA pair for this projection
            if lora_weights is None:
                lora_a = None
                lora_b = None
            else:
                lora_a = lora_weights.get(f"{module_name}.lora_A.weight", None)
                lora_b = lora_weights.get(f"{module_name}.lora_B.weight", None)

            lora_linear.init_lora_weight(lora_config, (lora_a, lora_b))
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def get_lora_layer_weight(
    transformer_layer: LLMDecoder,
    lora_config: LoraConfig,
    lora_weights: Dict[str, torch.Tensor],
):
    """Collect this layer's adapter tensors into ``lora_weights`` (in place).

    The key naming mirrors :func:`init_lora_layer_weight` so that a dict
    produced here can be fed back into it. Nothing is returned.
    """
    target_modules = lora_config.target_modules_
    attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
    attn_state_dict: Dict[str, torch.Tensor]
    mlp_state_dict: Dict[str, torch.Tensor]
    all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
    all_state_dict.update(mlp_state_dict)
    # Fix: gate_layer_name was only bound in the MixLoRA branch but used
    # below whenever a fused MLP MoE exists; initialize it for all branches
    # to avoid a NameError for non-MixLoRA configs.
    gate_layer_name = None
    if isinstance(lora_config, MixLoraConfig):
        model_prefix_name = "mixlora"
        gate_layer_name = (
            f"mixlora.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
        )
        moe_layer_name_list = list(mlp_state_dict.keys())
    elif isinstance(lora_config, LoraMoeConfig):
        model_prefix_name = "loramoe"
        moe_layer_name_list = list(mlp_state_dict.keys())
    elif isinstance(lora_config, MolaConfig):
        model_prefix_name = "mola"
        moe_layer_name_list = list(all_state_dict.keys())
    else:
        model_prefix_name = "base_model.model.model"
        moe_layer_name_list = []

    # for fused MoEs such as MixLoRA
    mlp_moe_layer: LLMMoeBlock = transformer_layer.mlp_.moes_.get(
        lora_config.adapter_name, None
    )
    if mlp_moe_layer is not None and gate_layer_name is not None:
        lora_weights[gate_layer_name] = mlp_moe_layer.gate_.weight

    for proj_name, lora_linear in all_state_dict.items():
        lora_linear: Linear
        if proj_name not in target_modules or not target_modules[proj_name]:
            continue
        module_name = (
            "self_attn"
            if proj_name in attn_state_dict
            else ("mlp" if proj_name in mlp_state_dict else None)
        )
        module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
        if proj_name in moe_layer_name_list:
            # per-projection MoE if plugged here, otherwise the fused one
            moe_layer = (
                lora_linear.moes_[lora_config.adapter_name]
                if lora_config.adapter_name in lora_linear.moes_
                else mlp_moe_layer
            )
            # for plugged MoEs such as LoRAMoE, MoLA, etc.
            if lora_config.adapter_name in lora_linear.moes_:
                lora_weights[f"{module_name}.moe_gate.weight"] = lora_linear.moes_[
                    lora_config.adapter_name
                ].gate_.weight

            for expert_idx in range(moe_layer.experts_):
                moe_lora_name = f"moe.{lora_config.adapter_name}.experts.{expert_idx}"
                lora_obj = lora_linear.loras_.get(moe_lora_name, None)
                if lora_obj is not None:
                    lora_weights[
                        f"{module_name}.experts.{expert_idx}.lora_A.weight"
                    ] = lora_obj.lora_a_.weight
                    lora_weights[
                        f"{module_name}.experts.{expert_idx}.lora_B.weight"
                    ] = lora_obj.lora_b_.weight

        else:
            lora_obj = lora_linear.loras_.get(lora_config.adapter_name, None)
            if lora_obj is not None:
                lora_weights[f"{module_name}.lora_A.weight"] = lora_obj.lora_a_.weight
                lora_weights[f"{module_name}.lora_B.weight"] = lora_obj.lora_b_.weight
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_atten_tar(x, y, device, dtype):
    """Precompute the attention-target weighting tensors ``alpha`` and ``beta``.

    Args:
        x: number of tracked generated offsets; ``beta`` covers ``x - 1`` of them.
        y: maximum sequence length covered by ``alpha``.
        device, dtype: placement/precision of the returned tensors.

    Returns:
        alpha: shape ``(y,)``, ``1 - exp(-i / 200)`` — starts at 0 and
            saturates toward 1 with position.
        beta: shape ``(x-1, x-1)`` with
            ``beta[i, j] = exp(lamb*(i+1)) * award[i] / sum_{k=1..j+1} exp(lamb*k)``.
        Both are detached constants (no gradients flow through them).
    """
    si = torch.arange(0, y, device=device, dtype=dtype)
    xi = torch.arange(1, x, device=device, dtype=dtype)  # offsets 1 .. x-1
    lamb = torch.tensor(-2, device=device, dtype=dtype)  # exponential decay rate
    alpha = (1 - torch.exp(-(si / 200))).detach()
    # base[j] = sum_{k=1..j+1} exp(lamb * k); a cumulative sum replaces the
    # original element-by-element Python loop with one vectorized op.
    base = torch.cumsum(torch.exp(lamb * xi), dim=0)
    # small position-dependent bonus in [0.2, 0.25)
    award = (0.1 * (0.5 - 1 / (xi + 1)) + 0.2).detach()
    # broadcast the per-offset numerator across columns, then divide by base
    beta = (torch.exp(lamb * xi) * award).expand(xi.shape[0], x - 1).T
    beta = (beta / base).detach()

    return alpha, beta  # alpha starts at 0; beta columns are normalized by base
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class LLMModel(torch.nn.Module):
|
| 375 |
+
def __init__(self, model: LLMForCausalLM):
|
| 376 |
+
super().__init__()
|
| 377 |
+
args: LLMModelConfig = model.config_
|
| 378 |
+
if args.vocab_size_ >= torch.finfo(args.dtype_).max:
|
| 379 |
+
logging.warn(
|
| 380 |
+
f"vocab_size >= max({args.dtype_}), consider load model with higher precision."
|
| 381 |
+
)
|
| 382 |
+
self.model_ = model
|
| 383 |
+
self.config_ = args
|
| 384 |
+
# configs
|
| 385 |
+
self.name_or_path_ = args.name_or_path_
|
| 386 |
+
self.vocab_size_ = args.vocab_size_
|
| 387 |
+
self.device_ = args.device_
|
| 388 |
+
self.dtype_ = args.dtype_
|
| 389 |
+
|
| 390 |
+
self.attention_weight = torch.nn.Parameter(torch.empty(
|
| 391 |
+
model.layers_[0].self_attn_.n_heads_,1,dtype=args.dtype_,device=args.device_,))
|
| 392 |
+
|
| 393 |
+
self.routerup = torch.nn.Parameter(torch.empty(
|
| 394 |
+
model.config_.dim_, 2,dtype=args.dtype_,device=args.device_,))
|
| 395 |
+
"""self. routerdown = torch.nn.Parameter(torch.empty(
|
| 396 |
+
model.config_.dim_ * 2, 2,dtype=args.dtype_,device=args.device_,))"""
|
| 397 |
+
self.cite_output = torch.nn.Parameter(torch.empty(
|
| 398 |
+
model.config_.dim_,model.config_.dim_,dtype=args.dtype_,device=args.device_,))
|
| 399 |
+
self.doc_proj = torch.nn.Parameter(torch.empty(
|
| 400 |
+
model.config_.dim_, model.config_.dim_,dtype=args.dtype_,device=args.device_,))
|
| 401 |
+
|
| 402 |
+
self.alpha, self.beta= get_atten_tar(40, 3000, args.device_, args.dtype_)
|
| 403 |
+
self.silu = torch.nn.SiLU()
|
| 404 |
+
|
| 405 |
+
self.output_ = OutputLayer()
|
| 406 |
+
# adapter configs
|
| 407 |
+
self.adapter_configs_: Dict[str, LoraConfig] = {}
|
| 408 |
+
|
| 409 |
+
def token2id(self, t):
|
| 410 |
+
if isinstance(t, torch.Tensor):
|
| 411 |
+
x = t.item()
|
| 412 |
+
else:
|
| 413 |
+
x = t
|
| 414 |
+
if x == 128002:
|
| 415 |
+
return 0
|
| 416 |
+
elif x == 128003:
|
| 417 |
+
return 1
|
| 418 |
+
elif x == 128004:
|
| 419 |
+
return 2
|
| 420 |
+
elif x == 128005:
|
| 421 |
+
return 3
|
| 422 |
+
elif x == 128008:
|
| 423 |
+
return 4
|
| 424 |
+
elif x >= 128010 and x <= 128255:
|
| 425 |
+
return x - 128005
|
| 426 |
+
else:
|
| 427 |
+
return -1
|
| 428 |
+
|
| 429 |
+
    def attention_target(self, i, j, T):
        # Target attention weight for generated offset i attending to
        # position j at total length T, combining the precomputed tables.
        # NOTE(review): self.award is never assigned in __init__ —
        # get_atten_tar keeps `award` local and does not return it — so
        # calling this method as written raises AttributeError. Confirm
        # whether this is dead code or `award` should be exposed.
        return self.alpha[j] * self.beta[T, i] * self.award[i]
|
| 431 |
+
|
| 432 |
+
    def _prepare_inputs(
        self, input_args: LLMModelInput, past_key_values: Optional[LLMCache] = None
    ):
        """Build ids, embeddings, masks and cache positions for a forward pass.

        Also replaces the embeddings at citation positions with (projected)
        per-document embeddings, cached on ``self.doc_embeds`` during prefill
        so single-token decode steps can reuse them.

        Returns:
            (input_ids, inputs_embeds, attention_mask, causal_mask,
            cache_position)
        """
        assert input_args.batch_tokens_ is not None, "Model have no input."
        assert (
            input_args.gradient_checkpoint_ == "none" or past_key_values is None
        ), "Cache is incompatible with gradient checkpointing."
        assert (
            not input_args.inference_mode_ or input_args.gradient_checkpoint_ == "none"
        ), "Can not use gradient checkpoint when inference."

        # prepare inputs
        if isinstance(input_args.batch_tokens_, torch.Tensor):
            # NOTE(review): Tensor.to() does not accept a requires_grad
            # keyword — confirm this branch is actually exercised.
            input_ids = input_args.batch_tokens_.to(
                dtype=torch.int64, device=self.device_, requires_grad=False
            )
        else:
            input_ids = torch.tensor(
                input_args.batch_tokens_, dtype=torch.int64, device=self.device_, requires_grad=False
            )

        inputs_embeds = self.model_.embed_tokens(input_ids)

        """if input_ids.shape[-1] > 1:
            self.doc_embeds = []
            cites = input_args.batch_cites
            docs = input_args.batch_docs
            for doc in docs:
                doc = doc.clone().to(self.device_)
                doc = doc @ self.doc_proj
                self.doc_embeds.append(doc)
            for i, cite in enumerate(cites):
                for c in range(len(input_args.batch_cites_value[i])):
                    inputs_embeds[i, cite[c]] = self.doc_embeds[i][self.token2id(input_args.batch_cites_value[i][c]) - 1].to(self.device_)
        else:
            fk = self.token2id(input_ids[0,0])
            if fk != -1:
                inputs_embeds[0][0] = self.doc_embeds[0][fk - 1].to(self.device_)"""

        docs = input_args.batch_docs
        if input_ids.shape[-1] > 1:
            # prefill: build mean-pooled document embeddings once per batch
            self.doc_embeds = []
            cites = input_args.batch_cites
            if not isinstance(docs[0][0], torch.Tensor):
                for i in range(len(docs)):
                    d = []
                    for j in range(len(docs[i])):
                        # docs[i][j][1:] — presumably skips a BOS/marker
                        # token at position 0; verify against the tokenizer
                        temp = self.model_.embed_tokens(torch.tensor(
                            docs[i][j][1:], dtype=torch.int64, device=self.device_, requires_grad=False))
                        temp = torch.mean(temp, dim = 0)
                        d.append(temp)
                    d = torch.stack(d)
                    self.doc_embeds.append(d)
            # overwrite the embedding at each citation position with the
            # embedding of the cited document
            for i, cite in enumerate(cites):
                for c in range(len(input_args.batch_cites_value[i])):
                    doc_ind = self.token2id(input_args.batch_cites_value[i][c]) - 1
                    # NOTE(review): the assert message calls print(), which
                    # evaluates to None; a plain string would be clearer
                    assert doc_ind >= 0, print("fake cite token")
                    inputs_embeds[i, cite[c]] = self.doc_embeds[i][doc_ind].to(self.device_)
        else:
            # single-token decode: substitute a document embedding if the
            # new token is a citation marker (uses cached self.doc_embeds)
            fk = self.token2id(input_ids[0,0]) - 1
            if fk >= 0:
                inputs_embeds[0][0] = self.doc_embeds[0][fk].to(self.device_)

        if input_args.gradient_checkpoint_ != "none":
            # checkpointing needs a grad-tracking input to recompute through
            inputs_embeds.requires_grad_(True)

        # prepare cache
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )

        if past_seen_tokens is None:
            past_seen_tokens = 0

        cache_position = torch.arange(
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

        # prepare mask
        if input_args.batch_masks_ is not None:
            # 2d mask is passed through the layers
            if isinstance(input_args.batch_masks_, torch.Tensor):
                attention_mask = input_args.batch_masks_.to(
                    dtype=torch.int64, device=self.device_
                )
            else:
                attention_mask = torch.tensor(
                    input_args.batch_masks_, dtype=torch.int64, device=self.device_
                )
        else:
            attention_mask = None

        if self.config_.attn_implementation_ != "flash_attn":
            # materialize the 4-D causal mask; flash attention builds its own
            causal_mask = self.model_.causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values
            )
        else:
            causal_mask = attention_mask

        return input_ids, inputs_embeds, attention_mask, causal_mask, cache_position
|
| 534 |
+
|
| 535 |
+
def _call_decoder_stack_original(
|
| 536 |
+
self,
|
| 537 |
+
hidden_states: torch.Tensor,
|
| 538 |
+
input_args: LLMModelInput,
|
| 539 |
+
rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
| 540 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 541 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 542 |
+
past_key_value: Optional[LLMCache] = None,
|
| 543 |
+
):
|
| 544 |
+
# decoder layers
|
| 545 |
+
num_adapters = len(input_args.batch_configs_)
|
| 546 |
+
all_router_logits = [[] for _ in range(num_adapters)]
|
| 547 |
+
gradient_checkpoint = CHECKPOINT_CLASSES[input_args.gradient_checkpoint_]
|
| 548 |
+
|
| 549 |
+
for decoder_layer in self.model_.decoder_stack():
|
| 550 |
+
hidden_states, *router_logits = gradient_checkpoint(
|
| 551 |
+
decoder_layer.forward,
|
| 552 |
+
hidden_states,
|
| 553 |
+
input_args,
|
| 554 |
+
rotary_emb,
|
| 555 |
+
attention_mask,
|
| 556 |
+
cache_position,
|
| 557 |
+
past_key_value,
|
| 558 |
+
)
|
| 559 |
+
if len(router_logits) == 0:
|
| 560 |
+
continue
|
| 561 |
+
# collecting router logits
|
| 562 |
+
assert len(router_logits) == num_adapters
|
| 563 |
+
for idx in range(num_adapters):
|
| 564 |
+
if router_logits[idx] is not None:
|
| 565 |
+
all_router_logits[idx].append(router_logits[idx])
|
| 566 |
+
|
| 567 |
+
hidden_states = self.model_.norm(hidden_states)
|
| 568 |
+
|
| 569 |
+
return hidden_states, all_router_logits
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def _call_decoder_stack(
|
| 573 |
+
self,
|
| 574 |
+
hidden_states: torch.Tensor,
|
| 575 |
+
input_args: LLMModelInput,
|
| 576 |
+
rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
| 577 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 578 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 579 |
+
past_key_value: Optional[LLMCache] = None,
|
| 580 |
+
#require_attention: Optional[int] = -1,
|
| 581 |
+
#require_hide: Optional[int] = -1,
|
| 582 |
+
):
|
| 583 |
+
# decoder layers
|
| 584 |
+
gradient_checkpoint = CHECKPOINT_CLASSES[input_args.gradient_checkpoint_]
|
| 585 |
+
|
| 586 |
+
#hidden_output = []
|
| 587 |
+
#hidden_atten = []
|
| 588 |
+
attention_matrixs = []
|
| 589 |
+
for idx, decoder_layer in enumerate(self.model_.decoder_stack()):
|
| 590 |
+
hidden_states, attention_matrix = gradient_checkpoint(
|
| 591 |
+
decoder_layer.forward,
|
| 592 |
+
hidden_states,
|
| 593 |
+
input_args,
|
| 594 |
+
rotary_emb,
|
| 595 |
+
attention_mask,
|
| 596 |
+
cache_position,
|
| 597 |
+
past_key_value,
|
| 598 |
+
)
|
| 599 |
+
if idx in [31,30,29]:
|
| 600 |
+
attention_matrixs.append(attention_matrix)
|
| 601 |
+
"""if require_hide == len(self.model_.layers_) or require_hide == idx:
|
| 602 |
+
hidden_output.append(hidden_states)
|
| 603 |
+
if require_attention == len(self.model_.layers_) or require_attention == idx:
|
| 604 |
+
hidden_atten.append(hidden_attention)"""
|
| 605 |
+
|
| 606 |
+
hidden_states = self.model_.norm(hidden_states)
|
| 607 |
+
|
| 608 |
+
return hidden_states, attention_matrixs#hidden_atten, hidden_output
|
| 609 |
+
|
| 610 |
+
# compute the model: output probs
|
| 611 |
+
def forward(
|
| 612 |
+
self, input_args: LLMModelInput, past_key_values: Optional[LLMCache] = None
|
| 613 |
+
) -> List[LLMModelOutput]:
|
| 614 |
+
input_ids, inputs_embeds, attention_mask, causal_mask, cache_position = (
|
| 615 |
+
self._prepare_inputs(input_args, past_key_values)
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
labels = input_args.batch_labels_
|
| 619 |
+
|
| 620 |
+
input_args.batch_labels_ = None
|
| 621 |
+
input_args.batch_tokens_ = None
|
| 622 |
+
input_args.batch_masks_ = None
|
| 623 |
+
|
| 624 |
+
# embed positions
|
| 625 |
+
hidden_states = inputs_embeds
|
| 626 |
+
|
| 627 |
+
rotary_emb = self.model_.rotary_embed(
|
| 628 |
+
hidden_states, cache_position.unsqueeze(0)
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
hidden_states, attention_matrixs = self._call_decoder_stack(
|
| 632 |
+
hidden_states,
|
| 633 |
+
input_args,
|
| 634 |
+
rotary_emb,
|
| 635 |
+
causal_mask,
|
| 636 |
+
cache_position,
|
| 637 |
+
past_key_values,
|
| 638 |
+
#require_attention,
|
| 639 |
+
#require_hide,
|
| 640 |
+
)
|
| 641 |
+
attention_matrixs[-1] = attention_matrixs[-1].permute(0,2,3,1)
|
| 642 |
+
attention_matrixs[-1] = torch.sum(attention_matrixs[-1], dim = -1).squeeze().to('cpu').detach()
|
| 643 |
+
#print(attention_matrixs[-1].shape)
|
| 644 |
+
#print(torch.mean(attention_matrixs[-1][input_args.batch_cites[0][0] + 1:input_args.batch_cites[0][2],input_args.batch_cites[0][0]]))
|
| 645 |
+
import numpy as np
|
| 646 |
+
import matplotlib.pyplot as plt
|
| 647 |
+
import seaborn as sns
|
| 648 |
+
plt.figure(figsize=(8, 6))
|
| 649 |
+
print(f"len:{input_args.batch_prompt_len[0]}")
|
| 650 |
+
print(attention_matrixs[-1].shape)
|
| 651 |
+
sns.heatmap(attention_matrixs[-1][input_args.batch_prompt_len[0]:,input_args.batch_prompt_len[0]:], annot=False, cmap="YlGnBu", vmin = 0, vmax = 0.2, xticklabels=False, yticklabels=False)
|
| 652 |
+
plt.savefig("/yy21/heatmap", bbox_inches='tight', dpi=300)
|
| 653 |
+
input()
|
| 654 |
+
#route_logits = hidden_states @ (self.routerup @ self.routerdown)
|
| 655 |
+
route_logits = hidden_states @ self.routerup
|
| 656 |
+
hidden_cites = hidden_states @ self.cite_output
|
| 657 |
+
norm_cite_logits = F.normalize(hidden_cites, p = 2, dim = 2)
|
| 658 |
+
cite_logits = []
|
| 659 |
+
for batch in range(hidden_states.shape[0]):
|
| 660 |
+
#norm_doc = F.normalize(self.doc_embeds[batch], p = 2, dim = 1)
|
| 661 |
+
norm_doc = F.normalize(self.doc_embeds[batch].detach(), p = 2, dim = 1)
|
| 662 |
+
cite_logits.append(norm_cite_logits[batch] @ norm_doc.T)
|
| 663 |
+
#cite_logits.append(norm_cite_logits[batch])
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
# calculate loss
|
| 667 |
+
output = self.output_(hidden_states, input_args)
|
| 668 |
+
#att_s = hidden_atten[0].sum(dim = 1).squeeze() / 32 ###这里把List变为一个值
|
| 669 |
+
assert isinstance(output, List)
|
| 670 |
+
for indx, lora_config in enumerate(input_args.batch_configs_):
|
| 671 |
+
output_data = output[indx]
|
| 672 |
+
assert isinstance(output_data, LLMModelOutput)
|
| 673 |
+
start_idx = lora_config.batch_start_idx_
|
| 674 |
+
end_idx = lora_config.batch_end_idx_
|
| 675 |
+
output_data.batch_start_idx_ = start_idx
|
| 676 |
+
output_data.batch_end_idx_ = end_idx
|
| 677 |
+
#print(f"router:{route_logits[0,-1]}")
|
| 678 |
+
#print(f"cite:{cite_logits}")
|
| 679 |
+
if (labels is None) and (route_logits[0, -1, 1] > route_logits[0, -1, 0]):
|
| 680 |
+
output_data.logits = cite_logits[0].unsqueeze(0)
|
| 681 |
+
#output_data.logits = hidden_states[0].unsqueeze(0)
|
| 682 |
+
output_data.cite_flag = True
|
| 683 |
+
else:
|
| 684 |
+
output_data.cite_flag = False
|
| 685 |
+
if labels is None:
|
| 686 |
+
continue
|
| 687 |
+
# compute loss when labels provided
|
| 688 |
+
output_data.loss = output_data.loss_fn_(
|
| 689 |
+
input_ids[start_idx:end_idx],
|
| 690 |
+
output_data.logits,
|
| 691 |
+
labels[start_idx:end_idx],
|
| 692 |
+
input_args.batch_cites,
|
| 693 |
+
input_args.batch_cites_value,
|
| 694 |
+
input_args.batch_prompt_len
|
| 695 |
+
)
|
| 696 |
+
output_data.loss_fn_ = None
|
| 697 |
+
# route_logits和下面的合并
|
| 698 |
+
for idx in range(len(input_args.batch_cites)):
|
| 699 |
+
new_cites = []
|
| 700 |
+
new_cites_v = []
|
| 701 |
+
for i in range(len(input_args.batch_cites[idx])):
|
| 702 |
+
if input_args.batch_cites[idx][i] >= input_args.batch_prompt_len[idx]:
|
| 703 |
+
new_cites.append(input_args.batch_cites[idx][i])
|
| 704 |
+
if i < len(input_args.batch_cites_value[idx]):
|
| 705 |
+
new_cites_v.append(input_args.batch_cites_value[idx][i])
|
| 706 |
+
input_args.batch_cites[idx] = new_cites
|
| 707 |
+
input_args.batch_cites_value[idx] = new_cites_v
|
| 708 |
+
if output_data.aux_loss is None:
|
| 709 |
+
output_data.aux_loss = self.attn_mat_coin * 0.01 * self.attention_loss_fn(attention_matrixs, causal_mask, input_args.batch_cites, input_args.batch_prompt_len)
|
| 710 |
+
else:
|
| 711 |
+
output_data.aux_loss += self.attn_mat_coin * 0.01 * self.attention_loss_fn(attention_matrixs, causal_mask, input_args.batch_cites, input_args.batch_prompt_len)
|
| 712 |
+
print(f"1:{output_data.aux_loss}")
|
| 713 |
+
for idx in range(len(input_args.batch_cites)):
|
| 714 |
+
if len(input_args.batch_cites[idx]) > len(input_args.batch_cites_value[idx]):
|
| 715 |
+
input_args.batch_cites[idx] = input_args.batch_cites[idx][:-1]
|
| 716 |
+
output_data.aux_loss += self.router_coin * 10 * self.compute_route_loss(route_logits, input_args.batch_cites)#router的label中,cite位置的是1,其他是0
|
| 717 |
+
print(f"2:{output_data.aux_loss}")
|
| 718 |
+
#output_data.aux_loss += self.cite_coin * self.compute_cite_loss2(hidden_states, input_args.batch_cites,input_args.batch_cites_value,batch_doc_embed)#router的label中,cite位置的是1,其他是0
|
| 719 |
+
output_data.aux_loss += self.cite_coin * 100 * self.compute_cite_loss(cite_logits, input_args.batch_cites,input_args.batch_cites_value)#router的label中,cite位置的是1,其他是0
|
| 720 |
+
print(f"3:{output_data.aux_loss}")
|
| 721 |
+
return output
|
| 722 |
+
|
| 723 |
+
def from_pretrained(
|
| 724 |
+
name_or_path: str,
|
| 725 |
+
device: str,
|
| 726 |
+
bits: int = None,
|
| 727 |
+
attn_impl: str = "eager",
|
| 728 |
+
use_sliding_window: bool = False,
|
| 729 |
+
load_dtype: torch.dtype = torch.bfloat16,
|
| 730 |
+
compute_dtype: torch.dtype = torch.bfloat16,
|
| 731 |
+
double_quant: bool = True,
|
| 732 |
+
quant_type: str = "nf4",
|
| 733 |
+
) -> "LLMModel":
|
| 734 |
+
# load_dtype will change the precision of LLaMA pre-trained model
|
| 735 |
+
# when loading with quantization (bits = 8 or bits = 4), load_dtype will only influence the actual computing precision
|
| 736 |
+
if load_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
|
| 737 |
+
raise ValueError(f"unsupported load dtype {load_dtype}")
|
| 738 |
+
|
| 739 |
+
if compute_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
|
| 740 |
+
raise ValueError(f"unsupported compute dtype {compute_dtype}")
|
| 741 |
+
|
| 742 |
+
if load_dtype in [torch.bfloat16, torch.float16]:
|
| 743 |
+
logging.info("Loading model with half precision.")
|
| 744 |
+
|
| 745 |
+
# BFloat16 is only supported after Ampere GPUs
|
| 746 |
+
if not executor.is_bf16_supported():
|
| 747 |
+
if load_dtype == torch.bfloat16:
|
| 748 |
+
logging.warning("bf16 is not available. deprecated to fp16.")
|
| 749 |
+
load_dtype = torch.float16
|
| 750 |
+
|
| 751 |
+
if bits in [4, 8] and compute_dtype == torch.bfloat16:
|
| 752 |
+
logging.warning("bf16 is not available. deprecated to fp16.")
|
| 753 |
+
compute_dtype = torch.float16
|
| 754 |
+
|
| 755 |
+
if bits in [4, 8]:
|
| 756 |
+
logging.info(f"Loading model with quantization, bits = {bits}.")
|
| 757 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 758 |
+
name_or_path,
|
| 759 |
+
device_map=device,
|
| 760 |
+
trust_remote_code=True,
|
| 761 |
+
quantization_config=BitsAndBytesConfig(
|
| 762 |
+
load_in_4bit=bits == 4,
|
| 763 |
+
load_in_8bit=bits == 8,
|
| 764 |
+
llm_int8_threshold=6.0,
|
| 765 |
+
llm_int8_has_fp16_weight=False,
|
| 766 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 767 |
+
bnb_4bit_use_double_quant=double_quant,
|
| 768 |
+
bnb_4bit_quant_type=quant_type,
|
| 769 |
+
),
|
| 770 |
+
torch_dtype=load_dtype,
|
| 771 |
+
)
|
| 772 |
+
else:
|
| 773 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 774 |
+
name_or_path,
|
| 775 |
+
device_map=device,
|
| 776 |
+
trust_remote_code=True,
|
| 777 |
+
torch_dtype=load_dtype,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
llm_model.requires_grad_(False)
|
| 781 |
+
|
| 782 |
+
model = from_pretrained(
|
| 783 |
+
llm_model,
|
| 784 |
+
attn_impl=attn_impl,
|
| 785 |
+
use_sliding_window=use_sliding_window,
|
| 786 |
+
device=device,
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
logging.info(f"Use {attn_impl} as attention implementation.")
|
| 790 |
+
|
| 791 |
+
return LLMModel(model)
|
| 792 |
+
|
| 793 |
+
    def init_adapter(
        self, config: AdapterConfig, weight: Optional[Dict[str, torch.Tensor]] = None
    ):
        """Register an adapter: loss coefficients, output head, citation
        heads, and (for LoRA-family configs) per-layer LoRA weights.

        When ``weight`` is None the citation heads are freshly initialized;
        otherwise they are copied from the checkpoint dict.
        Returns the adapter name.
        """
        # auxiliary-loss coefficients supplied by the adapter config
        self.attn_mat_coin = config.atten_coin
        self.router_coin = config.router_coin
        self.cite_coin = config.cite_coin
        # Patch for MixLoRA
        if isinstance(config, MixLoraConfig) and config.act_fn_ is None:
            config.act_fn_ = self.config_.hidden_act_

        self.adapter_configs_[config.adapter_name] = config
        # init output layer
        if config.task_name in task_dict and isinstance(
            task_dict[config.task_name], SequenceClassificationTask
        ):
            output_layer = ClassificationOutputLayer(
                **task_dict[config.task_name].init_kwargs(),
                hidden_size=self.config_.dim_,
                pad_token_id=self.config_.pad_token_id_,
                device=self.device_,
                weight=weight,
            )
        else:
            output_layer = CasualOutputLayer(
                vocab_size=self.config_.vocab_size_, weight=self.model_.lm_head_
            )

        if weight is None:
            # fresh initialization of the citation heads
            torch.nn.init.kaiming_normal_(self.attention_weight, mode='fan_in', nonlinearity='relu')
            torch.nn.init.kaiming_normal_(self.routerup, mode='fan_in', nonlinearity='relu')
            torch.nn.init.kaiming_normal_(self.cite_output, mode='fan_in', nonlinearity='relu')
            torch.nn.init.orthogonal_(self.doc_proj)
        else:
            # NOTE(review): weight.get(..., None) feeding copy_() raises a
            # TypeError if a key is missing — checkpoints must contain all
            # four head tensors.
            with torch.no_grad():
                self.attention_weight.copy_(weight.get(f"{config.adapter_name}.attention_mat_weight", None))
                self.routerup.copy_(weight.get(f"{config.adapter_name}.router_weight_up", None))
                self.cite_output.copy_(weight.get(f"{config.adapter_name}.cite_weight", None))
                self.doc_proj.copy_(weight.get(f"{config.adapter_name}.doc_weight", None))
        self.output_.layers_[config.adapter_name] = output_layer
        if type(config) is not AdapterConfig:
            # init transformer layers
            for transformer_layer in self.model_.layers_:
                init_lora_layer_weight(transformer_layer, self.config_, config, weight)
        else:
            assert weight is None, "can not load basic adapter with weight"

        return config.adapter_name
|
| 842 |
+
|
| 843 |
+
def get_adapter_weight_dict(self, adapter_name: str) -> Dict[str, torch.Tensor]:
|
| 844 |
+
# return the lora weight and target_module's name
|
| 845 |
+
lora_weight_dict = self.output_.layers_[adapter_name].state_dict()
|
| 846 |
+
atten_name = f"{adapter_name}.attention_mat_weight"
|
| 847 |
+
lora_weight_dict[atten_name] = self.attention_weight
|
| 848 |
+
route_name = f"{adapter_name}.router_weight_up"
|
| 849 |
+
lora_weight_dict[route_name] = self.routerup
|
| 850 |
+
"""route_name = f"{adapter_name}.router_weight_down"
|
| 851 |
+
lora_weight_dict[route_name] = self.routerdown"""
|
| 852 |
+
cite_name = f"{adapter_name}.cite_weight"
|
| 853 |
+
lora_weight_dict[cite_name] = self.cite_output
|
| 854 |
+
doc_name = f"{adapter_name}.doc_weight"
|
| 855 |
+
lora_weight_dict[doc_name] = self.doc_proj
|
| 856 |
+
lora_config = self.adapter_configs_[adapter_name]
|
| 857 |
+
for transformer_layer in self.model_.layers_:
|
| 858 |
+
get_lora_layer_weight(transformer_layer, lora_config, lora_weight_dict)
|
| 859 |
+
|
| 860 |
+
return lora_weight_dict
|
| 861 |
+
|
| 862 |
+
    def unload_adapter(
        self, adapter_name: str
    ) -> Tuple[LoraConfig, Dict[str, torch.Tensor]]:
        """Detach an adapter from the model and return its config and weights.

        The weights are snapshotted before removal so the adapter can be
        re-attached or saved later.
        """
        assert adapter_name in self.adapter_configs_, "adapter not exist"
        lora_weight = self.get_adapter_weight_dict(adapter_name)
        lora_config = self.adapter_configs_.pop(adapter_name)
        self.output_.layers_.pop(adapter_name)
        for transformer_layer in self.model_.layers_:
            attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
            attn_state_dict: Dict[str, torch.Tensor]
            mlp_state_dict: Dict[str, torch.Tensor]
            lora_layer_list = list(attn_state_dict.values())
            lora_layer_list.extend(mlp_state_dict.values())

            for lora_layer in lora_layer_list:
                if adapter_name in lora_layer.loras_:
                    # plain LoRA attached directly to this projection
                    lora_layer.loras_.pop(adapter_name, None)
                elif adapter_name in transformer_layer.mlp_.moes_:
                    # fused MLP MoE (e.g. MixLoRA): drop its expert LoRAs,
                    # then the gate itself. After the first pop this elif no
                    # longer matches for the layer's remaining projections.
                    for expert_idx in range(
                        transformer_layer.mlp_.moes_[adapter_name].experts_
                    ):
                        moe_lora_name = f"moe.{adapter_name}.experts.{expert_idx}"
                        lora_layer.loras_.pop(moe_lora_name, None)

                    transformer_layer.mlp_.moes_.pop(adapter_name)
                elif adapter_name in lora_layer.moes_:
                    # plugged per-projection MoE (e.g. LoRAMoE / MoLA)
                    for expert_idx in range(lora_layer.moes_[adapter_name].experts_):
                        moe_lora_name = f"moe.{adapter_name}.experts.{expert_idx}"
                        lora_layer.loras_.pop(moe_lora_name, None)

                    lora_layer.moes_.pop(lora_config.adapter_name, None)

        return lora_config, lora_weight
|
| 895 |
+
|
| 896 |
+
def load_adapter(self, name_or_path: str, adapter_name: Optional[str] = None):
    """Load a LoRA adapter from a local directory or the Hugging Face Hub.

    Args:
        name_or_path: local adapter directory, or a Hub repo id to download.
        adapter_name: registration name; defaults to ``name_or_path``.

    Returns:
        The name the adapter was registered under.
    """
    if adapter_name is None:
        adapter_name = name_or_path

    # Not a local path -> treat it as a Hub repo id and fetch a snapshot.
    if not os.path.exists(name_or_path):
        name_or_path = snapshot_download(repo_id=name_or_path, repo_type="model")

    config_path = name_or_path + os.sep + "adapter_config.json"
    with open(config_path, "r", encoding="utf8") as config_file:
        lora_config = lora_config_factory(json.load(config_file))
    lora_config.adapter_name = adapter_name

    weight_path = name_or_path + os.sep + "adapter_model.bin"
    lora_weight = torch.load(
        weight_path,
        map_location=self.device_,
        weights_only=False,
    )

    self.init_adapter(lora_config, lora_weight)
    return adapter_name
|
| 915 |
+
|
| 916 |
+
def compute_route_loss(self, logits, cites):
    """Cross-entropy routing loss over L2-normalized logits.

    Builds a per-position binary target (1 at positions listed in
    ``cites[k]``, else 0) and scores the normalized logits against it with
    the usual next-token shift (logits[t] predicts label[t+1]).

    Args:
        logits: float tensor of shape (b, l, v).
        cites: per-batch lists of position indices that count as citations.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    # Normalize each logit vector to unit L2 norm before scoring.
    norm_logits = logits / torch.norm(logits, dim=-1, keepdim=True)
    b, l, v = logits.shape

    # Binary target per position; note the 0/1 labels index into the
    # vocabulary dimension of CrossEntropyLoss.
    label = [[1 if i in cites[k] else 0 for i in range(l)] for k in range(b)]
    label = torch.tensor(label, dtype=torch.long, device=logits.device)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Shift: drop the last logit position and the first label position.
    return loss_fn(
        norm_logits[..., :-1, :].contiguous().view(-1, v),
        label[..., 1:].contiguous().view(-1),
    )
|
| 940 |
+
|
| 941 |
+
def compute_cite_loss2(self, logits, cites, cites_v, docs_pos):
    """Citation loss computed against document-position embeddings.

    For each batch element, the normalized logit rows at ``docs_pos`` act as
    document embeddings; similarity of every position to each document gives
    the per-position class scores.  Positions listed in ``cites`` are
    supervised with the (0-based) document index derived from ``cites_v``;
    all other positions are ignored (-100).

    NOTE(review): mutates ``cites_v`` in place via ``self.token2id``; the
    ``isinstance(labels[0], torch.Tensor)`` branch is dead here since labels
    are built as Python lists — kept as-is.
    """
    b = len(logits)
    docs_pos = [torch.tensor(i) for i in docs_pos]
    doc_embeds = []
    # L2-normalize each batch's logits row-wise (assumes logits is a
    # (b, l, v) tensor indexable per batch — TODO confirm with callers).
    norm_logits = [F.normalize(logits[batch], p = 2, dim = 1) for batch in range(logits.shape[0])]
    for i in range(b):
        # Rows at the document positions, transposed so matmul below yields
        # (positions x documents) similarity scores.
        doc_embeds.append(norm_logits[i][docs_pos[i]].transpose(0,1))
    b_logits = []

    for i in range(len(cites)):
        b_logits.append(norm_logits[i] @ doc_embeds[i])
    for k in range(len(cites_v)):
        # Map citation tokens to ids in place.
        cites_v[k] = [self.token2id(i) for i in cites_v[k]]

    labels = []
    for k in range(b):
        # -100 == CrossEntropyLoss ignore_index; only cite positions are scored.
        labels.append([-100 for _ in range(logits[k].shape[0])])
        for i, v in zip(cites[k], cites_v[k]):
            # v - 1: shift token id to a 0-based document/class index —
            # presumably citation ids start at 1; verify against token2id.
            labels[k][i] = v - 1

    if isinstance(labels[0], torch.Tensor):
        for k in range(b):
            labels[k] = (
                labels[k].clone()
                .detach()
                .to(dtype=torch.long, device=logits[0].device)
            )
    else:
        for k in range(b):
            labels[k] = torch.tensor(labels[k], dtype=torch.long, device=logits[0].device)

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index = -100)

    loss = 0
    for k in range(b):
        if len(cites[k]) != 0:
            # Same next-token shift as the other losses in this class.
            loss += loss_fn(
                b_logits[k][..., :-1, :].contiguous().view(-1, b_logits[k].shape[-1]),
                labels[k][..., 1:].contiguous().view(-1),
            )
    return loss / b
|
| 982 |
+
|
| 983 |
+
def compute_cite_loss(self, logits, cites, cites_v):
    """Citation loss on raw logits: supervise only the cite positions.

    Positions listed in ``cites[k]`` get label ``token2id(cites_v) - 1``;
    every other position is ignored (-100).  The loss is averaged over the
    batch, counting batches with no citations as zero.

    NOTE(review): mutates ``cites_v`` in place; the ``isinstance`` branch is
    dead here since labels are built as Python lists — kept as-is.
    """
    b = len(logits)

    for k in range(len(cites_v)):
        # Dead bare-string statement (disabled length fix-up); kept verbatim.
        """if len(cites[k]) > len(cites_v[k]):
            del cites[k][-1]"""
        cites_v[k] = [self.token2id(i) for i in cites_v[k]]

    labels = []
    for k in range(b):
        # -100 == CrossEntropyLoss ignore_index.
        labels.append([-100 for _ in range(logits[k].shape[0])])
        for i, v in zip(cites[k], cites_v[k]):
            # v - 1: shift to 0-based class index — presumably citation ids
            # start at 1; verify against token2id.
            labels[k][i] = v - 1

    if isinstance(labels[0], torch.Tensor):
        for k in range(b):
            labels[k] = (
                labels[k].clone()
                .detach()
                .to(dtype=torch.long, device=logits[0].device)
            )
    else:
        for k in range(b):
            labels[k] = torch.tensor(labels[k], dtype=torch.long, device=logits[0].device)

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index = -100)

    loss = 0
    for k in range(b):
        if len(cites[k]) != 0:
            # Standard next-token shift.
            loss += loss_fn(
                logits[k][..., :-1, :].contiguous().view(-1, logits[k].shape[-1]),
                labels[k][..., 1:].contiguous().view(-1),
            )
    return loss / b
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
def attention_loss_fn(self, mat, mask, cites, prompt_len):  # cites: T elements, each the column index where c_i is located
    """Hinge-style regularizer on attention toward earlier citation columns.

    Averages per-layer attention heads, applies the additive mask and a
    softmax, then penalizes (via ReLU hinge) positions between consecutive
    citations whose attention to earlier citation columns falls below a
    threshold built from ``self.alpha`` and ``self.beta``.

    NOTE(review): ``prompt_len`` is unused in this implementation; the
    exact semantics of alpha/beta (learned thresholds?) are defined
    elsewhere — confirm before relying on this description.
    """
    # Stack per-layer attention and move the head dim last so the mean
    # below averages over heads.
    mat = torch.stack(mat, dim = 0)
    mat = mat.permute(1,0,3,4,2)
    #final_mat = torch.matmul(mat, self.attention_weight).squeeze(-1)
    final_mat = torch.mean(mat, dim = -1)
    final_mat += mask
    final_mat = F.softmax(final_mat, dim=-1)
    loss = torch.tensor(0.0, dtype = final_mat.dtype, device = final_mat.device)
    num_layer = final_mat.shape[1]
    for batch in range(final_mat.shape[0]):
        if len(cites[batch]) == 0:
            continue
        for k in range(len(cites[batch]) - 1):
            for i in range(k + 1):
                # Adjacent citations leave no positions in between: skip.
                if cites[batch][k] == cites[batch][k + 1] - 1:
                    continue
                # Hinge: threshold (alpha-span scaled by beta[k-i, k]) minus
                # the attention from the span (c_k, c_{k+1}) to column c_i.
                loss_now = (self.alpha[cites[batch][k]:cites[batch][k + 1] - 1] * self.beta[k - i, k]).expand(1, num_layer,-1) - final_mat[batch,:,cites[batch][k]:cites[batch][k + 1] - 1,cites[batch][i]]
                loss += F.relu(loss_now).sum() / (cites[batch][k + 1] - cites[batch][k])

    return loss
|
c2cite/models/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .modeling_chatglm import GLMForCausalLM
|
| 2 |
+
from .modeling_gemma import GemmaForCausalLM
|
| 3 |
+
from .modeling_gemma2 import Gemma2ForCausalLM
|
| 4 |
+
from .modeling_llama import LlamaForCausalLM
|
| 5 |
+
from .modeling_mistral import MistralForCausalLM
|
| 6 |
+
from .modeling_mistral import MistralForCausalLM as Qwen2ForCausalLM
|
| 7 |
+
from .modeling_phi import PhiForCausalLM
|
| 8 |
+
from .modeling_phi3 import Phi3ForCausalLM
|
| 9 |
+
|
| 10 |
+
# Registry mapping a Hugging Face `config.model_type` string to the
# wrapped causal-LM implementation in this package.  "qwen2" reuses the
# Mistral implementation (see the import alias at the top of this module).
model_dict = {
    "llama": LlamaForCausalLM,
    "gemma": GemmaForCausalLM,
    "gemma2": Gemma2ForCausalLM,
    "mistral": MistralForCausalLM,
    "qwen2": Qwen2ForCausalLM,
    "phi": PhiForCausalLM,
    "phi3": Phi3ForCausalLM,
    "chatglm": GLMForCausalLM,
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def from_pretrained(llm_model, **kwargs):
    """Dispatch to the wrapper class matching ``llm_model.config.model_type``.

    Args:
        llm_model: a loaded Hugging Face model whose config carries
            ``model_type``.
        **kwargs: forwarded verbatim to the wrapper's ``from_pretrained``.

    Raises:
        RuntimeError: when the model type has no registered wrapper.
    """
    model_type = llm_model.config.model_type
    if model_type not in model_dict:
        raise RuntimeError(f"Model {model_type} not supported.")
    return model_dict[model_type].from_pretrained(llm_model, **kwargs)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Public API of this module (what `from ... import *` exposes).
__all__ = [
    "LlamaForCausalLM",
    "GemmaForCausalLM",
    "MistralForCausalLM",
    "Qwen2ForCausalLM",
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    "from_pretrained",
    "GLMForCausalLM",
]
|
c2cite/models/modeling_chatglm.py
ADDED
|
@@ -0,0 +1,855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn import LayerNorm
|
| 9 |
+
from transformers.utils import is_flash_attn_2_available
|
| 10 |
+
|
| 11 |
+
from moe_peft.common import (
|
| 12 |
+
FeedForward,
|
| 13 |
+
Linear,
|
| 14 |
+
LLMAttention,
|
| 15 |
+
LLMCache,
|
| 16 |
+
LLMDecoder,
|
| 17 |
+
LLMFeedForward,
|
| 18 |
+
LLMForCausalLM,
|
| 19 |
+
LLMModelConfig,
|
| 20 |
+
LLMModelInput,
|
| 21 |
+
collect_plugin_router_logtis,
|
| 22 |
+
flash_attention_forward,
|
| 23 |
+
slice_tensor,
|
| 24 |
+
)
|
| 25 |
+
from moe_peft.executors import executor
|
| 26 |
+
from moe_peft.utils import copy_parameters
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class GLMConfig(LLMModelConfig):
    """ChatGLM-specific configuration fields layered on LLMModelConfig.

    Defaults mirror the upstream ChatGLM configuration; the fields are
    consumed by the attention/MLP classes below and by code outside this
    view — hedged notes mark the ones whose use is not visible here.
    """

    post_layer_norm: bool = True
    # presumably selects RMSNorm vs LayerNorm for the block norms — the
    # selection code is outside this view; confirm.
    rmsnorm: bool = True
    layernorm_epsilon: float = 1e-5
    apply_residual_connection_post_layernorm: bool = False
    fp32_residual_connection: bool = False
    # per-head channel count; projection size = kv_channels * n_heads_
    # (see CoreAttention / GLMSelfAttention below).
    kv_channels: int = 128
    # multi-query attention: K/V shared across query-head groups.
    multi_query_attention: bool = False
    multi_query_group_num: int = 2
    # scales attention scores by layer number (see CoreAttention); also
    # forces softmax in fp32.
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True
    original_rope: bool = True
    add_bias_linear: bool = False
    padded_vocab_size: int = -1
    # multiplier on the RoPE base frequency (see RotaryEmbedding).
    rope_ratio: float = 1
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split *tensor* into *num_partitions* equal chunks along its last dim.

    Args:
        tensor: input tensor; its last dimension must be divisible by
            ``num_partitions`` for equal chunks.
        num_partitions: number of chunks to produce.
        contiguous_split_chunks: when True, force each chunk contiguous.

    Returns:
        Tuple of chunk tensors (views unless made contiguous).
    """
    chunk_size = tensor.size(-1) // num_partitions
    chunks = torch.split(tensor, chunk_size, dim=tensor.dim() - 1)
    # torch.split returns views sharing storage; callers opt in to copies.
    if contiguous_split_chunks:
        return tuple(piece.contiguous() for piece in chunks)
    return chunks
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RotaryEmbedding(nn.Module):
    """Builds the cos/sin cache for ChatGLM-style rotary position embedding.

    ``forward`` returns a ``[seq_len, dim // 2, 2]`` cache of (cos, sin)
    pairs; ``apply_rotary_pos_emb`` consumes it.
    """

    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
        super().__init__()
        # Registered mainly to carry the reference device/dtype for forward().
        exponents = torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
        self,
        seq_len: int,
        n_elem: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000,
    ):
        """Compute the (cos, sin) cache for positions 0..seq_len-1."""
        # Theta_i = (base * rope_ratio) ** (-2(i-1)/d), i in [1, d/2]
        scaled_base = base * self.rope_ratio
        theta = 1.0 / (
            scaled_base
            ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)
        )

        positions = torch.arange(seq_len, dtype=torch.float, device=device)

        # Outer product: angle for every (position, frequency) pair.
        angles = torch.outer(positions, theta).float()

        cache = torch.stack([angles.cos(), angles.sin()], dim=-1)

        # Mimic complex32 behaviour for reduced-precision dtypes so results
        # match the reference implementation.
        if dtype == torch.bfloat16:
            cache = cache.bfloat16()
        elif dtype in (torch.float16, torch.int8):
            cache = cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        # NOTE: `offset` is accepted but not used, matching the reference.
        return self.forward_impl(
            max_seq_len,
            self.dim,
            dtype=self.inv_freq.dtype,
            device=self.inv_freq.device,
        )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to the first rot_dim channels of x.

    Treats adjacent channel pairs as complex numbers and rotates them by the
    (cos, sin) pairs in rope_cache; remaining channels pass through.
    TorchScript-compiled, so the structure is kept flat.
    """
    # x: [b, np, sq, hn]
    b, np, sq, _ = x.shape
    # rope_cache holds rot_dim // 2 (cos, sin) pairs in its second-to-last dim.
    rot_dim = rope_cache.shape[-2] * 2
    # Only the first rot_dim channels are rotated; the tail passes through.
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
    # Complex multiply (x0 + i*x1) * (cos + i*sin) expressed on real pairs.
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Normalizes by the RMS of the last dimension (computed in fp32 for
    stability) and casts back to the input dtype.  Note the weight is
    allocated with ``torch.empty`` and left uninitialized here; loading
    pretrained weights is expected to fill it.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        scale = torch.empty(normalized_shape, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(scale)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        original_dtype = hidden_states.dtype
        # Mean of squares in fp32, then multiply by its inverse square root.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normalized).to(original_dtype)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class CoreAttention(torch.nn.Module):
    """Eager scaled-dot-product attention for the GLM blocks.

    Implements Megatron-style QK scaling: when
    ``apply_query_key_layer_scaling`` is on, scores are divided by an extra
    per-layer coefficient in the matmul and multiplied back before the
    (fp32) softmax, trading no math change for better numeric range.
    """

    def __init__(self, config: GLMConfig, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # QK layer scaling only makes sense with an fp32 softmax.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.is_causal = True

        projection_size = config.kv_channels * config.n_heads_

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.n_heads_
        self.num_attention_heads_per_partition = config.n_heads_

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            # Extra 1/layer_number scaling folded into the score matmul.
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        """Compute attention output [b, sq, hp] from [b, np, s, hn] inputs.

        ``attention_mask`` is boolean with True at masked positions; when it
        is None and sq == sk, a causal mask is built on the fly.
        """
        # [b, np, sq, sk]
        output_size = (
            query_layer.size(0),
            query_layer.size(1),
            query_layer.size(2),
            key_layer.size(2),
        )

        # [b, np, sq, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            output_size[0] * output_size[1], output_size[2], -1
        )
        # [b, np, sk, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)

        # preallocating input tensor: [b * np, sq, sk]
        # (contents irrelevant: baddbmm below uses beta=0.0)
        matmul_input_buffer = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer,  # [b * np, sq, hn]
            key_layer.transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        if self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()
        if self.coeff is not None:
            # Undo the extra layer-number scaling before softmax.
            attention_scores = attention_scores * self.coeff
        if (
            attention_mask is None
            and attention_scores.shape[2] == attention_scores.shape[3]
        ):
            # Full square attention with no mask: build the causal mask
            # (True marks positions to suppress).
            attention_mask = torch.ones(
                output_size[0],
                1,
                output_size[2],
                output_size[3],
                device=attention_scores.device,
                dtype=torch.bool,
            )
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(
                attention_mask, float("-inf")
            )
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.type_as(value_layer)

        # query layer shape: [b * np, sq, hn]
        # value layer shape: [b, np, sk, hn]
        # attention shape: [b, np, sq, sk]
        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(0),
            value_layer.size(1),
            query_layer.size(1),
            value_layer.size(3),
        )
        # change view [b * np, sk, hn]
        value_layer = value_layer.view(
            output_size[0] * output_size[1], value_layer.size(2), -1
        )
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [b, sq, np, hn]
        context_layer = context_layer.transpose(1, 2).contiguous()
        # [b, sq, np, hn] --> [b, sq, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class FlashAttention2(CoreAttention):
    """CoreAttention variant that delegates to flash-attention-2 kernels.

    Keeps CoreAttention's configuration/scaling setup but replaces the
    eager score/softmax/matmul path with ``flash_attention_forward``.
    """

    def __init__(self, *args, **kwargs):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(*args, **kwargs)

    def forward(self, query_states, key_states, value_states, attention_mask):
        # Inputs arrive as [b, np, s, hn]; the flash kernel wants [b, s, np, hn].
        query_states, key_states, value_states = (
            t.transpose(1, 2) for t in (query_states, key_states, value_states)
        )

        batch_size, query_length = query_states.shape[:2]

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=self.is_causal,
        )

        # Merge heads back into the hidden dimension: [b, sq, hp].
        return attn_output.reshape(
            batch_size, query_length, self.hidden_size_per_partition
        ).contiguous()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# Dispatch table used by GLMSelfAttention to choose the attention kernel
# according to config.attn_implementation_.
CORE_ATTENTION_CLASSES = {
    "eager": CoreAttention,
    "flash_attn": FlashAttention2,
}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class GLMSelfAttention(LLMAttention):
    """GLM self-attention wrapper with a fused QKV projection.

    Splits the fused QKV output (supporting multi-query attention, where
    K/V have fewer head groups), applies rotary embeddings, updates the KV
    cache, expands K/V groups to match query heads, runs the configured
    core attention, and projects back through the dense layer.  The QKV and
    dense projections are LoRA-capable ``Linear`` wrappers.
    """

    def __init__(
        self,
        qkv_layer: torch.nn.Module,
        dense_layer: torch.nn.Module,
        config: GLMConfig,
        layer_idx,
    ):
        super(GLMSelfAttention, self).__init__()
        self.layer_idx = layer_idx

        self.projection_size = config.kv_channels * config.n_heads_

        # Per attention head and per-partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.n_heads_
        self.num_attention_heads_per_partition = config.n_heads_
        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        if self.multi_query_attention:
            # MQA: K and V each use multi_query_group_num head groups
            # instead of the full head count.
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size
                + 2
                * self.hidden_size_per_attention_head
                * self.num_multi_query_groups_per_partition
            )

        # QKV layer.
        self.query_key_value = Linear(base_layer=qkv_layer, device=config.device_)
        # Core attention layer.
        self.core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation_](
            config, self.layer_idx
        )

        # Dense layer.
        self.dense = Linear(base_layer=dense_layer, device=config.device_)

    def state_dict(self) -> Dict[str, Linear]:
        # Exposes the LoRA-capable projections for adapter bookkeeping.
        return {"qkv_proj": self.query_key_value, "dense": self.dense}

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # Fused projection: [b, sq, h] -> [b, sq, qkv_hidden_size].
        mixed_x_layer = self.query_key_value(hidden_states, input_args)

        if self.multi_query_attention:
            # Unequal split: full-width Q, group-width K and V.
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(
                mixed_x_layer, 3
            )

        # [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]
        ]

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        if past_key_value is not None:
            # Append this step's K/V to the cache and use the full history.
            key_layer, value_layer = past_key_value.update(
                key_layer,
                value_layer,
                self.layer_idx,
                {"cache_position": cache_position},
            )

        if self.multi_query_attention:
            # Repeat each K/V group so every query head has a matching
            # K/V head: [b, ng, sk, hn] -> [b, np, sk, hn].
            key_layer = key_layer.unsqueeze(2)
            key_layer = key_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + key_layer.size()[3:]
            )
            value_layer = value_layer.unsqueeze(2)
            value_layer = value_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + value_layer.size()[3:]
            )

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
        )

        # Output projection back to the model hidden size.
        output = self.dense(context_layer, input_args)

        return output
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def swiglu(x):
    """SwiGLU activation: silu(gate) * value over the two halves of x.

    The last dimension is split in half; the first half is gated with SiLU
    and multiplies the second half element-wise.
    """
    gate, value = torch.chunk(x, 2, dim=-1)
    return F.silu(gate) * value
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class GLMMLP(LLMFeedForward):
    """ChatGLM feed-forward block: hidden -> 4h projection, SwiGLU (which halves
    the 4h dim), then projection back to hidden. Both projections are wrapped in
    the framework's `Linear` so LoRA adapters / MoE experts can attach to them.
    """

    def __init__(
        self,
        dense_h_to_4h: torch.nn.Module,
        dense_4h_to_h: torch.nn.Module,
        config: GLMConfig,
    ) -> None:
        super().__init__()
        # Wrap the frozen base projections; `Linear` hosts per-adapter LoRA
        # weights in its `loras_` dict (used below).
        self.dense_h_to_4h: Linear = Linear(dense_h_to_4h, config.device_)
        self.dense_4h_to_h: Linear = Linear(dense_4h_to_h, config.device_)

        self.activation_func = swiglu

    def state_dict(self) -> Dict[str, torch.nn.Module]:
        """Expose the two wrapped projections for adapter injection/saving."""
        return {
            "dense_h_to_4h": self.dense_h_to_4h,
            "dense_4h_to_h": self.dense_4h_to_h,
        }

    def _batch_forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Plain forward path: both projections dispatch per-adapter internally."""
        # [b, sq, h] -> [b, sq, 4hp]
        intermediate_parallel = self.dense_h_to_4h(data, input_args)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [b, sq, 4hp] -> [b, sq, h]
        output = self.dense_4h_to_h(intermediate_parallel, input_args)
        return output

    def _lora_forward(
        self, lora_name: str, act_fn: torch.nn.Module, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Forward for a single named LoRA adapter.

        NOTE(review): `act_fn` is accepted for interface compatibility with the
        base class but ignored — GLM always applies SwiGLU (`self.activation_func`).
        """
        # LoRA wrapper is called as forward(base_output, original_input) so it can
        # add its low-rank delta to the frozen base result.
        if lora_name in self.dense_h_to_4h.loras_:
            hidden_states = self.dense_h_to_4h.loras_[lora_name].forward(
                self.dense_h_to_4h.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.dense_h_to_4h.base_layer_.forward(hidden_states)

        hidden_states = self.activation_func(hidden_states)

        if lora_name in self.dense_4h_to_h.loras_:
            hidden_states = self.dense_4h_to_h.loras_[lora_name].forward(
                self.dense_4h_to_h.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.dense_4h_to_h.base_layer_.forward(hidden_states)

        return hidden_states

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """MixLoRA forward: route token subsets to per-expert LoRA adapters.

        The base h->4h projection is computed once for all tokens, then each
        expert only processes the tokens routed to it (`top_x` indices from
        `expert_mask`). Returns one output tensor per expert, in expert order.
        `act_fn` is ignored (SwiGLU is fixed), same as `_lora_forward`.
        """
        # Shared base projection, computed in the compute dtype then cast back.
        common_dense_h_to_4h = self.dense_h_to_4h.base_layer_.forward(
            hidden_states.to(input_dtype)
        ).to(hidden_states.dtype)
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # Token indices assigned to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.dense_h_to_4h.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                act_result = self.activation_func(
                    self.dense_h_to_4h.loras_[lora_name].forward(
                        slice_tensor(common_dense_h_to_4h, top_x, input_dtype),
                        lora_data,
                    )
                )
            else:
                act_result = self.activation_func(
                    slice_tensor(common_dense_h_to_4h, top_x, input_dtype)
                )

            if lora_name in self.dense_4h_to_h.loras_:
                final_expert_states.append(
                    self.dense_4h_to_h.loras_[lora_name].forward(
                        self.dense_4h_to_h.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(
                    self.dense_4h_to_h.base_layer_.forward(act_result)
                )

        return final_expert_states
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
class GLMDecoderLayer(LLMDecoder):
    """A single ChatGLM transformer block: pre-norm self-attention and MLP,
    each followed by dropout and a residual connection. The residual source is
    configurable (`apply_residual_connection_post_layernorm`).
    """

    def __init__(
        self, self_attn: GLMSelfAttention, mlp: FeedForward, config: GLMConfig
    ) -> None:
        super().__init__()
        self.layer_id_ = self_attn.layer_idx
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection

        # ChatGLM supports either RMSNorm or classic LayerNorm.
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Input layer norm.
        self.input_layernorm = LayerNormFunc(
            config.dim_,
            eps=config.layernorm_epsilon,
            device=config.device_,
            dtype=config.dtype_,
        )
        # Self-attention layer.
        self.self_attn_: GLMSelfAttention = self_attn
        self.hidden_dropout = config.hidden_dropout_

        # Post attention layer norm.
        self.post_layernorm = LayerNormFunc(
            config.dim_,
            eps=config.layernorm_epsilon,
            device=config.device_,
            dtype=config.dtype_,
        )
        # mlp
        self.mlp_: FeedForward = mlp

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention weights, MLP weights) for adapter handling."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # Pre-norm before attention.
        layernorm_output = self.input_layernorm(hidden_states)

        attention_output = self.self_attn_.forward(
            layernorm_output,
            input_args,
            rotary_pos_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )

        # Residual connection: either from the normed input (post-layernorm
        # style) or the raw block input.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Dropout is disabled during inference (training flag derived from
        # `inference_mode_` rather than nn.Module.training).
        layernorm_input = F.dropout(
            attention_output,
            p=self.hidden_dropout,
            training=not input_args.inference_mode_,
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_layernorm(layernorm_input)

        # MLP. Returns (output, router_logits) — router logits only meaningful
        # when a MoE adapter is active.
        mlp_output, router_logits = self.mlp_(layernorm_output, input_args)

        # Second residual connection (same source selection as above).
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = F.dropout(
            mlp_output, p=self.hidden_dropout, training=not input_args.inference_mode_
        )
        output = residual + output

        if input_args.output_router_logits_:
            # Merge router logits from adapter plugins for this layer.
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return output, *router_logits
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
class GLMEmbedding(torch.nn.Module):
    """Token embedding for ChatGLM, with optional upcast to float32 so the
    residual stream runs in full precision."""

    def __init__(self, config: GLMConfig):
        super().__init__()

        self.hidden_size = config.dim_
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings; upcast when fp32 residuals are enabled."""
        embeddings = self.word_embeddings(input_ids)
        return embeddings.float() if self.fp32_residual_connection else embeddings
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class GLMForCausalLM(LLMForCausalLM):
    """ChatGLM causal LM wrapper for the PEFT framework.

    Holds the embedding, rotary embedding, decoder layers, final norm and LM
    head, and provides `from_pretrained` to rebuild the structure from a
    Hugging Face ChatGLM checkpoint with frozen base weights.
    """

    def __init__(self, config: GLMConfig) -> None:
        # NOTE(review): super().__init__() is not called here — presumably the
        # LLMForCausalLM base is stateless; confirm against its definition.
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_

        # Embedding layer.
        self.embed_tokens_ = GLMEmbedding(config)
        # Rotary Position Embedding (populated by `from_pretrained`).
        self.rotary_emb_layer: RotaryEmbedding = None
        # Encoder(Decoder) layers.
        self.layers_: List[GLMDecoderLayer] = []
        # Final layer norm.
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        if self.config_.post_layer_norm:
            self.final_layernorm_ = LayerNormFunc(
                config.dim_,
                eps=config.layernorm_epsilon,
                device=config.device_,
                dtype=config.dtype_,
            )
        else:
            self.final_layernorm_ = nn.Identity()
        # Output layer.
        self.lm_head_ = torch.nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=config.add_bias_linear,
            dtype=config.dtype_,
            device=config.device_,
        )

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to embeddings."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the rotary embedding cache slice for the given positions.

        NOTE(review): indexes with `position_ids[-1]` — assumes all batch rows
        share the same positions (last row representative); confirm callers.
        """
        return self.rotary_emb_layer(max_seq_len=self.config_.max_seq_len_)[
            None, position_ids[-1]
        ]

    def decoder_stack(self) -> List[LLMDecoder]:
        """Expose decoder layers to the generic forward driver."""
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final layer norm (identity when post_layer_norm is off)."""
        return self.final_layernorm_(hidden_states)

    def get_masks(
        self,
        input_ids: torch.Tensor,
        past_key_values: LLMCache,
        padding_mask: torch.Tensor,
    ):
        """Build ChatGLM's boolean attention mask (True = masked position).

        `input_ids` is actually the embedded input here — it is unpacked as
        [batch, seq, hidden]. Past cache positions are always attendable.
        """
        batch_size, seq_length, _ = input_ids.shape
        full_attention_mask = torch.ones(
            batch_size, seq_length, seq_length, device=input_ids.device
        )
        full_attention_mask.tril_()  # causal lower-triangular part
        past_length = 0
        if past_key_values:
            past_length = past_key_values.get_seq_length()
        if past_length:
            # Prepend all-ones columns so cached tokens are visible.
            full_attention_mask = torch.cat(
                (
                    torch.ones(
                        batch_size, seq_length, past_length, device=input_ids.device
                    ),
                    full_attention_mask,
                ),
                dim=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            # Let padded query rows attend to themselves to avoid NaN softmax.
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # Invert: True marks positions to be masked out.
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)  # add head dim -> [b, 1, sq, kv]
        return full_attention_mask

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Framework hook; delegates to ChatGLM's own mask construction.
        `cache_position` is unused (positions derive from the cache length)."""
        return self.get_masks(input_tensor, past_key_values, attention_mask)

    def model_config(self) -> GLMConfig:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Build a GLMForCausalLM from a loaded HF ChatGLM model.

        Freezes the source model and copies/wraps its weights layer by layer.
        Raises AssertionError if sliding-window attention is requested.
        """
        assert not use_sliding_window, "ChatGLM model does not support SWA."
        # Get the config from LLM model and input args.
        llm_config = llm_model.config
        config = GLMConfig(
            # LLM model args.
            name_or_path_=llm_config._name_or_path,
            device_=device,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.multi_query_group_num,
            n_layers_=llm_config.num_layers,
            hidden_act_=swiglu,
            hidden_dropout_=llm_config.hidden_dropout,
            vocab_size_=llm_config.vocab_size,
            pad_token_id_=llm_config.pad_token_id,
            max_seq_len_=llm_config.seq_length,
            attn_implementation_=attn_impl,
            dtype_=llm_model.dtype,
            # ChatGLM args.
            post_layer_norm=llm_config.post_layer_norm,
            rmsnorm=llm_config.rmsnorm,
            layernorm_epsilon=llm_config.layernorm_epsilon,
            apply_residual_connection_post_layernorm=llm_config.apply_residual_connection_post_layernorm,
            fp32_residual_connection=llm_config.fp32_residual_connection,
            apply_query_key_layer_scaling=llm_config.apply_query_key_layer_scaling,
            kv_channels=llm_config.kv_channels,
            multi_query_attention=llm_config.multi_query_attention,
            multi_query_group_num=llm_config.multi_query_group_num,
            attention_softmax_in_fp32=llm_config.attention_softmax_in_fp32,
            original_rope=llm_config.original_rope,
            add_bias_linear=llm_config.add_bias_linear,
            padded_vocab_size=llm_config.padded_vocab_size,
            # Older ChatGLM configs lack rope_ratio; default to 1.
            rope_ratio=(
                llm_config.rope_ratio if hasattr(llm_config, "rope_ratio") else 1
            ),
        )

        model = GLMForCausalLM(config)
        llm_model.requires_grad_(False)  # base weights stay frozen

        copy_parameters(
            llm_model.transformer.embedding,
            model.embed_tokens_,
        )

        # ChatGLM applies RoPE to half of each head's channels, hence dim//2.
        rotary_dim = (
            config.dim_ // config.n_heads_
            if config.kv_channels is None
            else config.kv_channels
        )
        model.rotary_emb_layer = RotaryEmbedding(
            dim=rotary_dim // 2,
            rope_ratio=config.rope_ratio,
            original_impl=config.original_rope,
            device=device,
            dtype=config.dtype_,
        )

        for idx, layer in enumerate(llm_model.transformer.encoder.layers):
            # Get self-attention layer.
            self_attention = GLMSelfAttention(
                qkv_layer=layer.self_attention.query_key_value,
                dense_layer=layer.self_attention.dense,
                config=config,
                layer_idx=idx,
            )
            # Get MLP layer.
            mlp = FeedForward(
                GLMMLP(layer.mlp.dense_h_to_4h, layer.mlp.dense_4h_to_h, config=config)
            )
            # Create a transformer block.
            encoder = GLMDecoderLayer(self_attention, mlp, config)
            copy_parameters(layer.input_layernorm, encoder.input_layernorm)
            copy_parameters(layer.post_attention_layernorm, encoder.post_layernorm)
            model.layers_.append(encoder)

        if config.post_layer_norm:
            copy_parameters(
                llm_model.transformer.encoder.final_layernorm,
                model.final_layernorm_,
            )

        copy_parameters(llm_model.transformer.output_layer, model.lm_head_)

        return model
|
c2cite/models/modeling_gemma.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from transformers.models.gemma import modeling_gemma
|
| 5 |
+
|
| 6 |
+
from moe_peft.common import FeedForward
|
| 7 |
+
from moe_peft.executors import executor
|
| 8 |
+
from moe_peft.models.modeling_llama import (
|
| 9 |
+
LLAMA_ATTENTION_CLASSES as GEMMA_ATTENTION_CLASSES,
|
| 10 |
+
)
|
| 11 |
+
from moe_peft.models.modeling_llama import (
|
| 12 |
+
LlamaConfig,
|
| 13 |
+
LlamaDecoderLayer,
|
| 14 |
+
LlamaForCausalLM,
|
| 15 |
+
LlamaMLP,
|
| 16 |
+
)
|
| 17 |
+
from moe_peft.utils import copy_parameters
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GemmaRMSNorm(nn.Module):
    """RMSNorm with Gemma's `(1 + weight)` scaling, computed in float32."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.norm_eps_ = eps
        self.weight_ = weight

    def _norm(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.norm_eps_)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.to(torch.float32))
        scaled = normed * (1.0 + self.weight_.to(torch.float32))
        return scaled.to(x.dtype)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GemmaEmbedding(nn.Module):
    """Token embedding that rescales outputs by Gemma's sqrt(hidden) normalizer."""

    def __init__(self, embedding: torch.Tensor, pad_token: int, normalizer: float):
        super().__init__()
        self.token_embedding_: torch.Tensor = embedding
        self.padding_idx_: int = pad_token
        self.normalizer_: float = normalizer

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        looked_up = F.embedding(
            tokens, self.token_embedding_, padding_idx=self.padding_idx_
        )
        # normalized
        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        scale = torch.tensor(self.normalizer_, dtype=looked_up.dtype)
        return looked_up * scale
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _patch_hidden_act(config: modeling_gemma.GemmaConfig) -> str:
    """Prefer the newer `hidden_activation` config field; fall back to `hidden_act`."""
    hidden_activation = getattr(config, "hidden_activation", None)
    if hidden_activation is None:
        return config.hidden_act
    return hidden_activation
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GemmaForCausalLM(LlamaForCausalLM):
    """Gemma causal LM, reusing the Llama decoder implementation with Gemma's
    own embedding normalization and (1 + w) RMSNorm variants."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__(config)

    @staticmethod
    def from_pretrained(
        llm_model: modeling_gemma.GemmaForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Build a GemmaForCausalLM from a loaded HF Gemma model.

        Freezes the source model and wraps/copies its weights. Raises
        AssertionError if sliding-window attention is requested.
        """
        assert not use_sliding_window, "Gemma model does not support SWA."
        llm_config: modeling_gemma.GemmaConfig = llm_model.config
        # Gemma's architecture maps 1:1 onto the Llama config.
        llm_args = LlamaConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.head_dim,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=_patch_hidden_act(llm_config),
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Sentinel pad id when the checkpoint defines none.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = GemmaForCausalLM(llm_args)
        llm_model.requires_grad_(False)  # base weights stay frozen
        # Gemma scales embeddings by sqrt(hidden_size).
        model.embed_tokens_ = GemmaEmbedding(
            llm_model.model.embed_tokens.weight,
            llm_args.pad_token_id_,
            llm_args.dim_**0.5,
        )
        model.norm_ = GemmaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            # Reuse Llama attention (eager/flash chosen by attn_impl).
            decoder.self_attn_ = GEMMA_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            # Gemma-specific norms (weight interpreted as 1 + w).
            decoder.input_layernorm_ = GemmaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = GemmaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
|
c2cite/models/modeling_gemma2.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from transformers.models.gemma2 import modeling_gemma2
|
| 7 |
+
from transformers.models.gemma2.modeling_gemma2 import apply_rotary_pos_emb, repeat_kv
|
| 8 |
+
from transformers.utils import is_flash_attn_2_available
|
| 9 |
+
|
| 10 |
+
from moe_peft.common import (
|
| 11 |
+
FeedForward,
|
| 12 |
+
Linear,
|
| 13 |
+
LLMAttention,
|
| 14 |
+
LLMCache,
|
| 15 |
+
LLMDecoder,
|
| 16 |
+
LLMForCausalLM,
|
| 17 |
+
LLMModelConfig,
|
| 18 |
+
LLMModelInput,
|
| 19 |
+
collect_plugin_router_logtis,
|
| 20 |
+
flash_attention_forward,
|
| 21 |
+
prepare_4d_causal_attention_mask,
|
| 22 |
+
)
|
| 23 |
+
from moe_peft.executors import executor
|
| 24 |
+
from moe_peft.models.modeling_gemma import GemmaEmbedding, GemmaRMSNorm
|
| 25 |
+
from moe_peft.models.modeling_llama import LlamaMLP
|
| 26 |
+
from moe_peft.utils import copy_parameters, is_package_available
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class Gemma2Config(LLMModelConfig):
    # Epsilon for Gemma2's RMSNorm layers.
    rms_norm_eps_: float = 1e-6
    # Tanh soft cap applied to attention logits (None disables capping).
    attn_logit_softcapping_: float = 50.0
    # Tanh soft cap applied to final LM logits.
    final_logit_softcapping_: float = 30.0
    # Query scaling uses query_pre_attn_scalar_ ** -0.5 (not head_dim).
    query_pre_attn_scalar_: int = 224
    # When True, even-indexed layers use sliding-window attention.
    use_sliding_window_: bool = False
    # Sliding-window size (tokens) for the windowed layers.
    sliding_window_: int = 4096
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Gemma2RotaryEmbedding(nn.Module):
    """Rotary position embedding for Gemma2.

    Computes cos/sin tables for the given positions in float32 (autocast
    disabled) to avoid bfloat16 precision loss on long contexts, then casts
    back to the input dtype.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base ** (-2i / dim), the classic RoPE frequency ladder.
        inv_freq = 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
                / self.dim
            )
        )
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # BUGFIX: the original called `self.inv_freq.to(x.device)` and discarded
        # the result (Tensor.to returns a new tensor; it is not in-place), so the
        # buffer never actually moved. Move the expanded copy onto x's device.
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # Duplicate frequencies across the two rotated halves.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Multi-headed attention from 'Attention Is All You Need' paper.
|
| 82 |
+
# Multi-headed attention from 'Attention Is All You Need' paper.
class Gemma2Attention(LLMAttention):
    """Eager (matmul + softmax) attention for Gemma2.

    Gemma2 specifics: queries scaled by `query_pre_attn_scalar_ ** -0.5`
    instead of head_dim, optional tanh soft-capping of attention logits, and
    sliding-window attention on even-indexed layers when enabled.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        o_proj: nn.Module,
        layer_idx: int,
        config: Gemma2Config,
    ):
        super().__init__()
        # attention projections, wrapped so LoRA adapters can attach
        self.q_proj_: Linear = Linear(q_proj, config.device_)
        self.k_proj_: Linear = Linear(k_proj, config.device_)
        self.v_proj_: Linear = Linear(v_proj, config.device_)
        self.o_proj_: Linear = Linear(o_proj, config.device_)
        # config
        self.layer_idx_ = layer_idx
        self.config_ = config
        self.dim_ = config.dim_
        self.n_heads_ = config.n_heads_
        self.n_kv_heads_ = config.n_kv_heads_
        # GQA: each kv head serves n_rep_ query heads.
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.head_dim_ = config.head_dim_
        self.dtype_ = config.dtype_
        self.is_causal_ = True

        self.scaling_ = config.query_pre_attn_scalar_**-0.5
        # Sliding window only on even layers (Gemma2 alternates global/local).
        self.sliding_window_ = (
            config.sliding_window_
            if config.use_sliding_window_ and not bool(layer_idx % 2)
            else None
        )

    def state_dict(self) -> Dict[str, Linear]:
        """Expose the four projections for adapter injection/saving."""
        return {
            "q_proj": self.q_proj_,
            "k_proj": self.k_proj_,
            "v_proj": self.v_proj_,
            "o_proj": self.o_proj_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj_(hidden_states, input_args)
        key_states = self.k_proj_(hidden_states, input_args)
        value_states = self.v_proj_(hidden_states, input_args)

        # [b, sq, h] -> [b, heads, sq, head_dim]
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # The cache may truncate per-layer to the sliding window.
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window_,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # Expand kv heads to match query heads (GQA).
        key_states = repeat_kv(key_states, self.n_rep_)
        value_states = repeat_kv(value_states, self.n_rep_)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling_
        )

        # Gemma2 logit soft-capping: cap * tanh(logits / cap).
        if self.config_.attn_logit_softcapping_ is not None:
            attn_weights = attn_weights / self.config_.attn_logit_softcapping_
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config_.attn_logit_softcapping_

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()

        # [b, sq, heads * head_dim] -> output projection
        attn_output = attn_output.view(bsz, q_len, -1)
        return self.o_proj_(attn_output, input_args)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class Gemma2FlashAttention2(Gemma2Attention):
    """Gemma2 attention that dispatches to the flash-attn v2 kernel instead
    of the eager matmul/softmax path.

    Requires the optional ``flash_attn`` package; construction fails fast
    when it is not installed.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        o_proj: nn.Module,
        layer_idx: int,
        config: Gemma2Config,
    ):
        # Fail early if the optional flash-attn dependency is missing.
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(q_proj, k_proj, v_proj, o_proj, layer_idx, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        # LoRA-aware projections: they take the batch descriptor as well.
        query_states = self.q_proj_(hidden_states, input_args)
        key_states = self.k_proj_(hidden_states, input_args)
        value_states = self.v_proj_(hidden_states, input_args)

        # Reshape to (bsz, n_heads, q_len, head_dim) for RoPE / cache update.
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window_,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        if attention_mask is not None:
            # Trim cached keys/values to the mask length so key/value length
            # matches the padding mask handed to the flash kernel.
            seq_len = attention_mask.shape[1]
            key_states = key_states[:, :, :seq_len]
            value_states = value_states[:, :, :seq_len]

        # flash-attn expects (bsz, seq_len, n_heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # flash-attn kernels only run in half precision; temporarily downcast
        # fp32 inputs and restore the original dtype afterwards.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            is_causal=self.is_causal_,
            softmax_scale=self.scaling_,
            sliding_window=(
                self.sliding_window_ if self.config_.use_sliding_window_ else None
            ),
            # Attention-logit soft-capping is only supported by
            # flash-attn >= 2.6.0; silently disabled otherwise.
            softcap=(
                self.config_.attn_logit_softcapping_
                if is_package_available("flash_attn", "2.6.0")
                else None
            ),
        ).to(input_dtype)

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj_(attn_output, input_args)

        return attn_output
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Maps the config's `attn_implementation_` string to the attention class
# used when building decoder layers in `from_pretrained`.
GEMMA2_ATTENTION_CLASSES = {
    "eager": Gemma2Attention,
    "flash_attn": Gemma2FlashAttention2,
}
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class Gemma2DecoderLayer(LLMDecoder):
    """One Gemma2 transformer block.

    Pre-norm attention with a residual connection, followed by a
    pre/post-norm feed-forward with its own residual.  Gemma2 alternates
    sliding-window and global attention across layers; when sliding windows
    are enabled this layer narrows the attention mask accordingly.
    """

    def __init__(self, layer_idx: int, config: Gemma2Config) -> None:
        super().__init__()
        self.layer_id_: int = layer_idx
        # Sub-modules are populated later by `Gemma2ForCausalLM.from_pretrained`.
        self.self_attn_: Gemma2Attention = None
        self.mlp_: FeedForward = None
        self.input_layernorm_: GemmaRMSNorm = None
        self.post_attention_layernorm_: GemmaRMSNorm = None

        self.config_ = config
        # Even-indexed layers use sliding-window attention (Gemma2 convention).
        self.is_sliding_ = not bool(layer_idx % 2)
        self.pre_feedforward_layernorm_: GemmaRMSNorm = None
        self.post_feedforward_layernorm_: GemmaRMSNorm = None
        self.sliding_window_ = config.sliding_window_

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention weights dict, MLP weights dict)."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        if (
            self.config_.use_sliding_window_
            and self.is_sliding_
            and attention_mask is not None
        ):
            if self.config_.attn_implementation_ == "flash_attn":
                if past_key_value is not None:  # when decoding
                    # BUGFIX: original read `self.sliding_window` (no trailing
                    # underscore), an attribute that is never set — the value
                    # is stored as `self.sliding_window_` in __init__ — so the
                    # flash-attn decode path raised AttributeError.
                    attention_mask = attention_mask[:, -self.sliding_window_ :]
            else:
                # Mask out positions further back than the sliding window by
                # writing the dtype's minimum into the additive mask.
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool),
                    diagonal=-self.sliding_window_,
                )
                attention_mask = torch.where(
                    sliding_window_mask, min_dtype, attention_mask
                )
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window_ :]

        residual = hidden_states

        hidden_states = self.input_layernorm_(hidden_states)

        hidden_states = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = self.post_attention_layernorm_(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm_(hidden_states)
        hidden_states, router_logits = self.mlp_.forward(hidden_states, input_args)
        hidden_states = self.post_feedforward_layernorm_(hidden_states)
        hidden_states = residual + hidden_states

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class Gemma2OutputLayer(nn.Module):
    """Final vocabulary projection for Gemma2.

    Wraps an untied ``lm_head`` linear layer and optionally applies Gemma2's
    final-logit soft-capping: ``tanh(logits / cap) * cap``.
    """

    def __init__(self, config: "Gemma2Config"):
        super().__init__()
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.final_logit_softcapping_ = config.final_logit_softcapping_

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits (soft-capped if configured)."""
        logits = self.lm_head_(hidden_states)
        cap = self.final_logit_softcapping_
        if cap is None:
            return logits
        # Soft-cap keeps every logit inside (-cap, cap).
        return torch.tanh(logits / cap) * cap
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class Gemma2ForCausalLM(LLMForCausalLM):
    """Causal-LM wrapper around a Gemma2 decoder stack.

    Instances are normally built via :meth:`from_pretrained`, which converts
    a loaded HuggingFace Gemma2 model into this project's adapter-aware
    module tree (frozen base weights, per-layer decoder objects).
    """

    def __init__(self, config: Gemma2Config) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Filled in by `from_pretrained` with weights from the HF model.
        self.embed_tokens_: GemmaEmbedding = None
        self.norm_: GemmaRMSNorm = None
        self.rotary_emb_ = Gemma2RotaryEmbedding(
            config.head_dim_,
            max_position_embeddings=config.max_seq_len_,
            base=config.rope_theta_,
            device=config.device_,
        )
        self.lm_head_ = Gemma2OutputLayer(config)
        self.layers_: List[Gemma2DecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Token ids -> embeddings (delegates to the wrapped embedding)."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) rotary tables for the given positions."""
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Final RMS norm applied after the decoder stack."""
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        # Build the standard 4D additive causal mask; per-layer sliding-window
        # narrowing happens inside Gemma2DecoderLayer.forward.
        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def cache_implementation(self) -> str:
        """Pick a KV-cache flavor: 'hybrid' when sliding windows are active,
        plain 'dynamic' otherwise."""
        if self.config_.use_sliding_window_ and self.config_.sliding_window_:
            return "hybrid"
        else:
            return "dynamic"

    def model_config(self) -> Gemma2Config:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_gemma2.Gemma2PreTrainedModel,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a HF Gemma2 model into a :class:`Gemma2ForCausalLM`.

        The base model's parameters are frozen; its weights are wrapped (or
        copied, for the lm_head) into this project's layer classes.
        """
        llm_config: modeling_gemma2.Gemma2Config = llm_model.config
        model_config = Gemma2Config(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.head_dim,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_activation,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            attn_logit_softcapping_=llm_config.attn_logit_softcapping,
            final_logit_softcapping_=llm_config.final_logit_softcapping,
            query_pre_attn_scalar_=llm_config.query_pre_attn_scalar,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Models without a pad token get a -1 sentinel.
        if model_config.pad_token_id_ is None:
            model_config.pad_token_id_ = -1

        model = Gemma2ForCausalLM(model_config)
        # Freeze the base model; only adapter weights should train.
        llm_model.requires_grad_(False)
        model.embed_tokens_ = GemmaEmbedding(
            llm_model.model.embed_tokens.weight,
            model_config.pad_token_id_,
            model_config.dim_**0.5,  # embedding scale factor used by GemmaEmbedding
        )
        model.norm_ = GemmaRMSNorm(
            llm_model.model.norm.weight, model_config.rms_norm_eps_
        )
        copy_parameters(llm_model.lm_head, model.lm_head_.lm_head_)

        for layer_idx, layer in enumerate(llm_model.model.layers):
            decoder = Gemma2DecoderLayer(layer_idx, model_config)
            # Attention implementation is chosen by config ("eager"/"flash_attn").
            decoder.self_attn_ = GEMMA2_ATTENTION_CLASSES[
                model_config.attn_implementation_
            ](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                layer_idx,
                model_config,
            )
            # Gemma2's MLP layout matches Llama's (gate/down/up), so the
            # LlamaMLP wrapper is reused here.
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    model_config,
                )
            )
            decoder.input_layernorm_ = GemmaRMSNorm(
                layer.input_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = GemmaRMSNorm(
                layer.post_attention_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.pre_feedforward_layernorm_ = GemmaRMSNorm(
                layer.pre_feedforward_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.post_feedforward_layernorm_ = GemmaRMSNorm(
                layer.post_feedforward_layernorm.weight, model_config.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
|
c2cite/models/modeling_llama.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers.activations import ACT2FN
|
| 8 |
+
from transformers.models.llama import modeling_llama
|
| 9 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
| 10 |
+
from transformers.utils import is_flash_attn_2_available
|
| 11 |
+
|
| 12 |
+
from moe_peft.common import (
|
| 13 |
+
ROPE_INIT_FUNCTIONS,
|
| 14 |
+
FeedForward,
|
| 15 |
+
Linear,
|
| 16 |
+
LLMAttention,
|
| 17 |
+
LLMCache,
|
| 18 |
+
LLMDecoder,
|
| 19 |
+
LLMFeedForward,
|
| 20 |
+
LLMForCausalLM,
|
| 21 |
+
LLMModelConfig,
|
| 22 |
+
LLMModelInput,
|
| 23 |
+
collect_plugin_router_logtis,
|
| 24 |
+
eager_attention_forward,
|
| 25 |
+
flash_attention_forward,
|
| 26 |
+
prepare_4d_causal_attention_mask,
|
| 27 |
+
slice_tensor,
|
| 28 |
+
)
|
| 29 |
+
from moe_peft.executors import executor
|
| 30 |
+
from moe_peft.utils import copy_parameters
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class LlamaConfig(LLMModelConfig):
    # Epsilon added to the variance inside RMSNorm layers.
    rms_norm_eps_: float = 1e-6
    # Optional RoPE scaling settings, HF-style dict (keys "rope_type" or
    # legacy "type", plus type-specific parameters).
    rope_scaling_: Optional[Dict[str, Any]] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for Llama models.

    Supports the scaled variants registered in ``ROPE_INIT_FUNCTIONS``
    (default / linear / dynamic / yarn, ...).  ``forward`` returns the
    (cos, sin) tables for the requested positions in the input's dtype.
    """

    def __init__(
        self,
        config: Optional[LlamaConfig],
        scaling_factor=1.0,
        rope_type="default",
    ):
        super().__init__()
        # BUGFIX: the original dereferenced `config` unconditionally (both
        # while building rope_kwargs and inside its own `if config is None`
        # branch, which read `config.max_seq_len_`), so a None config always
        # crashed with AttributeError.  Fail fast with a clear error instead.
        if config is None:
            raise ValueError("LlamaRotaryEmbedding requires a LlamaConfig instance")

        self.rope_kwargs = {
            "rope_type": rope_type,
            "factor": scaling_factor,
            "dim": config.head_dim_,
            "base": config.rope_theta_,
            "max_position_embeddings": config.max_seq_len_,
        }
        # BC: "rope_type" was originally "type"
        if config.rope_scaling_ is not None:
            self.rope_type = config.rope_scaling_.get(
                "rope_type", config.rope_scaling_.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_seq_len_
        self.original_max_seq_len = config.max_seq_len_

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, config.device_, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Keep the untouched table so dynamic scaling can reset to it.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute inv_freq when the sequence grows past the cached length
        (dynamic RoPE), and restore the original table once it shrinks back
        below the model's native maximum."""
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) of shape (batch, seq_len, rotary_dim) in x's dtype."""
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Multi-headed attention from 'Attention Is All You Need' paper.
class LlamaAttention(LLMAttention):
    """Eager (matmul + softmax) multi-head attention for Llama.

    The q/k/v/o projections are LoRA-aware ``Linear`` wrappers, so each
    projection call also takes the ``input_args`` batch descriptor.
    ``forward`` returns a 2-tuple: the attention output plus the secondary
    value produced by ``eager_attention_forward`` (presumably the attention
    weights — confirm against moe_peft.common).
    """

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: LlamaConfig,
    ):
        super().__init__()
        # attention
        self.wq_: Linear = Linear(wq, args.device_)  # dim * dim
        self.wk_: Linear = Linear(wk, args.device_)  # dim * dim
        self.wv_: Linear = Linear(wv, args.device_)  # dim * dim
        self.wo_: Linear = Linear(wo, args.device_)  # dim * dim
        # config
        self.layer_idx_ = idx
        self.dim_ = args.dim_
        self.n_heads_ = args.n_heads_
        self.n_kv_heads_ = args.n_kv_heads_
        # how many query heads share one kv head (grouped-query attention)
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.head_dim_ = args.head_dim_
        self.dtype_ = args.dtype_
        self.is_causal_ = True

    def state_dict(self) -> Dict[str, Linear]:
        # Expose projections under the HF llama parameter names.
        return {
            "q_proj": self.wq_,
            "k_proj": self.wk_,
            "v_proj": self.wv_,
            "o_proj": self.wo_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # convert shape to multi head: (batch, n_heads, seq_len, head_dim)
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # apply rotary embedding
        cos, sin = rotary_emb
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # for llama2 need to repeat the heads
        # before dim: batch_size, n_kv_head, seq_len, head_dim
        # after dim: batch_size, n_head, seq_len, head_dim
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        attention_score, attention_matrix = eager_attention_forward(xq, xk, xv, attention_mask)
        attention_score = attention_score.reshape(batch_size, max_seq_len, -1)

        # get output attention score
        return self.wo_.forward(attention_score, input_args), attention_matrix
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class LlamaFlashAttention(LlamaAttention):
    """Llama attention using the flash-attn v2 kernel.

    Behaves like :class:`LlamaAttention` but never materializes the full
    attention matrix; the second element of the returned tuple is therefore
    always ``None``.
    """

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: LlamaConfig,
    ):
        # Fail early if the optional flash-attn dependency is missing.
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(wq, wk, wv, wo, idx, args)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # convert shape to multi head: (batch, n_heads, seq_len, head_dim)
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # apply rotary embedding
        cos, sin = rotary_emb
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # flash-attn expects (batch, seq_len, n_heads, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash-attn kernels only run in half precision; temporarily downcast
        # fp32 inputs and restore the original dtype afterwards.
        input_dtype = xq.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            xq = xq.to(target_dtype)
            xk = xk.to(target_dtype)
            xv = xv.to(target_dtype)

        attn_output = flash_attention_forward(
            xq,
            xk,
            xv,
            attention_mask,
            max_seq_len,
            is_causal=self.is_causal_,
        ).to(input_dtype)

        attn_output = attn_output.reshape(batch_size, max_seq_len, -1).contiguous()
        attn_output = self.wo_.forward(attn_output, input_args)

        # BUGFIX: the eager LlamaAttention.forward returns a 2-tuple
        # (output, attention_matrix) and LlamaDecoderLayer.forward unpacks
        # two values; the original flash variant returned a bare tensor,
        # which broke that unpacking.  Flash attention never materializes
        # the attention matrix, so return None for it.
        return attn_output, None
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Maps the config's attention-implementation string to the attention class
# instantiated per decoder layer.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attn": LlamaFlashAttention,
}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class LlamaMLP(LLMFeedForward):
    """Llama gated feed-forward: ``w2(act(w1(x)) * w3(x))``.

    Provides three execution paths: the plain batched path
    (``_batch_forward``), a single-adapter LoRA path (``_lora_forward``),
    and a MixLoRA mixture-of-experts path (``_mixlora_forward``).
    """

    def __init__(
        self, w1: nn.Module, w2: nn.Module, w3: nn.Module, args: LlamaConfig
    ) -> None:
        super().__init__()
        # feed forward
        self.w1_: Linear = Linear(w1, args.device_)  # gate projection
        self.w2_: Linear = Linear(w2, args.device_)  # down projection
        self.w3_: Linear = Linear(w3, args.device_)  # up projection
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        # Expose projections under the HF llama parameter names.
        return {
            "gate_proj": self.w1_,
            "down_proj": self.w2_,
            "up_proj": self.w3_,
        }

    def _batch_forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Gated MLP over the whole batch via LoRA-aware projections."""
        w1 = self.w1_.forward(data, input_args)
        w3 = self.w3_.forward(data, input_args)
        return self.w2_.forward(self.act_(w1) * w3, input_args)

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        """Gated MLP with one named LoRA adapter applied to whichever of the
        three projections carries it; missing adapters fall back to the
        frozen base weights."""
        # Applying LoRA weights to FFN weights
        if lora_name in self.w1_.loras_:
            w1 = self.w1_.loras_[lora_name].forward(
                self.w1_.base_layer_.forward(data), data
            )
        else:
            w1 = self.w1_.base_layer_.forward(data)

        if lora_name in self.w3_.loras_:
            w3 = self.w3_.loras_[lora_name].forward(
                self.w3_.base_layer_.forward(data), data
            )
        else:
            w3 = self.w3_.base_layer_.forward(data)

        act_result = act_fn(w1) * w3
        if lora_name in self.w2_.loras_:
            return self.w2_.loras_[lora_name].forward(
                self.w2_.base_layer_.forward(act_result), act_result
            )
        else:
            return self.w2_.base_layer_.forward(act_result)

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """MixLoRA expert dispatch.

        For each expert, selects the token rows routed to it (via
        ``expert_mask``) and runs the gated MLP with that expert's LoRA
        adapters (named ``moe.<moe_name>.experts.<idx>``); the shared base
        projections are computed once for the whole batch.  Returns the
        per-expert output list.
        """
        # Base gate/up projections computed once and sliced per expert below.
        common_w1 = self.w1_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        common_w3 = self.w3_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # Row indices of tokens routed to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.w1_.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                w1 = self.w1_.loras_[lora_name].forward(
                    slice_tensor(common_w1, top_x, input_dtype), lora_data
                )
            else:
                lora_data = None
                w1 = slice_tensor(common_w1, top_x, input_dtype)

            if lora_name in self.w3_.loras_:
                w3 = self.w3_.loras_[lora_name].forward(
                    slice_tensor(common_w3, top_x, input_dtype),
                    # Reuse the already-sliced input when w1's branch made it.
                    slice_tensor(hidden_states, top_x, input_dtype, lora_data),
                )
            else:
                w3 = slice_tensor(common_w3, top_x, input_dtype)

            act_result = act_fn(w1) * w3
            if lora_name in self.w2_.loras_:
                final_expert_states.append(
                    self.w2_.loras_[lora_name].forward(
                        self.w2_.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(self.w2_.base_layer_.forward(act_result))

        return final_expert_states
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias).

    Normalizes over the last dimension and scales by a learned weight.
    """

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.norm_eps_ = eps
        self.weight_ = weight

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        input_dtype = data.dtype
        # Compute entirely in float32, matching the HF reference RMSNorm.
        # BUGFIX: the original computed the variance in fp32 but performed the
        # normalizing multiply in the input dtype, losing precision for
        # fp16/bf16 inputs.
        data = data.to(torch.float32)
        variance = data.pow(2).mean(-1, keepdim=True)
        data = data * torch.rsqrt(variance + self.norm_eps_)
        return (self.weight_ * data).to(input_dtype)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class LlamaDecoderLayer(LLMDecoder):
    """One transformer decoder layer: self-attention then MLP, each wrapped
    in a pre-RMSNorm and a residual connection."""

    def __init__(self, layer_id: int) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        # Sub-modules are attached later by the model loader.
        self.self_attn_: LlamaAttention = None
        self.mlp_: FeedForward = None
        self.input_layernorm_: LlamaRMSNorm = None
        self.post_attention_layernorm_: LlamaRMSNorm = None

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        # Attention weights and MLP weights are exposed as two separate dicts.
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # --- self-attention sub-block (pre-norm + residual) ---
        skip = hidden_states
        normed = self.input_layernorm_(hidden_states)
        attn_out, attention_matrix = self.self_attn_.forward(
            normed,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = skip + attn_out

        # --- feed-forward sub-block (pre-norm + residual) ---
        skip = hidden_states
        hidden_states = skip + self.mlp_.forward(
            self.post_attention_layernorm_(hidden_states), input_args
        )

        return hidden_states, attention_matrix
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class LlamaEmbedding(nn.Module):
    """Token-embedding lookup backed by a shared weight tensor."""

    def __init__(self, embedding: torch.Tensor, pad_token: int):
        super().__init__()
        # The embedding matrix is shared with the source model (no copy).
        self.token_embedding_: torch.Tensor = embedding
        self.padding_idx_: int = pad_token

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Row lookup into the shared table; padding_idx follows the usual
        # F.embedding semantics.
        return F.embedding(
            tokens, self.token_embedding_, padding_idx=self.padding_idx_
        )
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class LlamaForCausalLM(LLMForCausalLM):
    """Llama causal-LM wrapper for this PEFT runtime.

    Owns the embedding table, the decoder stack, the final RMSNorm, the
    rotary-embedding module and the LM head, and knows how to build itself
    from a HuggingFace ``LlamaForCausalLM`` via :meth:`from_pretrained`.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Embedding table and final norm are attached by from_pretrained().
        self.embed_tokens_: LlamaEmbedding = None
        self.norm_: LlamaRMSNorm = None
        self.rotary_emb_ = LlamaRotaryEmbedding(config)
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layers_: List[LlamaDecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to embedding vectors."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) rotary tables for the given positions."""
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        """Expose the decoder layers in forward order."""
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final RMSNorm."""
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the 4-D causal attention mask for the current step."""

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> LlamaConfig:
        """Return the runtime model configuration."""
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_llama.LlamaForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a HuggingFace Llama model into this runtime's layout.

        The HF weights are frozen and shared by reference (norm/embedding/
        per-layer weights); only ``lm_head`` parameters are copied.
        Raises AssertionError if ``use_sliding_window`` is requested —
        Llama has no sliding-window attention.
        """
        assert not use_sliding_window, "Llama model does not support SWA."
        llm_config: modeling_llama.LlamaConfig = llm_model.config
        llm_args = LlamaConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            rope_scaling_=llm_config.rope_scaling,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # -1 marks "no pad token" for downstream code.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = LlamaForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = LlamaEmbedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = LlamaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Rebuild each decoder layer around the (shared) HF projection modules.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            decoder.self_attn_ = LLAMA_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = LlamaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = LlamaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
|
c2cite/models/modeling_mistral.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from transformers.models.mistral import modeling_mistral
|
| 7 |
+
from transformers.models.qwen2 import modeling_qwen2
|
| 8 |
+
from transformers.utils import is_flash_attn_2_available
|
| 9 |
+
|
| 10 |
+
from moe_peft.common import (
|
| 11 |
+
FeedForward,
|
| 12 |
+
LLMCache,
|
| 13 |
+
LLMModelInput,
|
| 14 |
+
flash_attention_forward,
|
| 15 |
+
)
|
| 16 |
+
from moe_peft.executors import executor
|
| 17 |
+
from moe_peft.models.modeling_llama import (
|
| 18 |
+
LlamaAttention,
|
| 19 |
+
LlamaConfig,
|
| 20 |
+
LlamaDecoderLayer,
|
| 21 |
+
LlamaEmbedding,
|
| 22 |
+
LlamaForCausalLM,
|
| 23 |
+
LlamaMLP,
|
| 24 |
+
LlamaRMSNorm,
|
| 25 |
+
apply_rotary_pos_emb,
|
| 26 |
+
repeat_kv,
|
| 27 |
+
)
|
| 28 |
+
from moe_peft.utils import copy_parameters
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class MistralConfig(LlamaConfig):
    # Master switch for sliding-window attention (Qwen2-style flag).
    use_sliding_window_: bool = False
    # Qwen2 only: layers with index >= this use SWA; None means all layers.
    max_window_layers_: int = None
    # Window size in tokens for SWA; None disables windowing entirely.
    sliding_window_: int = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MistralFlashAttention(LlamaAttention):
    """Flash-attention implementation for Mistral (and Qwen2) models.

    Extends :class:`LlamaAttention` with sliding-window attention (SWA):
    the KV cache is truncated to the window before updating, and
    ``flash_attention_forward`` receives a ``sliding_window`` argument
    when the window applies to this layer.
    """

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: MistralConfig,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(wq, wk, wv, wo, idx, args)
        # Qwen2
        self.use_sliding_window_ = args.use_sliding_window_
        self.max_window_layers_ = args.max_window_layers_
        # Mistral and Qwen2
        self.sliding_window_ = args.sliding_window_

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        # Q/K/V projections (LoRA-aware wrappers).
        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # conver shape to multi head
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # Total KV length including whatever is already cached.
        kv_seq_len = xk.shape[-2]
        if past_key_value is not None:
            kv_seq_len += cache_position[0]

        # apply rotary embedding
        cos, sin = rotary_emb
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx_) > 0
            if (
                self.sliding_window_ is not None
                and kv_seq_len > self.sliding_window_
                and cache_has_contents
            ):
                # Keep only the most recent (sliding_window - 1) cached
                # positions; the incoming token fills the window.
                slicing_tokens = 1 - self.sliding_window_

                past_key = past_key_value[self.layer_idx_][0]
                past_value = past_key_value[self.layer_idx_][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.sliding_window_ - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.sliding_window - 1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    # Trim the mask to the window and extend it by one for
                    # the newly generated position.
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat(
                        [attention_mask, torch.ones_like(attention_mask[:, -1:])],
                        dim=-1,
                    )

            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # Expand KV heads to the number of query heads (grouped-query attn).
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        # flash_attn does not accept fp32 inputs: drop to bf16/fp16 here and
        # restore the original dtype on the output.
        input_dtype = xq.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            xq = xq.to(target_dtype)
            xk = xk.to(target_dtype)
            xv = xv.to(target_dtype)

        # flash_attn expects (batch, seq_len, heads, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Use SWA unless a Qwen2-style config restricts it to layers at or
        # above max_window_layers_.
        if (
            (self.use_sliding_window_ is None or self.use_sliding_window_)
            and self.sliding_window_ is not None
            and (
                self.max_window_layers_ is None
                or self.layer_idx_ >= self.max_window_layers_
            )
        ):
            sliding_window = self.sliding_window_
        else:
            sliding_window = None

        attn_output = flash_attention_forward(
            xq,
            xk,
            xv,
            attention_mask,
            max_seq_len,
            is_causal=self.is_causal_,
            sliding_window=sliding_window,
        ).to(input_dtype)

        attn_output = attn_output.reshape(
            batch_size, max_seq_len, self.dim_
        ).contiguous()
        attn_output = self.wo_.forward(attn_output, input_args)

        return attn_output
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Maps the attn_implementation_ config value to the attention class used
# when rebuilding decoder layers in from_pretrained().
MISTRAL_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attn": MistralFlashAttention,
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class MistralForCausalLM(LlamaForCausalLM):
    """Mistral (and Qwen2-compatible) causal LM.

    Reuses the Llama module tree and adds sliding-window attention
    configuration via :class:`MistralConfig`.
    """

    def __init__(self, config: MistralConfig) -> None:
        super().__init__(config)

    @staticmethod
    def from_pretrained(
        llm_model: modeling_mistral.MistralForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a HuggingFace Mistral (or Qwen2) model into this
        runtime's layout.

        Weights are frozen and shared by reference; only ``lm_head``
        parameters are copied.
        """
        llm_config: modeling_mistral.MistralConfig = llm_model.config
        llm_args = MistralConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # compatible with qwen2
        if isinstance(llm_config, modeling_qwen2.Qwen2Config):
            llm_args.max_window_layers_ = llm_config.max_window_layers

        # -1 marks "no pad token" for downstream code.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = MistralForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = LlamaEmbedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = LlamaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Rebuild each decoder layer around the (shared) HF projections.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            decoder.self_attn_ = MISTRAL_ATTENTION_CLASSES[
                llm_args.attn_implementation_
            ](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = LlamaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = LlamaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
|
c2cite/models/modeling_phi.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers.activations import ACT2FN
|
| 8 |
+
from transformers.models.phi import modeling_phi
|
| 9 |
+
from transformers.models.phi.modeling_phi import (
|
| 10 |
+
PhiRotaryEmbedding,
|
| 11 |
+
apply_rotary_pos_emb,
|
| 12 |
+
repeat_kv,
|
| 13 |
+
)
|
| 14 |
+
from transformers.utils import is_flash_attn_2_available
|
| 15 |
+
|
| 16 |
+
from moe_peft.common import (
|
| 17 |
+
FeedForward,
|
| 18 |
+
Linear,
|
| 19 |
+
LLMAttention,
|
| 20 |
+
LLMCache,
|
| 21 |
+
LLMDecoder,
|
| 22 |
+
LLMFeedForward,
|
| 23 |
+
LLMForCausalLM,
|
| 24 |
+
LLMModelConfig,
|
| 25 |
+
LLMModelInput,
|
| 26 |
+
collect_plugin_router_logtis,
|
| 27 |
+
eager_attention_forward,
|
| 28 |
+
flash_attention_forward,
|
| 29 |
+
prepare_4d_causal_attention_mask,
|
| 30 |
+
slice_tensor,
|
| 31 |
+
)
|
| 32 |
+
from moe_peft.executors import executor
|
| 33 |
+
from moe_peft.utils import copy_parameters
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class PhiConfig(LLMModelConfig):
    # Epsilon for the model's LayerNorms.
    layer_norm_eps_: float = 1e-05
    # Residual dropout probability (mirrors the HF config field).
    resid_pdrop_: float = 0.0
    # Embedding dropout probability (mirrors the HF config field).
    embd_pdrop_: float = 0.0
    # Number of leading channels per head that receive rotary embeddings
    # (partial RoPE); the remaining channels pass through unrotated.
    rotary_emb_dim_: int = 0
    # Whether to apply LayerNorm to the query/key projections.
    qk_layernorm_: bool = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def apply_partial_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    rotary_emb_dim: int,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE to only the first ``rotary_emb_dim`` channels of q/k.

    The remaining channels pass through unchanged (Phi-style partial
    rotary embedding); rotated and pass-through parts are re-joined on
    the last dimension.
    """
    # Split each tensor into the slice that gets rotated and the rest.
    rot_q, pass_q = xq[..., :rotary_emb_dim], xq[..., rotary_emb_dim:]
    rot_k, pass_k = xk[..., :rotary_emb_dim], xk[..., rotary_emb_dim:]

    # [batch_size, seq_length, num_heads, head_dim // partial_rotary_factor]
    rot_q, rot_k = apply_rotary_pos_emb(rot_q, rot_k, cos, sin, position_ids)

    # [batch_size, seq_length, num_heads, head_dim]
    return (
        torch.cat((rot_q, pass_q), dim=-1),
        torch.cat((rot_k, pass_k), dim=-1),
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Multi-headed attention from 'Attention Is All You Need' paper.
|
| 72 |
+
class PhiAttention(LLMAttention):
    """Eager (math) multi-head attention for Phi models.

    Uses partial rotary embeddings (only the first ``rotary_emb_dim_``
    channels per head are rotated) and, optionally, LayerNorm on the
    query/key projections.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        dense: nn.Module,
        idx: int,
        config: PhiConfig,
    ):
        super().__init__()
        # attention projections (LoRA-aware wrappers)
        self.wq_: Linear = Linear(q_proj, config.device_)
        self.wk_: Linear = Linear(k_proj, config.device_)
        self.wv_: Linear = Linear(v_proj, config.device_)
        self.dense_: Linear = Linear(dense, config.device_)
        # config
        self.layer_idx_ = idx
        self.dim_ = config.dim_
        self.n_heads_ = config.n_heads_
        self.n_kv_heads_ = config.n_kv_heads_
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.rotary_emb_dim_ = config.rotary_emb_dim_
        self.head_dim_ = config.head_dim_
        self.dtype_ = config.dtype_
        self.is_causal_ = True
        # qk norm
        self.qk_layernorm_: bool = config.qk_layernorm_
        if self.qk_layernorm_:
            # FIX: the original referenced self.hidden_size_/self.num_heads_
            # and config.norm_eps_, none of which exist on this class/config,
            # raising AttributeError whenever qk_layernorm_ was True.
            # head_dim_ == hidden_size // num_attention_heads (see PhiConfig
            # loading), and layer_norm_eps_ is the Phi LayerNorm epsilon.
            self.q_layernorm_ = nn.LayerNorm(
                self.head_dim_,
                eps=config.layer_norm_eps_,
                elementwise_affine=True,
            )
            self.k_layernorm_ = nn.LayerNorm(
                self.head_dim_,
                eps=config.layer_norm_eps_,
                elementwise_affine=True,
            )
        else:
            self.q_layernorm_ = nn.Identity()
            self.k_layernorm_ = nn.Identity()

    def state_dict(self) -> Dict[str, Linear]:
        return {
            "q_proj": self.wq_,
            "k_proj": self.wk_,
            "v_proj": self.wv_,
            "dense": self.dense_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # Optional Q/K LayerNorm (identity when qk_layernorm_ is False).
        xq = self.q_layernorm_(xq)
        xk = self.k_layernorm_(xk)

        # conver shape to multi head
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb

        # partial rotary embedding
        xq, xk = apply_partial_rotary_emb(
            xq,
            xk,
            self.rotary_emb_dim_,
            cos,
            sin,
            cache_position.unsqueeze(0),
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb_dim_,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # before dim: batch_size, n_kv_head, seq_len, head_dim
        # after dim: batch_size, n_head, seq_len, head_dim
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        # Phi computes the attention scores in float32 for stability.
        attention_score = eager_attention_forward(
            xq.to(torch.float32), xk.to(torch.float32), xv, attention_mask
        )

        attention_score = attention_score.reshape(batch_size, max_seq_len, -1)
        attention_score = self.dense_.forward(attention_score, input_args)

        return attention_score
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class PhiFlashAttention2(PhiAttention):
    """Flash-attention variant of :class:`PhiAttention`.

    Same projections, Q/K LayerNorm and partial RoPE as the eager path,
    but dispatches the score computation to ``flash_attention_forward``.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        dense: nn.Module,
        idx: int,
        args: PhiConfig,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(q_proj, k_proj, v_proj, dense, idx, args)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        # Q/K/V projections (LoRA-aware wrappers).
        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # Optional Q/K LayerNorm (identity when qk_layernorm_ is False).
        xq = self.q_layernorm_(xq)
        xk = self.k_layernorm_(xk)

        # conver shape to multi head
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb

        # partial rotary embedding
        xq, xk = apply_partial_rotary_emb(
            xq,
            xk,
            self.rotary_emb_dim_,
            cos,
            sin,
            cache_position.unsqueeze(0),
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb_dim_,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # flash_attn expects (batch, seq_len, heads, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash_attn does not accept fp32 inputs: drop to bf16/fp16 here
        # and restore the original dtype on the output.
        input_dtype = xq.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            xq = xq.to(target_dtype)
            xk = xk.to(target_dtype)
            xv = xv.to(target_dtype)

        attn_output = flash_attention_forward(
            xq,
            xk,
            xv,
            attention_mask,
            max_seq_len,
            is_causal=self.is_causal_,
        ).to(input_dtype)

        attn_output = attn_output.reshape(
            batch_size, max_seq_len, self.dim_
        ).contiguous()
        attn_output = self.dense_.forward(attn_output, input_args)

        return attn_output
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Maps the attn_implementation_ config value to the attention class used
# when rebuilding Phi decoder layers.
PHI_ATTENTION_CLASSES = {
    "eager": PhiAttention,
    "flash_attn": PhiFlashAttention2,
}
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class PhiMLP(LLMFeedForward):
    """Phi feed-forward block: fc1 -> activation -> fc2.

    Each projection is wrapped in a c2cite ``Linear`` so per-adapter LoRA
    weights (``loras_``) can be attached; the three forward variants cover
    the plain batched path, a single named LoRA, and MixLoRA expert routing.
    """

    def __init__(self, fc1: nn.Module, fc2: nn.Module, args: PhiConfig) -> None:
        super().__init__()
        # feed forward
        self.fc1_: Linear = Linear(fc1, args.device_)
        self.fc2_: Linear = Linear(fc2, args.device_)
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        """Expose the two wrapped projections for checkpointing."""
        return {
            "fc1": self.fc1_,
            "fc2": self.fc2_,
        }

    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Standard path: fc1 -> act -> fc2, with LoRA dispatch inside Linear."""
        hidden_states = self.fc1_.forward(hidden_states, input_args)
        hidden_states = self.act_(hidden_states)
        hidden_states = self.fc2_.forward(hidden_states, input_args)
        return hidden_states

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Forward through a single named LoRA adapter (base output + delta).

        Falls back to the frozen base layer when the adapter does not exist
        on a given projection.
        """
        if lora_name in self.fc1_.loras_:
            # LoRA adds its low-rank delta on top of the base projection output.
            hidden_states = self.fc1_.loras_[lora_name].forward(
                self.fc1_.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.fc1_.base_layer_.forward(hidden_states)

        hidden_states = act_fn(hidden_states)

        if lora_name in self.fc2_.loras_:
            hidden_states = self.fc2_.loras_[lora_name].forward(
                self.fc2_.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.fc2_.base_layer_.forward(hidden_states)

        return hidden_states

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """MixLoRA path: compute per-expert outputs for the tokens routed to each.

        ``expert_mask`` rows select tokens per expert (via ``torch.where``);
        the frozen fc1 output is computed once and sliced per expert.
        Returns a list of per-expert output tensors in expert-index order.
        """
        # Shared base fc1 output, computed once in the compute dtype and cast
        # back so per-expert slicing stays in the caller's dtype.
        common_fc1 = self.fc1_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # top_x: indices of tokens assigned to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.fc1_.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                act_result = act_fn(
                    self.fc1_.loras_[lora_name].forward(
                        slice_tensor(common_fc1, top_x, input_dtype), lora_data
                    )
                )
            else:
                act_result = act_fn(slice_tensor(common_fc1, top_x, input_dtype))

            if lora_name in self.fc2_.loras_:
                final_expert_states.append(
                    self.fc2_.loras_[lora_name].forward(
                        self.fc2_.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(self.fc2_.base_layer_.forward(act_result))

        return final_expert_states
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class PhiDecoderLayer(LLMDecoder):
    """One Phi transformer layer.

    Phi uses a *parallel* residual: a single pre-attention LayerNorm feeds
    both the attention branch and the MLP branch, and the two branch outputs
    are summed with the residual (see the final addition in ``forward``).
    """

    def __init__(
        self, layer_id: int, self_attn: LLMAttention, mlp: FeedForward, args: PhiConfig
    ) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        self.self_attn_ = self_attn
        self.mlp_ = mlp
        # Single shared pre-norm for both branches (parallel residual).
        self.input_layernorm_ = nn.LayerNorm(
            args.dim_, eps=args.layer_norm_eps_, dtype=args.dtype_, device=args.device_
        )
        self.resid_pdrop_: float = args.resid_pdrop_

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention weights, MLP weights) for checkpointing."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm_(hidden_states)
        # Self Attention
        attn_outputs = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        # Residual dropout is active only when training (not inference_mode_).
        attn_outputs = F.dropout(
            attn_outputs, self.resid_pdrop_, not input_args.inference_mode_
        )
        # Fully Connected — note: fed the same normalized hidden_states as
        # attention, not the attention output (parallel residual).
        feed_forward_outputs, router_logits = self.mlp_.forward(
            hidden_states, input_args
        )
        feed_forward_outputs = F.dropout(
            feed_forward_outputs, self.resid_pdrop_, not input_args.inference_mode_
        )
        hidden_states = attn_outputs + feed_forward_outputs + residual

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class PhiEmbedding(nn.Module):
    """Token embedding table for Phi followed by embedding dropout."""

    def __init__(self, config: PhiConfig):
        super().__init__()
        # vocab x hidden embedding; third positional arg is the padding index.
        table = nn.Embedding(
            config.vocab_size_,
            config.dim_,
            config.pad_token_id_,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.embed_tokens = table
        self.embed_dropout = nn.Dropout(config.embd_pdrop_)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` and apply dropout."""
        return self.embed_dropout(self.embed_tokens(input_ids))
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class PhiLayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` configured from a ``PhiConfig``."""

    def __init__(self, config: PhiConfig) -> None:
        super().__init__()
        norm = nn.LayerNorm(
            config.dim_,
            eps=config.layer_norm_eps_,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layernorm_ = norm

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization to ``data``."""
        return self.layernorm_(data)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class PhiForCausalLM(LLMForCausalLM):
    """Phi decoder-only causal language model in the c2cite wrapper format.

    Holds the embedding, decoder stack, final LayerNorm, partial rotary
    embedding and (biased) LM head. ``from_pretrained`` converts a Hugging
    Face ``PhiForCausalLM`` into this structure, copying frozen weights.
    """

    def __init__(self, config: PhiConfig) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        self.embed_tokens_ = PhiEmbedding(config)
        self.final_layernorm_ = PhiLayerNorm(config)
        # Phi applies RoPE only to the first rotary_emb_dim_ channels
        # (partial rotary; see rotary_emb_dim_ computed in from_pretrained).
        self.rotary_emb_ = PhiRotaryEmbedding(
            dim=config.rotary_emb_dim_,
            max_position_embeddings=config.max_seq_len_,
            base=config.rope_theta_,
            device=config.device_,
        )
        # Phi's LM head carries a bias term.
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=True,
            dtype=config.dtype_,
            device=config.device_,
        )
        # Decoder layers are appended by from_pretrained.
        self.layers_: List[PhiDecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` (embedding dropout included)."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables sized to the largest position id + 1."""
        return self.rotary_emb_(input_tensor, seq_len=position_ids[-1, -1] + 1)

    def decoder_stack(self) -> List[LLMDecoder]:
        """Return the ordered list of decoder layers."""
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final LayerNorm before the LM head."""
        return self.final_layernorm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the 4D causal attention mask for the current step."""

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> PhiConfig:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_phi.PhiForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a Hugging Face Phi model into a ``PhiForCausalLM``.

        Copies frozen base weights and rebuilds each decoder layer with the
        selected attention implementation ("eager" or "flash_attn").
        Raises AssertionError if sliding-window attention is requested.
        """
        assert not use_sliding_window, "Phi model does not support SWA."
        llm_config: modeling_phi.PhiConfig = llm_model.config
        # Translate HF config field names into the c2cite PhiConfig layout.
        llm_args = PhiConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            resid_pdrop_=llm_config.resid_pdrop,
            embd_pdrop_=llm_config.embd_pdrop,
            max_seq_len_=llm_config.max_position_embeddings,
            layer_norm_eps_=llm_config.layer_norm_eps,
            rope_theta_=llm_config.rope_theta,
            partial_rotary_factor_=llm_config.partial_rotary_factor,
            qk_layernorm_=llm_config.qk_layernorm,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Only the first partial_rotary_factor_ fraction of each head gets RoPE.
        llm_args.rotary_emb_dim_ = int(
            llm_args.partial_rotary_factor_ * llm_args.head_dim_
        )

        if llm_args.pad_token_id_ is None:
            # Sentinel used downstream when the tokenizer defines no pad token.
            llm_args.pad_token_id_ = -1

        model = PhiForCausalLM(llm_args)
        # Base weights stay frozen; only adapters train.
        llm_model.requires_grad_(False)
        copy_parameters(llm_model.model.embed_tokens, model.embed_tokens_.embed_tokens)
        copy_parameters(
            llm_model.model.final_layernorm, model.final_layernorm_.layernorm_
        )
        copy_parameters(llm_model.lm_head, model.lm_head_)

        for idx, layer in enumerate(llm_model.model.layers):
            decoder = PhiDecoderLayer(
                idx,
                PHI_ATTENTION_CLASSES[llm_args.attn_implementation_](
                    layer.self_attn.q_proj,
                    layer.self_attn.k_proj,
                    layer.self_attn.v_proj,
                    layer.self_attn.dense,
                    idx,
                    llm_args,
                ),
                FeedForward(
                    PhiMLP(
                        layer.mlp.fc1,
                        layer.mlp.fc2,
                        llm_args,
                    )
                ),
                llm_args,
            )
            copy_parameters(layer.input_layernorm, decoder.input_layernorm_)
            model.layers_.append(decoder)

        return model
|
c2cite/models/modeling_phi3.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from transformers.activations import ACT2FN
|
| 8 |
+
from transformers.models.phi3 import modeling_phi3
|
| 9 |
+
from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
|
| 10 |
+
from transformers.utils import is_flash_attn_2_available
|
| 11 |
+
|
| 12 |
+
from moe_peft.common import (
|
| 13 |
+
FeedForward,
|
| 14 |
+
Linear,
|
| 15 |
+
LLMAttention,
|
| 16 |
+
LLMCache,
|
| 17 |
+
LLMDecoder,
|
| 18 |
+
LLMFeedForward,
|
| 19 |
+
LLMForCausalLM,
|
| 20 |
+
LLMModelConfig,
|
| 21 |
+
LLMModelInput,
|
| 22 |
+
collect_plugin_router_logtis,
|
| 23 |
+
eager_attention_forward,
|
| 24 |
+
flash_attention_forward,
|
| 25 |
+
prepare_4d_causal_attention_mask,
|
| 26 |
+
slice_tensor,
|
| 27 |
+
)
|
| 28 |
+
from moe_peft.executors import executor
|
| 29 |
+
from moe_peft.utils import copy_parameters
|
| 30 |
+
|
| 31 |
+
from .modeling_gemma2 import Gemma2RotaryEmbedding as Phi3RotaryEmbedding
|
| 32 |
+
from .modeling_llama import LlamaEmbedding as Phi3Embedding
|
| 33 |
+
from .modeling_llama import LlamaRMSNorm as Phi3RMSNorm
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class Phi3Config(LLMModelConfig):
    """Phi-3 specific model configuration extending the shared LLM config."""

    # Epsilon for RMSNorm layers.
    rms_norm_eps_: float = 1e-6
    # Pre-extension context length; LongRoPE scales beyond this.
    original_max_position_embeddings_: int = 4096
    # RoPE scaling spec (e.g. {"type": "longrope", "short_factor": ..., "long_factor": ...}).
    rope_scaling_: Optional[Dict[str, Any]] = None
    # Sliding-window attention toggle and window size (flash-attn path only).
    use_sliding_window_: bool = False
    sliding_window_: int = 4096
    # Residual dropout probability.
    resid_pdrop_: float = 0.0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """LongRoPE rotary embedding for Phi-3.

    Selects between ``short_factor`` and ``long_factor`` frequency rescaling
    depending on whether the sequence exceeds the original (pre-extension)
    context length, and applies a magnitude scaling factor derived from the
    context-length ratio.
    """

    def __init__(self, dim, config: Phi3Config, device=None):
        super().__init__(dim, config.max_seq_len_, config.rope_theta_, device)

        self.short_factor = config.rope_scaling_["short_factor"]
        self.long_factor = config.rope_scaling_["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings_

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) tables for ``position_ids``, cast to ``x.dtype``."""
        seq_len = torch.max(position_ids) + 1
        # Long sequences use the "long" rescale factors, short ones the "short".
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(
                self.long_factor, dtype=torch.float32, device=x.device
            )
        else:
            ext_factors = torch.tensor(
                self.short_factor, dtype=torch.float32, device=x.device
            )

        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float()
            / self.dim
        )
        # Recomputed each call because ext_factors depends on seq_len.
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)

        # Attention-magnitude compensation for extended context (LongRoPE).
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1
                + math.log(scale) / math.log(self.original_max_position_embeddings)
            )

        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Phi3Attention(LLMAttention):
    """Eager (matmul + softmax) self-attention for Phi-3.

    Phi-3 fuses Q, K and V into a single ``qkv_proj``; the fused output is
    split by channel offsets. Supports grouped-query attention via
    ``repeat_kv`` and an optional external KV cache.
    """

    def __init__(
        self, qkv_proj: nn.Module, o_proj: nn.Module, layer_idx: int, args: Phi3Config
    ) -> None:
        super().__init__()
        # attention
        self.qkv_proj_ = Linear(qkv_proj, args.device_)
        self.o_proj_ = Linear(o_proj, args.device_)
        # config
        self.layer_idx_ = layer_idx
        self.args_ = args
        self.dim_ = args.dim_
        self.n_heads_ = args.n_heads_
        self.n_kv_heads_ = args.n_kv_heads_
        # GQA: how many times each KV head is shared across query heads.
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.rope_theta_ = args.rope_theta_
        self.head_dim_ = self.dim_ // self.n_heads_
        self.dtype_ = args.dtype_
        self.is_causal_ = True

    def state_dict(self) -> Dict[str, Linear]:
        """Expose the fused QKV and output projections for checkpointing."""
        return {
            "qkv_proj": self.qkv_proj_,
            "o_proj": self.o_proj_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Fused projection, then split into Q / K / V by channel offset.
        qkv = self.qkv_proj_.forward(hidden_states, input_args)
        query_pos = self.n_heads_ * self.head_dim_
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.n_kv_heads_ * self.head_dim_]
        value_states = qkv[..., query_pos + self.n_kv_heads_ * self.head_dim_ :]

        # Reshape to (bsz, heads, q_len, head_dim).
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # apply rotary embedding
        cos, sin = rotary_emb
        assert query_states.dtype == key_states.dtype
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, cache_position.unsqueeze(0)
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            # Merge the new K/V into the cache and read back the full history.
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # GQA: duplicate KV heads so each query head has a counterpart.
        value_states = repeat_kv(value_states, self.n_rep_)
        key_states = repeat_kv(key_states, self.n_rep_)

        attn_output = eager_attention_forward(
            query_states, key_states, value_states, attention_mask
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)

        return self.o_proj_(attn_output, input_args)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class Phi3FlashAttention2(Phi3Attention):
    """Flash-Attention-2 variant of ``Phi3Attention``.

    Adds optional sliding-window KV-cache slicing and the fp32 -> bf16/fp16
    downcast flash-attn requires. Requires the flash-attn package.
    """

    def __init__(
        self, qkv_proj: nn.Module, o_proj: nn.Module, layer_idx: int, args: Phi3Config
    ) -> None:
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(qkv_proj, o_proj, layer_idx, args)
        self.sliding_window_ = args.sliding_window_

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):

        bsz, q_len, _ = hidden_states.size()

        # cutting
        qkv = self.qkv_proj_.forward(hidden_states, input_args)
        query_pos = self.n_heads_ * self.head_dim_
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.n_kv_heads_ * self.head_dim_]
        value_states = qkv[..., query_pos + self.n_kv_heads_ * self.head_dim_ :]

        # viewing
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # Effective KV length = new tokens plus cached prefix offset.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += cache_position[0]

        # apply rotary embedding
        # NOTE(review): unlike the eager path, no position_ids are passed here —
        # cos/sin are presumably already indexed by position; confirm against
        # the rotary_emb producer.
        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Activate slicing cache
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx_) > 0
            if (
                self.sliding_window_ is not None
                and kv_seq_len > self.sliding_window_
                and cache_has_contents
            ):
                # Keep only the most recent (sliding_window - 1) cached tokens.
                slicing_tokens = 1 - self.sliding_window_

                past_key = past_key_value[self.layer_idx_][0]
                past_value = past_key_value[self.layer_idx_][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.sliding_window_ - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.sliding_window - 1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    # Trim the mask to the window and extend it for the new token.
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat(
                        [attention_mask, torch.ones_like(attention_mask[:, -1:])],
                        dim=-1,
                    )

            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.n_rep_)
        value_states = repeat_kv(value_states, self.n_rep_)

        # flash-attn does not accept fp32 inputs; downcast, then restore below.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # flash-attn expects (bsz, seq_len, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            is_causal=self.is_causal_,
            sliding_window=self.sliding_window_,
        ).to(input_dtype)

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj_(attn_output, input_args)

        return attn_output
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Maps the configured `attn_implementation_` string to the Phi-3 attention
# class realizing it.
PHI3_ATTENTION_CLASSES = dict(
    eager=Phi3Attention,
    flash_attn=Phi3FlashAttention2,
)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class Phi3MLP(LLMFeedForward):
    """Phi-3 gated feed-forward block.

    Phi-3 fuses gate and up projections into one ``gate_up_proj`` whose
    output is split in half with ``chunk(2, dim=-1)``:
    ``down(up * act(gate))``. Each projection is a c2cite ``Linear`` so
    LoRA / MixLoRA adapters can attach.
    """

    def __init__(self, gate: nn.Module, down: nn.Module, args: Phi3Config) -> None:
        super().__init__()
        # feed forward
        self.gate_up_proj_ = Linear(gate, args.device_)
        self.down_proj_ = Linear(down, args.device_)
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        """Expose the fused gate_up and down projections for checkpointing."""
        return {
            "gate_up_proj": self.gate_up_proj_,
            "down_proj": self.down_proj_,
        }

    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Standard path: fused gate_up -> split -> gate * up -> down."""
        up_states = self.gate_up_proj_(hidden_states, input_args)

        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.act_(gate)

        return self.down_proj_(up_states, input_args)

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        # Applying LoRA weights to FFN weights
        if lora_name in self.gate_up_proj_.loras_:
            gate_up_states = self.gate_up_proj_.loras_[lora_name].forward(
                self.gate_up_proj_.base_layer_.forward(data), data
            )
        else:
            gate_up_states = self.gate_up_proj_.base_layer_.forward(data)

        gate_states, up_states = gate_up_states.chunk(2, dim=-1)
        act_result = act_fn(gate_states) * up_states

        if lora_name in self.down_proj_.loras_:
            return self.down_proj_.loras_[lora_name].forward(
                self.down_proj_.base_layer_.forward(act_result), act_result
            )
        else:
            return self.down_proj_.base_layer_.forward(act_result)

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """MixLoRA path: per-expert outputs for the tokens routed to each expert.

        The frozen gate_up output is computed once and sliced per expert via
        the indices from ``torch.where(expert_mask[i])``. Returns a list of
        per-expert output tensors in expert-index order.
        """
        # Shared base gate_up output, computed once in the compute dtype.
        common_gate_up = self.gate_up_proj_.base_layer_.forward(
            hidden_states.to(input_dtype)
        ).to(hidden_states.dtype)

        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # top_x: indices of tokens assigned to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.gate_up_proj_.loras_:
                gate_up_states = self.gate_up_proj_.loras_[lora_name].forward(
                    slice_tensor(common_gate_up, top_x, input_dtype),
                    slice_tensor(hidden_states, top_x, input_dtype),
                )
            else:
                gate_up_states = slice_tensor(common_gate_up, top_x, input_dtype)

            gate_states, up_states = gate_up_states.chunk(2, dim=-1)
            act_result = up_states * act_fn(gate_states)

            if lora_name in self.down_proj_.loras_:
                final_expert_states.append(
                    self.down_proj_.loras_[lora_name].forward(
                        self.down_proj_.base_layer_.forward(act_result),
                        act_result,
                    )
                )
            else:
                final_expert_states.append(
                    self.down_proj_.base_layer_.forward(act_result)
                )

        return final_expert_states
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class Phi3DecoderLayer(LLMDecoder):
    """One Phi-3 transformer layer: pre-norm attention then pre-norm MLP,
    each followed by residual dropout and a residual add (sequential
    residuals — unlike Phi-2's parallel layout).

    Sub-modules (attention, MLP, the two RMSNorms) are assigned after
    construction by the model builder, hence the ``None`` placeholders.
    """

    def __init__(self, layer_id: int, config: Phi3Config) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        # Filled in externally after construction.
        self.self_attn_: Phi3Attention = None
        self.mlp_: FeedForward = None
        self.input_layernorm_: Phi3RMSNorm = None

        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop_)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop_)
        self.post_attention_layernorm_: Phi3RMSNorm = None

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention weights, MLP weights) for checkpointing."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm_(hidden_states)
        # Self Attention
        attn_outputs = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = residual + self.resid_attn_dropout(attn_outputs)
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm_(hidden_states)
        hidden_states, router_logits = self.mlp_.forward(hidden_states, input_args)
        hidden_states = residual + self.resid_mlp_dropout(hidden_states)

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class Phi3ForCausalLM(LLMForCausalLM):
    """Phi-3 causal language model assembled from project building blocks.

    The decoder stack, token embedding, and final norm are not created in
    ``__init__``; they are populated by :meth:`from_pretrained`, which
    transplants weights from a Hugging Face ``Phi3ForCausalLM``.
    """

    def _init_rope(self):
        # Plain rotary embedding when no scaling config is present;
        # otherwise only the "longrope" scaling variant is supported.
        if self.config_.rope_scaling_ is None:
            return Phi3RotaryEmbedding(
                self.config_.head_dim_,
                max_position_embeddings=self.config_.max_seq_len_,
                base=self.config_.rope_theta_,
                device=self.config_.device_,
            )
        else:
            scaling_type = self.config_.rope_scaling_["type"]
            # NOTE(review): the ValueError instance is only the assert message;
            # it is never raised, and the check disappears under `python -O`.
            assert scaling_type == "longrope", ValueError(
                f"Unknown RoPE scaling type {scaling_type}"
            )
            return Phi3LongRoPEScaledRotaryEmbedding(
                self.config_.head_dim_,
                config=self.config_,
                device=self.config_.device_,
            )

    def __init__(self, config: Phi3Config) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Attached later in from_pretrained().
        self.embed_tokens_: Phi3Embedding = None
        # Annotation corrected: from_pretrained assigns a Phi3RMSNorm here.
        self.norm_: Phi3RMSNorm = None
        self.rotary_emb_ = self._init_rope()
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layers_: List[Phi3DecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to input embeddings."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute rotary position embedding tensors for the given positions."""
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        """Expose the decoder layers for the generic forward driver."""
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final RMS norm."""
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        # Delegate 4-D causal mask construction to the shared helper.
        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> Phi3Config:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_phi3.Phi3ForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a Hugging Face Phi-3 model into this project's wrapper.

        The HF model's parameters are frozen (``requires_grad_(False)``) and
        its weights are shared/copied into freshly constructed project modules.
        """
        llm_config: modeling_phi3.Phi3Config = llm_model.config
        # Translate HF config fields into the project's Phi3Config field names.
        llm_args = Phi3Config(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            resid_pdrop_=llm_config.resid_pdrop,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            rope_scaling_=llm_config.rope_scaling,
            original_max_position_embeddings_=llm_config.original_max_position_embeddings,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Use -1 as a sentinel when the tokenizer defines no pad token.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = Phi3ForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = Phi3Embedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = Phi3RMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Rebuild each decoder layer around the HF layer's projection modules.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = Phi3DecoderLayer(idx, llm_args)
            decoder.self_attn_ = PHI3_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.qkv_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                Phi3MLP(
                    layer.mlp.gate_up_proj,
                    layer.mlp.down_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = Phi3RMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = Phi3RMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
|
c2cite/prompter.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from typing import Dict, Optional, Union
|
| 5 |
+
|
| 6 |
+
# Built-in prompt templates, keyed by name. Each template defines:
#   description     - human-readable label, logged at load time
#   prompt_input    - format string used when an "input" context is provided
#   prompt_no_input - format string used when only an instruction is given
#   response_split  - marker used to split the model response from the prompt
prompt_templates = {
    "moe_peft": {
        "description": "Default Prompt Template Provided by MoE-PEFT",
        "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n",
        "prompt_no_input": "### Instruction:\n{instruction}\n\n### Output:\n",
        "response_split": "### Output:",
    },
    "alpaca": {
        "description": "Template used by Alpaca-LoRA.",
        "prompt_input": "Below is an instruction that describes a task, "
        + "paired with an input that provides further context. "
        + "Write a response that appropriately completes the request.\n\n"
        + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
        "prompt_no_input": "Below is an instruction that describes a task. "
        + "Write a response that appropriately completes the request.\n\n"
        + "### Instruction:\n{instruction}\n\n### Response:\n",
        "response_split": "### Response:",
    },
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# manage templates and prompt building.
|
| 28 |
+
class Prompter:
    """Build prompts from a named, file-based, or user-supplied template.

    ``template`` may be ``None`` (use the default "moe_peft" template), a
    path to a JSON template file, the name of a built-in template, or a
    ready-made template dict.
    """

    def __init__(self, template: Optional[Union[Dict, str]] = None):
        if template is None:
            resolved = prompt_templates["moe_peft"]
        elif not isinstance(template, str):
            resolved = template
        elif osp.exists(template):
            # A string that names an existing file is loaded as JSON.
            with open(template) as fp:
                resolved = json.load(fp)
        else:
            # Otherwise the string is treated as a built-in template name.
            resolved = prompt_templates[template]
        self.template = resolved

        logging.info(f"Using prompt template: {self.template['description']}")

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        """Render the full prompt for *instruction*.

        When *input* is non-empty the "prompt_input" template is used;
        otherwise "prompt_no_input". A truthy *label* (the expected
        response) is appended with a trailing newline.
        """
        if not input:
            res = self.template["prompt_no_input"].format(instruction=instruction)
        else:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        if label:
            res = f"{res}{label}\n"
        logging.debug(res)
        return res

    def get_response(self, output: str) -> str:
        """Return the stripped text after the last response-split marker."""
        marker = self.template["response_split"]
        return output.split(marker)[-1].strip()
|
c2cite/solutions.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Based on the "peering" paper (placeholder; not yet implemented).
def get_output(layers, hidden, output, ans_len):
    """Placeholder for a layer-count-dependent output selection.

    NOTE(review): both branches are empty, so this function currently does
    nothing and returns None for any input. `layers` presumably is the
    number of transformer layers (32 would match a 7B-class model) —
    confirm before implementing.
    """
    if layers == 32:
        pass
    else:
        pass
|
c2cite/tasks/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task registry for the `tasks` package: concrete task modules register
# their tasks into the shared `task_dict` so they can be looked up by name.
from . import glue_tasks, qa_tasks, attribute_tasks
from .common import (
    AutoMetric,
    BasicMetric,
    BasicTask,
    CasualTask,
    CommonSenseTask,
    MultiTask,
    SequenceClassificationTask,
    task_dict,
)
from .qa_tasks import QuestionAnswerTask

# Populate the global task dictionary with each task family at import time.
glue_tasks.update_task_dict(task_dict)
qa_tasks.update_task_dict(task_dict)
attribute_tasks.update_task_dict(task_dict)


__all__ = [
    "BasicMetric",
    "AutoMetric",
    "BasicTask",
    "CasualTask",
    "SequenceClassificationTask",
    "CommonSenseTask",
    "QuestionAnswerTask",
    "MultiTask",
    "task_dict",
]
|
c2cite/tasks/attribute_tasks.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import datasets as hf_datasets
|
| 6 |
+
import torch
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
import os
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from transformers import BertTokenizer, BertModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from moe_peft.common import InputData
|
| 16 |
+
|
| 17 |
+
from moe_peft.tasks.common import AttributeTask, BasicMetric, AutoMetric
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AttributedAnswerTask(AttributeTask):
    """Base class for attributed (citation-grounded) answer tasks."""

    def __init__(self) -> None:
        super().__init__()


    def loading_metric(self, metrics: List[str]):
        # Build the metric through AutoMetric for the "attribute" family.
        # NOTE(review): subclasses (ASQA/ELI5) override this with a
        # zero-argument signature — confirm which variant callers rely on.
        return AutoMetric("attribute", metrics)
|
| 28 |
+
|
| 29 |
+
class ASQA(AttributedAnswerTask):
    """ASQA attributed-QA task loaded from the ALCE data release.

    `sub` selects which document field is shown to the model:
    'vani' uses the full passage text, anything else uses the summary.
    """

    def __init__(self, sub: str = 'vani'):
        super().__init__()
        # Instruction variants; loading_data currently uses inst_special_token.
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
    ) -> List[InputData]:
        """Load ASQA eval examples and build Llama-3-style chat prompts.

        Each returned InputData carries: the rendered prompt, the gold
        answer, the `qa_pairs` grounds, the cited document texts, and the
        raw question.
        """
        few_shot = False  # NOTE(review): few-shot is force-disabled here, overriding the argument

        num_docs = 5
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/asqa_eval_gtr_top100.json"  # go up two levels, then into the dataset directory
        file_path = os.path.join(current_dir, relative_path)

        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for ASQA")
        ret: List[InputData] = []
        #cnt = 5
        """tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        model = BertModel.from_pretrained('bert-large-uncased')
        device = 'cuda:6'
        model = model.to(device)
        model.eval()"""
        for data_point in tqdm(data):
            #if cnt == 0:
            #    break
            #cnt = cnt - 1
            #prompt = ""
            # Llama-3 chat header: system turn, then the user turn content.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            #prompt += self.inst_new
            prompt += self.inst_special_token
            if few_shot:
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            #prompt += f"\n\n\nQusetion: {data_point['qa_pairs'][0]['question']}\n\n"
            # NOTE(review): "Qusetion" is a typo in the runtime prompt; left
            # unchanged since the model/eval may depend on the exact text.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            # Collect the top-`num_docs` retrieved documents for this question.
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                    'summary': data_point['docs'][i]['summary'],
                })
            #random.shuffle(cites)
            # Render each document behind a reserved special token marker.
            for i in range(num_docs):
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
                #docs += f"Document <|reserved_special_token_{i+1}|>(Title: {cites[i]['title']}): {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
                #docs += f"Document [{i+1}](Title: {cites[i]['title']}): {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
            # Flatten to the plain text (or summary) strings used as citations.
            cites = [cites[i]['text'] if self.sub=='vani' else cites[i]['summary'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            # prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            #citation_embeds = sents_embed(cites, model, tokenizer, device)
            ret.append(InputData(inputs=prompt, labels=data_point['answer'], \
                grounds=data_point['qa_pairs'], citations = cites,# citation_embeds = citation_embeds,\
                query = data_point['question']))

        return ret

    def loading_metric(self):
        # NOTE(review): `metric_list` is not defined in this module's visible
        # scope — confirm it is provided elsewhere, otherwise this raises
        # NameError at call time.
        config = {}
        config['task'] = 'asqa'
        config['metric'] = metric_list['asqa']
        return AutoMetric("attribute", config)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ELI5(AttributedAnswerTask):
    """ELI5 attributed-QA task loaded from the ALCE data release.

    `sub` selects which document field is shown to the model:
    'vani' uses the full passage text, anything else uses the summary.
    """

    def __init__(self, sub: str = 'vani'):
        super().__init__()
        # Instruction variants; loading_data currently uses inst_special_token.
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
    ) -> List[InputData]:
        """Load ELI5 eval examples and build Llama-3-style chat prompts.

        Each returned InputData carries: the rendered prompt, the gold
        answer, the `claims` grounds, the cited document texts, and the
        raw question.
        """
        few_shot = False  # NOTE(review): few-shot is force-disabled here, overriding the argument
        num_docs = 5
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/eli5_eval_bm25_top100.json"  # go up two levels, then into the dataset directory
        file_path = os.path.join(current_dir, relative_path)
        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for ELI5")
        ret: List[InputData] = []
        #cnt = 5
        for data_point in tqdm(data):
            #if cnt == 0:
            #    break
            #cnt = cnt - 1
            # Llama-3 chat header: system turn, then the user turn content.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            #prompt += self.inst
            prompt += self.inst_special_token
            if few_shot:
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            # NOTE(review): "Qusetion" is a typo in the runtime prompt; left
            # unchanged since the model/eval may depend on the exact text.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            # Collect the top-`num_docs` retrieved documents for this question.
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                    'summary': data_point['docs'][i]['summary'],
                })
            #random.shuffle(cites)
            # Render each document behind a reserved special token marker.
            for i in range(num_docs):
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
                #docs += f"Document [{i+1}](Title: {cites[i]['title']}): {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
            # Flatten to the plain text (or summary) strings used as citations.
            cites = [cites[i]['text'] if self.sub=='vani' else cites[i]['summary'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            # prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            ret.append(InputData(inputs=prompt, labels=data_point['answer'], \
                grounds=data_point['claims'], citations = cites, \
                query = data_point['question']))

        return ret

    def loading_metric(self):
        # NOTE(review): `metric_list` is not defined in this module's visible
        # scope — confirm it is provided elsewhere, otherwise this raises
        # NameError at call time.
        config = {}
        config['task'] = 'eli5'
        config['metric'] = metric_list['eli5']
        return AutoMetric("attribute", config)
|
| 155 |
+
|
| 156 |
+
class Qampari(AttributedAnswerTask):
    """QAMPARI attributed list-QA task (ALCE benchmark, GTR top-100 retrieval).

    Builds Llama-3-style chat prompts in which each retrieved document is
    labelled with a reserved special token (``<|reserved_special_token_k|>``)
    so that generated citations can be matched back to source documents.
    """

    def __init__(self, sub: str = 'vani'):
        super().__init__()
        # Instruction using bracketed [1][2][3] citations (currently unused).
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Instruction variant for the reserved-special-token citation style
        # (this is the one actually appended to prompts below).
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Sub-variant selector; only the 'vani' branch is exercised in this class.
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
                     ) -> List[InputData]:
        """Load the QAMPARI evaluation split and wrap each example as InputData.

        Args:
            is_train: unused — QAMPARI is only loaded for evaluation here.
            path: optional override for the default dataset location.
            few_shot: ignored — forced to False below.

        Returns:
            List of InputData carrying the formatted prompt, gold answers,
            citation texts, and the original question.
        """
        few_shot = False  # few-shot prompting deliberately disabled
        num_docs = 5  # number of retrieved documents included per question
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/qampari_eval_gtr_top100.json"  # go up two levels, then into the dataset directory
        file_path = os.path.join(current_dir, relative_path)
        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Qampari")
        ret: List[InputData] = []
        for data_point in tqdm(data):
            # Llama-3 chat header; the assistant header is presumably appended
            # elsewhere in the generation pipeline — TODO confirm.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += self.inst_special_token
            if few_shot:
                # One worked example with bracketed citations — dead code while
                # few_shot is forced to False above.
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            # NOTE(review): "Qusetion" is a typo preserved on purpose — changing
            # it would alter the exact prompt used for training/evaluation.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                })
            for i in range(num_docs):
                # Each document is tagged with a reserved special token instead of [k].
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text']}\n"
            # Keep only the raw document texts for citation grounding.
            cites = [cites[i]['text'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            ret.append(InputData(inputs=prompt, labels=data_point['answers'], \
                                 citations = cites, \
                                 query = data_point['question']))
        return ret

    def loading_metric(self):
        """Build the attribute-style metric evaluator configured for QAMPARI."""
        config = {}
        config['task'] = 'qam'
        config['metric'] = metric_list['qam']
        return AutoMetric("attribute", config)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class QouteSum(AttributedAnswerTask):
    """QuoteSum attributed summarization/QA task.

    Groups jsonl rows by question id, formats each group into a single
    Llama-3-style prompt whose documents are tagged with reserved special
    tokens, and yields training or evaluation InputData records.
    ``sub`` selects the dataset variant: 'vani' (default), 'alce', or 'ans'.
    """

    def __init__(self, sub: str = 'vani'):
        super().__init__()
        self.sub = sub
        # Instruction using bracketed [1][2][3] citations (currently unused).
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Alternative step-by-step instruction (currently unused).
        self.inst2 = 'Based on the information contained in the document, answer the question with details to the best of your bilities. Think step by step and explain your answer if that will help better understand the answer.'
        # Reserved-special-token instruction (the one actually used below).
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Sentence-end citation instruction (currently unused).
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'

    def loading_data(self, is_train: bool = False, path: str = None,
                     few_shot: bool = True ) -> List[InputData]:
        """Read the QuoteSum jsonl split and build one prompt per question id.

        Args:
            is_train: selects the train split and emits one record per gold
                summary; otherwise the test split with all summaries as labels.
            path: optional override of the hard-coded dataset path.
            few_shot: ignored — forced to False below.

        Returns:
            List of InputData records; question groups with no non-empty
            source document are skipped.
        """
        few_shot = False  # few-shot prompting deliberately disabled
        if is_train:
            few_shot = False
        ret: List[InputData] = []
        examples_by_qid = {}
        # NOTE(review): dataset root is a machine-specific hard-coded path;
        # pass `path` to override it.
        with open(f"/yy21/MoE-PEFT/dataset/{'qoutesum_alce' if self.sub == 'alce' else ( 'qoutesum_ans' if self.sub == 'ans' else 'qoutesum')}/{'train' if is_train else 'test'}.jsonl" if path is None else path, 'r') as f:
            for line in f:
                example = json.loads(line.strip())
                if example['qid'] not in examples_by_qid:
                    examples_by_qid[example['qid']] = [example]
                else:
                    examples_by_qid[example['qid']].append(example)

        examples = list(examples_by_qid.values())
        for example in examples:
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += self.inst_special_token
            if few_shot:
                # Worked examples per sub-variant — dead code while few_shot is
                # forced False above.
                if self.sub == 'alce':
                    prompt += f" Here are some examples:\nQuestion: how much power does a wind turbine produce?\nDocument [1](Title:): Compact wind acceleration turbine: It is generally thought that since the amount of power produced by a wind turbine is proportional to the cube of the wind speed, any acceleration benefit is potentially statistically significant in the economics of wind. As noted though this is an inaccurate as it ignores the impact of the exit to area ratio and is therefore an apples to oranges comparison. In the case of a typical CWAT/DAWT the power result in perfect theoretical operation once adjusted for the area of the shroud is actually the square of the velocity at the rotor. As the CWAT/DAWT diverges from theoretical function the power increase drops significantly according\nDocument [2](Title:): Sustainable architecture: roof ledge. Small-scale rooftop wind turbines have been known to be able to generate power from 10% to up to 25% of the electricity required of a regular domestic household dwelling. Turbines for residential scale use are usually between 7 feet (2 m) to 25 feet (8 m) in diameter and produce electricity at a rate of 900 watts to 10,000 watts at their tested wind speed. Building integrated wind turbine performance can be enhanced with the addition of an aerofoil wing on top of a roof mounted turbine. Solar water heaters, also called solar domestic hot water systems, can\nDocument [3](Title:): Turby wind turbine: can because horizontal axis (HAWT) types cannot change their pitch to face the wind directly. The turbine measures 2.0m (6'7\") in diameter by 2.9m (9'6\") high (including generator), and weighs 136 kg (300 lb). It is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts), and can survive winds of 55 m/s (123 mph, 107kts). The rated power at 14 m/s is 2.5 kW (3.35 hp). The AC output from the synchronous generator is rectified to DC, then inverted to AC at 230V 50 Hz. Core International developed the turbine\nAnswer: One source states the amount of power produced by a wind turbine is proportional to the cube of the wind speed [1]. Other sources state Turbines for residential scale use produce electricity at a rate of 900 watts to 10,000 watts, and is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts) [2][3]."
                elif self.sub == 'vani':
                    prompt += f" Here are some examples:\nQuestion: how much power does a wind turbine produce?\n[1] Compact wind acceleration turbine: It is generally thought that since the amount of power produced by a wind turbine is proportional to the cube of the wind speed, any acceleration benefit is potentially statistically significant in the economics of wind. As noted though this is an inaccurate as it ignores the impact of the exit to area ratio and is therefore an apples to oranges comparison. In the case of a typical CWAT/DAWT the power result in perfect theoretical operation once adjusted for the area of the shroud is actually the square of the velocity at the rotor. As the CWAT/DAWT diverges from theoretical function the power increase drops significantly according\n[2] Sustainable architecture: roof ledge. Small-scale rooftop wind turbines have been known to be able to generate power from 10% to up to 25% of the electricity required of a regular domestic household dwelling. Turbines for residential scale use are usually between 7 feet (2 m) to 25 feet (8 m) in diameter and produce electricity at a rate of 900 watts to 10,000 watts at their tested wind speed. Building integrated wind turbine performance can be enhanced with the addition of an aerofoil wing on top of a roof mounted turbine. Solar water heaters, also called solar domestic hot water systems, can\n[3] Turby wind turbine: can because horizontal axis (HAWT) types cannot change their pitch to face the wind directly. The turbine measures 2.0m (6'7\") in diameter by 2.9m (9'6\") high (including generator), and weighs 136 kg (300 lb). It is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts), and can survive winds of 55 m/s (123 mph, 107kts). The rated power at 14 m/s is 2.5 kW (3.35 hp). The AC output from the synchronous generator is rectified to DC, then inverted to AC at 230V 50 Hz. Core International developed the turbine\nAnswer: One source states the [ 1 amount of power produced by a wind turbine is proportional to the cube of the wind speed ] . Other sources state [ 2 Turbines for residential scale use ] [ 2 produce electricity at a rate of 900 watts to 10,000 watts ] , and [ 3 is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts) ] .\n\nQuestion: a component is what?\n[1] Modular programming: in Dart, Go or Java) is sometimes used instead of module. In other implementations, this is a distinct concept; in Python a package is a collection of modules, while in Java 9 the introduction of the new module concept (a collection of packages with enhanced access control) is planned. Furthermore, the term \"package\" has other uses in software (for example .NET NuGet packages). A component is a similar concept, but typically refers to a higher level; a component is a piece of a whole system, while a module is a piece of an individual program. The scale of the term\n[2] Physical body: the system at a point in time changes from identifying the object to not identifying it. Also an object's identity is created at the first point in time that the simplest model of the system consistent with perception identifies it. An object may be composed of components. A component is an object completely within the boundary of a containing object. In classical mechanics a physical body is collection of matter having properties including mass, velocity, momentum and energy. The matter exists in a volume of three-dimensional space. This space is its extension. Under Newtonian gravity the gravitational field further away\nQuoted summary: [ 1 A component is a similar concept, but typically refers to a higher level; a component is a piece of a whole system, while a module is a piece of an individual program ] in terms of [ 1 Modular programming ] . Whereas in the [ 2 Physical body ] , a [ 2 component is an object completely within the boundary of a containing object ] ."
                elif self.sub == 'ans':
                    pass
            # NOTE(review): "Qusetion" is a typo preserved on purpose — fixing it
            # would change the prompt the model was trained/evaluated with.
            prompt += f"\n\nQusetion: {example[0]['question']}\n"
            docs = ""
            sources = []
            citations = []
            # Collect up to 8 (title{k}, source{k}) field pairs from the record.
            for i in range(8):
                if f"title{i+1}" not in example[0]:
                    break
                sources.append({'title': example[0][f'title{i+1}'],
                                'doc': example[0][f"source{i+1}"]}
                               )
            # FIX: iterate only over the sources actually collected above.
            # The original `for i in range(8)` indexed sources[i] and raised
            # IndexError whenever a record had fewer than 8 title/source
            # fields and no empty-doc sentinel to break on first.
            for i in range(len(sources)):
                if sources[i]['doc'] != "":
                    docs += f"Document <|reserved_special_token_{i+1}|>: {sources[i]['doc']}\n"
                    citations.append(sources[i]['doc'])
                else:
                    break
            if len(citations) == 0:
                continue
            prompt += docs
            prompt += f"\nAnswer:"
            if is_train:
                # One training record per gold summary; [k] markers in the
                # summary are rewritten to reserved special tokens.
                for e in example:
                    ret.append(InputData(inputs = prompt + cite2token(e['summary']),
                                         citations=citations, prompt = prompt))
            else:
                ret.append(InputData(inputs=prompt, labels=[e['summary'] for e in example],
                                     grounds=[i for e in example for i in e['covered_short_answers']],
                                     citations=citations, query = example[0]['question']))
        return ret

    def loading_metric(self):
        """Metric config: 'qsum-a' metrics for the alce variant, 'qsum' otherwise."""
        config = {}
        config['task'] = 'qsum'
        if self.sub == 'alce':
            config['metric'] = metric_list['qsum-a']
        else:
            config['metric'] = metric_list['qsum']
        return AutoMetric("attribute", config)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class Front(AttributedAnswerTask):
    """FRONT attribution dataset (SFT or DPO split) as a train/eval task.

    The raw JSON already carries fully formatted inputs; this task wraps
    them in a Llama-3 chat header and rewrites bracketed [k] citations
    into reserved special tokens via cite2token().
    """

    def __init__(self, sub):
        super().__init__()
        # Only records whose instruction matches this exact string are kept.
        self.inst = 'Extract the relevant content from the provided documents and then use the extracted content to guide answer generation and cite the sources properly.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Parse the FRONT sft/dpo JSON dump into InputData records."""
        few_shot = False  # few-shot prompting is not used for this task
        source_file = (
            "/yy21/MoE-PEFT/dataset/front/sft.json"
            if self.sub == 'sft'
            else "/yy21/MoE-PEFT/dataset/front/dpo.json"
        )
        with open(source_file, 'r', encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Front")
        ret: List[InputData] = []

        for data_point in data:
            # Skip records produced under a different instruction template.
            if data_point['instruction'] != self.inst:
                continue
            chat_header = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            )
            prompt = cite2token(chat_header + self.inst + data_point['input'] + "\nAnswer:")
            # The input begins with "Question: <q>\n\n"; slice the question out.
            q_start = len("Question: ")
            q_end = data_point['input'].find("\n\n", q_start)
            q = data_point['input'][q_start:q_end]
            # Everything after the question is a run of "Document [k]: ..." blocks.
            doc_pattern = r"Document \[(\d+)\]: (.*?)(?=Document \[\d+\]:|$)"
            doc_matches = re.findall(doc_pattern, data_point['input'][q_end + 2:], re.DOTALL)
            cites = [body.strip() for _, body in doc_matches]
            # The gold answer follows the "[ANSWER]" marker in the output field.
            ans_idx = data_point['output'].find("[ANSWER]")
            ans = cite2token(data_point['output'][ans_idx + len("[ANSWER]"):])
            if is_train:
                ret.append(InputData(inputs=prompt + ans, prompt=prompt, citations=cites))
            else:
                ret.append(InputData(inputs=prompt, labels=ans,
                                     citations=cites, query=q))
        return ret

    def loading_metric(self):
        """Metric config for FRONT (its metric list is currently empty)."""
        config = {'task': 'front', 'metric': metric_list['front']}
        return AutoMetric("attribute", config)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class Synsciqa(AttributedAnswerTask):
    """SynSciQA synthetic scientific-QA attribution task (training-only).

    Parses each record's instruction to recover the question and its source
    documents, remaps the original (author, year) citation ids onto reserved
    special tokens, and keeps only answers whose citations survive the
    remapping cleanly. ``sub`` selects SynSciQA / SynSciQA+ / SynSciQA++.
    """

    def __init__(self, sub):
        super().__init__()
        # Dataset variant: 'synsci', 'synsci+', or 'synsci++'.
        # NOTE(review): any other value leaves relative_path unset in
        # loading_data and raises NameError there — confirm intended.
        self.sub = sub
        # Original SynSciQA instruction template (currently unused; typos
        # "Stricly"/"concisly" are part of the dataset's prompt text).
        self.inst = lambda query: f"Can you respond to the question {query} by only relying on the sources. Ignore all sources that do not provide an answer to the question. Do not include any knowledge from outside of these sources. Only write a single paragraph. Each sentence must end with the reference in the form of (author, year, page number). Stricly follow this format. Citing multiple sources in one sentence is not allowed. However, if no source addresses the question, admit truthfully that no answer can be given. Answer the question concisly and avoid being verbose."
        # Bracketed-citation instruction (currently unused).
        self.inst_a = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Reserved-special-token instruction (currently unused — the prompt
        # built below carries no instruction at all).
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Sentence-end citation instruction (currently unused).
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Build training InputData from the selected SynSciQA json dump.

        Records are skipped when the remapped answer contains no reserved
        special token, or still contains an un-remapped (author, year, p.X)
        style citation. Returns an empty list when is_train is False.
        """
        few_shot = False  # few-shot prompting deliberately disabled
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # Go up two levels, then into the dataset directory.
        if self.sub == 'synsci':
            relative_path = "../../dataset/SynSciQA/SynSciQA.json"
        elif self.sub == 'synsci+':
            relative_path = "../../dataset/SynSciQA/SynSciQA+.json"
        elif self.sub == 'synsci++':
            relative_path = "../../dataset/SynSciQA/SynSciQA++.json"
        file_path = os.path.join(current_dir, relative_path)
        with open(file_path, 'r',encoding='utf-8') as file:
            data = json.load(file)

        logging.info("Preparing data for SynsciQA")
        ret: List[InputData] = []
        for line in tqdm(data):
            data_point = line["instruction"]
            answer = line["response"]
            # The source documents sit between these two sentinel markers.
            doc_start = data_point.find("[BEGIN OF SOURCES]")
            doc_end = data_point.find("[END OF SOURCES]")
            documents = data_point[doc_start + len("[BEGIN OF SOURCES]"): doc_end].strip().split("\n")
            # NOTE(review): the assert message is print(...), which returns
            # None; the message expression only runs when the assert fails.
            assert len(documents) > 0, print(f"No docs detected!")

            # The question is the first double-quoted span after the sources.
            data_point = data_point[doc_end + len("[END OF SOURCES]"):]
            pattern = r'"([^"]*)"'
            query = re.findall(pattern, data_point)
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += f"\n\nQuestion: {query[0]}\n"

            docs = ""
            citations = []
            index_map = []
            index_map2 = []
            for i, d in enumerate(documents):
                # Each source line is "<Id>: <content>"; split at the first colon.
                Ids = d[:d.find(":")]
                cont = d[d.find(":") + 2:]
                docs += f"Document <|reserved_special_token_{i+1}|>: {cont}\n"
                citations.append(cont)
                # Two remapping tables: "(<Id>)" with parentheses and bare "<Id>".
                index_map.append({'index': f"({Ids})", 'ID': f'<|reserved_special_token_{i+1}|>'})
                index_map2.append({'index': f"{Ids}", 'ID': f'<|reserved_special_token_{i+1}|>'})
            index_map = {item['index']: item['ID'] for item in index_map}
            index_map2 = {item['index']: item['ID'] for item in index_map2}
            prompt +=docs
            prompt += "\nAnswer:"
            # Replace parenthesised ids first, then bare ids. NOTE(review):
            # the alternation keeps document order, so if one id is a prefix
            # of another the earlier one wins — presumably ids never collide;
            # confirm against the dataset.
            pattern = re.compile('|'.join(map(re.escape, index_map)))
            answer = pattern.sub(lambda m: index_map[m.group()], answer)
            pattern = re.compile('|'.join(map(re.escape, index_map2)))
            answer = pattern.sub(lambda m: index_map2[m.group()], answer)
            # Collapse "(<tok>; <tok>)" pairs into adjacent tokens.
            pattern = r'\(\s*(<\|[^|]+\|>)\s*;\s*(<\|[^|]+\|>)\s*\)'
            answer = re.sub(pattern, r'\1\2', answer)

            # Skip answers that ended up with no special-token citation at all.
            pattern = r'<\|reserved_special_token_\d+\|>'
            if bool(re.search(pattern, answer)) == False:
                continue
            # Skip answers that still contain a raw (author, year, p.X) citation.
            pattern = r"\((?:[^)]*,){2}[^)]*p\.[^)]*\)"
            fk = re.findall(pattern, answer)
            if fk:
                continue
            if is_train:
                ret.append(InputData(
                    inputs = prompt + answer, citations = citations, prompt = prompt
                ))
        return ret

    def loading_metric(self):
        """Reuses the FRONT metric configuration (currently an empty metric list)."""
        config = {}
        config['task'] = 'front'
        config['metric'] = metric_list['front']
        return AutoMetric("attribute", config)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class Reinf(AttributedAnswerTask):
    """Training-only task over the reinforcement dataset.

    Keeps only records whose first output carries bracketed [k] citations
    that all point at an existing document, rewrites those citations into
    reserved special tokens, and emits one training prompt per record.
    """

    def __init__(self):
        super().__init__()
        # Bracketed-citation instruction (currently unused).
        self.inst_a = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Reserved-special-token instruction (the one used in the prompt below).
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Load the combined reinforcement dump into training InputData records."""
        few_shot = False  # few-shot prompting is not used for this task
        with open("/yy21/MoE-PEFT/dataset/reinforcement/combined_train.json", 'r', encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Reinforcement")
        ret: List[InputData] = []

        for line in tqdm(data):
            raw_answer = line["output"][0]
            # Guard clauses: require at least one [k] citation, and every
            # cited index must refer to an existing document.
            cited_ids = re.findall(r'\[(\d+)\]', raw_answer)
            if not cited_ids:
                continue
            if max(int(c) for c in cited_ids) > len(line["docs"]):
                continue

            question = line["question"]
            documents = line["docs"]
            # Rewrite [k] markers into reserved special tokens.
            answer = self.get_ans(raw_answer)

            header = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt = header + self.inst_special_token
            prompt += f"\n\nQuestion: {question}\n"

            citations = [d["text"] for d in documents]
            doc_block = "".join(
                f"Document <|reserved_special_token_{idx+1}|>: {text}\n"
                for idx, text in enumerate(citations)
            )
            prompt += doc_block
            prompt += "\nAnswer:"
            if is_train:
                ret.append(InputData(
                    inputs=prompt + answer, citations=citations, prompt=prompt
                ))
        return ret

    def get_ans(self, sent):
        """Rewrite every bracketed citation [k] as <|reserved_special_token_k|>."""
        return re.sub(
            r'\[(\d+)\]',
            lambda m: f"<|reserved_special_token_{m.group(1)}|>",
            sent,
        )

    def loading_metric(self):
        """Reuses the FRONT metric configuration (currently an empty metric list)."""
        config = {'task': 'front', 'metric': metric_list['front']}
        return AutoMetric("attribute", config)
|
| 521 |
+
|
| 522 |
+
def sents_embed(sents, model, tokenizer, device):
    """Encode each sentence with a BERT-style encoder's pooled output.

    Args:
        sents: iterable of sentence strings.
        model: encoder whose forward pass returns an object with a
            ``pooler_output`` tensor of shape (1, hidden).
        tokenizer: HuggingFace-style tokenizer producing pt tensors.
        device: device the tokenized inputs are moved to.

    Returns:
        Tensor of shape (len(sents), hidden); an empty (0, 0) tensor
        when ``sents`` is empty.
    """
    embeds = []
    with torch.no_grad():
        for sent in sents:
            inputs = tokenizer(sent, return_tensors='pt', padding=True, truncation=True)
            inputs = inputs.to(device)
            output = model(**inputs)
            embeds.append(output.pooler_output)
    if not embeds:
        # FIX: torch.stack([]) raises RuntimeError; return an empty tensor
        # instead of crashing on an empty sentence list.
        return torch.empty(0, 0)
    # Each pooler_output is (1, hidden); cat along dim 0 yields
    # (len(sents), hidden), equivalent to the former stack(...).squeeze(1).
    return torch.cat(embeds, dim=0)
|
| 532 |
+
|
| 533 |
+
def cite2token(sent):
    """Rewrite every bracketed citation ``[k]`` as ``<|reserved_special_token_k|>``."""
    return re.sub(
        r'\[(\d+)\]',
        lambda m: f"<|reserved_special_token_{m.group(1)}|>",
        sent,
    )
|
| 537 |
+
|
| 538 |
+
# Per-task metric names consumed by each task's loading_metric() via AutoMetric.
metric_list = {
    'asqa': ['cite_pr', 'length', 'short_ans'],
    'qsum': ['rouge_all', 'semqa_f1', 'semqa_short'],
    'qsum-a': ['rouge_all','semqa_short', 'cite_pr', 'length', 'semqa_f1'],
    'eli5': ['cite_pr', 'eli5_acc', 'length'],
    'qam': ['cite_pr', 'qampari'],
    'front': [],  # no automatic metrics configured (also reused by Synsciqa/Reinf)
}
|
| 546 |
+
|
| 547 |
+
def update_task_dict(task_dict):
    """Register every attributed-answer task under its CLI name (in place)."""
    registry = {
        "asqa": ASQA(),
        "qsum": QouteSum('vani'),
        "qsum-a": QouteSum('alce'),
        "qsum-ans": QouteSum('ans'),
        "eli5": ELI5(),
        "front-s": Front('sft'),
        "front-d": Front('dpo'),
        "synsci": Synsciqa('synsci'),
        "synsci+": Synsciqa('synsci+'),
        "synsci++": Synsciqa('synsci++'),
        "rein": Reinf(),
        "qam": Qampari(),
    }
    task_dict.update(registry)
|
| 564 |
+
|
| 565 |
+
if __name__ == '__main__':
    # Smoke test: build the QouteSum task and load its eval split.
    task = QouteSum()
    task.loading_data()
|
c2cite/tasks/common.py
ADDED
|
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import copy
|
| 6 |
+
import string
|
| 7 |
+
from nltk import sent_tokenize
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rouge import Rouge
|
| 11 |
+
import collections
|
| 12 |
+
from rouge_score import rouge_scorer, scoring
|
| 13 |
+
import functools
|
| 14 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import transformers
|
| 17 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 18 |
+
import datasets as hf_datasets
|
| 19 |
+
import evaluate as hf_evaluate
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from moe_peft.common import InputData, Prompt
|
| 23 |
+
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
| 24 |
+
|
| 25 |
+
# FIX: a `global` statement at module level is a no-op and was removed;
# the names below are module globals simply by being assigned here.
autoais_model = None       # lazily-loaded NLI model for citation checks
autoais_tokenizer = None   # tokenizer paired with autoais_model
qa_pipeline = None         # lazily-built transformers question-answering pipeline
ais_LLM = None


def get_docs_by_index(i, docs):
    """Return docs[i], or None when i is past the end of `docs`.

    FIX: was a lambda assigned to a name (PEP 8 E731). Negative indices are
    still passed straight through to list indexing, as before.
    """
    return docs[i] if i < len(docs) else None


evaluate_device = 'cuda:6'  # device used for every evaluation model below
#QA_MODEL = "gaotianyu1350/roberta-large-squad"
QA_MODEL = "/yy21/qa_model"  # local path to the SQuAD QA model
#AUTOAIS_MODEL = "google/t5_xxl_true_nli_mixture"
AUTOAIS_MODEL = "/yy21/autoais"  # local path to the TRUE NLI model
|
| 37 |
+
|
| 38 |
+
class BasicMetric:
    """Base interface for evaluation metrics: accumulate batches, then compute.

    Subclasses override `add_batch` and `compute`; the base implementations
    are intentional no-ops.
    """

    def __init__(self) -> None:
        pass

    def add_batch(self, predictions: torch.Tensor, references: Optional[torch.Tensor] = None):
        """Accumulate one batch of predictions (and optional references).

        FIX: the class originally defined `add_batch` twice, so the second
        definition silently shadowed the first. The two are merged here with
        `references` made optional so both original call styles keep working.
        """
        pass

    def compute(self) -> Dict[str, Any]:
        """Return the aggregated metric values (None in the base class)."""
        pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
from statistics import harmonic_mean
|
| 53 |
+
|
| 54 |
+
def normalize_answers(text):
    """QA style answer normalization. Similar to TriviaQA.

    Lowercases, replaces every punctuation character with a space, removes
    English articles (a/an/the), and collapses runs of whitespace.
    """
    lowered = text.lower()
    # Punctuation becomes a space (it is NOT deleted), so word boundaries survive.
    punct = set(string.punctuation)
    no_punct = "".join(" " if ch in punct else ch for ch in lowered)
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct)
    return " ".join(no_articles.split())
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def strip_attribution_tokens(text):
    """Strip the attribution tokens from an answer, keeping the quoted span text."""
    # '[ <source-digit> <span> ]' -> '<span>'
    quoted_span = r'\[ ([1-9]) ([^\[\]]*) \]'
    return re.sub(quoted_span, r'\2', text)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def non_quoted(text):
    """Returns only the text that is outside of quoted spans."""
    # Delete every '[ <digit> <span> ]' block entirely.
    quoted_span = r'\[ ([1-9]) ([^\[\]]*) \]'
    return re.sub(quoted_span, '', text)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def only_quoted(text, sources='1-9', sep = ' '):
    """Returns only the text that is within quoted spans.

    `sources` is a regex character-class body restricting which source ids count.
    """
    pattern = r'\[ [{}] ([^\[\]]*) \]'.format(sources)
    spans = [m.group(1) for m in re.finditer(pattern, text)]
    return sep.join(spans)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def quoted_sources(text):
    """Returns the sorted list of input source ids that were quoted in the answer."""
    pattern = r'\[ ([1-9]) [^\[\]]* \]'
    ids = {int(m.group(1)) for m in re.finditer(pattern, text)}
    return sorted(ids)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def score_all(data, scorer, aggr_measure, score_keys, preprocess_func=None, bootstrap=False):
    """
    Aggregates across all targets per sample.

    all_targets: list of list of strings
    all_predictions: list of strings

    For each example the best-scoring reference (by `aggr_measure`) is kept;
    with aggr_measure='independent' each key in `score_keys` is maximised
    separately. Returns per-example scores scaled x100, plus bootstrap
    confidence intervals when `bootstrap=True`.
    """
    all_targets = [d['answer'] for d in data]
    all_predictions = [d['output'] for d in data]

    # Fixed seed so the bootstrap resampling is reproducible.
    np.random.seed(1337)

    # ROUGE results are Score namedtuples (need .fmeasure); others are plain floats.
    is_rouge_measure = 'rouge' in aggr_measure

    if preprocess_func is not None:
        # Apply the same preprocessing to both sides before scoring.
        scoring_func = lambda target, prediction: scorer.score(target=preprocess_func(target), prediction=preprocess_func(prediction))
    else:
        scoring_func = scorer.score

    aggregator = scoring.BootstrapAggregator()
    all_scores = [] if is_rouge_measure else dict((k,[]) for k in score_keys)
    for targets, prediction in zip(all_targets, all_predictions):
        # Max across references by aggr_measure
        if is_rouge_measure:
            max_scores = max([scoring_func(target, prediction) for target in targets], key=lambda x: x[aggr_measure].fmeasure)

            aggregator.add_scores(max_scores)
            all_scores.append(max_scores[aggr_measure].fmeasure*100)
        else:
            if aggr_measure == 'independent':
                # Best reference chosen per key, independently of the others.
                max_scores = {}
                for key in score_keys:
                    max_scores[key] = max([scoring_func(target, prediction)[key] for target in targets])
            else:
                max_scores = max([scoring_func(target, prediction) for target in targets], key=lambda x: x[aggr_measure])

            aggregator.add_scores(max_scores)
            for key in score_keys:
                all_scores[key].append(max_scores[key]*100)

    if not bootstrap:
        return all_scores

    result = aggregator.aggregate()
    postprocess_result = (lambda x: x.fmeasure*100) if is_rouge_measure else (lambda x: x*100)
    bootstrap_results = {}
    for key in score_keys:
        # (mid, low, high) of the bootstrap confidence interval, scaled x100.
        bootstrap_results[key] = (postprocess_result(result[key].mid), postprocess_result(result[key].low), postprocess_result(result[key].high))
    return bootstrap_results, all_scores
|
| 144 |
+
|
| 145 |
+
## ROUGE ##

# Pre-configured ROUGE variant of score_all: picks the best-matching reference
# by rougeLsum F-measure, strips attribution tokens before scoring, and
# returns bootstrap confidence intervals.
score_all_rouge = functools.partial(score_all, scorer=rouge_scorer.RougeScorer(rouge_types=("rouge1", "rouge2", "rougeLsum", "rougeL")), aggr_measure='rougeLsum', score_keys=("rouge1", "rouge2", "rougeLsum"), preprocess_func=strip_attribution_tokens, bootstrap=True)
|
| 148 |
+
|
| 149 |
+
## F1 ##
|
| 150 |
+
|
| 151 |
+
class _f1_scorer:
|
| 152 |
+
def score(self, target, prediction):
|
| 153 |
+
"""Computes token F1 score for a single target and prediction."""
|
| 154 |
+
prediction_tokens = prediction.split()
|
| 155 |
+
target_tokens = target.split()
|
| 156 |
+
common = (collections.Counter(prediction_tokens) &
|
| 157 |
+
collections.Counter(target_tokens))
|
| 158 |
+
num_same = sum(common.values())
|
| 159 |
+
if len(target_tokens) == 0 and len(prediction_tokens) == 0:
|
| 160 |
+
return {'F1': 1.0, 'recall': 1.0, 'precision': 1.0}
|
| 161 |
+
elif len(target_tokens) == 0 and len(prediction_tokens) > 0:
|
| 162 |
+
return {'F1': 0.0, 'recall': 1.0, 'precision': 0.0}
|
| 163 |
+
elif len(target_tokens) > 0 and len(prediction_tokens) == 0:
|
| 164 |
+
return {'F1': 0.0, 'recall': 0.0, 'precision': 1.0}
|
| 165 |
+
elif num_same == 0:
|
| 166 |
+
return {'F1': 0.0, 'recall': 0.0, 'precision': 0.0}
|
| 167 |
+
else:
|
| 168 |
+
precision = 1.0 * num_same / len(prediction_tokens)
|
| 169 |
+
recall = 1.0 * num_same / len(target_tokens)
|
| 170 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
| 171 |
+
return {'F1': f1, 'recall': recall, 'precision': precision}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# Token-level F1 variant of score_all: best reference picked by F1 by default.
score_all_f1 = functools.partial(score_all, scorer=_f1_scorer(), aggr_measure='F1', score_keys=("F1", "recall", "precision"))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def preprocess_quotes_f1(text, sep=' ', sources='1-7'):
    """Extract quoted spans for the given source ids, then QA-normalize them."""
    quoted = only_quoted(text, sep=sep, sources=sources)
    return normalize_answers(quoted)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def score_semqa_f1(data, harmonic=False):
    """SEMQA quote F1: per-source (1..7) token F1 over quoted spans, averaged.

    Each item needs `docs`, `answer` and `output`; `docs` is indexed directly
    by source id here (assumes ids 1..7 are valid indices — TODO confirm
    schema). With `harmonic=True` the per-example score is the harmonic mean
    of all per-source precisions and recalls instead of the mean of F1s.
    """
    examples = [d['docs'] for d in data]
    per_source_prf1 = {}
    for source in range(1, 8):
        # Restrict quote extraction to a single source id before scoring.
        preprocess_quotes_f1_partial_sources = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
        scores = score_all_f1(data, aggr_measure='independent', preprocess_func=preprocess_quotes_f1_partial_sources)

        for aggr_measure in ('F1', 'recall', 'precision'):
            per_source_prf1[f'{aggr_measure}_source_{source}'] = scores[aggr_measure]

    semqa_f1s = []
    for i in range(len(examples)):
        precisions, recalls, f1s = [], [] , []
        for source in range(1,8):
            # Only count sources that are present (truthy) for this example.
            if examples[i][source]:
                precisions.append(per_source_prf1[f'precision_source_{source}'][i])
                recalls.append(per_source_prf1[f'recall_source_{source}'][i])
                f1s.append(per_source_prf1[f'F1_source_{source}'][i])
        if harmonic:
            f1 = harmonic_mean(precisions + recalls)
        else:
            f1 = np.mean(f1s)
        semqa_f1s.append(f1)

    return np.mean(semqa_f1s)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Recall-only variant of score_all; used by score_semqa_short_recall.
score_all_recall = functools.partial(score_all, scorer=_f1_scorer(), aggr_measure='recall', score_keys=("recall",))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def score_semqa_short_recall(data):
    """SEMQA short-answer recall.

    When QA pairs carry a 'num' field, falls back to STR-EM over their 'ans'
    lists; otherwise averages per-source (1..7) quoted-span recall over the
    examples that have non-empty targets.
    """
    if 'num' in data[0]['qa_pairs'][0].keys():
        # Reshape into the schema compute_str_em expects and delegate.
        reshaped = []
        for d in data:
            pairs = [{'short_answers': p['ans']} for p in d['qa_pairs']]
            reshaped.append({'qa_pairs': pairs, 'output': d['output']})
        return compute_str_em(reshaped)

    all_targets = [d['qa_pairs'] for d in data]
    all_predictions = [d['output'] for d in data]

    # Ignore examples with no targets (empty list or all-empty strings).
    scored_items, kept_targets = [], []
    for tar, pred in zip(all_targets, all_predictions):
        if len(tar) == 0 or all(x == '' for x in tar):
            continue
        scored_items.append({'answer': tar, 'output': pred})
        kept_targets.append(tar)

    # Per-source recall over the kept examples.
    per_source_recall = {}
    for source in range(1, 8):
        extractor = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
        scores = score_all_recall(scored_items, preprocess_func=extractor)
        per_source_recall[source] = scores['recall']

    # Average recall over only the sources that actually occur in a target.
    semqa_recalls = []
    for i, targets in enumerate(kept_targets):
        recalls = []
        for source in range(1, 8):
            extractor = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
            if any(extractor(t) for t in targets):
                recalls.append(per_source_recall[source][i])
        semqa_recalls.append(np.mean(recalls))

    return np.mean(semqa_recalls)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def exact_presence(short_answers, context):
    """Verify if any of the answers is present in the given context.

    Args:
        short_answers: list of short answers to look for in the context
        context: a paragraph to search for short answers
    Returns:
        True if any normalized short answer is a substring of the
        normalized context.
    """
    normalized_context = normalize_answer(context)
    return any(normalize_answer(ans) in normalized_context for ans in short_answers)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def normalize_answer(s):
    """SQuAD-style normalization: lowercase, delete punctuation, drop
    articles (a/an/the), and collapse whitespace.

    Note punctuation is DELETED here (not replaced by a space), unlike
    normalize_answers above.
    """
    lowered = s.lower()
    punct = set(string.punctuation)
    no_punct = "".join(ch for ch in lowered if ch not in punct)
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct)
    return " ".join(no_articles.split())
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def remove_citations(sent):
    """Strip inline citation markers like ' [3]'/'[3]' plus stray ' |' and ']' characters."""
    without_spaced = re.sub(r" \[\d+", "", sent)     # drop ' [N' with leading space
    without_bare = re.sub(r"\[\d+", "", without_spaced)  # drop remaining '[N'
    return without_bare.replace(" |", "").replace("]", "")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def load_auto_ais():
    """Lazily load the NLI model/tokenizer used for citation precision & recall.

    Populates the module globals `autoais_model` and `autoais_tokenizer`.
    """
    global autoais_model, autoais_tokenizer
    print('Initializing eval model for citation precision and recall...')
    autoais_model = AutoModelForSeq2SeqLM.from_pretrained(
        AUTOAIS_MODEL, torch_dtype=torch.bfloat16, device_map=evaluate_device,
    )
    autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL, use_fast=False)
    print('Done!')
|
| 307 |
+
|
| 308 |
+
def _run_nli_autoais(passage, claim, test = False):
    """
    Run inference for assessing AIS between a premise and .hypothesis
    Adapted from https://github.com/google-research-datasets/Attributed-QA/blob/main/evaluation.py

    Returns 1 when the model judges `passage` to entail `claim`, else 0.
    """
    if not test:
        global autoais_model, autoais_tokenizer
        # Lazily load the NLI model on first use.
        if not autoais_model:
            load_auto_ais()
        input_text = "premise: {} hypothesis: {}".format(passage, claim)
        input_ids = autoais_tokenizer(input_text, return_tensors="pt").input_ids.to(autoais_model.device)
        with torch.inference_mode():
            outputs = autoais_model.generate(input_ids, max_new_tokens=10)
        result = autoais_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The TRUE model emits the literal token "1" for entailment.
        inference = 1 if result == "1" else 0
        return inference
    else:
        # Debug stub: test=True short-circuits inference and returns a
        # sentinel value instead of a 0/1 judgement.
        res = 114514

        return res
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def compute_autoais(data,
                    qampari=False,
                    at_most_sents = 50,
                    at_most_citations=3,
                    entail_function = _run_nli_autoais):
    """
    Compute AutoAIS score.

    Args:
        data: requires field `output` and `docs`
              - docs should be a list of items with fields `title` and `text` (or `phrase` and `sent` for QA-extracted docs)
        citation: check citations and use the corresponding references.
        decontext: decontextualize the output

    Returns a dict with `citation_rec` and `citation_prec`, both scaled x100.
    """

    global autoais_model, autoais_tokenizer


    ais_scores = []        # per-item citation recall
    ais_scores_prec = []   # per-item citation precision

    sent_total = 0
    sent_mcite = 0           # sentences with more than one citation
    sent_mcite_support = 0   # multi-cite sentences supported by the joint set
    sent_mcite_overcite = 0  # citations found to be unnecessary
    autoais_log = []
    for item in tqdm(data):
        # Get sentences by using NLTK
        if qampari:
            #print('now qampari...')
            # QAMPARI outputs are comma-separated answers; prefix each with the query.
            sents = [item['query'] + " " + x.strip() for x in
                     item['output'].rstrip().rstrip(".").rstrip(",").split(",")]
        else:
            sents = sent_tokenize(item['output'])[:at_most_sents]
        if len(sents) == 0:
            # Empty output counts as zero recall and zero precision.
            ais_scores.append(0.0)
            ais_scores_prec.append(0.0) # len(sents))
            continue

        target_sents = [remove_citations(sent).strip() for sent in sents]

        entail = 0
        entail_prec = 0
        total_citations = 0
        for sent_id, sent in enumerate(sents):
            target_sent = target_sents[sent_id] # Citation removed and (if opted for) decontextualized
            joint_entail = -1 # Undecided

            # Find references
            #ref = [int(r[1:]) - 1 for r in re.findall(r"\[\d+", sent)] # In text citation id starts from 1
            # Handles both '[1]' and grouped '[1, 2]' citation formats;
            # in-text ids start from 1, so subtract 1 for doc indexing.
            matches = re.findall(r"\[(\d+(?:,\s*\d+)*)\]", sent)
            ref = [int(num)-1 for match in matches for num in match.replace(' ', '').split(',')]
            if len(ref) == 0:
                # No citations
                joint_entail = 0
            elif any([ref_id >= len(item['docs']) for ref_id in ref]):
                # Citations out of range
                joint_entail = 0
            else:
                if at_most_citations is not None:
                    ref = ref[:at_most_citations]
                total_citations += len(ref)
                # NOTE: joint_passage is only defined on this branch, which is
                # exactly the case where joint_entail stays -1 below.
                joint_passage = '\n'.join([(item['docs'][psgs_id]) for psgs_id in ref])

            # If not directly rejected by citation format error, calculate the recall score
            if joint_entail == -1:
                joint_entail = entail_function(joint_passage, target_sent)
                autoais_log.append({
                    #"question": item['question'],
                    "output": item['output'],
                    "claim": sent,
                    "passage": [joint_passage],
                    "model_type": "NLI",
                    "model_output": joint_entail,
                })

            entail += joint_entail
            if len(ref) > 1:
                sent_mcite += 1

            # calculate the precision score if applicable
            if joint_entail and len(ref) > 1:
                sent_mcite_support += 1
                # Precision check: did the model cite any unnecessary documents?
                for psgs_id in ref:
                    # condition A: the single cited passage entails on its own
                    passage = item['docs'][psgs_id]
                    nli_result = entail_function(passage, target_sent)

                    # condition B: the remaining citations entail without it
                    if not nli_result:
                        subset_exclude = copy.deepcopy(ref)
                        subset_exclude.remove(psgs_id)
                        passage = '\n'.join([item['docs'][pid] for pid in subset_exclude])
                        nli_result =entail_function(passage, target_sent)
                        if nli_result: # psgs_id is not necessary
                            # NOTE(review): `flag` is assigned but never read.
                            flag = 0
                            sent_mcite_overcite += 1
                        else:
                            entail_prec += 1
                    else:
                        entail_prec += 1
            else:
                entail_prec += joint_entail
        sent_total += len(sents)
        ais_scores.append(entail / len(sents))
        ais_scores_prec.append(entail_prec / total_citations if total_citations > 0 else 0) # len(sents))

    if sent_mcite > 0 and sent_mcite_support > 0:
        print(
            "Among all sentences, %.2f%% have multiple citations, among which %.2f%% are supported by the joint set, among which %.2f%% overcite." % (
                100 * sent_mcite / sent_total,
                100 * sent_mcite_support / sent_mcite,
                100 * sent_mcite_overcite / sent_mcite_support
            ))

    return {
        "citation_rec": 100 * np.mean(ais_scores),
        "citation_prec": 100 * np.mean(ais_scores_prec),
    }
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def compute_f1(a_gold, a_pred):
    """Compute token-level F1 score between two strings (SQuAD-style)."""

    def _tokens(text):
        return normalize_answer(text).split() if text else []

    gold = _tokens(a_gold)
    pred = _tokens(a_pred)

    overlap = collections.Counter(gold) & collections.Counter(pred)
    num_same = sum(overlap.values())

    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    if not gold or not pred:
        return int(gold == pred)
    if num_same == 0:
        return 0

    precision = num_same / len(pred)
    recall = num_same / len(gold)
    return (2 * precision * recall) / (precision + recall)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def compute_exact(a_gold, a_pred):
    """Return 1 when the two strings are equal after normalization, else 0."""
    norm_gold = normalize_answer(a_gold)
    norm_pred = normalize_answer(a_pred)
    return int(norm_gold == norm_pred)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def compute_qa(data):
    """Compute QA-based accuracy.

    Args:
        data: requires field `qa_pairs/short_answers` and `output`
    Returns:
        QA metrics (QA-EM, QA-F1, QA-Hit), each scaled to 0-100.
    """
    if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
        # FIX: logging.warn is deprecated in favour of logging.warning.
        logging.warning("Warning: no QA pairs found in data")
        return {
            'QA-EM': 0,
            'QA-F1': 0,
            'QA-Hit': 0,
        }

    # Lazily build the extractive-QA pipeline once and reuse it.
    global qa_pipeline
    if not qa_pipeline:
        qa_pipeline = transformers.pipeline("question-answering", model=QA_MODEL, device = evaluate_device)

    em, f1, bins = [], [], []
    for item in tqdm(data):
        question = [qa_pair['question'] for qa_pair in item['qa_pairs']]
        if not question:
            # FIX: an item with an empty qa_pairs list previously caused a
            # ZeroDivisionError (loc_counter stayed 0); skip it instead.
            continue
        context = item['output'] if len(item['output']) > 0 else " "
        results = qa_pipeline(question=question, context=remove_citations(context), handle_impossible_answer=True)
        # FIX: with a single question the pipeline returns a bare dict rather
        # than a list; normalize so the loop below always sees a list.
        if isinstance(results, dict):
            results = [results]
        loc_counter, loc_em, loc_f1 = 0, 0, 0

        for idx, res in enumerate(results):
            answers = item["qa_pairs"][idx]["short_answers"]
            prediction = res["answer"]

            # Best match over the acceptable short answers.
            loc_em += max([compute_exact(a, prediction) for a in answers])
            loc_f1 += max([compute_f1(a, prediction) for a in answers])
            loc_counter += 1

        em.append(loc_em / loc_counter)
        f1.append(loc_f1 / loc_counter)
        bins.append(loc_em == loc_counter)  # "hit" = every sub-question exactly matched

    return {
        'QA-EM': 100 * np.mean(em),
        'QA-F1': 100 * np.mean(f1),
        'QA-Hit': 100 * np.mean(bins)
    }
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def compute_claims(data):
    """Fraction of each item's claims (in `qa_pairs`) entailed by the
    citation-stripped output, averaged and scaled x100."""
    global autoais_model, autoais_tokenizer
    if autoais_model is None:
        #logger.info("Loading AutoAIS model...")
        # autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16, max_memory=get_max_memory(), device_map="auto")
        autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16,
                                                              device_map=evaluate_device)
        # autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16, max_memory=get_max_memory(), device_map="auto",offload_folder= "/data/hongbang/zsf/projects/ALCE/ALCE/model/t5_xxl_true_nli_mixture/offload1")
        autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL, use_fast=False)
    #logger.info("Computing claims...")
    scores = []
    for item in tqdm(data):
        normalized_output = remove_citations(item['output'])
        entail = 0
        # NOTE(review): claims are taken directly from `qa_pairs`; assumes
        # they are plain strings usable as NLI hypotheses — TODO confirm.
        claims = item["qa_pairs"]
        for claim in claims:
            entail += _run_nli_autoais(normalized_output, claim)
        scores.append(entail / len(claims))
    return 100 * np.mean(scores)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def compute_qampari_f1(data, cot=False):
    """QAMPARI list-answer precision/recall/F1 (plus recall@5 variants), x100.

    Each item needs `output` (comma-separated answers) and `answer` (a list of
    alias lists). With `cot=True` the text before the first ':' is treated as
    chain-of-thought and discarded.
    """
    prec = []
    rec = []
    rec_top5 = []
    f1 = []
    f1_top5 = []

    num_preds = []
    for item in data:
        if cot:
            if ":" in item['output']:
                o = ':'.join(item['output'].split(":")[1:]) # try to separate the COT part and the answer list part.
            else:
                o = ""
        else:
            o = item['output']

        # Normalize predicted answers and drop empties.
        preds = [normalize_answer(x.strip()) for x in remove_citations(o).rstrip().rstrip(".").rstrip(",").split(",")]
        preds = [p for p in preds if len(p) > 0] # delete empty answers
        num_preds.append(len(preds))
        # Each gold answer is a list of acceptable aliases.
        answers = [[normalize_answer(x) for x in ans] for ans in item['answer']]
        # NOTE(review): the comprehension variable `item` shadows the loop
        # variable `item`; harmless here but fragile.
        flat_answers = [item for sublist in answers for item in sublist]
        prec.append(sum([p in flat_answers for p in preds]) / len(preds) if len(preds) > 0 else 0)

        # Recall: a gold answer counts as found when any alias is predicted.
        rec.append(sum([any([x in preds for x in a]) for a in answers]) / len(answers))
        rec_top5.append(min(5, sum([any([x in preds for x in a]) for a in answers])) / min(5, len(answers)))
        if (prec[-1] + rec[-1]) == 0:
            f1.append(0)
        else:
            f1.append(2 * prec[-1] * rec[-1] / (prec[-1] + rec[-1]))
        if (prec[-1] + rec_top5[-1]) == 0:
            f1_top5.append(0)
        else:
            f1_top5.append(2 * prec[-1] * rec_top5[-1] / (prec[-1] + rec_top5[-1]))

    return {
        "num_preds": np.mean(num_preds),
        "qampari_prec": 100 * np.mean(prec),
        "qampari_rec": 100 * np.mean(rec),
        "qampari_rec_top5": 100 * np.mean(rec_top5),
        "qampari_f1": 100 * np.mean(f1),
        "qampari_f1_top5": 100 * np.mean(f1_top5),
    }
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def compute_str_em(data):
    """Compute STR-EM metric (only for ASQA)

    Args:
        data: requires field `qa_pairs/short_answers` and `output`
    Returns:
        STR-EM score x100 (only the FIRST qa_pair of each item is checked;
        the full per-pair loop was deliberately disabled upstream).
    """
    if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
        return 0

    acc = []
    for item in data:
        if len(item['qa_pairs']) == 0:
            continue
        hits = [exact_presence(item['qa_pairs'][0]['short_answers'], item["output"])]
        acc.append(float(np.mean(hits)))
    return 100 * np.mean(acc) if len(acc) > 0 else 0
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def compute_mauve(data):
    """Compute the MAUVE fluency score (gold vs. model continuations)."""
    logging.info("Computing MAUVE...")

    def _clip(text):
        # Truncate to 100 whitespace-separated tokens (this also removes
        # newlines) and strip trailing punctuation.
        return ' '.join(text.split()[:100]).rstrip(string.punctuation)

    human_data = [_clip(item['query'] + " " + item['answer'].strip()) for item in data]
    model_data = [_clip(item['query'] + " " + item['output'].strip()) for item in data]

    import mauve
    result = mauve.compute_mauve(
        p_text=human_data,
        q_text=model_data,
        device_id=0,
        max_text_length=512,
        verbose=True,
        batch_size=8,
        featurize_model_name="gpt2-large"
    )
    return result.mauve * 100
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def compute_rouge_l(data):
    """Average ROUGE-L recall/precision/F1 of `output` against `answer`.

    Args:
        data: list of items with `output` (hypothesis) and `answer`
            (reference) strings.
    Returns:
        Dict with keys "r", "p", "f" averaged over len(data); items with an
        empty hypothesis or reference contribute 0 to the average.
    """
    total = len(data)
    res = {
        "r": 0.0,
        "p": 0.0,
        "f": 0.0
    }
    if total == 0:
        # BUG FIX: avoid ZeroDivisionError on an empty dataset.
        return res
    # BUG FIX: the scorer was instantiated once per item; one instance
    # suffices for the whole loop.
    rouge = Rouge()
    for item in data:
        if item['output'] and item['answer']:
            scores = rouge.get_scores(item['output'], item['answer'])
            rl = scores[0]['rouge-l']
            res['r'] += rl['r']
            res['p'] += rl['p']
            res['f'] += rl['f']
        else:
            print('Warning: no hypothesis or references')
    res['r'] /= total
    res['p'] /= total
    res['f'] /= total

    return res
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def compute_length(data):
    """Average `output` length in space-separated tokens over `data`.

    Note: splitting on a single space mirrors the original metric exactly,
    so runs of whitespace produce empty tokens that are counted too.
    Returns 0.0 for empty input instead of raising ZeroDivisionError.
    """
    if not data:
        # BUG FIX: the original divided by len(data) unconditionally.
        return 0.0
    return sum(len(item['output'].split(' ')) for item in data) / len(data)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
# Mapping from metric name to the function that computes it; keys must stay
# in sync with `data_list` and `AttributeMetric.self.data`.
metric_list = {
    'cite_pr': compute_autoais,
    'asqa_acc': compute_qa,
    'eli5_acc': compute_claims,
    'qampari': compute_qampari_f1,
    'short_ans': compute_str_em,
    # 'fluence': compute_mauve,
    'rouge': compute_rouge_l,
    'length': compute_length,
    'rouge_all': score_all_rouge,
    'semqa_f1': score_semqa_f1,  # roughly corresponds to precision
    'semqa_short': score_semqa_short_recall,  # roughly corresponds to recall
}
|
| 694 |
+
|
| 695 |
+
# For each metric, the sample fields it consumes.  Only the KEYS matter —
# the None values are placeholders; see AttributeMetric.add_batch, which
# filters each incoming sample down to these fields.
data_list = {
    'cite_pr': {'output': None, 'docs': None, 'query': None},
    'asqa_acc': {'output': None,'qa_pairs': None, 'query': None},
    'eli5_acc': {'output': None, 'qa_pairs': None},
    'qampari': {'output': None, 'answer': None},
    'short_ans': {'qa_pairs': None, 'output': None},
    # 'fluence': {'query': None, 'answer': None, 'output': None},
    'rouge': {'output': None, 'answer': None},
    'length': {'output': None},
    'rouge_all': {'answer': None, 'output': None},
    'semqa_f1': {'answer': None, 'output': None, 'docs': None},
    'semqa_short':{'output': None, 'qa_pairs': None},
    'semqa': {}
}
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class AttributeMetric:
    """Collects per-sample fields and computes attribution metrics.

    `config['metric']` lists the metric names to run; each must be a key of
    the module-level `metric_list`, and `data_list` tells which sample
    fields each metric needs.
    """

    def __init__(self, config):
        self.task = 'attribute'
        self.metrics = config['metric']
        self.flag = False
        # One buffer per supported metric; add_batch only fills the buffers
        # for the metrics listed in self.metrics.
        self.data = {
            'cite_pr': [],
            'asqa_acc': [],
            'eli5_acc': [],
            'qampari': [],
            'short_ans': [],
            'fluence': [],
            'rouge': [],
            'length': [],
            'rouge_all': [],
            'semqa_f1': [],
            'semqa_short': [],
            'semqa': [],
        }

    def add_batch(self, data):
        """Store one sample dict (output/qa_pairs/answer/docs/query),
        keeping only the fields each requested metric needs."""
        for metric in self.metrics:
            self.data[metric].append(
                {k: v for k, v in data.items() if k in data_list[metric]}
            )

    def compute(self):
        """Run every requested metric and return {metric_name: score}."""
        ans = {}
        for metric in self.metrics:
            # BUG FIX: the original used `assert ..., logging.info(...)`,
            # which produced an AssertionError whose message was None.
            assert metric in metric_list, f"Invalid metric: {metric}"
            if metric == 'cite_pr' and 'qampari' in self.metrics:
                # Citation P/R needs QAMPARI-specific answer handling.
                ans[metric] = metric_list[metric](data=self.data[metric], qampari=True)
            else:
                ans[metric] = metric_list[metric](data=self.data[metric])
        return ans
|
| 754 |
+
|
| 755 |
+
class AutoMetric(BasicMetric):
    """Metric wrapper: a custom AttributeMetric for the "attribute" task,
    otherwise a HuggingFace `evaluate` metric.

    `config` is only consumed by the "attribute" task.  BUG FIX: it now
    defaults to None — `SequenceClassificationTask.loading_metric` calls
    `AutoMetric(name)` with a single argument, which previously raised
    TypeError because `config` was mandatory.
    """

    def __init__(self, task_name: str, config: Optional[List] = None) -> None:
        super().__init__()
        # Optional local prefix (MOE_PEFT_METRIC_PATH) for offline metric
        # scripts; normalized to end with the path separator.
        path_prefix = os.getenv("MOE_PEFT_METRIC_PATH")
        if path_prefix is None:
            path_prefix = ""
        elif not path_prefix.endswith(os.sep):
            path_prefix += os.sep

        if task_name == "attribute":
            self.metric_ = AttributeMetric(config)
        elif ":" in task_name:
            # "name:subset" selects a metric configuration.
            split = task_name.split(":")
            self.metric_ = hf_evaluate.load(path_prefix + split[0], split[1])
        else:
            self.metric_ = hf_evaluate.load(path_prefix + task_name)

    def add_batch(self, predictions: torch.Tensor, references: torch.Tensor):
        """Forward a batch of predictions/references to the wrapped metric."""
        self.metric_.add_batch(predictions=predictions, references=references)

    def compute(self) -> Dict[str, Any]:
        """Compute and return the wrapped metric's result dict."""
        return self.metric_.compute()
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
class BasicTask:
    """Abstract interface shared by all fine-tuning tasks."""

    def __init__(self) -> None:
        pass

    @property
    def peft_task_type(self) -> str:
        """PEFT task-type string (e.g. "CAUSAL_LM"); overridden by subclasses."""

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load and return the task's samples; overridden by subclasses."""

    def loading_metric(self) -> BasicMetric:
        """Return the metric object for this task; overridden by subclasses."""

    def init_kwargs(self) -> Dict:
        """Extra kwargs forwarded to model initialization (none by default)."""
        return {}
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
# Causal fine-tuning tasks
# Instant-created class
|
| 801 |
+
class CasualTask(BasicTask):
    """Causal-LM supervised fine-tuning on a user-supplied dataset."""

    @property
    def peft_task_type(self) -> str:
        return "CAUSAL_LM"

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load instruction/input/output records from `path` as prompts."""
        assert path is not None, "Casual supervised fine-tuning requires data path."
        assert is_train, "Casual supervised fine-tuning task only supports training."
        # Resolve the dataset source: local JSON(L) file, "name:subset",
        # or a plain hub/dataset name.
        if path.endswith((".json", ".jsonl")):
            dataset = hf_datasets.load_dataset("json", data_files=path)
        elif ":" in path:
            parts = path.split(":")
            dataset = hf_datasets.load_dataset(parts[0], parts[1])
        else:
            dataset = hf_datasets.load_dataset(path)
        return [
            InputData(
                inputs=Prompt(
                    instruction=point["instruction"],
                    input=point.get("input", None),
                    label=point.get("output", None),
                )
            )
            for point in dataset["train"]
        ]
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
# Sequence Classification
|
| 835 |
+
class SequenceClassificationTask(BasicTask):
    """Generic single-/multi-sentence classification task (e.g. GLUE).

    Args:
        task_name: HF dataset name, optionally "name:subset".
        task_type: problem type, e.g. "single_label_classification".
        label_dtype: torch dtype used for label tensors.
        num_labels: number of classes.
        dataload_function: maps a raw data point to (inputs, labels).
        metric_name: metric to load; defaults to `task_name`.
        subset_map: (train split name, eval split name).
    """

    def __init__(
        self,
        task_name: str,
        task_type: str,
        label_dtype: torch.dtype,
        num_labels: int,
        dataload_function: Callable,
        # Setting to `None` corresponds to the task name.
        metric_name: Optional[str] = None,
        # The default values are "train" and "validation".
        subset_map: Optional[Tuple[str, str]] = ("train", "validation"),
    ) -> None:
        super().__init__()
        self.task_name_ = task_name
        self.task_type_ = task_type
        self.label_dtype_ = label_dtype
        self.num_labels_ = num_labels
        self.dataload_function_ = dataload_function
        # Fall back to the task name when no explicit metric is given.
        self.metric_name_ = task_name if metric_name is None else metric_name
        self.subset_map_ = subset_map

    @property
    def peft_task_type(self) -> str:
        return "SEQ_CLS"

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load the train/eval split and convert rows via `dataload_function_`."""
        if ":" in self.task_name_:
            split = self.task_name_.split(":")
            data = hf_datasets.load_dataset(
                split[0] if path is None else path, split[1]
            )
        else:
            data = hf_datasets.load_dataset(self.task_name_ if path is None else path)
        data = data[self.subset_map_[0] if is_train else self.subset_map_[1]]
        logging.info(f"Preparing data for {self.task_name_.upper()}")
        ret: List[InputData] = []
        for data_point in data:
            inputs, labels = self.dataload_function_(data_point)
            # FIX: check against the builtin list, not the deprecated
            # typing.List alias.
            assert isinstance(labels, list)
            ret.append(InputData(inputs=inputs, labels=labels))

        return ret

    def loading_metric(self) -> BasicMetric:
        # BUG FIX: AutoMetric's original signature requires a `config`
        # argument; pass None explicitly so this call does not raise
        # TypeError.
        return AutoMetric(self.metric_name_, None)

    def init_kwargs(self) -> Dict:
        return {
            "task_type": self.task_type_,
            "num_labels": self.num_labels_,
            "label_dtype": self.label_dtype_,
        }
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
# Common Sense
|
| 896 |
+
class CommonSenseTask(BasicTask):
    """Base class for multiple-choice common-sense QA tasks."""

    def __init__(self) -> None:
        super().__init__()
        self.task_type_ = "common_sense"
        self.label_dtype_ = None

    @property
    def peft_task_type(self) -> str:
        return "QUESTION_ANS"

    def label_list(self) -> List[str]:
        """Candidate answer labels; overridden by subclasses."""
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
class AttributeTask(BasicTask):
    """Attributed (citation-grounded) question-answering task."""

    def __init__(self) -> None:
        super().__init__()
        self.task_type_ = "attribute"
        self.label_dtype_ = None

    @property
    def peft_task_type(self) -> str:
        return "ATTRIBUTE"
|
| 919 |
+
|
| 920 |
+
# Global registry mapping task names to task instances; populated by the
# per-suite update_task_dict() helpers (e.g. glue_tasks.update_task_dict).
task_dict = {}
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
# Multi-Task (Only for train)
|
| 924 |
+
class MultiTask(BasicTask):
    """Training-only task that concatenates samples from several tasks.

    `task_names` is a ";"-separated list of names registered in `task_dict`.
    """

    def __init__(self, task_names: str) -> None:
        super().__init__()
        self.task_type_ = "multi_task"
        self.label_dtype_ = None
        self.task_list_: List[BasicTask] = []
        for name in task_names.split(";"):
            self.task_list_.append(task_dict[name])

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load every sub-task's data; `path` may carry per-task paths,
        ";"-separated and aligned with the task list."""
        logging.info(f"Preparing data for {len(self.task_list_)} tasks")
        path_list = None if path is None else path.split(";")
        samples: List[InputData] = []
        assert is_train
        for idx, task in enumerate(self.task_list_):
            # An empty entry means "use the task's default source".
            sub_path = "" if path_list is None else path_list[idx].strip()
            samples.extend(task.loading_data(is_train, sub_path if sub_path else None))
        return samples
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def main():
    # Ad-hoc evaluation driver: scores a QAMPARI prediction dump against the
    # reference data with compute_autoais.  The triple-quoted string literals
    # in this function are disabled code variants kept by the author (APO
    # preference data / ELI5 scoring); they are no-op expression statements.
    """source = '/yy21/MoE-PEFT/dataset/APO/preference_data.jsonl'
    data = []
    with open(source, 'r') as f:
        for line in f:
            y = json.loads(line)
            output = ""
            for s in y['statements']:
                if isinstance(s, List):
                    for i in s:
                        output += i + " "
                else:
                    dot = s['statement'].strip()[-1]
                    output += s['statement'].strip()[:-1]
                    if 'revised_used_document' in s:
                        for i in s['revised_used_document']:
                            output += '[' + i + ']'
                    else:
                        if len(s['used_document']) != 0:
                            for i in s['used_document']:
                                output += '[' + i + ']'
                    output += dot + ' '

            docs = [d['text'] for d in y['documents']]
            fk = {
                'query': y['query'],
                'output': output,
                'docs': docs,
            }
            ans = compute_autoais(fk)
            print(ans)"""
    def split_docs_and_answer(input_str):
        # Extract the text after the "[ANSWER]" tag; the final 4 characters
        # are also dropped (presumably a trailing end-of-sequence marker —
        # TODO confirm against the generation format).

        if "[ANSWER]" not in input_str:
            return ""
        index = input_str.find("[ANSWER]")
        ans = input_str[index + len("[ANSWER]"):][:-4].strip()

        return ans

    # Build one evaluation sample per prediction line, joined by index with
    # the reference file, then score citation precision/recall.
    test_data = []
    with open('/yy21/test_qamp_v2.jsonl', "r", encoding="utf-8") as fuck:
        with open('/yy21/MoE-PEFT/dataset/front_output/qampari.json', "r", encoding="utf-8") as f:
            data = json.load(f)
        for idx, line in enumerate(fuck):
            opt = json.loads(line)

            # Normalize citation markers: "[ref_3]" -> "[3]".
            ori_output = re.sub(r'\[ref_(\d+)\]', r'[\1]', opt['response'])
            #qa_pairs = data[idx]['qa_pairs']
            answer = data[idx]['answer']
            query = data[idx]['question']

            output = split_docs_and_answer(ori_output)
            # NOTE(review): `output` (the [ANSWER]-stripped text) is computed
            # but unused below — the raw `ori_output` is the value scored;
            # confirm that is intended.
            ori_docs = []
            # Only the first 5 retrieved documents are scored.
            for i in range(5):
                ori_docs.append(data[idx]['docs'][i]['text'])
            fk = {
                #'qa_pairs' : qa_pairs,
                'answer' : answer,
                'query' : query,
                'docs' : ori_docs,
                'output' : ori_output
            }
            test_data.append(fk)
    ans = compute_autoais(test_data, qampari=True)
    print(ans)
    """with open('/yy21/test_eli5_output0.jsonl', "r", encoding="utf-8") as fuck,\
    open('/yy21/test_eli5_output.jsonl', "w", encoding="utf-8") as outputf:
        for idx, line in enumerate(fuck):
            opt = json.loads(line)
            opt['accuracy'] = acc[idx]
            outputf.write(json.dumps(opt, ensure_ascii=False) + '\n')"""
    """ with open('/yy21/MoE-PEFT/dataset/front_output/eli5.json', "r", encoding="utf-8") as f:
        data = json.load(f)
    test_data = []
    for data_point in data:

        ori_output = data_point['output']
        qa_pairs = data_point['claims']
        answer = data_point['answer']
        query = data_point['question']

        output = split_docs_and_answer(ori_output)
        ori_docs = []
        for i in range(5):
            ori_docs.append(data_point['docs'][i]['text'])
        fk = {
            'qa_pairs' : qa_pairs,
            'answer' : answer,
            'query' : query,
            'docs' : ori_docs,
            'output' : output
        }
        test_data.append(fk)
    ans = compute_claims(test_data)
    print(ans)"""


if __name__ == "__main__":
    main()
|
c2cite/tasks/glue_tasks.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .common import SequenceClassificationTask
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def update_task_dict(task_dict):
    """Register the GLUE sequence-classification tasks into `task_dict`."""

    def _loader(fields, joiner=None):
        # Field names are bound via the factory's arguments, avoiding the
        # late-binding closure pitfall.  With `joiner` set, the fields are
        # concatenated into a single input string; otherwise each field is a
        # separate input.
        if joiner is not None:
            return lambda data_point: (
                [joiner.join(data_point[f] for f in fields)],
                [int(data_point["label"])],
            )
        return lambda data_point: (
            [data_point[f] for f in fields],
            [int(data_point["label"])],
        )

    def _task(name, n_labels, fields, joiner=None):
        # All GLUE tasks here share the same problem type and label dtype.
        return SequenceClassificationTask(
            task_name=name,
            task_type="single_label_classification",
            num_labels=n_labels,
            label_dtype=torch.long,
            dataload_function=_loader(fields, joiner),
        )

    task_dict.update(
        {
            "glue:cola": _task("glue:cola", 2, ("sentence",)),
            "glue:mnli": _task("glue:mnli", 3, ("premise", "hypothesis")),
            "glue:mrpc": _task("glue:mrpc", 2, ("sentence1", "sentence2")),
            "glue:qnli": _task("glue:qnli", 2, ("question", "sentence")),
            "glue:qqp": _task("glue:qqp", 2, ("question1", "question2")),
            "glue:rte": _task("glue:rte", 2, ("sentence1", "sentence2")),
            "glue:sst2": _task("glue:sst2", 2, ("sentence",)),
            "glue:wnli": _task("glue:wnli", 2, ("sentence1", "sentence2"), joiner=" </s> "),
        }
    )