loadingy commited on
Commit
51be264
·
0 Parent(s):

first push

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .gitignore +186 -0
  3. Dockerfile +55 -0
  4. LICENSE +201 -0
  5. README.md +34 -0
  6. c2cite.py +300 -0
  7. c2cite/__init__.py +52 -0
  8. c2cite/adapters/__init__.py +104 -0
  9. c2cite/adapters/loramoe/__init__.py +7 -0
  10. c2cite/adapters/loramoe/config.py +42 -0
  11. c2cite/adapters/loramoe/model.py +62 -0
  12. c2cite/adapters/mixlora/__init__.py +19 -0
  13. c2cite/adapters/mixlora/config.py +144 -0
  14. c2cite/adapters/mixlora/model.py +610 -0
  15. c2cite/adapters/mola/__init__.py +8 -0
  16. c2cite/adapters/mola/config.py +57 -0
  17. c2cite/adapters/mola/model.py +159 -0
  18. c2cite/common/__init__.py +92 -0
  19. c2cite/common/abstracts.py +194 -0
  20. c2cite/common/attention.py +293 -0
  21. c2cite/common/cache.py +554 -0
  22. c2cite/common/checkpoint.py +33 -0
  23. c2cite/common/config.py +234 -0
  24. c2cite/common/feed_forward.py +70 -0
  25. c2cite/common/lora_linear.py +511 -0
  26. c2cite/common/moe_utils.py +57 -0
  27. c2cite/common/rope.py +88 -0
  28. c2cite/dispatcher.py +378 -0
  29. c2cite/evaluator.py +518 -0
  30. c2cite/executors/__init__.py +54 -0
  31. c2cite/executors/common.py +77 -0
  32. c2cite/executors/cpu.py +51 -0
  33. c2cite/executors/cuda.py +53 -0
  34. c2cite/executors/mps.py +71 -0
  35. c2cite/generator.py +669 -0
  36. c2cite/model.py +1039 -0
  37. c2cite/models/__init__.py +40 -0
  38. c2cite/models/modeling_chatglm.py +855 -0
  39. c2cite/models/modeling_gemma.py +131 -0
  40. c2cite/models/modeling_gemma2.py +528 -0
  41. c2cite/models/modeling_llama.py +579 -0
  42. c2cite/models/modeling_mistral.py +255 -0
  43. c2cite/models/modeling_phi.py +576 -0
  44. c2cite/models/modeling_phi3.py +581 -0
  45. c2cite/prompter.py +63 -0
  46. c2cite/solutions.py +9 -0
  47. c2cite/tasks/__init__.py +29 -0
  48. c2cite/tasks/attribute_tasks.py +567 -0
  49. c2cite/tasks/common.py +1045 -0
  50. c2cite/tasks/glue_tasks.py +90 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ paper_wsdm_c2cite.pdf filter=lfs diff=lfs merge=lfs -text
2
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+
162
+ # IDEs
163
+ .vscode/
164
+
165
+ # MoE-PEFT
166
+ __pycache__/
167
+ *.egg-info/
168
+ *.egg
169
+ moe_peft.json
170
+ moe_peft_train_*.json
171
+
172
+ # macOS junk files
173
+ .DS_Store
174
+
175
+ # PEFT adapters
176
+ adapter_model.bin
177
+ adapter_config.json
178
+
179
+ result/
180
+ checkpoints/
181
+ cases/
182
+ dataset/
183
+ tblogs/
184
+ *.png
185
+ *.svg
186
+ logs
Dockerfile ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.5.1-devel-ubuntu22.04
2
+
3
+ ARG PYTHON_VERSION=3.11
4
+ ARG http_proxy
5
+ ARG https_proxy
6
+
7
+ RUN apt-get update
8
+
9
+ RUN apt-get install -y \
10
+ locales \
11
+ build-essential \
12
+ git \
13
+ git-lfs \
14
+ vim \
15
+ cmake \
16
+ pkg-config \
17
+ zlib1g-dev libncurses5-dev \
18
+ libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev wget \
19
+ liblzma-dev libsqlite3-dev libbz2-dev
20
+
21
+ RUN apt-get clean
22
+
23
+ ENV LANG=en_US.UTF-8
24
+ ENV LANGUAGE=en_US:en
25
+ ENV LC_ALL=en_US.UTF-8
26
+
27
+ RUN sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen && locale-gen
28
+
29
+ ENV PYENV_ROOT=/root/.pyenv
30
+ ENV PATH="$PYENV_ROOT/bin/:$PATH"
31
+
32
+ RUN /usr/bin/echo -e '#!/bin/bash\neval "$(pyenv init -)"\neval "$(pyenv virtualenv-init -)"\ncd /moe_peft\nbash' | tee /opt/init.sh \
33
+ && chmod +x /opt/init.sh \
34
+ && /usr/bin/echo -e 'export PYENV_ROOT=/root/.pyenv' >> ~/.bashrc \
35
+ && /usr/bin/echo -e 'export PATH=/root/.pyenv/bin:$PATH' >> ~/.bashrc \
36
+ && /usr/bin/echo -e 'eval "$(pyenv init -)"' >> ~/.bashrc \
37
+ && /usr/bin/echo -e 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc \
38
+ && git clone https://github.com/pyenv/pyenv.git /root/.pyenv \
39
+ && git clone https://github.com/pyenv/pyenv-virtualenv.git /root/.pyenv/plugins/pyenv-virtualenv \
40
+ && cd /root/.pyenv && src/configure && make -C src \
41
+ && eval "$(pyenv init -)" \
42
+ && eval "$(pyenv virtualenv-init -)"
43
+
44
+ RUN . ~/.bashrc \
45
+ && pyenv install $PYTHON_VERSION \
46
+ && pyenv global $PYTHON_VERSION \
47
+ && git clone https://github.com/TUDB-Labs/MoE-PEFT /moe_peft \
48
+ && cd /moe_peft \
49
+ && pyenv virtualenv $PYTHON_VERSION moe_peft \
50
+ && pyenv local moe_peft \
51
+ && pip install -r ./requirements.txt --upgrade --no-compile --no-cache-dir
52
+
53
+ WORKDIR /moe_peft
54
+
55
+ CMD ["/bin/bash", "/opt/init.sh"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This repository contains the code for the paper “C$^2$-Cite: Contextual-Aware Citation Generation for Attributed Large Language Models”. The project is based on the open-source repository "[TUDB-Labs/MoE-PEFT](https://github.com/TUDB-Labs/MoE-PEFT)". C$^2$-Cite is a model that can answer questions with citation markers.
2
+ ## File description
3
+ - **config**: Including the configurations of training or evaluating
4
+ - **c2cite/backends**: Some backend tools for GMoE.
5
+ - **c2cite/common**: The implementation of Transformer architecture.
6
+ - **c2cite/models**: The implementation of some series of Transformer-based models.
7
+ - **c2cite/tasks**: The implementation of datasets.
8
+ - **c2cite.py**: The entry point of this project.
9
+ ## Environment Requirements
10
+ - python3=3.11
11
+ - pytorch >= 2.1.2
12
+ - Other dependencies, See ```requirements.txt```
13
+ ## Quick Start
14
+ ### STEP 1: Download Base models
15
+ - [Llama-3-8B-inst]
16
+ ### STEP 2: Download training datasets
17
+ To get the training dataset proposed in the paper "Towards Faithful and Robust LLM Specialists for Evidence-Based Question-Answering", download [SynSciQA](https://github.com/EdisonNi-hku/Robust_Evidence_Based_QA), then put SynSciQA.json, SynSciQA+.json, and SynSciQA++.json in ./dataset/SynSciQA
18
+ ### STEP 3: Download evaluation datasets
19
+ We evaluate our model and baselines using [ALCE](https://github.com/princeton-nlp/ALCE). To get the evaluation datasets, please run
20
+ ```bash
21
+ bash download_test_data.sh
22
+ ```
23
+ ### STEP 4: Start training
24
+ Replace the **[base model]** and the **[train/evaluate config]** below with the directory of base model and the configuration in Folder "config".
25
+ ``````python
26
+ python c2cite.py --dir ./checkpoint --log_file ./logs --verbose --seed 42 --attn_impl eager --base_model [base model] --config [train/evaluate config] --device cuda:0
27
+ ``````
28
+ ### STEP 5: Conduct evaluation
29
+ After training process, we can conduct the evaluation step with the command below:
30
+ ``````python
31
+ python c2cite.py --dir ./checkpoint --log_file ./logs --verbose --seed 42 --attn_impl eager --base_model [base model] --config [train/evaluate config] --device cuda:0 --evaluate
32
+ ``````
33
+ ***Note***: **Do not** change the information in the **train config** after training step, or it won't find the right adapter.
34
+
c2cite.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from typing import Dict, List, Tuple, Union
7
+
8
+ import torch
9
+ from transformers.utils import is_flash_attn_2_available
10
+
11
+ import moe_peft
12
+ import moe_peft.adapters
13
+
14
# Command Line Arguments
# Builds the CLI for the MoE-PEFT entry point. `args` is a module-level
# global consumed by the helper functions and the __main__ block below.
# Fix: help text for --verbose said "informations" (not a word) — now
# "information".
parser = argparse.ArgumentParser(description="MoE-PEFT main program")
parser.add_argument(
    "--base_model", type=str, required=True, help="Path to or name of base model"
)
parser.add_argument(
    "--inference", action="store_true", help="The inference mode (just for test)"
)
parser.add_argument(
    "--evaluate", action="store_true", help="The evaluate mode (just for test)"
)
parser.add_argument(
    "--disable_prompter", action="store_true", help="Disable prompter when inference"
)
parser.add_argument(
    "--load_adapter",
    action="store_true",
    help="Load adapter from file instead of init randomly",
)
# NOTE(review): --disable_adapter is declared but never read in this file;
# presumably consumed elsewhere — confirm before removing.
parser.add_argument(
    "--disable_adapter", action="store_true", help="Disable the adapter modules"
)
parser.add_argument(
    "--attn_impl", type=str, help="Specify the implementation of attention"
)
parser.add_argument(
    "--sliding_window",
    action="store_true",
    help="Use sliding window attention (requires flash attention)",
)
parser.add_argument(
    "--disable_cache",
    action="store_true",
    help="Disable cache when inference",
)
parser.add_argument(
    "--cache_implementation",
    type=str,
    help="Specify the implementation of cache",
)
parser.add_argument(
    "--fp16", action="store_true", help="Load base model in float16 precision"
)
parser.add_argument(
    "--bf16", action="store_true", help="Load base model in bfloat16 precision"
)
parser.add_argument(
    "--tf32", action="store_true", help="Use tfloat32 instead of float32 if available"
)
parser.add_argument(
    "--load_8bit", action="store_true", help="Load base model with 8bit quantization"
)
parser.add_argument(
    "--load_4bit", action="store_true", help="Load base model with 4bit quantization"
)
parser.add_argument("--device", type=str, help="Specify which GPU to be used")
parser.add_argument(
    "--config", type=str, required=True, help="Path to finetune configuration"
)
parser.add_argument(
    "--seed", type=int, default=42, help="Random seed in integer, default is 42"
)
parser.add_argument(
    "--dir", type=str, default=".", help="Path to read or save checkpoints"
)
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
parser.add_argument("--log_file", type=str, help="Save log to specific file")
parser.add_argument(
    "--verbose", action="store_true", help="Show extra information such as parameters"
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Overwrite adapter model when older one existed",
)
parser.add_argument("--debug", action="store_true", help="Enabling debugging mode")
parser.add_argument(
    "--deterministic",
    action="store_true",
    help="Use deterministic algorithms to improve the reproducibility",
)

args = parser.parse_args()
97
+
98
+
99
def query_yes_no(question, default="no"):
    """Ask *question* on stdout and read a yes/no reply from stdin.

    Returns True for an affirmative answer and False for a negative one.
    An empty reply selects *default*, which must be "yes", "no", or None
    (None forces the user to answer explicitly); anything else raises
    ValueError. Keeps prompting until a recognizable reply is given.
    """
    answers = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    prompts = {None: " [y/n] ", "yes": " [Y/n] ", "no": " [y/N] "}

    if default not in prompts:
        raise ValueError("invalid default answer: '%s'" % default)
    suffix = prompts[default]

    while True:
        sys.stdout.write(question + suffix)
        reply = input().lower()
        if not reply and default is not None:
            return answers[default]
        if reply in answers:
            return answers[reply]
        sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
119
+
120
+
121
def load_base_model() -> Tuple[moe_peft.Tokenizer, moe_peft.LLMModel]:
    """Instantiate the tokenizer and base LLM selected on the command line.

    Quantization bits come from --load_8bit / --load_4bit (8-bit wins when
    both are given, matching the original precedence) and the load dtype
    from --bf16 / --fp16; the defaults are no quantization and float32.
    Reads the module-level ``args`` namespace.
    """
    logging.info("Initializing pre-trained model.")

    if args.load_8bit:
        quant_bits = 8
    elif args.load_4bit:
        quant_bits = 4
    else:
        quant_bits = None

    if args.bf16:
        load_dtype = torch.bfloat16
    elif args.fp16:
        load_dtype = torch.float16
    else:
        load_dtype = torch.float32

    model = moe_peft.LLMModel.from_pretrained(
        name_or_path=args.base_model,
        device=args.device,
        attn_impl=args.attn_impl,
        use_sliding_window=args.sliding_window,
        bits=quant_bits,
        load_dtype=load_dtype,
    )
    return moe_peft.Tokenizer(args.base_model), model
139
+
140
+
141
def init_adapter_config(
    config: Dict[str, any],
    llm_model: moe_peft.LLMModel,
) -> List[Union[moe_peft.GenerateConfig, moe_peft.TrainConfig]]:
    """Attach every adapter listed under config["lora"] to *llm_model* and
    build one run-config per adapter for the selected mode.

    Mode is taken from the module-level ``args``: inference produces
    GenerateConfig, --evaluate produces EvaluateConfig entries (possibly
    several per adapter, hence extend), otherwise TrainConfig for training.
    May mutate *config* (cutoff_len) and call exit(0) if the user declines
    to overwrite an existing adapter file.
    """
    config_list = []

    # A cutoff_len of -1 in the JSON config means "use the model's maximum
    # sequence length"; resolved in place so later code sees a real value.
    if config["cutoff_len"] == -1:
        config["cutoff_len"] = llm_model.config_.max_seq_len_
        logging.info(f"Setting cutoff_len to {llm_model.config_.max_seq_len_} automatically.")

    for lora_config in config["lora"]:
        adapter_name = lora_config["name"]
        adapter_path = f"{args.dir}{os.sep}{adapter_name}"
        # Training would clobber an existing adapter checkpoint: either the
        # user passed --overwrite, or we ask interactively and abort on "no".
        if not args.load_adapter and os.path.exists(adapter_path):
            if args.overwrite:
                logging.warning(
                    f"Overwriting existed adapter model file: {adapter_path}"
                )
            elif not query_yes_no(
                f"Existed adapter model file detected: {adapter_path}\n" + "Overwrite?"
            ):
                logging.info("User canceled training due to file conflict.")
                exit(0)

        # Either restore trained adapter weights from disk or initialize a
        # fresh adapter from the JSON description.
        if args.load_adapter:
            llm_model.load_adapter(adapter_path, adapter_name)
        else:
            llm_model.init_adapter(moe_peft.adapters.lora_config_factory(lora_config))

        if args.inference:
            config_class = moe_peft.GenerateConfig(adapter_name=adapter_name)
            # --disable_prompter suppresses the per-adapter prompt template.
            if not args.disable_prompter:
                config_class.prompt_template = lora_config.get("prompt", None)
            config_list.append(config_class)
        elif args.evaluate:
            config_list.extend(moe_peft.EvaluateConfig.from_config(lora_config))
        else:
            config_list.append(moe_peft.TrainConfig.from_config(lora_config))

        # NOTE(review): with --evaluate this logs only the last of possibly
        # several configs added above.
        if args.verbose:
            logging.info(config_list[-1].__dict__)

    return config_list
184
+
185
+
186
def inference_callback(cur_pos, outputs):
    """Streaming callback: print the current decode position, then each
    adapter's first partial output so far ({adapter_name: [text, ...]})."""
    print(f"POSITION: {cur_pos}")
    for adapter_name, texts in outputs.items():
        print(f"{adapter_name} OUTPUT: {texts[0]}")
190
+
191
+
192
def inference(
    model: moe_peft.LLMModel,
    tokenizer: moe_peft.Tokenizer,
    configs: List[moe_peft.GenerateConfig],
    concurrent_jobs: int,
):
    """Interactive generation loop.

    Reads raw prompts from stdin (no prompt template applied here beyond
    what each config carries), runs generation for every adapter config,
    and prints each adapter's output. Entering "QUIT" exits the loop.
    Streaming output is echoed via inference_callback unless --disable_log.
    """
    separator = f"\n{'='*10}\n"
    stream_cb = None if args.disable_log else inference_callback

    while (raw_prompt := input("INPUT WITHOUT PROMPT: ")) != "QUIT":
        for cfg in configs:
            cfg.prompts = [raw_prompt]
        results = moe_peft.generate(
            model,
            tokenizer,
            configs,
            max_gen_len=128,
            use_cache=not args.disable_cache,
            concurrent_jobs=concurrent_jobs,
            cache_implementation=args.cache_implementation,
            stream_callback=stream_cb,
        )
        print(separator)
        print(f"PROMPT: {raw_prompt}")
        for adapter_name, texts in results.items():
            print(f"{adapter_name} OUTPUT:")
            print(texts[0])
        print(separator)
221
+
222
+
223
# Main Function
if __name__ == "__main__":
    # Anomaly detection makes autograd report the op that produced NaN/Inf.
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # Inference and evaluation both run against previously trained adapters,
    # so force loading them from disk instead of random initialization.
    if args.inference or args.evaluate:
        args.load_adapter = True
        inference_mode = True
    else:
        inference_mode = False
    moe_peft.setup_logging("INFO", args.log_file)

    moe_peft_executor = moe_peft.executor

    # Abort early when no usable compute backend is available.
    if not moe_peft_executor.check_available():
        exit(-1)

    # Default attention implementation: flash attention for CUDA inference
    # when the package is installed, otherwise the eager reference kernel.
    if args.attn_impl is None:
        if (
            inference_mode
            and moe_peft_executor.device_name() == "cuda"
            and is_flash_attn_2_available()
        ):
            args.attn_impl = "flash_attn"
        else:
            args.attn_impl = "eager"

    if args.device is None:
        args.device = moe_peft.executor.default_device_name()

    # Reproducibility knobs are set before any model weights are touched.
    moe_peft_executor.use_deterministic_algorithms(args.deterministic)
    moe_peft_executor.allow_tf32(args.tf32)
    moe_peft_executor.manual_seed(args.seed)

    with open(args.config, "r", encoding="utf8") as fp:
        config = json.load(fp)

    tokenizer, model = load_base_model()
    adapters = init_adapter_config(config, model)

    # Drop temporary buffers created while attaching adapters.
    moe_peft_executor.empty_cache()

    if os.getenv("MOE_PEFT_EVALUATE_MODE") is None:
        logging.info("Using efficient operators.")
    else:
        logging.info("Using deterministic operators.")

    # Dispatch: interactive inference, batch evaluation, or training (default).
    if args.inference:
        inference(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
        )
    elif args.evaluate:
        moe_peft.evaluate(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
            retrying_steps=config.get("eval_rollback_retrying_steps", 20),
            max_seq_len=config["cutoff_len"],
            save_file=config.get("evaluate_result", None),
            # NOTE(review): -1 presumably means "disabled" for these two
            # options — confirm against moe_peft.evaluate's signature.
            require_attention = -1,
            require_hide = -1,
        )
    else:
        moe_peft.train(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
            strategy=config["train_strategy"],
            cutoff_len=config["cutoff_len"],
            save_step=config["save_step"],
            save_dir=args.dir,
        )
c2cite/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Package entry point: re-export the public MoE-PEFT API and validate the
runtime environment at import time."""

from .common import (
    AdapterConfig,
    LLMBatchConfig,
    LLMCache,
    LLMForCausalLM,
    LLMModelConfig,
    LLMModelInput,
    LLMModelOutput,
    LoraConfig,
    cache_factory,
)
from .dispatcher import Dispatcher, TrainTask
from .evaluator import EvaluateConfig, evaluate
from .executors import executor
from .generator import GenerateConfig, generate
from .model import LLMModel
from .prompter import Prompter
from .tokenizer import Tokenizer
from .trainer import TrainConfig, train
from .utils import is_package_available, setup_logging

# Explicit raises instead of `assert`: asserts are stripped when Python runs
# with -O, which would silently skip these mandatory version checks.
if not is_package_available("torch", "2.3.0"):
    raise ImportError("MoE-PEFT requires torch>=2.3.0")
if not is_package_available("transformers", "4.43.0"):
    raise ImportError("MoE-PEFT requires transformers>=4.43.0")

setup_logging()

__all__ = [
    "LLMCache",
    "cache_factory",
    "LLMModelConfig",
    "LLMModelOutput",
    "LLMForCausalLM",
    "LLMBatchConfig",
    "LLMModelInput",
    "AdapterConfig",
    "LoraConfig",
    "TrainTask",
    "Dispatcher",
    "EvaluateConfig",
    "evaluate",
    "GenerateConfig",
    "generate",
    "TrainConfig",
    "train",
    "LLMModel",
    "Prompter",
    "Tokenizer",
    "setup_logging",
    "executor",
]
c2cite/adapters/__init__.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, TypeAlias
2
+
3
+ import torch
4
+
5
+ from moe_peft.common import AdapterConfig, LoraConfig
6
+
7
+ from .loramoe import LoraMoe, LoraMoeConfig
8
+ from .mixlora import (
9
+ DynamicRouterLoss,
10
+ DynamicSparseMoe,
11
+ MixLoraConfig,
12
+ MixtralRouterLoss,
13
+ MixtralSparseMoe,
14
+ SwitchRouterLoss,
15
+ SwitchSparseMoe,
16
+ )
17
+ from .mola import MolaConfig, MolaRouterLoss, MolaSparseMoe
18
+
19
# Explicit "peft_type" field -> adapter config class.
peft_type_dict = {
    "LORA": LoraConfig,
    "MIXLORA": MixLoraConfig,
    "LORAMOE": LoraMoeConfig,
    "MOLA": MolaConfig,
}

# Fallback lookup used when a raw config only carries "routing_strategy".
routing_strategy_dict = {
    "mixlora": MixLoraConfig,
    "mixlora-dynamic": MixLoraConfig,
    "mixlora-switch": MixLoraConfig,
    "loramoe": LoraMoeConfig,
    "mola": MolaConfig,
}

# Routing strategy -> auxiliary router-loss module.
# Note: "loramoe" intentionally has no entry (no router loss for that strategy).
router_loss_dict = {
    "mixlora": MixtralRouterLoss,
    "mixlora-dynamic": DynamicRouterLoss,
    "mixlora-switch": SwitchRouterLoss,
    "mola": MolaRouterLoss,
}

# Routing strategy -> sparse-MoE block implementation.
moe_layer_dict = {
    "mixlora": MixtralSparseMoe,
    "mixlora-dynamic": DynamicSparseMoe,
    "mixlora-switch": SwitchSparseMoe,
    "loramoe": LoraMoe,
    "mola": MolaSparseMoe,
}
48
+
49
+
50
def lora_config_factory(config: Dict[str, any]) -> LoraConfig:
    """Instantiate the adapter config class that matches *config*.

    Resolution order: explicit ``peft_type`` first, then ``routing_strategy``,
    falling back to a plain ``LoraConfig``. The result is validated with
    ``check()`` before being returned.

    Fixes: the original annotated the locals as ``TypeAlias[AdapterConfig]``,
    which is not a valid use of ``typing.TypeAlias`` (it is not a generic);
    the annotations are dropped and the double dict lookups collapsed.
    """
    peft_type = config.get("peft_type", "")
    routing_strategy = config.get("routing_strategy", "")
    if peft_type in peft_type_dict:
        config_class = peft_type_dict[peft_type]
    elif routing_strategy in routing_strategy_dict:
        config_class = routing_strategy_dict[routing_strategy]
    else:
        config_class = LoraConfig

    return config_class.from_config(config).check()
63
+
64
+
65
def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
    """Build the auxiliary router-loss module for *config*.

    Returns ``None`` when the routing strategy has no registered loss or when
    the config disables the router loss.
    """
    loss_cls = router_loss_dict.get(config.routing_strategy_)
    if loss_cls is None or not config.router_loss_:
        return None
    return loss_cls(config)
72
+
73
+
74
def moe_layer_factory(
    in_features: int,
    device: torch.device,
    config: MolaConfig,
    gate: Optional[torch.Tensor] = None,
) -> torch.nn.Module:
    """Instantiate the sparse-MoE block registered for the config's strategy.

    Raises ``ValueError`` for an unknown ``routing_strategy_``; an optional
    pre-trained ``gate`` tensor is forwarded to the block constructor.
    """
    strategy = config.routing_strategy_
    if strategy not in moe_layer_dict:
        raise ValueError(f"Unknown routing strategy {strategy}")
    moe_cls = moe_layer_dict[strategy]
    return moe_cls(in_features, device, config, gate)
83
+
84
+
85
# Explicit public API of the adapters sub-package.
__all__ = [
    "MixLoraConfig",
    "MixtralRouterLoss",
    "MixtralSparseMoe",
    "DynamicRouterLoss",
    "DynamicSparseMoe",
    "SwitchRouterLoss",
    "SwitchSparseMoe",
    "LoraMoeConfig",
    "LoraMoe",
    "MolaConfig",
    "MolaSparseMoe",
    "peft_type_dict",
    "routing_strategy_dict",
    "router_loss_dict",
    "moe_layer_dict",
    "lora_config_factory",
    "router_loss_factory",
    "moe_layer_factory",
]
c2cite/adapters/loramoe/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Public surface of the LoraMoe adapter sub-package.
from .config import LoraMoeConfig
from .model import LoraMoe

__all__ = [
    "LoraMoeConfig",
    "LoraMoe",
]
c2cite/adapters/loramoe/config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Dict
4
+
5
+ from moe_peft.common import LoraConfig
6
+
7
+
8
@dataclass
class LoraMoeConfig(LoraConfig):
    """LoRAMoE adapter settings: base LoRA config plus MoE router fields."""

    num_experts_: int = None
    router_init_range_: float = None
    routing_strategy_: str = "loramoe"

    def check(self) -> "LoraMoeConfig":
        """Validate the MoE-specific fields on top of the base LoRA checks."""
        super().check()
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
        """Build from a raw config dict, inheriting all base LoRA fields."""
        return LoraMoeConfig(
            num_experts_=config["num_experts"],
            router_init_range_=config.get("router_init_range", 5.0),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialize back to the on-disk adapter config format."""
        config = super().export()
        config["peft_type"] = "LORAMOE"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the per-expert LoRA config with a namespaced adapter name.

        Bug fix: the previous code called ``copy.deepcopy(super())``; ``super``
        proxy objects cannot be deep-copied (deepcopy falls back to pickling,
        which raises ``TypeError: cannot pickle 'super' object``). Deep-copying
        ``self`` yields an equivalent, LoraConfig-compatible object (extra MoE
        fields are ignored by LoRA consumers).
        """
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
c2cite/adapters/loramoe/model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from moe_peft.common import Linear, LLMMoeBlock
8
+
9
+ from .config import LoraMoeConfig
10
+
11
+
12
class LoraMoe(LLMMoeBlock):
    """LoRAMoE block: a dense softmax router that adds every expert's LoRA
    delta to the residual stream, weighted by the gate probabilities."""

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: LoraMoeConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Router math runs in float32 regardless of the model dtype.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=torch.float32,
        )
        self.experts_ = config.num_experts_
        # Cached flat router logits (tokens x experts), read by the router loss.
        self.router_logits_: torch.Tensor = None

        if gate is None:
            torch.nn.init.kaiming_uniform_(
                self.gate_.weight, a=math.sqrt(config.router_init_range_)
            )
        else:
            # Restore a previously trained gate.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Optional[Linear] = None,
    ) -> torch.Tensor:
        # NOTE(review): indexing routing_weights[:, :, expert_idx] assumes
        # hidden_states is (batch, seq, in_features) — TODO confirm callers.
        assert lora_linear is not None
        router_logits = self.gate_(hidden_states.to(self.dtype_))
        self.router_logits_ = router_logits.reshape(-1, self.experts_)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        # Dense mixture: every expert contributes, scaled by its gate weight.
        for expert_idx in range(self.experts_):
            expert_lora = lora_linear.loras_[
                f"moe.{self.adapter_name_}.experts.{expert_idx}"
            ]
            residual = residual + (
                torch.unsqueeze(routing_weights[:, :, expert_idx], -1)
                * expert_lora.lora_forward(hidden_states)
            ).to(hidden_states.dtype)

        return residual
c2cite/adapters/mixlora/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Public surface of the MixLoRA adapter sub-package: the shared config plus
# the three routing variants (top-k "mixtral", dynamic top-p, and switch).
from .config import MixLoraConfig
from .model import (
    DynamicRouterLoss,
    DynamicSparseMoe,
    MixtralRouterLoss,
    MixtralSparseMoe,
    SwitchRouterLoss,
    SwitchSparseMoe,
)

__all__ = [
    "MixLoraConfig",
    "MixtralRouterLoss",
    "MixtralSparseMoe",
    "DynamicRouterLoss",
    "DynamicSparseMoe",
    "SwitchRouterLoss",
    "SwitchSparseMoe",
]
c2cite/adapters/mixlora/config.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Optional, Union
4
+
5
+ import torch
6
+ from transformers.activations import ACT2FN
7
+
8
+ from moe_peft.common import LoraConfig
9
+
10
available_routing_strategies = ["mixlora", "mixlora-dynamic", "mixlora-switch"]


@dataclass
class MixLoraConfig(LoraConfig):
    """Configuration for MixLoRA adapters (all three routing strategies).

    Fields that apply to only one strategy stay ``None`` for the others;
    ``check()`` validates exactly the fields relevant to ``routing_strategy_``.
    """

    # expert lora
    expert_config_: LoraConfig = None
    # router config
    router_aux_loss_coef_: float = None
    router_init_range_: float = None
    routing_strategy_: str = None
    jitter_noise_: float = None
    router_loss_: bool = True
    num_experts_: int = None
    act_fn_: Optional[Union[str, torch.nn.Module]] = None
    # mixtral config
    top_k_: int = None
    # dynamic config
    top_p_: float = None
    temperature_: float = None
    # switch transformers config
    router_z_loss_coef_: float = None
    expert_capacity_: int = None
    ffn_dropout_: float = None
    sparse_step_: int = None

    def check(self) -> "MixLoraConfig":
        """Validate shared fields plus the strategy-specific ones."""
        super().check()
        if self.expert_config_ is not None:
            self.expert_config_.check()
        assert (
            isinstance(self.router_aux_loss_coef_, float)
            and self.router_aux_loss_coef_ >= 0
        )
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        assert (
            isinstance(self.routing_strategy_, str)
            and self.routing_strategy_ in available_routing_strategies
        )
        assert isinstance(self.jitter_noise_, float) and self.jitter_noise_ >= 0
        assert isinstance(self.router_loss_, bool)
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert self.act_fn_ is None or (
            isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
        )
        if self.routing_strategy_ == "mixlora":
            assert isinstance(self.top_k_, int) and self.top_k_ > 0
        elif self.routing_strategy_ == "mixlora-dynamic":
            assert (
                isinstance(self.top_p_, float) and self.top_p_ > 0 and self.top_p_ <= 1
            )
            assert isinstance(self.temperature_, float) and self.temperature_ >= 0
        elif self.routing_strategy_ == "mixlora-switch":
            assert (
                isinstance(self.router_z_loss_coef_, float)
                and self.router_z_loss_coef_ >= 0
            )
            if self.sparse_step_ is not None:
                assert isinstance(self.sparse_step_, int) and self.sparse_step_ > 0
            assert isinstance(self.expert_capacity_, int) and self.expert_capacity_ > 0
            assert isinstance(self.ffn_dropout_, float) and self.ffn_dropout_ >= 0

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "MixLoraConfig":
        """Build from a raw config dict, applying per-strategy defaults."""
        lora_config = MixLoraConfig(**LoraConfig.from_config(config).__dict__)
        if "expert_lora" in config:
            # Expert-level overrides are layered on top of the shared config.
            expert_config = copy.deepcopy(config)
            expert_config.update(config["expert_lora"])
            # Call from_config on the class, not on a throwaway instance.
            lora_config.expert_config_ = LoraConfig.from_config(expert_config)
        lora_config.router_aux_loss_coef_ = config.get(
            "router_aux_loss_coef", 0.001
        )  # for training
        lora_config.routing_strategy_ = config["routing_strategy"]
        lora_config.router_loss_ = config.get("router_loss", True)
        lora_config.num_experts_ = config["num_experts"]
        # silu for mixtral or gelu_new for switch transformers
        # left blank to automatically use the original act_fn of FFN
        lora_config.act_fn_ = config.get("act_fn", None)
        if lora_config.routing_strategy_ == "mixlora":
            lora_config.router_init_range_ = config.get("router_init_range", 0.02)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
            lora_config.top_k_ = config.get("top_k", 2)
        elif lora_config.routing_strategy_ == "mixlora-dynamic":
            lora_config.router_init_range_ = config.get("router_init_range", 0.02)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
            lora_config.top_p_ = config.get("top_p", 0.8)
            lora_config.temperature_ = config.get("temperature", 0.0)
        elif lora_config.routing_strategy_ == "mixlora-switch":
            lora_config.router_init_range_ = config.get("router_init_range", 1.0)
            lora_config.jitter_noise_ = config.get("jitter_noise", 0.01)
            lora_config.router_z_loss_coef_ = config.get(
                "router_z_loss_coef", 0.001
            )  # for training
            # expert_capacity = (max_sequence_length / num_experts) * capacity_factor
            # common values of capacity_factor: 1.0, 1.25, 2.0
            lora_config.expert_capacity_ = config.get("expert_capacity", 32)
            lora_config.ffn_dropout_ = config.get("ffn_dropout", 0.0)
            lora_config.sparse_step_ = config.get("sparse_step", None)

        return lora_config

    def export(self) -> Dict[str, any]:
        """Serialize back to the on-disk adapter config format."""
        config = super().export()
        config["peft_type"] = "MIXLORA"
        if self.expert_config_ is not None:
            expert_config = self.expert_config_.export()
            expert_config.pop("peft_type")
            expert_config.pop("target_modules")
            config["expert_lora"] = expert_config
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        if self.act_fn_ is not None and isinstance(self.act_fn_, str):
            config["act_fn"] = self.act_fn_
        if self.routing_strategy_ == "mixlora":
            config["top_k"] = self.top_k_
        elif self.routing_strategy_ == "mixlora-dynamic":
            config["top_p"] = self.top_p_
            config["temperature"] = self.temperature_
        elif self.routing_strategy_ == "mixlora-switch":
            config["expert_capacity"] = self.expert_capacity_
            config["sparse_step"] = self.sparse_step_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the per-expert LoRA config with a namespaced adapter name.

        Bug fix: the previous code called ``copy.deepcopy(super())`` when no
        explicit expert config was set; ``super`` proxy objects cannot be
        deep-copied (pickling them raises ``TypeError``). Deep-copying ``self``
        yields an equivalent LoraConfig-compatible object instead.
        """
        if self.expert_config_ is None:
            config = copy.deepcopy(self)
        else:
            config = copy.deepcopy(self.expert_config_)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
c2cite/adapters/mixlora/model.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers.activations import ACT2FN
6
+
7
+ from moe_peft.common import LLMFeedForward, LLMModelInput, LLMMoeBlock, slice_tensor
8
+
9
+ from .config import MixLoraConfig
10
+
11
+
12
def _mixlora_compatible_forward(
    ffn_layer: LLMFeedForward,
    moe_name: str,
    act_fn: torch.nn.Module,
    expert_mask: torch.Tensor,
    hidden_states: torch.Tensor,
    input_dtype: torch.dtype,
):
    """Fallback per-expert dispatch used when ``ffn_layer`` has no fused
    ``_mixlora_forward`` fast path.

    For each expert, gathers the tokens routed to it (nonzero columns of
    ``expert_mask[expert_idx]``), runs them through that expert's LoRA-adapted
    FFN, and returns one output tensor per expert index.
    """
    final_expert_states = []
    for expert_idx in range(expert_mask.shape[0]):
        # Token (column) indices assigned to this expert.
        _, top_x = torch.where(expert_mask[expert_idx])
        lora_name = f"moe.{moe_name}.experts.{expert_idx}"
        lora_data = slice_tensor(hidden_states, top_x, input_dtype)
        final_expert_states.append(
            ffn_layer._lora_forward(lora_name, act_fn, lora_data)
        )

    return final_expert_states
30
+
31
+
32
def _mixtral_load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: int,
    top_k: int,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mixtral auxiliary load-balancing loss.

    Penalizes the dot product between the fraction of tokens routed to each
    expert (hard top-k selection) and the mean softmax probability of that
    expert, scaled by ``num_experts``. When ``attention_mask`` is given,
    padding tokens are excluded from both statistics.

    NOTE(review): the masked branch divides ``gate_logits.shape[0]`` by
    batch*seq to infer a layer count — assumes logits are concatenated over
    layers; TODO confirm against callers.
    """
    routing_weights = torch.nn.functional.softmax(gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand(
                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
            )
            .reshape(-1, top_k, num_experts)
            .to(routing_weights.device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(
            expert_mask.float() * expert_attention_mask, dim=0
        ) / torch.sum(expert_attention_mask, dim=0)

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(routing_weights.device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
84
+
85
+
86
class MixtralRouterLoss(torch.nn.Module):
    """Scaled Mixtral load-balancing auxiliary loss."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.topk = config.top_k_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        # Delegate the balancing statistics, then apply the configured scale.
        raw_loss = _mixtral_load_balancing_loss_func(
            gate_logits, self.experts, self.topk, attention_mask
        )
        return self.aux_loss_coef * raw_loss
97
+
98
+
99
class MixtralSparseMoe(LLMMoeBlock):
    """Mixtral-style top-k sparse MoE over per-expert LoRA FFN experts.

    A float32 linear gate scores each token; the top-k experts are evaluated
    through ``ffn_layer`` and combined with renormalized routing weights.
    ``forward`` returns ``(hidden_states, router_logits)``.
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Router computations run in float32 for numerical stability.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.topk_: int = config.top_k_
        self.jitter_noise_: float = config.jitter_noise_
        # When router_profile_ is enabled, _profiling keeps a running average
        # of per-expert routing pressure in profiler_.
        self.router_profile_: bool = False
        self.profiler_: List[float] = None

        if gate is None:
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            # Restore a previously trained gate.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Only the gate is trainable state owned by this block."""
        return {"gate": self.gate_.weight}

    def _profiling(
        self, batch_size: int, sequence_length: int, selected_experts: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure (running average)."""
        if not self.router_profile_:
            return

        router_statistic_ = list(0 for _ in range(self.experts_))
        for selected in selected_experts.tolist():
            for idx in selected:
                router_statistic_[idx] += 1

        if self.profiler_ is None:
            self.profiler_ = list(0 for _ in range(self.experts_))
            for idx in range(self.experts_):
                self.profiler_[idx] = (
                    router_statistic_[idx] / batch_size
                ) / sequence_length
        else:
            for idx in range(self.experts_):
                pressure = (router_statistic_[idx] / batch_size) / sequence_length
                self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate_(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.topk_, dim=-1
        )

        self._profiling(batch_size, sequence_length, selected_experts)

        # Renormalize the kept top-k weights so they sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.experts_
        ).permute(2, 1, 0)

        # Perform the computation on each expert
        if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"):
            expert_states = ffn_layer._mixlora_forward(
                self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype
            )
        else:
            expert_states = _mixlora_compatible_forward(
                ffn_layer,
                self.adapter_name_,
                self.act_,
                expert_mask,
                hidden_states,
                input_dtype,
            )

        # Unpack
        for expert_idx in range(self.experts_):
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_hidden_states = (
                expert_states[expert_idx] * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(self.dtype_)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        ).to(input_dtype)

        return final_hidden_states, router_logits
241
+
242
+
243
+ def _dynamic_top_p(router_logits: torch.Tensor, top_p: float, temperature: float = 0.0):
244
+ if temperature > 0.0:
245
+ router_logits = router_logits / temperature
246
+ sorted_logits, sorted_indices = torch.sort(router_logits, dim=-1, descending=True)
247
+ cumulative_probs = sorted_logits.cumsum(dim=-1)
248
+ expert_mask = cumulative_probs > top_p
249
+ threshold_indices = expert_mask.long().argmax(dim=-1)
250
+ threshold_mask = torch.nn.functional.one_hot(
251
+ threshold_indices, num_classes=sorted_indices.size(-1)
252
+ ).bool()
253
+ expert_mask = expert_mask & ~threshold_mask
254
+ sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
255
+ sorted_indices = sorted_indices.masked_fill(expert_mask, -1)
256
+ return sorted_logits, sorted_indices
257
+
258
+
259
def _dynamic_load_balancing_loss_func(
    routing_weights: torch.Tensor,
    num_experts: int,
    top_p: float,
    temperature: float,
) -> torch.Tensor:
    """Load-balancing auxiliary loss for dynamic (top-p) routing.

    Re-runs the top-p selection on ``routing_weights`` and, like Mixtral's
    loss, penalizes the dot product between per-expert token fractions and
    per-expert mean routing probability, scaled by ``num_experts``.
    """
    _, selected_experts = _dynamic_top_p(routing_weights, top_p, temperature)

    expert_mask = torch.empty(
        (num_experts, num_experts, routing_weights.size(0)),
        dtype=routing_weights.dtype,
        device=routing_weights.device,
    )

    # expert_mask[e, slot, token]: whether `token` picked expert `e` in that
    # sorted slot (dropped slots hold index -1 and never match).
    for expert_idx in range(num_experts):
        expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

    expert_mask = expert_mask.permute(2, 1, 0)

    # Compute the percentage of tokens routed to each experts
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

    # Compute the average probability of routing to these experts
    router_prob_per_expert = torch.mean(routing_weights, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
286
+
287
+
288
class DynamicRouterLoss(torch.nn.Module):
    """Scaled load-balancing loss for dynamic top-p routing."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.top_p = config.top_p_
        self.temperature = config.temperature_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        # Convert raw gate logits to probabilities before balancing them.
        probs = torch.nn.functional.softmax(gate_logits, dim=-1)
        balance_loss = _dynamic_load_balancing_loss_func(
            probs,
            self.experts,
            self.top_p,
            self.temperature,
        )
        return self.aux_loss_coef * balance_loss
304
+
305
+
306
class DynamicSparseMoe(LLMMoeBlock):
    """Dynamic (top-p) sparse MoE over per-expert LoRA FFN experts.

    Like :class:`MixtralSparseMoe`, but the number of active experts per token
    is chosen by nucleus-style top-p selection instead of a fixed top-k.
    ``forward`` returns ``(hidden_states, router_logits)``.
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Router computations run in float32 for numerical stability.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.top_p_: float = config.top_p_
        self.temperature_: float = config.temperature_
        self.jitter_noise_: float = config.jitter_noise_
        # When router_profile_ is enabled, _profiling keeps a running average
        # of per-expert routing pressure in profiler_.
        self.router_profile_: bool = False
        self.profiler_: List[float] = None

        if gate is None:
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            # Restore a previously trained gate.
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Only the gate is trainable state owned by this block."""
        return {"gate": self.gate_.weight}

    def _profiling(
        self, batch_size: int, sequence_length: int, selected_experts: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure (running average).

        NOTE(review): with top-p routing, dropped slots in selected_experts
        hold -1, which here increments the *last* expert's counter — looks
        unintended; confirm before trusting the profile numbers.
        """
        if not self.router_profile_:
            return

        router_statistic_ = list(0 for _ in range(self.experts_))
        for selected in selected_experts.tolist():
            for idx in selected:
                router_statistic_[idx] += 1

        if self.profiler_ is None:
            self.profiler_ = list(0 for _ in range(self.experts_))
            for idx in range(self.experts_):
                self.profiler_[idx] = (
                    router_statistic_[idx] / batch_size
                ) / sequence_length
        else:
            for idx in range(self.experts_):
                pressure = (router_statistic_[idx] / batch_size) / sequence_length
                self.profiler_[idx] = (self.profiler_[idx] + pressure) / 2

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate_(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)
        routing_weights, selected_experts = _dynamic_top_p(
            routing_weights, self.top_p_, self.temperature_
        )

        self._profiling(batch_size, sequence_length, selected_experts)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        # expert_mask[e, slot, token]: whether `token` picked expert `e` in
        # that sorted slot (dropped slots hold -1 and never match).
        expert_mask = torch.empty(
            (self.experts_, self.experts_, batch_size * sequence_length),
            dtype=self.dtype_,
            device=hidden_states.device,
        )

        for expert_idx in range(self.experts_):
            expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

        # Perform the computation on each expert
        if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"):
            expert_states = ffn_layer._mixlora_forward(
                self.adapter_name_, self.act_, expert_mask, hidden_states, input_dtype
            )
        else:
            expert_states = _mixlora_compatible_forward(
                ffn_layer,
                self.adapter_name_,
                self.act_,
                expert_mask,
                hidden_states,
                input_dtype,
            )

        # Unpack
        for expert_idx in range(self.experts_):
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_hidden_states = (
                expert_states[expert_idx] * routing_weights[top_x, idx, None]
            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(self.dtype_)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        ).to(input_dtype)

        return final_hidden_states, router_logits
450
+
451
+
452
+ def _switch_router_z_loss_func(router_logits: torch.Tensor) -> float:
453
+ log_z = torch.logsumexp(router_logits, dim=-1)
454
+ z_loss = log_z**2
455
+ return torch.sum(z_loss) / (router_logits.size(0))
456
+
457
+
458
+ def _switch_load_balancing_loss_func(router_probs: torch.Tensor) -> float:
459
+ num_experts = router_probs.size(-1)
460
+
461
+ expert_mask = torch.argmax(router_probs, dim=-1)
462
+ expert_mask = torch.nn.functional.one_hot(expert_mask, num_classes=num_experts)
463
+
464
+ tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=0)
465
+
466
+ router_prob_per_group_and_expert = torch.mean(router_probs, dim=0)
467
+ return torch.mean(
468
+ tokens_per_group_and_expert * router_prob_per_group_and_expert
469
+ ) * (num_experts**2)
470
+
471
+
472
class SwitchRouterLoss(torch.nn.Module):
    """Combined Switch-style router loss: z-loss plus auxiliary balancing loss."""

    def __init__(self, config: MixLoraConfig) -> None:
        super().__init__()
        self.experts = config.num_experts_
        self.expert_capacity_ = config.expert_capacity_
        self.z_loss_coef = config.router_z_loss_coef_
        self.aux_loss_coef = config.router_aux_loss_coef_

    def forward(self, router_logits, attention_mask) -> torch.Tensor:
        """Return the weighted loss sum; `attention_mask` is accepted for
        interface parity but unused here."""
        z_loss = _switch_router_z_loss_func(router_logits)
        router_probs = F.softmax(router_logits, dim=-1)
        # recompute expert indexes due to MoE-PEFT constraints
        aux_loss = _switch_load_balancing_loss_func(router_probs)
        return self.z_loss_coef * z_loss + self.aux_loss_coef * aux_loss
486
+
487
+
488
class SwitchSparseMoe(LLMMoeBlock):
    """Switch-Transformer style top-1 MoE block over LoRA experts.

    Each token is routed to its single highest-probability expert, subject to a
    per-expert capacity limit; tokens past capacity keep their input unchanged.
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MixLoraConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Router math is kept in float32 for numerical stability.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=self.dtype_,
        )
        self.act_ = (
            ACT2FN[config.act_fn_]
            if isinstance(config.act_fn_, str)
            else config.act_fn_
        )
        self.experts_: int = config.num_experts_
        self.dropout_ = (
            torch.nn.Dropout(config.ffn_dropout_)
            if config.ffn_dropout_ > 0
            else torch.nn.Identity()
        )
        self.expert_capacity_: int = config.expert_capacity_
        self.jitter_noise_: float = config.jitter_noise_
        self.router_profile_: bool = False
        self.profiler_: List[int] = None

        if gate is None:
            torch.nn.init.normal_(
                self.gate_.weight,
                mean=0.0,
                std=config.router_init_range_,
            )
        else:
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def _profiling(
        self, batch_size: int, sequence_length: int, router_mask: torch.Tensor
    ) -> None:
        """Accumulate per-expert routing pressure when profiling is enabled."""
        if not self.router_profile_:
            return

        chosen = torch.argmax(router_mask, dim=-1)

        counts = [0] * self.experts_
        for row in chosen.tolist():
            for expert_idx in row:
                counts[expert_idx] += 1

        if self.profiler_ is None:
            self.profiler_ = [0] * self.experts_
            for expert_idx in range(self.experts_):
                self.profiler_[expert_idx] = (
                    counts[expert_idx] / batch_size
                ) / sequence_length
        else:
            # Running average with the previously recorded pressure.
            for expert_idx in range(self.experts_):
                pressure = (counts[expert_idx] / batch_size) / sequence_length
                self.profiler_[expert_idx] = (
                    self.profiler_[expert_idx] + pressure
                ) / 2

    def route(self, hidden_states: torch.Tensor, input_args: LLMModelInput) -> Tuple:
        """Top-1 routing with capacity masking.

        Returns (one-hot expert mask, per-token max probability, raw logits).
        """
        if not input_args.inference_mode_ and self.jitter_noise_ > 0:
            # Multiplicative jitter noise on the inputs during training.
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise_, 1.0 + self.jitter_noise_
            )

        router_logits = self.gate_(hidden_states)
        router_probs = F.softmax(router_logits, dim=-1, dtype=self.dtype_)

        expert_index = torch.nn.functional.one_hot(
            torch.argmax(router_probs, dim=-1), num_classes=self.experts_
        )

        # Cumulative count along the sequence gives each token its arrival
        # priority at its expert; tokens past capacity are masked out.
        token_priority = torch.cumsum(expert_index, dim=-2)
        expert_capacity_mask = token_priority <= self.expert_capacity_
        expert_index = expert_index * expert_capacity_mask

        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits

    def forward(
        self,
        hidden_states: torch.Tensor,
        ffn_layer: LLMFeedForward,
        input_args: LLMModelInput,
    ) -> Tuple:
        batch_size, sequence_length, _ = hidden_states.shape

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype_)

        router_mask, router_probs, router_logits = self.route(
            hidden_states, input_args
        )

        self._profiling(batch_size, sequence_length, router_mask)

        # Start from the inputs so tokens dropped by capacity pass through.
        next_states = hidden_states.clone()
        for expert_idx in range(self.experts_):
            token_indices = router_mask[:, :, expert_idx].bool()
            lora_name = f"moe.{self.adapter_name_}.experts.{expert_idx}"
            next_states[token_indices] = ffn_layer._lora_forward(
                lora_name, self.act_, hidden_states[token_indices].to(input_dtype)
            ).to(next_states.dtype)

        # NOTE(review): in inference mode the expert outputs (`next_states`)
        # are discarded and the router input is returned unchanged — confirm
        # this is intentional rather than a missing `next_states` path.
        if input_args.inference_mode_:
            hidden_states = hidden_states.to(input_dtype)
        else:
            hidden_states = self.dropout_(router_probs * next_states).to(input_dtype)

        return hidden_states, router_logits.reshape(-1, self.experts_)
c2cite/adapters/mola/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .config import MolaConfig
2
+ from .model import MolaRouterLoss, MolaSparseMoe
3
+
4
+ __all__ = [
5
+ "MolaConfig",
6
+ "MolaSparseMoe",
7
+ "MolaRouterLoss",
8
+ ]
c2cite/adapters/mola/config.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Dict
4
+
5
+ from moe_peft.common import LoraConfig
6
+
7
+
8
@dataclass
class MolaConfig(LoraConfig):
    """Configuration for MoLA: a per-layer top-k mixture of LoRA experts."""

    top_k_: int = None
    num_experts_: int = None
    routing_strategy_: str = "mola"
    router_init_range_: float = None
    # this router loss is copied from MixLoRA
    # and only for test MoE-PEFT propose
    router_aux_loss_coef_: float = None
    router_loss_: bool = True

    def check(self) -> "MolaConfig":
        """Validate MoLA-specific fields on top of the base LoRA checks."""
        super().check()
        assert isinstance(self.top_k_, int) and self.top_k_ > 0
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        assert (
            isinstance(self.router_aux_loss_coef_, float)
            and self.router_aux_loss_coef_ >= 0
        )
        assert isinstance(self.router_loss_, bool)

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "MolaConfig":
        """Build a MolaConfig from a plain dict, applying MoLA defaults."""
        return MolaConfig(
            top_k_=config.get("top_k", 2),
            num_experts_=config["num_experts"],
            router_init_range_=config.get("router_init_range", 5.0),
            router_aux_loss_coef_=config.get("router_aux_loss_coef", 0.001),
            router_loss_=config.get("router_loss", False),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialize to a PEFT-style configuration dict."""
        config = super().export()
        config["peft_type"] = "MOLA"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        config["top_k"] = self.top_k_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Derive the per-expert LoRA config with a namespaced adapter name.

        BUG FIX: the original called `copy.deepcopy(super())`, which attempts
        to deep-copy the `super` proxy object itself and fails at runtime
        (`TypeError: cannot pickle 'super' object`). Deep-copy the config
        instance instead; the copy is-a LoraConfig, so callers are unaffected.
        """
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config
c2cite/adapters/mola/model.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from moe_peft.common import Linear, LLMMoeBlock
8
+
9
+ from .config import MolaConfig
10
+
11
+
12
+ # copied from mixlora.model._mixtral_load_balancing_loss_func
13
+ def _mixtral_load_balancing_loss_func(
14
+ gate_logits: torch.Tensor,
15
+ num_experts: int,
16
+ top_k: int,
17
+ attention_mask: Optional[torch.Tensor] = None,
18
+ ) -> float:
19
+ routing_weights = torch.nn.functional.softmax(gate_logits, dim=-1)
20
+
21
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
22
+
23
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
24
+
25
+ if attention_mask is None:
26
+ # Compute the percentage of tokens routed to each experts
27
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
28
+
29
+ # Compute the average probability of routing to these experts
30
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
31
+ else:
32
+ batch_size, sequence_length = attention_mask.shape
33
+ num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
34
+
35
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
36
+ expert_attention_mask = (
37
+ attention_mask[None, :, :, None, None]
38
+ .expand(
39
+ (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
40
+ )
41
+ .reshape(-1, top_k, num_experts)
42
+ .to(routing_weights.device)
43
+ )
44
+
45
+ # Compute the percentage of tokens routed to each experts
46
+ tokens_per_expert = torch.sum(
47
+ expert_mask.float() * expert_attention_mask, dim=0
48
+ ) / torch.sum(expert_attention_mask, dim=0)
49
+
50
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
51
+ router_per_expert_attention_mask = (
52
+ attention_mask[None, :, :, None]
53
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
54
+ .reshape(-1, num_experts)
55
+ .to(routing_weights.device)
56
+ )
57
+
58
+ # Compute the average probability of routing to these experts
59
+ router_prob_per_expert = torch.sum(
60
+ routing_weights * router_per_expert_attention_mask, dim=0
61
+ ) / torch.sum(router_per_expert_attention_mask, dim=0)
62
+
63
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
64
+ return overall_loss * num_experts
65
+
66
+
67
class MolaRouterLoss(torch.nn.Module):
    """Auxiliary balancing loss for MoLA routing (reuses the Mixtral formula)."""

    def __init__(self, config: MolaConfig) -> None:
        super().__init__()
        self.aux_loss_coef = config.router_aux_loss_coef_
        self.experts = config.num_experts_
        self.topk = config.top_k_

    def forward(self, gate_logits, attention_mask) -> torch.Tensor:
        """Scale the Mixtral load-balancing loss by the configured coefficient."""
        loss = _mixtral_load_balancing_loss_func(
            gate_logits, self.experts, self.topk, attention_mask
        )
        return self.aux_loss_coef * loss
78
+
79
+
80
class MolaSparseMoe(LLMMoeBlock):
    """MoLA top-k sparse MoE block mixing per-expert LoRA outputs into a residual."""

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: MolaConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        # Router math is done in float32 for numerical stability.
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=torch.float32,
        )
        self.experts_ = config.num_experts_
        self.topk_ = config.top_k_
        # Saved so the (optional) router loss can be computed outside the block.
        self.router_logits_: torch.Tensor = None

        if gate is None:
            torch.nn.init.kaiming_uniform_(
                self.gate_.weight, a=math.sqrt(config.router_init_range_)
            )
        else:
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Optional[Linear] = None,
    ):
        assert lora_linear is not None
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        input_dtype = hidden_states.dtype
        flat_states = hidden_states.view(-1, hidden_dim).to(self.dtype_)

        router_logits = self.gate_(flat_states)
        self.router_logits_ = router_logits.reshape(-1, self.experts_)
        dense_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)

        routing_weights, selected_experts = torch.topk(
            dense_weights, self.topk_, dim=-1
        )
        # Renormalize so the kept top-k weights sum to one per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.experts_
        ).permute(2, 1, 0)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, lora_linear.out_features_),
            dtype=self.dtype_,
            device=flat_states.device,
        )

        for expert_idx in range(self.experts_):
            expert_lora = lora_linear.loras_[
                f"moe.{self.adapter_name_}.experts.{expert_idx}"
            ]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Gather the tokens assigned to this expert, run them through its
            # LoRA, and scale by the (renormalized) routing weight.
            current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
            weighted_output = (
                expert_lora.lora_forward(current_state)
                * routing_weights[top_x, idx, None]
            )
            final_hidden_states.index_add_(
                0, top_x, weighted_output.to(self.dtype_)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, lora_linear.out_features_
        ).to(input_dtype)

        return residual + final_hidden_states
c2cite/common/__init__.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Basic Abstract Class
2
+ from .abstracts import (
3
+ LLMAttention,
4
+ LLMCache,
5
+ LLMDecoder,
6
+ LLMFeedForward,
7
+ LLMForCausalLM,
8
+ LLMMoeBlock,
9
+ LLMOutput,
10
+ )
11
+ from .attention import (
12
+ eager_attention_forward,
13
+ flash_attention_forward,
14
+ prepare_4d_causal_attention_mask,
15
+ )
16
+ from .cache import (
17
+ DynamicCache,
18
+ HybridCache,
19
+ SlidingWindowCache,
20
+ StaticCache,
21
+ cache_factory,
22
+ )
23
+ from .checkpoint import (
24
+ CHECKPOINT_CLASSES,
25
+ CheckpointNoneFunction,
26
+ CheckpointOffloadFunction,
27
+ CheckpointRecomputeFunction,
28
+ )
29
+
30
+ # Model Configuration
31
+ from .config import (
32
+ AdapterConfig,
33
+ InputData,
34
+ Labels,
35
+ LLMBatchConfig,
36
+ LLMModelConfig,
37
+ LLMModelInput,
38
+ LLMModelOutput,
39
+ LoraConfig,
40
+ Masks,
41
+ Prompt,
42
+ Tokens,
43
+ )
44
+ from .feed_forward import FeedForward
45
+
46
+ # LoRA
47
+ from .lora_linear import Linear, Lora, get_range_tensor
48
+
49
+ # MoEs
50
+ from .moe_utils import collect_plugin_router_logtis, slice_tensor, unpack_router_logits
51
+ from .rope import ROPE_INIT_FUNCTIONS
52
+
53
+ __all__ = [
54
+ "prepare_4d_causal_attention_mask",
55
+ "eager_attention_forward",
56
+ "flash_attention_forward",
57
+ "LLMCache",
58
+ "DynamicCache",
59
+ "HybridCache",
60
+ "SlidingWindowCache",
61
+ "StaticCache",
62
+ "cache_factory",
63
+ "CheckpointNoneFunction",
64
+ "CheckpointOffloadFunction",
65
+ "CheckpointRecomputeFunction",
66
+ "CHECKPOINT_CLASSES",
67
+ "FeedForward",
68
+ "slice_tensor",
69
+ "unpack_router_logits",
70
+ "collect_plugin_router_logtis",
71
+ "get_range_tensor",
72
+ "Lora",
73
+ "Linear",
74
+ "LLMAttention",
75
+ "LLMFeedForward",
76
+ "LLMMoeBlock",
77
+ "LLMDecoder",
78
+ "LLMOutput",
79
+ "LLMForCausalLM",
80
+ "Tokens",
81
+ "Labels",
82
+ "Masks",
83
+ "Prompt",
84
+ "InputData",
85
+ "LLMModelConfig",
86
+ "LLMModelOutput",
87
+ "LLMBatchConfig",
88
+ "LLMModelInput",
89
+ "AdapterConfig",
90
+ "LoraConfig",
91
+ "ROPE_INIT_FUNCTIONS",
92
+ ]
c2cite/common/abstracts.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+
6
+ from .config import LLMModelConfig, LLMModelInput
7
+
8
+
9
class LLMCache(torch.nn.Module):
    """Abstract base for key/value caches; subclasses implement the accessors."""

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append key/value states for `layer_idx` and return the full cache."""
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError(
            "Make sure to implement `get_seq_length` in a subclass."
        )

    def get_max_length(self) -> Optional[int]:
        """Maximum cache capacity, or None for unbounded caches."""
        raise NotImplementedError(
            "Make sure to implement `get_max_length` in a subclass."
        )

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Number of cached tokens usable when appending `new_seq_length` more."""
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        # A bounded cache that would overflow can only keep the tail.
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder per-layer caches along the batch dimension for beam search."""
        for layer_idx in range(len(self.key_cache)):
            for cache in (self.key_cache, self.value_cache):
                device = cache[layer_idx].device
                cache[layer_idx] = cache[layer_idx].index_select(
                    0, beam_idx.to(device)
                )
52
+
53
+
54
class LLMAttention(metaclass=ABCMeta):
    """Interface for model-specific attention modules."""

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        """Return the named projection modules that carry adapter weights."""
        return {}

    @classmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Compute attention for one layer; implemented per model backend."""
        pass
70
+
71
+
72
class LLMFeedForward(metaclass=ABCMeta):
    """Interface for model-specific feed-forward (MLP) modules."""

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        """Return the named linear modules that carry adapter weights."""
        return {}

    @classmethod
    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Run the MLP over a batch, dispatching per-adapter as needed."""
        pass

    @classmethod
    def _lora_forward(
        self, lora_name: str, act_fn: torch.nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        """Run the MLP with a single named LoRA adapter applied."""
        pass
88
+
89
+
90
class LLMMoeBlock(metaclass=ABCMeta):
    """Base class for mixture-of-experts blocks; holds common router state."""

    def __init__(self) -> None:
        super().__init__()

        # Adapter identity and routing configuration, filled in by subclasses.
        self.adapter_name_: str = None
        self.dtype_: torch.dtype = None
        self.gate_: torch.nn.Linear = None
        self.experts_: int = None
        # Optional per-expert routing-pressure profiling.
        self.router_profile_: bool = False
        self.profiler_: List[int] = None

    @classmethod
    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> Tuple:
        """Route tokens to experts and combine their outputs; see subclasses."""
        pass
109
+
110
+
111
class LLMDecoder(metaclass=ABCMeta):
    """Interface for a single transformer decoder layer."""

    def __init__(self) -> None:
        super().__init__()
        # Populated by the concrete model implementation.
        self.self_attn_: LLMAttention = None
        self.mlp_: LLMFeedForward = None

    @classmethod
    def state_dict(
        self,
    ) -> Tuple[Dict[str, torch.nn.Module], Dict[str, torch.nn.Module]]:
        """Return the attention and feed-forward modules carrying weights."""
        return {}

    @classmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run one decoder layer (attention + MLP); implemented per backend."""
        pass
134
+
135
+
136
class LLMOutput(metaclass=ABCMeta):
    """Interface for task-specific output heads (logits + loss)."""

    @classmethod
    def state_dict(self) -> Dict[str, torch.nn.Module]:
        """Return the named modules of the output head."""
        return {}

    @classmethod
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Project hidden states to task outputs (e.g. logits)."""
        pass

    @classmethod
    def loss(
        self,
        input_ids: torch.Tensor,
        output_logits: torch.Tensor,
        labels: List[List[int]],
    ) -> torch.Tensor:
        """Compute the task loss from logits and labels."""
        pass
154
+
155
class LLMForCausalLM(metaclass=ABCMeta):
    """Interface every wrapped causal-LM backbone must implement."""

    @classmethod
    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to input embeddings."""
        pass

    @classmethod
    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) rotary embedding tensors."""
        pass

    @classmethod
    def decoder_stack(self) -> List[LLMDecoder]:
        """Return the ordered list of decoder layers."""
        pass

    @classmethod
    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final layer norm."""
        pass

    @classmethod
    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the model-specific 4D causal attention mask."""
        pass

    @classmethod
    def cache_implementation(self) -> str:
        # Default KV-cache flavor; overridden by models needing static/hybrid.
        return "dynamic"

    @classmethod
    def model_config(self) -> LLMModelConfig:
        """Return the model's configuration object."""
        pass

    @staticmethod
    def from_pretrained(llm_model, **kwargs):
        """Wrap a pretrained HF model into this interface."""
        pass
c2cite/common/attention.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers.utils import is_flash_attn_2_available
8
+
9
+ from .cache import LLMCache, StaticCache
10
+
11
+ if is_flash_attn_2_available():
12
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
13
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
14
+
15
+ _flash_supports_window_size = "window_size" in list(
16
+ inspect.signature(flash_attn_func).parameters
17
+ )
18
+
19
+
20
def prepare_4d_causal_attention_mask(
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: LLMCache,
) -> torch.Tensor:
    """Build a 4D additive causal mask of shape (batch, 1, q_len, kv_len).

    Disallowed positions hold the dtype's minimum value; padded positions from
    `attention_mask` (if given) are masked as well.
    """
    past_seen_tokens = (
        past_key_values.get_seq_length() if past_key_values is not None else 0
    )
    if past_seen_tokens is None:
        past_seen_tokens = 0

    using_static_cache = isinstance(past_key_values, StaticCache)

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
    # Static caches are preallocated, so the mask must span their full length.
    if using_static_cache:
        target_length = past_key_values.get_max_length()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = past_seen_tokens + sequence_length + 1

    causal_mask = torch.full(
        (sequence_length, target_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=device,
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep positions at or before each query's cache position unmasked.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(
        -1, 1
    )
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = (
            causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        ) == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[
            :, :, :, :mask_length
        ].masked_fill(padding_mask, min_dtype)

    return causal_mask
71
+
72
+
73
def eager_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Plain (non-fused) scaled dot-product attention.

    Returns `(output, probs)` where `output` is transposed to
    (batch, q_len, heads, head_dim) and `probs` are the softmaxed scores.
    """
    head_dim = query_states.size(-1)
    scores = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(head_dim)
    if attention_mask is not None:
        # Only the slice of the mask covering the key length applies.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]
    # Softmax in float32 for stability, then cast back to the value dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(value_states.dtype)
    output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return output, probs
92
+
93
+
94
+ def _get_unpad_data(
95
+ attention_mask: torch.Tensor,
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
97
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
98
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
99
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
100
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
101
+ return (
102
+ indices,
103
+ cu_seqlens,
104
+ max_seqlen_in_batch,
105
+ )
106
+
107
+
108
def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """Strip padding for flash-attn varlen kernels.

    Returns unpadded q/k/v plus the index / cu_seqlen / max_seqlen metadata
    needed to re-pad the kernel output afterwards.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    flat_kv = (batch_size * kv_seq_len, num_key_value_heads, head_dim)
    key_layer = index_first_axis(key_layer.reshape(*flat_kv), indices_k)
    value_layer = index_first_axis(value_layer.reshape(*flat_kv), indices_k)

    if query_length == kv_seq_len:
        # Prefill: queries share the key layout, so reuse the key metadata.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Decoding: exactly one query per sequence.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
            query_layer, attention_mask
        )

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
155
+
156
+
157
def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """Flatten batched q/k/v for the flash-attn varlen kernel using position ids.

    Sequence starts are detected wherever `position_ids == 0`, so packed
    (concatenated) sequences within one batch row are split correctly.
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(
        position_ids.size(0), device=position_ids.device, dtype=torch.int32
    )

    # Boundaries are every index whose position id restarts at 0, plus the
    # total token count as the final boundary.
    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(
                position_ids.size(), device=position_ids.device, dtype=torch.int32
            ),
        )
    )

    max_length = position_ids.max() + 1

    return (
        query,
        key,
        value,
        indices_q,
        (cu_seq_lens, cu_seq_lens),
        (max_length, max_length),
    )
185
+
186
+
187
def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
):
    """Dispatch to the appropriate flash-attn kernel.

    Three paths: (1) padded batches -> unpad + varlen kernel + repad;
    (2) packed sequences (signalled by non-monotonic position ids) -> varlen
    kernel; (3) dense batches -> plain `flash_attn_func`.
    """
    # With a top-left-aligned mask, causal masking must be disabled for
    # single-token (decode) queries.
    if not use_top_left_mask:
        causal = is_causal
    else:
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence
    # length (source length).
    use_sliding_windows = (
        _flash_supports_window_size
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )
    flash_kwargs = (
        {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
    )

    if deterministic is not None:
        flash_kwargs["deterministic"] = deterministic
    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    if attention_mask is not None:
        # Contains at least one padding token in the sequence.
        batch_size = query_states.shape[0]
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        ) = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    elif (
        position_ids is not None
        and not (torch.diff(position_ids, dim=-1) >= 0).all()
        and query_length != 1
    ):
        # Packed (concatenated) sequences without an attention mask.
        batch_size = query_states.size(0)
        (
            query_states,
            key_states,
            value_states,
            _,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        ) = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(
            batch_size, -1, attn_output.size(-2), attn_output.size(-1)
        )

    else:
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

    return attn_output
c2cite/common/cache.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ from transformers.utils import is_torchdynamo_compiling
6
+
7
+ from .abstracts import LLMCache
8
+ from .config import LLMModelConfig
9
+
10
+
11
class DynamicCache(LLMCache):
    """KV cache that grows along the sequence dimension as generation proceeds.

    Mirrors the HuggingFace ``DynamicCache`` interface: one key and one value
    tensor per layer. Skipped layers are represented by an empty-list
    placeholder until they are first updated.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self._seen_tokens = (
            0  # Used in `generate` to keep tally of how many tokens the cache has seen
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(
                f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
            )

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for `layer_idx` and return the full cache for that layer."""
        # Update the number of seen tokens (only counted once per forward pass,
        # on the first layer).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif (
            len(self.key_cache[layer_idx]) == 0
        ):  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache)
            <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = (
            self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        )
        return layer_seq_length

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.
        """
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            # BUGFIX: the original used `self.key_cache[idx] != []`, which for a
            # real tensor performs element-wise comparison and fails when used
            # as a boolean; `len(...)` works for both the empty-list placeholder
            # and a tensor (batch dimension).
            if len(self.key_cache[idx]) > 0:
                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(
        self, full_batch_size: int, split_size: int
    ) -> List["DynamicCache"]:
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [
                tensor[i : i + split_size] for tensor in self.key_cache
            ]
            current_split.value_cache = [
                tensor[i : i + split_size] for tensor in self.value_cache
            ]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        cache = cls()
        for idx in range(len(splits[0])):
            key_cache = [
                current.key_cache[idx]
                for current in splits
                if len(current.key_cache[idx]) > 0
            ]
            # BUGFIX: the original gathered `current.key_cache[idx]` here,
            # silently storing keys in the value cache after a re-stack.
            value_cache = [
                current.value_cache[idx]
                for current in splits
                if len(current.value_cache[idx]) > 0
            ]
            if len(key_cache) > 0:
                layer_keys = torch.cat(key_cache, dim=0)
                layer_values = torch.cat(value_cache, dim=0)
                cache.update(layer_keys, layer_values, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(
                repeats, dim=0
            )
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(
                repeats, dim=0
            )

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
171
+
172
+
173
class StaticCache(LLMCache):
    """Fixed-size KV cache, pre-allocated once so it is `torch.compile`/CUDA-graph friendly.

    One zero tensor pair of shape ``(batch_size, n_kv_heads, max_cache_len, head_dim)``
    is allocated per layer; `update` writes into it in place at `cache_position`.
    """

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        # Fall back to the model's maximum sequence length when no explicit cap is given.
        self.max_cache_len = (
            config.max_seq_len_ if max_cache_len is None else max_cache_len
        )

        self.head_dim = config.head_dim_

        self.dtype = dtype
        self.num_key_value_heads = config.n_kv_heads_

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        for idx in range(config.n_layers_):
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            #    it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                # NOTE(review): fresh zero buffers replace the tensors allocated
                # above (which are then discarded); assumes LLMCache is an
                # nn.Module so `register_buffer` exists — confirm in abstracts.py.
                self.register_buffer(
                    f"key_cache_{idx}",
                    torch.zeros(cache_shape, dtype=dtype, device=device),
                )
                self.register_buffer(
                    f"value_cache_{idx}",
                    torch.zeros(cache_shape, dtype=dtype, device=device),
                )
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `key_states`/`value_states` in place at `cache_position` and
        return the full (padded) per-layer key/value tensors.

        NOTE(review): `cache_kwargs` is dereferenced unconditionally, so callers
        must always pass a dict (possibly without "cache_position").
        """
        cache_position = cache_kwargs.get("cache_position")

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            # No positions given: the incoming states must cover the whole cache.
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        # NOTE(review): returns a 0-dim tensor, not a Python int.
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
276
+
277
+
278
class SlidingWindowCache(StaticCache):
    """Static cache whose length is capped at the model's sliding window.

    Once the window is full, existing entries are rolled left by one slot on
    each decode step so the cache always holds the most recent
    ``max_cache_len`` tokens.
    """

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        # Validate before allocating anything: a sliding-window cache is
        # meaningless without a configured window size.
        if not hasattr(config, "sliding_window_") or config.sliding_window_ is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        max_cache_len = min(config.sliding_window_, max_cache_len)
        # BUGFIX: the original additionally called a bare `super().__init__()`
        # before this one; `StaticCache.__init__` has required parameters, so
        # that call would raise TypeError. Only the fully-parameterized call is kept.
        super().__init__(
            config=config,
            batch_size=batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """Write new states into the rolling window and return the layer's key/value tensors."""
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        # Build a left-rotation index: once the window is full (`to_shift`),
        # every slot shifts by one so the oldest token falls off.
        slicing = torch.ones(
            self.max_cache_len, dtype=torch.long, device=value_states.device
        ).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        to_shift = cache_position >= self.max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_length(self) -> Optional[int]:
        # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
        return None

    def reset(self):
        """Zero out all layers in place (keeps static addresses intact)."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
361
+
362
+
363
class HybridCache(LLMCache):
    """Static cache that mixes sliding-window layers and global-attention layers.

    Layers where ``i % 2 == 0`` get a window-capped buffer (rolled on update);
    the others get a full ``max_cache_len`` buffer written at `cache_position`.
    """

    def __init__(
        self,
        config: LLMModelConfig,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window_") or config.sliding_window_ is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.batch_size = batch_size
        self.head_dim = config.head_dim_

        self.dtype = dtype
        self.num_key_value_heads = config.n_kv_heads_
        # Even layer indices use sliding-window attention; odd ones are global.
        self.is_sliding = torch.tensor(
            [not bool(i % 2) for i in range(config.n_layers_)],
            dtype=torch.bool,
            device=device,
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        global_cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            max_cache_len,
            self.head_dim,
        )
        sliding_cache_shape = (
            self.batch_size,
            self.num_key_value_heads,
            min(config.sliding_window_, max_cache_len),
            self.head_dim,
        )
        for i in range(config.n_layers_):
            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            cache_shape = (
                global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            )
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self.dtype, device=device
            )
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(
        self,
        cache_position,
        layer_idx,
        key_states,
        value_states,
        k_out,
        v_out,
        max_cache_len,
    ):
        # Sliding-window write: roll old entries left once the window is full.
        if cache_position.shape[0] > max_cache_len:
            # Prefill longer than the window: keep only the last window's worth.
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        # Left-rotation index; shift kicks in once the window is full.
        slicing = torch.ones(
            max_cache_len, dtype=torch.long, device=value_states.device
        ).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(
        self,
        cache_position,
        layer_idx,
        key_states,
        value_states,
        k_out,
        v_out,
        max_cache_len,
    ):
        # Global-attention write: plain positional assignment into the buffer.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """Dispatch to the sliding or static update for this layer.

        NOTE(review): the choice is driven by `cache_kwargs["sliding_window"]`
        supplied by the caller, not by `self.is_sliding` — confirm callers keep
        the two consistent. `cache_kwargs` is dereferenced unconditionally.
        """
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],  # per-layer capacity (window-capped for sliding layers)
        )

    def get_max_length(self) -> Optional[int]:
        # in theory there is no limit because the sliding window size is fixed
        # no matter how long the sentence is
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
525
+
526
+
527
# Registry mapping the user-facing `cache_implementation` string to its class;
# consumed by `cache_factory` below.
cache_dict = {
    "dynamic": DynamicCache,
    "static": StaticCache,
    "sliding_window": SlidingWindowCache,
    "hybrid": HybridCache,
}
533
+
534
+
535
def cache_factory(
    cache_implementation: str,
    config: LLMModelConfig,
    batch_size: int,
    max_cache_len: int,
):
    """Instantiate the cache class registered under `cache_implementation`.

    Device and dtype come from `config`; for "sliding_window" the requested
    length is clamped to the model's window size first.
    """
    assert (
        cache_implementation in cache_dict
    ), f"Unknown cache type. {cache_implementation}"
    logging.info(f"Use {cache_implementation} as cache implementation.")
    if cache_implementation == "sliding_window":
        assert hasattr(config, "sliding_window_")
        max_cache_len = min(config.sliding_window_, max_cache_len)
    cache_cls = cache_dict[cache_implementation]
    return cache_cls(
        config=config,
        batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=config.device_,
        dtype=config.dtype_,
    )
c2cite/common/checkpoint.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Tuple
2
+
3
+ import torch
4
+
5
+
6
def pack_hook(to_offload: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
    """Saved-tensors pack hook: move the tensor to CPU, remembering its device."""
    original_device = to_offload.device
    return original_device, to_offload.to("cpu")
8
+
9
+
10
def unpack_hook(to_offload_info: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
    """Inverse of `pack_hook`: restore the offloaded tensor to its original device."""
    target_device, offloaded = to_offload_info
    return offloaded.to(target_device)
13
+
14
+
15
def CheckpointNoneFunction(run_function: Callable, *args):
    """No-op checkpoint strategy: invoke the wrapped function directly."""
    return run_function(*args)
17
+
18
+
19
def CheckpointOffloadFunction(run_function: Callable, *args):
    """Checkpoint strategy that offloads saved activations to CPU via
    `torch.autograd.graph.saved_tensors_hooks` (pack_hook / unpack_hook)."""
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        return run_function(*args)
23
+
24
+
25
def CheckpointRecomputeFunction(run_function: Callable, *args):
    """Checkpoint strategy that discards activations and recomputes them in the
    backward pass (reentrant variant of torch gradient checkpointing)."""
    return torch.utils.checkpoint.checkpoint(run_function, *args, use_reentrant=True)
27
+
28
+
29
# Maps the `gradient_checkpoint_` config string (see LLMModelInput) to the
# strategy callable used to run a layer's forward pass.
CHECKPOINT_CLASSES = {
    "none": CheckpointNoneFunction,
    "offload": CheckpointOffloadFunction,
    "recompute": CheckpointRecomputeFunction,
}
c2cite/common/config.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import Callable, Dict, List, Optional, TypeAlias, Union
5
+
6
+ import torch
7
+
8
+ Tokens: TypeAlias = List[int]
9
+ Labels: TypeAlias = List[int]
10
+ Masks: TypeAlias = List[bool]
11
+ Ground: TypeAlias = List[str]
12
+ Citations: TypeAlias = List[str]
13
+ Query: TypeAlias = List[str]
14
+
15
+
16
@dataclass
class Prompt:
    """Alpaca-style prompt triple; any field may be left as None."""

    instruction: str = None
    input: str = None
    label: str = None
21
+
22
+
23
@dataclass
class InputData:
    """One training/evaluation sample, with optional citation side data.

    All fields default to None; which subset is populated depends on the task
    pipeline that produced the sample — TODO(review) confirm per-task usage.
    """

    # Raw input: a Prompt, a list of strings, or a single string.
    inputs: List[Union[Prompt, List[str], str]] = None
    # Number of prefix tokens (e.g. prompt part excluded from the loss) — presumably.
    prefix_length_: int = None
    tokens: Optional[Tokens] = None
    labels: Optional[Labels] = None
    grounds: Optional[Ground] = None
    citations: Optional[Citations] = None
    citation_tokens: Optional[List] = None
    citation_embeds: Optional[List] = None
    query: Optional[Query] = None
    token_len: Optional[int] = None
    prompt: Optional[str] = None
    prompt_len: Optional[int] = None
    test_citations: Optional[Citations] = None
38
+
39
+
40
@dataclass
class LLMModelConfig:
    """Normalized architecture hyper-parameters shared by all supported models.

    Field names carry a trailing underscore by project convention.
    """

    name_or_path_: str = None
    device_: str = None
    # Hidden size.
    dim_: int = None
    head_dim_: int = None
    # MLP intermediate size.
    intermediate_: int = None
    n_heads_: int = None
    # Number of key/value heads (== n_heads_ unless GQA/MQA).
    n_kv_heads_: int = None
    n_layers_: int = None
    hidden_act_: str = None
    hidden_dropout_: float = None
    vocab_size_: int = None
    pad_token_id_: int = None
    rope_theta_: float = None
    partial_rotary_factor_: float = None
    max_seq_len_: int = None
    # eager or flash_attn
    attn_implementation_: str = "eager"
    # data type
    dtype_: torch.dtype = None
+ dtype_: torch.dtype = None
61
+
62
+
63
@dataclass
class LLMModelOutput:
    """Per-adapter output slice of a batched forward pass."""

    adapter_name: str = None
    logits: torch.Tensor = None
    router_logits: torch.Tensor = None
    loss: torch.Tensor = None
    # Whether this output carries citation supervision — TODO(review) confirm semantics.
    cite_flag: bool = False
    # Auxiliary (e.g. router load-balancing) loss, if any.
    aux_loss: torch.Tensor = None
    # for internal use: slice of the batch this adapter owns, and its loss function
    batch_start_idx_: int = -1
    batch_end_idx_: int = -1
    loss_fn_: Callable = None
+ loss_fn_: Callable = None
75
+
76
+
77
@dataclass
class LLMBatchConfig:
    """Maps one adapter to its half-open slice [start, end) of the batch."""

    adapter_name_: str = ""
    batch_start_idx_: int = -1
    batch_end_idx_: int = -1
+ batch_end_idx_: int = -1
82
+
83
+
84
+ def _efficient_operator_factory():
85
+ efficient_operator = os.getenv("MOE_PEFT_EVALUATE_MODE") is None
86
+ return efficient_operator
87
+
88
+
89
@dataclass
class LLMModelInput:
    """Everything the model forward needs for one multi-adapter batch."""

    # One entry per adapter, defining its slice of the batch tensors below.
    batch_configs_: List[LLMBatchConfig] = None
    batch_tokens_: List[Tokens] = None
    batch_labels_: List[Labels] = None
    batch_grounds_: List[Ground] = None
    batch_cites: List[List] = None
    batch_cites_value: List[List] = None
    batch_masks_: List[Masks] = None
    batch_docs: List[str] = None
    batch_prompt_len: List[int] = None

    output_router_logits_: bool = True

    # One of "none" / "offload" / "recompute" (see common/checkpoint.py).
    gradient_checkpoint_: str = "none"
    efficient_operator_: bool = field(default_factory=_efficient_operator_factory)
    inference_mode_: bool = False
+ inference_mode_: bool = False
106
+
107
+
108
@dataclass
class AdapterConfig:
    """Base configuration shared by every adapter type."""

    # Human-readable adapter identifier.
    adapter_name: str = ""
    # Downstream task tag; defaults to the causal-LM task.
    task_name: str = "casual"

    @staticmethod
    def from_config(config: Dict[str, any]) -> "AdapterConfig":
        """Build an AdapterConfig from a raw JSON-style config dict."""
        name = config.get("name", None)
        task = config.get("task_name", None)
        return AdapterConfig(adapter_name=name, task_name=task)
+
120
+
121
# All projection-module names LoRA can attach to, across the supported
# architectures. Values mark whether the module is targeted; everything starts
# off and is switched on via `LoraConfig.from_config`.
# BUGFIX: the original literal repeated several keys (q_proj, k_proj, v_proj,
# o_proj, dense, qkv_proj, down_proj) across architecture sections; duplicate
# dict keys are silently overwritten, so they are listed once here. The
# resulting mapping is identical.
lora_target_modules = {
    # LLaMA names
    "q_proj": False,
    "k_proj": False,
    "v_proj": False,
    "o_proj": False,
    "gate_proj": False,
    "down_proj": False,
    "up_proj": False,
    # Phi names (q/k/v shared with LLaMA above)
    "dense": False,
    "fc1": False,
    "fc2": False,
    # Phi3 names (o_proj/down_proj shared above)
    "qkv_proj": False,
    "gate_up_proj": False,
    # GLM names (qkv_proj/dense shared above)
    "dense_h_to_4h": False,
    "dense_4h_to_h": False,
}
+ }
148
+
149
+
150
@dataclass
class LoraConfig(AdapterConfig):
    """Configuration for one LoRA adapter (with optional DoRA/rsLoRA variants)."""

    # Weight-Decomposed Low-Rank Adaptation
    use_dora_: bool = False
    # Rank-Stabilized LoRA
    # sets the adapter scaling factor to `alpha/math.sqrt(r)`
    use_rslora_: bool = False
    # can be original or gaussian
    lora_init_: str = "original"
    lora_r_: int = None
    lora_alpha_: int = None
    lora_dropout_: float = None
    # module-name -> enabled flag; seeded from `lora_target_modules`
    target_modules_: Dict[str, bool] = None
    # Loss-weighting coefficients; exact semantics defined by the training
    # loop — TODO(review) confirm against the trainer.
    atten_coin: float = None
    router_coin: float = None
    cite_coin: float = None
    learning_rate: float = None

    def check(self) -> "LoraConfig":
        """Validate field types and ranges; returns self for chaining.

        NOTE: validation uses `assert`, so it is skipped under `python -O`.
        """
        assert isinstance(self.use_dora_, bool)
        assert isinstance(self.use_rslora_, bool)
        assert isinstance(self.lora_init_, str) and self.lora_init_ in [
            "original",
            "gaussian",
        ]
        assert isinstance(self.lora_r_, int) and self.lora_r_ > 0
        assert isinstance(self.lora_alpha_, int) and self.lora_alpha_ > 0
        assert isinstance(self.lora_dropout_, float) and self.lora_dropout_ >= 0
        assert isinstance(self.target_modules_, Dict)
        for key, value in self.target_modules_.items():
            assert isinstance(key, str) and len(key) > 0
            assert isinstance(value, bool)

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "LoraConfig":
        """Build a LoraConfig from a raw config dict.

        `r`, `lora_alpha`, `lora_dropout`, `atten_mat_coin`, `router_coin`,
        `cite_coin`, `lr`, and `target_modules` are required keys (KeyError
        otherwise); the rest have defaults.
        """
        lora_config = LoraConfig(**AdapterConfig.from_config(config).__dict__)
        lora_config.use_dora_ = config.get("use_dora", False)
        lora_config.use_rslora_ = config.get("use_rslora", False)
        lora_config.lora_init_ = config.get("lora_init", "original")
        lora_config.lora_r_ = config["r"]
        lora_config.lora_alpha_ = config["lora_alpha"]
        lora_config.lora_dropout_ = config["lora_dropout"]
        # Start from the full module registry (all disabled), then enable
        # whatever the config names.
        lora_config.target_modules_ = copy.deepcopy(lora_target_modules)
        lora_config.atten_coin = config["atten_mat_coin"]
        lora_config.router_coin = config["router_coin"]
        lora_config.cite_coin = config["cite_coin"]
        lora_config.learning_rate = config["lr"]
        if isinstance(config["target_modules"], List):
            for target in config["target_modules"]:
                if target in lora_target_modules:
                    lora_config.target_modules_[target] = True
        elif isinstance(config["target_modules"], Dict):
            for target, value in config["target_modules"].items():
                if target in lora_target_modules:
                    lora_config.target_modules_[target] = value
        else:
            raise ValueError("broken config item: target_modules")

        return lora_config

    def export(self) -> Dict[str, any]:
        """Serialize to a PEFT-style adapter_config dict (inverse of `from_config`)."""
        config = {}
        if self.use_dora_:
            config["use_dora"] = True
        if self.use_rslora_:
            config["use_rslora"] = True
        config["bias"] = "none"
        config["peft_type"] = "LORA"
        config["r"] = self.lora_r_
        config["lora_alpha"] = self.lora_alpha_
        config["lora_dropout"] = self.lora_dropout_
        # Only enabled modules are exported, as a flat list.
        tgt_list = []
        for target, value in self.target_modules_.items():
            if value:
                tgt_list.append(target)
        config["target_modules"] = tgt_list

        config["atten_mat_coin"] = self.atten_coin
        config["router_coin"] = self.router_coin
        config["cite_coin"] = self.cite_coin
        config["lr"] = self.learning_rate

        return config
+ return config
c2cite/common/feed_forward.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+
3
+ import torch
4
+
5
+ from moe_peft.executors import executor
6
+
7
+ from .abstracts import LLMFeedForward, LLMMoeBlock
8
+ from .config import LLMModelInput
9
+ from .lora_linear import Linear, get_range_tensor
10
+
11
+
12
class FeedForward(torch.nn.Module):
    """Wraps a model's MLP and optionally routes batch slices through MoE blocks.

    With no MoE adapters attached (`moes_` empty) this is a thin pass-through
    to the wrapped `LLMFeedForward`; otherwise each adapter's slice of the
    batch is dispatched to its MoE block, falling back to the plain
    (LoRA-aware) MLP path for adapters without one.
    """

    def __init__(self, mlp: LLMFeedForward) -> None:
        super().__init__()
        self.mlp_: LLMFeedForward = mlp
        # mix of experts, keyed by adapter name
        self.moes_: Dict[str, LLMMoeBlock] = {}

    def state_dict(self) -> Dict[str, Linear]:
        # Delegates to the wrapped MLP; adapter weights live there.
        return self.mlp_.state_dict()

    def forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> Tuple[torch.Tensor, List]:
        if len(self.moes_) == 0:
            return self.mlp_._batch_forward(data, input_args)
        else:
            return self._moe_forward(data, input_args)

    def _moe_forward(self, data: torch.Tensor, input_args: LLMModelInput):
        # Output buffer; each adapter's slice is written back via index_copy.
        final_hidden_states = executor.init_tensor(data)

        if input_args.output_router_logits_:
            # One slot per batch config so router logits align with config index.
            router_logits = [None for _ in range(len(input_args.batch_configs_))]
        else:
            router_logits = []

        lora_range = get_range_tensor(data.device, data.shape[0])
        for idx, lora_config in enumerate(input_args.batch_configs_):
            moe_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if moe_name in self.moes_:
                current_hidden_states, current_router_outputs = self.moes_[
                    moe_name
                ].forward(
                    hidden_states=data[start_idx:end_idx],
                    ffn_layer=self.mlp_,
                    input_args=input_args,
                )

                if (
                    input_args.output_router_logits_
                    and current_router_outputs is not None
                ):
                    router_logits[idx] = current_router_outputs
            else:
                # No MoE block for this adapter: plain LoRA-aware MLP path.
                current_hidden_states = self.mlp_._lora_forward(
                    moe_name, self.mlp_.act_, data[start_idx:end_idx]
                )

            executor.index_copy(
                final_hidden_states,
                0,
                lora_range[start_idx:end_idx],
                current_hidden_states,
            )

        return final_hidden_states, router_logits
c2cite/common/lora_linear.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.utils import is_bitsandbytes_available
7
+
8
+ from moe_peft.executors import executor
9
+
10
+ from .abstracts import LLMMoeBlock
11
+ from .config import LLMModelInput, LoraConfig
12
+
13
+ if is_bitsandbytes_available():
14
+ import bitsandbytes as bnb
15
+ from bitsandbytes.nn import Linear4bit, Linear8bitLt
16
+ else:
17
+ from moe_peft.utils import Linear8bitLt, Linear4bit
18
+
19
+ from typing import Any, Dict, List, Tuple
20
+
21
+
22
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """Dequantize a bitsandbytes 4-bit/8-bit parameter back to a dense tensor.

    NOTE(review): for the 8-bit path *state* must be the module's bnb
    quantization state (it is dereferenced unconditionally) — confirm callers
    always pass it for Int8 weights.
    """
    # BNB requires CUDA weights
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu:
        weight = weight.to(torch.device("cuda"))

    cls_name = weight.__class__.__name__
    if cls_name == "Params4bit":
        # 4-bit path: bnb exposes a direct dequantization kernel.
        dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        if is_cpu:
            dequantized = dequantized.to(device)
        return dequantized

    if state.SCB is None:
        state.SCB = weight.SCB

    # 8-bit path: recover the dense weight by multiplying an identity matrix
    # through bnb's int8 matmul pipeline, then dequantizing the result.
    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(
            weight.data, to_order=state.formatB
        )
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    dequantized = bnb.functional.mm_dequant(
        out32, Sout32, SCim, state.SCB, bias=None
    ).t()
    if is_cpu:
        dequantized = dequantized.to(device)
    return dequantized
53
+
54
+
55
def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
    """Return the dense weight of *module*, dequantizing HQQ or bnb weights.

    Plain ``nn.Parameter`` weights are returned as-is; quantized weights are
    expanded to a dense tensor (see :func:`dequantize_bnb_weight`).
    """
    if hasattr(module, "W_q"):  # For handling HQQ quantized weight
        weight = module.dequantize()
        return weight

    weight = module.weight
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(
            f"Input weight should be of type nn.Parameter, got {type(weight)} instead"
        )

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        # Already a plain dense parameter — nothing to do.
        return weight

    quant_state = getattr(module, "state", None)
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
    if is_cpu:
        # dequantize_bnb_weight for 8bit moves the device in-place, thus we need to move it back to CPU if necessary
        module.weight = module.weight.to(device)
    return weight
78
+
79
+
80
# Per-device cache of arange tensors used for index_copy/index_add scatter ops.
g_cached_range_tensor: Dict[torch.device, torch.Tensor] = {}
# also max batch size
g_max_range = 128


def get_range_tensor(device: torch.device, batch_size: int = 1024):
    """Return a cached ``torch.arange`` tensor on *device* with at least
    *batch_size* elements.

    Fix: the original only rebuilt a device's tensor when *batch_size*
    exceeded the global ``g_max_range``. After one device grew the global
    bound, other devices kept their stale, smaller cached tensors and could
    be returned with fewer than *batch_size* entries. We now check the size
    of the cached tensor for this specific device.
    """
    global g_cached_range_tensor
    global g_max_range
    cached = g_cached_range_tensor.get(device)
    if cached is None or cached.shape[0] < batch_size:
        # Grow the shared upper bound monotonically, then (re)build this
        # device's tensor to the new bound.
        g_max_range = g_max_range if g_max_range > batch_size else batch_size
        g_cached_range_tensor[device] = torch.arange(
            0, g_max_range, step=1, device=device
        )
    return g_cached_range_tensor[device]
94
+
95
+
96
class LoraFunction(torch.autograd.Function):
    """Fused multi-adapter LoRA autograd function.

    ``*args`` is a flat sequence of ``(lora_a, lora_b)`` weight pairs, one
    pair per entry in ``input_args.batch_configs_``; a ``(None, None)`` pair
    means "no adapter for this slice". The forward accumulates each
    adapter's delta into ``result`` in place via ``index_add_``.
    """

    @staticmethod
    def forward(
        ctx,
        result: torch.Tensor,
        data: torch.Tensor,
        input_args: LLMModelInput,
        dropouts: List[float],
        scalings: List[float],
        *args,
    ):
        # the lora module is f32 precision
        data = data.to(torch.float32)

        # Saved-tensor layout: (data, then per-adapter triples
        # (lora_a, lora_b, drop_data) — None triples for skipped adapters).
        save_inputs: Tuple[torch.Tensor | None, ...] = (data,)

        lora_range = get_range_tensor(data.device, data.shape[0])
        for lora_a, lora_b, lora_config, dropout, scaling in zip(
            args[::2],
            args[1::2],
            input_args.batch_configs_,
            dropouts,
            scalings,
        ):
            # A and B must be present (or absent) together.
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                save_inputs += (None, None, None)
                continue

            assert not ((lora_a.requires_grad) ^ (lora_b.requires_grad))
            if not lora_a.requires_grad and not lora_b.requires_grad:
                save_inputs += (None, None, None)
                continue

            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            # must ensure the dropout is not zero
            # is dropout == 0, dropdata is a data's referece, so the data will be changed
            assert dropout > 0.0

            # Apply dropout + scaling, then the low-rank A/B projections.
            drop_data = F.dropout(data[start_idx:end_idx], p=dropout)
            drop_data.mul_(scaling)
            drop_data = drop_data @ lora_a.transpose(0, 1)
            lora_data = drop_data @ lora_b.transpose(0, 1)

            lora_data = lora_data.to(result.dtype)

            # Accumulate this adapter's delta into its batch slice in place.
            result.index_add_(0, lora_range[start_idx:end_idx], lora_data)

            save_inputs += (lora_a, lora_b, drop_data)

        ctx.input_args = input_args
        ctx.dropouts = dropouts
        ctx.scalings = scalings
        ctx.save_for_backward(*save_inputs)

        return result

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Gradient w.r.t. (result, data, -, -, -, lora_a_0, lora_b_0, ...)."""
        grad_output: torch.Tensor = grad_outputs[0]
        grad_result = None
        grad_data: torch.Tensor | None = None
        grad_input_args = None
        grad_dropouts = None
        grad_scalings = None
        grad_loras: Tuple[torch.Tensor | None, ...] = ()

        data, *loras = ctx.saved_tensors

        if ctx.needs_input_grad[0]:
            # result enters additively, so its gradient passes straight through.
            grad_result = grad_output
        if ctx.needs_input_grad[1]:
            grad_data = executor.init_tensor(data)

        # the lora module is fp32 precision
        grad_output = grad_output.to(torch.float32)
        lora_range = get_range_tensor(
            grad_output.device, batch_size=grad_output.shape[0]
        )
        for lora_a, lora_b, drop_data, dropout, scaling, lora_config in zip(
            loras[::3],
            loras[1::3],
            loras[2::3],
            ctx.dropouts,
            ctx.scalings,
            ctx.input_args.batch_configs_,
        ):
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                grad_loras += (None, None)
                if grad_data is not None:
                    # No adapter for this slice -> zero input gradient here.
                    executor.index_fill(grad_data, 0, lora_range[start_idx:end_idx], 0)
                continue

            # lora_data shape is batch_size * seq_len * in_dim
            lora_data = data[start_idx:end_idx]
            # grad_y shape is batch_size * seq_len * out_dim
            grad_y = grad_output[start_idx:end_idx]

            # drop_data shape is batch_size * seq_len * r

            # bstage shape is batch_size * seq_len * r
            bstage = grad_y @ lora_b
            # 1/(1-dropout) compensates for the dropout scaling at train time.
            bstage *= scaling / (1 - dropout)

            grad_a = torch.sum(bstage.transpose(1, 2) @ lora_data, dim=0)
            grad_b = torch.sum(grad_y.transpose(1, 2) @ drop_data, dim=0)
            grad_loras += (grad_a, grad_b)

            # grad_data shape is batch_size * seq_len * in_dim
            if grad_data is not None:
                grad_x = bstage @ lora_a
                executor.index_copy(grad_data, 0, lora_range[start_idx:end_idx], grad_x)

        return (
            grad_result,
            grad_data,
            grad_input_args,
            grad_dropouts,
            grad_scalings,
            *grad_loras,
        )
222
+
223
+
224
class Lora(nn.Module):
    """A single LoRA adapter (fp32 A/B pair plus optional DoRA magnitude)
    attached to one base linear layer."""

    def __init__(
        self,
        base_layer: nn.Module,
        shape: Tuple[int, int],
        config: LoraConfig,
        device: str,
    ):

        super().__init__()

        self.base_layer_ = base_layer
        self.device_ = torch.device(device)

        self.initializer_ = config.lora_init_
        self.r_ = config.lora_r_
        self.alpha_ = config.lora_alpha_

        # rsLoRA scales by alpha/sqrt(r) rather than the classic alpha/r.
        if config.use_rslora_:
            self.scaling_ = self.alpha_ / math.sqrt(self.r_)
        else:
            self.scaling_ = self.alpha_ / self.r_

        self.in_features_, self.out_features_ = shape

        # Strictly positive dropout is required by the fused LoraFunction path.
        assert config.lora_dropout_ > 0.0
        self.dropout_ = nn.Dropout(p=config.lora_dropout_)

        # The low-rank projections are kept in fp32 regardless of base dtype.
        self.lora_a_ = nn.Linear(
            self.in_features_,
            self.r_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )
        self.lora_b_ = nn.Linear(
            self.r_,
            self.out_features_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )

        self.use_dora_: bool = config.use_dora_
        # Populated in reset_parameters() when DoRA is enabled.
        self.magnitude_vector_: nn.Parameter = None

    def _get_weight_norm(self, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = dequantize_module_weight(self.base_layer_).to(dtype)
        lora_weight = self.lora_b_.weight @ self.lora_a_.weight
        weight = weight + self.scaling_ * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    def reset_parameters(self, lora_tensor=(None, None)) -> None:
        """Initialize A/B weights (and the DoRA magnitude, if enabled),
        either from scratch or from a provided pair of tensors."""
        # if the lora_tensor is not (None, None), use it to init the lora weight
        assert isinstance(lora_tensor, Tuple)
        assert len(lora_tensor) == 2
        assert ((lora_tensor[0] is None) and (lora_tensor[1] is None)) or (
            isinstance(lora_tensor[0], torch.Tensor)
            and isinstance(lora_tensor[1], torch.Tensor)
        )

        if lora_tensor == (None, None):
            if self.initializer_ == "original":
                nn.init.kaiming_uniform_(self.lora_a_.weight, a=math.sqrt(5))
            elif self.initializer_ == "gaussian":
                nn.init.normal_(self.lora_a_.weight, std=1 / self.r_)
            else:
                raise ValueError(f"Unknown initialization {self.initializer_}")
            # B starts at zero so the adapter is initially an identity delta.
            nn.init.zeros_(self.lora_b_.weight)
        else:
            with torch.no_grad():
                self.lora_a_.weight.copy_(lora_tensor[0])
                self.lora_b_.weight.copy_(lora_tensor[1])

        if self.use_dora_:
            self.magnitude_vector_ = nn.Parameter(
                self._get_weight_norm(), requires_grad=True
            )

    def apply_dora(
        self,
        residual: torch.Tensor,
        result_lora: torch.Tensor,
    ):
        """Rescale base + LoRA outputs by the learned magnitude / weight norm."""
        # The norm is detached: DoRA treats it as a constant during backprop.
        weight_norm = self._get_weight_norm().detach()
        mag_norm_scale = (self.magnitude_vector_ / weight_norm).view(1, -1)
        return mag_norm_scale * residual + mag_norm_scale * result_lora

    def lora_forward(self, hidden_states: torch.Tensor):
        # dropout -> A -> B, scaled; computed in fp32.
        return (
            self.lora_b_(self.lora_a_(self.dropout_(hidden_states.to(torch.float32))))
            * self.scaling_
        )

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Add this adapter's delta to *residual* (DoRA-aware)."""
        result_lora = self.lora_forward(hidden_states)
        if self.use_dora_:
            return self.apply_dora(residual, result_lora).to(residual.dtype)
        else:
            return residual + result_lora.to(residual.dtype)
330
+
331
+
332
class Linear(nn.Module):
    """A frozen base linear layer (dense or bnb-quantized) carrying multiple
    named LoRA adapters and/or MoE blocks, dispatched per batch slice."""

    def __init__(self, base_layer: nn.Module, device: str):
        super().__init__()

        if not isinstance(base_layer, nn.Linear):
            # Only bnb-quantized linears are accepted besides nn.Linear.
            assert isinstance(base_layer, Linear8bitLt) or isinstance(
                base_layer, Linear4bit
            ), f"error type - {type(base_layer)}."
        else:
            # The base weight is frozen; only adapters are trainable.
            base_layer.requires_grad_(False)

        self.device_ = torch.device(device)
        self.base_layer_ = base_layer.to(self.device_)
        self.loras_: Dict[str, Lora] = {}
        self.moes_: Dict[str, LLMMoeBlock] = {}

        if isinstance(self.base_layer_, Linear4bit):
            # Linear4bit packs its weight, so read the declared features.
            self.out_features_, self.in_features_ = (
                self.base_layer_.out_features,
                self.base_layer_.in_features,
            )
        else:
            self.out_features_, self.in_features_ = self.base_layer_.weight.shape

    def init_lora_weight(
        self, lora_config: LoraConfig, lora_tensor=(None, None), adapter_name=None
    ):
        """Create (if absent) and (re)initialize the adapter's LoRA weights."""
        if adapter_name is None:
            # NOTE(review): reads `adapter_name` (no trailing underscore) off
            # LoraConfig while batch configs use `adapter_name_` — confirm
            # LoraConfig really exposes this attribute.
            adapter_name = lora_config.adapter_name

        if adapter_name not in self.loras_:
            self.loras_[adapter_name] = Lora(
                self.base_layer_,
                (self.in_features_, self.out_features_),
                lora_config,
                self.device_,
            )

        self.loras_[adapter_name].reset_parameters(lora_tensor)

    # NOTE(review): name keeps the original spelling ("appy") — internal
    # callers in this class use it.
    def _appy_dora(
        self,
        residual: torch.Tensor,
        lora_delta: torch.Tensor,
        input_args: LLMModelInput,
    ):
        """Combine residual and LoRA deltas slice-by-slice, applying DoRA
        rescaling for adapters that enable it."""
        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(
            next_states.device, batch_size=next_states.shape[0]
        )
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            # Slices without a known adapter are left zeroed in next_states.
            if adapter_name == "" or adapter_name not in self.loras_:
                continue

            if self.loras_[adapter_name].use_dora_:
                lora_data = self.loras_[adapter_name].apply_dora(
                    residual[start_idx:end_idx],
                    lora_delta[start_idx:end_idx],
                )
            else:
                lora_data = residual[start_idx:end_idx] + lora_delta[start_idx:end_idx]

            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def _efficient_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Fused path: all adapters applied in one LoraFunction call."""
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states)

        if len(self.loras_) == 0:
            return residual

        # split the data and result
        dropouts: List[float] = []
        scalings: List[float] = []
        loras: Tuple[torch.Tensor] = ()
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_

            if adapter_name not in self.loras_:
                # Placeholder triple keeps positions aligned with batch configs.
                loras += (None, None)
                dropouts.append(None)
                scalings.append(None)
                continue

            loras += (
                self.loras_[adapter_name].lora_a_.weight,
                self.loras_[adapter_name].lora_b_.weight,
            )
            dropouts.append(self.loras_[adapter_name].dropout_.p)
            scalings.append(self.loras_[adapter_name].scaling_)

        have_dora = any(lora.use_dora_ for lora in self.loras_.values())

        if have_dora:
            # DoRA needs the raw LoRA delta separately, so accumulate into a
            # zero tensor instead of directly into the residual.
            lora_delta = torch.zeros_like(residual, dtype=torch.float32)
            lora_delta = LoraFunction.apply(
                lora_delta,
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )
            next_states = self._appy_dora(
                residual.to(torch.float32), lora_delta, input_args
            )
        else:
            next_states = LoraFunction.apply(
                residual.to(torch.float32),
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )

        return next_states.to(hidden_states.dtype)

    def _compatible_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Fallback path: per-slice dispatch through Lora or MoE modules."""
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states)

        if len(self.loras_) == 0:
            return residual

        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(hidden_states.device, hidden_states.shape[0])

        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if adapter_name in self.loras_:
                fwd_fn = self.loras_[adapter_name].forward
                kwargs = {}
            elif adapter_name in self.moes_:
                fwd_fn = self.moes_[adapter_name].forward
                kwargs = {"lora_linear": self}
            else:
                # Unknown adapter: pass the base-layer output through as-is.
                executor.index_copy(
                    next_states,
                    0,
                    lora_range[start_idx:end_idx],
                    residual[start_idx:end_idx],
                )
                continue

            lora_data = fwd_fn(
                residual=residual[start_idx:end_idx],
                hidden_states=hidden_states[start_idx:end_idx],
                **kwargs,
            )
            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Choose the fused path when allowed and no MoE blocks are attached."""
        if input_args.efficient_operator_ and len(self.moes_) == 0:
            return self._efficient_impl(hidden_states, input_args)
        else:
            return self._compatible_impl(hidden_states, input_args)
c2cite/common/moe_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+ from .abstracts import LLMDecoder, LLMModelInput
7
+
8
+
9
def slice_tensor(
    data: torch.Tensor,
    slice: torch.Tensor,
    dtype: torch.dtype,
    last_value: Optional[torch.Tensor] = None,
):
    """Gather the rows of *data* selected by the index tensor *slice* and
    return them as a 2-D tensor of the requested *dtype*.

    If *last_value* is provided, it is returned unchanged (memoized result).
    """
    if last_value is not None:
        return last_value
    # for macOS debugging, please uncomment this line
    # assert data.dtype in (torch.float, torch.int, torch.bool)
    gathered = data[None, slice]
    return gathered.reshape(-1, data.shape[-1]).to(dtype)
21
+
22
+
23
def unpack_router_logits(gate_logits: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate per-layer router logits into one tensor, moving every
    layer's logits onto the first layer's device first."""
    target_device = gate_logits[0].device
    moved = [layer_gate.to(target_device) for layer_gate in gate_logits]
    return torch.cat(moved, dim=0)
29
+
30
+
31
# NOTE(review): the name keeps the original misspelling ("logtis") because
# external callers import it by this name.
def collect_plugin_router_logtis(
    router_logits, input_args: LLMModelInput, decoder_layer: LLMDecoder
):
    """Harvest router logits stashed on plugin MoE blocks of *decoder_layer*
    for every batch config that did not already produce logits.

    Side effect: each harvested block's ``router_logits_`` is reset to None.
    """
    if router_logits is None or len(router_logits) == 0:
        router_logits = [None for _ in range(len(input_args.batch_configs_))]

    # Merge attention and MLP projections into a single name -> module map.
    attn_proj, mlp_proj = decoder_layer.state_dict()
    all_proj = copy.copy(attn_proj)
    all_proj.update(mlp_proj)
    for idx, config in enumerate(input_args.batch_configs_):
        if router_logits[idx] is not None:
            # This adapter already reported logits through the normal path.
            continue
        adapter_name = config.adapter_name_
        for proj in all_proj.values():
            if adapter_name in proj.moes_ and hasattr(
                proj.moes_[adapter_name], "router_logits_"
            ):
                if router_logits[idx] is None:
                    router_logits[idx] = []
                router_logits[idx].append(proj.moes_[adapter_name].router_logits_)
                # Clear the stash so logits are not double-counted next step.
                proj.moes_[adapter_name].router_logits_ = None

    # Collapse each adapter's per-projection list into one tensor.
    for idx, logits in enumerate(router_logits):
        if isinstance(logits, list):
            router_logits[idx] = torch.cat(logits, 0)

    return router_logits
c2cite/common/rope.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+
6
+ from .config import LLMModelConfig
7
+
8
+
9
+ def _compute_default_rope_parameters(
10
+ config: Optional[LLMModelConfig] = None,
11
+ device: Optional[torch.device] = None,
12
+ seq_len: Optional[int] = None,
13
+ **rope_kwargs,
14
+ ) -> Tuple[torch.Tensor, float]:
15
+ if len(rope_kwargs) > 0:
16
+ base = rope_kwargs["base"]
17
+ dim = rope_kwargs["dim"]
18
+ elif config is not None:
19
+ base = config.rope_theta_
20
+ partial_rotary_factor = (
21
+ config.partial_rotary_factor_
22
+ if config.partial_rotary_factor_ is not None
23
+ else 1.0
24
+ )
25
+ head_dim = (
26
+ config.dim_ // config.n_heads_
27
+ if config.head_dim_ is None
28
+ else config.head_dim_
29
+ )
30
+ dim = int(head_dim * partial_rotary_factor)
31
+
32
+ attention_factor = 1.0 # Unused in this type of RoPE
33
+
34
+ # Compute the inverse frequencies
35
+ inv_freq = 1.0 / (
36
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)
37
+ )
38
+ return inv_freq, attention_factor
39
+
40
+
41
def _compute_llama3_parameters(
    config: LLMModelConfig,
    device: torch.device,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple[torch.Tensor, float]:
    """Llama-3 frequency scaling: low frequencies are divided by ``factor``,
    high frequencies are untouched, and a smooth interpolation bridges the
    band in between (mirrors HF transformers' llama3 rope_scaling)."""
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )

    factor = config.rope_scaling_["factor"]  # `8` in the original implementation
    low_freq_factor = config.rope_scaling_[
        "low_freq_factor"
    ]  # `1` in the original implementation
    high_freq_factor = config.rope_scaling_[
        "high_freq_factor"
    ]  # `4` in the original implementation
    old_context_len = config.rope_scaling_[
        "original_max_position_embeddings"
    ]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # wavelen < high_freq_wavelen: do nothing
    # wavelen > low_freq_wavelen: divide by factor
    inv_freq_llama = torch.where(
        wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
    )
    # otherwise: interpolate between the two, using a smooth factor
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed_inv_freq = (
        1 - smooth_factor
    ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    # Medium band = neither high- nor low-frequency region.
    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor


# Registry mapping a rope_scaling "type" name to its init function.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "llama3": _compute_llama3_parameters,
}
c2cite/dispatcher.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ import sys
4
+ from abc import abstractmethod
5
+ from typing import Callable, Dict, List
6
+
7
+ import datasets
8
+ import copy
9
+
10
+ from .common import InputData, LLMBatchConfig, LLMModelInput, Masks, Tokens
11
+ from .tokenizer import Tokenizer
12
+
13
+
14
class Event:
    """A tiny callback chain: callbacks registered later run earlier, and
    activation stops at the first callback that returns a truthy value."""

    __callback_list: List[Callable] = None

    def __init__(self):
        self.__callback_list = []

    def register(self, func: Callable) -> "Event":
        """Add *func* at the front of the chain; returns self for chaining."""
        self.__callback_list.insert(0, func)
        return self

    def activate(self, **kwargs) -> bool:
        """Invoke callbacks in order; True once one reports the event handled."""
        for callback in self.__callback_list:
            if callback(**kwargs):
                return True
        return False
30
+
31
+
32
def load_dataset(data_path: str):
    """Load a dataset from a local JSON/JSONL file or from the HF hub.

    Hub paths may carry a subset name encoded as ``"name:subset"``.
    """
    if data_path.endswith((".json", ".jsonl")):
        return datasets.load_dataset("json", data_files=data_path)
    if ":" in data_path:
        parts = data_path.split(":")
        return datasets.load_dataset(parts[0], parts[1])
    return datasets.load_dataset(data_path)
41
+
42
+
43
class TrainTask:
    """One adapter's training data plus epoch/micro-batch bookkeeping."""

    tokenizer_: Tokenizer = None

    adapter_name_: str = ""
    data_path_: str = ""
    dataload_function_: Callable = None
    train_token_data_: List[InputData] = None

    # train parameter
    total_epoch_num_: int = -1
    max_train_batch_size_: int = -1
    max_train_micro_batch_size_: int = -1
    max_test_batch_size_: int = -1

    train_cutoff_len_: int = -1
    group_by_length_: bool = False

    # count the stat of train and test data
    epoch_cnt_: int = 1
    next_train_data_start_idx_: int = 0
    next_test_data_start_idx_: int = 0

    # NOTE(review): the parameter is spelled "tokenzer" — kept because
    # callers (Dispatcher) pass it by keyword.
    def __init__(
        self,
        tokenzer: Tokenizer,
        adapter_name: str,
        dataload_function: Callable,
        total_epoch_num: int,
        max_train_batch_size: int,
        max_train_micro_batch_size: int,
        train_cutoff_len: int = 256,
        group_by_length: bool = True,
    ):
        self.tokenizer_ = tokenzer
        self.adapter_name_ = adapter_name
        self.dataload_function_ = dataload_function
        self.total_epoch_num_ = total_epoch_num
        self.max_train_batch_size_ = max_train_batch_size
        self.max_train_micro_batch_size_ = max_train_micro_batch_size
        self.train_cutoff_len_ = train_cutoff_len
        self.group_by_length_ = group_by_length

    def load_data(self):
        """Tokenize/load the data, truncate to the cutoff length, and order it
        (longest-first when grouping by length, otherwise shuffled)."""
        self.train_token_data_ = self.dataload_function_(self.tokenizer_)
        max_train_tokens_len = 0
        for data in self.train_token_data_:
            max_train_tokens_len = max(max_train_tokens_len, len(data.tokens))
            if len(data.tokens) > self.train_cutoff_len_:
                data.tokens = data.tokens[: self.train_cutoff_len_]

        logging.info(
            f"Max train tokens length: {max_train_tokens_len}/{self.train_cutoff_len_}"
        )
        if self.group_by_length_:
            # Sorting is required by get_train_deta_max_seq_len(), which
            # assumes the longest remaining sequence is at the cursor.
            self.train_token_data_.sort(key=lambda x: len(x.tokens), reverse=True)
        else:
            random.shuffle(self.train_token_data_)

    def is_train_done(self):
        if self.epoch_cnt_ <= self.total_epoch_num_:
            return False
        return True

    def is_test_done(self):
        # NOTE(review): self.test_token_data_ is never assigned in this class
        # — this raises AttributeError unless a subclass/caller sets it; verify.
        if self.next_test_data_start_idx_ < len(self.test_token_data_):
            return False
        return True

    def reset_test_status(self):
        self.next_test_data_start_idx_ = 0

    # reentry function
    # NOTE(review): "deta" typo kept — Dispatcher calls this exact name.
    def get_train_deta_max_seq_len(self) -> int:
        """Length of the next micro-batch's longest sequence (no cursor move)."""
        start_idx = self.next_train_data_start_idx_
        assert start_idx < len(self.train_token_data_)
        # in this strategy must sort
        return len(self.train_token_data_[start_idx].tokens)

    # non reentry function
    def get_train_data(self) -> List[InputData]:
        """Pop the next micro-batch, advancing the cursor and epoch counter."""
        start_idx = self.next_train_data_start_idx_
        end_idx = start_idx + self.max_train_micro_batch_size_

        ret_data = self.train_token_data_[start_idx:end_idx]

        logging.info(f"{self.adapter_name_} train data:")
        logging.info(
            f"    epoch: {self.epoch_cnt_}/{self.total_epoch_num_} \
              step in epoch: {start_idx}/{len(self.train_token_data_)}"
        )

        self.next_train_data_start_idx_ += self.max_train_micro_batch_size_
        if self.next_train_data_start_idx_ >= len(self.train_token_data_):
            # Wrapped around the dataset -> one epoch completed.
            self.next_train_data_start_idx_ = 0
            self.epoch_cnt_ += 1

        return ret_data
140
+
141
+
142
class DispatcherConfig:
    """Interface for objects that can supply TrainTask constructor kwargs."""

    @abstractmethod
    def dispatcher_context(self) -> Dict[str, any]:
        """Return keyword arguments used by Dispatcher to build a TrainTask."""
        return {}
146
+
147
+
148
class Dispatcher:
    """Schedules multiple adapters' TrainTasks and assembles their next
    micro-batches into a single padded ``LLMModelInput``.

    Fixes vs. the original:
    - ``raise "unkown strategy"`` raised a *string* (always a TypeError in
      Python 3); replaced with a proper ``ValueError``.
    - ``data.citation_embeds == None`` -> ``is None``.
    - the citation-length assert's message called ``print(...)`` (evaluates
      to None) and reported the wrong variable; it now reports the actual
      citation count.
    - the citation-token predicate is a static helper instead of a closure
      redefined per data item.
    """

    config_ = None
    tokenizer_: Tokenizer = None

    # all train task
    ready_train_task_: List[TrainTask] = None
    running_train_task_: List[TrainTask] = None
    done_train_task_: List[TrainTask] = None

    # train task in event
    train_task_in_event_: Event = None
    train_task_out_event_: Event = None

    # the number of max candidate training lora model
    # can chose train data from this dataset
    train_lora_candidate_num_: int = 0
    # the number of simultaneously train lora model
    train_lora_simultaneously_num_: int = 0

    strategy_: str = ""

    def __init__(
        self,
        tokenizer: Tokenizer,
        configs: List[DispatcherConfig],
        max_concurrent_jobs: int = None,
        strategy: str = "optim",
        cutoff_len: int = 256,
    ) -> None:
        """Build one TrainTask per config; tasks start in the ready queue."""
        if max_concurrent_jobs is None:
            max_concurrent_jobs = len(configs)

        self.tokenizer_ = tokenizer

        self.ready_train_task_ = []
        self.running_train_task_ = []
        self.done_train_task_ = []

        self.train_task_in_event_ = Event()
        self.train_task_out_event_ = Event()

        # No practical cap on how many tasks may be "running" at once; the
        # concurrency limit applies to how many are batched together.
        self.train_lora_candidate_num_ = sys.maxsize
        self.train_lora_simultaneously_num_ = max_concurrent_jobs
        self.strategy_ = strategy

        # create ready task
        for config_class in configs:
            kwargs = config_class.dispatcher_context()
            self.ready_train_task_.append(
                TrainTask(
                    tokenzer=self.tokenizer_, train_cutoff_len=cutoff_len, **kwargs
                )
            )

    def optim_dispatch_strategy(self) -> Dict[str, List[InputData]]:
        """Pick the window of running tasks whose next sequence lengths are
        most similar, minimizing padding in the combined batch."""
        task_len = {}
        for idx, task in enumerate(self.running_train_task_):
            task_len[idx] = task.get_train_deta_max_seq_len()
        # sort to get the seq most similar data
        task_len = sorted(task_len.items(), key=lambda x: x[1], reverse=True)
        # find the mini diff
        min_need_pad_len = sys.maxsize
        win_start_idx = 0
        for sidx in range(0, len(task_len) - self.train_lora_simultaneously_num_ + 1):
            win = task_len[sidx : sidx + self.train_lora_simultaneously_num_]
            need_pad_len = 0
            for i in range(1, len(win)):
                # align to the max seq len (win[0] is the longest in window)
                need_pad_len += abs(win[i][1] - win[0][1])
            if need_pad_len < min_need_pad_len:
                min_need_pad_len = need_pad_len
                win_start_idx = sidx
        # the result is win_start_idx
        result_win = task_len[
            win_start_idx : win_start_idx + self.train_lora_simultaneously_num_
        ]
        ret_train_data = {}
        for result_task_len in result_win:
            task_idx = result_task_len[0]
            ret_train_data[self.running_train_task_[task_idx].adapter_name_] = (
                self.running_train_task_[task_idx].get_train_data()
            )

        return ret_train_data

    def none_dispatch_strategy(self) -> Dict[str, List[InputData]]:
        """Take data from running tasks in order, up to the concurrency cap."""
        ret_train_data = {}
        cnt = 0
        for task in self.running_train_task_:
            assert not task.is_train_done()
            if cnt >= self.train_lora_simultaneously_num_:
                break
            ret_train_data[task.adapter_name_] = task.get_train_data()
            cnt += 1
        return ret_train_data

    def check_task_done(self) -> bool:
        """True when no task remains in the ready or running queues."""
        if len(self.ready_train_task_) == 0 and len(self.running_train_task_) == 0:
            return True
        return False

    def check_test_done(self) -> bool:
        for task in self.running_train_task_:
            if task.is_train_done():
                return False
        return True

    def reset_test_task(self):
        for task in self.running_train_task_:
            task.reset_test_status()

    # ready task -> running task
    def __dispatch_task_in(self):
        assert len(self.running_train_task_) <= self.train_lora_candidate_num_
        if len(self.running_train_task_) == self.train_lora_candidate_num_:
            return
        # chose task into running
        while (
            len(self.running_train_task_) < self.train_lora_candidate_num_
            and len(self.ready_train_task_) > 0
        ):
            # TODO to dispatch task
            task = self.ready_train_task_.pop(0)
            # to lazy load data
            task.load_data()
            self.train_task_in_event_.activate(task=task)
            self.running_train_task_.append(task)

    # running task -> done task
    def __dispatch_task_out(self):
        for task in self.running_train_task_:
            if task.is_train_done():
                self.train_task_out_event_.activate(task=task)
                self.done_train_task_.append(task)

        self.running_train_task_ = [
            task for task in self.running_train_task_ if not task.is_train_done()
        ]

    def get_test_data(self) -> LLMModelInput:
        pass

    @staticmethod
    def _is_cite_token(token_id: int) -> bool:
        # Llama-3 reserved special-token ids used here as citation markers.
        return (128010 <= token_id <= 128255) or token_id in {
            128004,
            128002,
            128003,
            128005,
            128008,
        }

    def get_train_data(self) -> LLMModelInput:
        """Gather the next micro-batches, pad/align them, extract citation
        marker positions, and package everything as an LLMModelInput."""
        self.__dispatch_task_in()

        # get task train data
        all_train_data: Dict[str, List[InputData]] = {}
        if self.strategy_ == "none":
            all_train_data = self.none_dispatch_strategy()
        elif self.strategy_ == "optim":
            all_train_data = self.optim_dispatch_strategy()
        else:
            # Fixed: originally raised a bare string, which is a TypeError.
            raise ValueError(f"unknown strategy: {self.strategy_}")

        batch_seq_len: int = -1
        # to align batch token data
        for adapter in all_train_data:
            for data in all_train_data[adapter]:
                batch_seq_len = max(batch_seq_len, len(data.tokens))
        # all prompts and tokens / config
        batch_tokens: List[Tokens] = []
        attention_masks: List[Masks] = []
        batch_labels: List[List] = []
        lora_batch_data_config: List[LLMBatchConfig] = []

        cites = []
        cites_value = []
        docs = []
        prompt_len = []
        # batch the all adapter data
        adapter_start_idx: int = 0
        for adapter in all_train_data:
            adapter_end_idx: int = adapter_start_idx + len(all_train_data[adapter])
            for data in all_train_data[adapter]:
                tokens: Tokens = data.tokens.copy()
                prompt_len.append(data.prompt_len)
                # Positions and values of citation-marker tokens in the sample.
                cite = [
                    index
                    for index, value in enumerate(tokens)
                    if self._is_cite_token(value)
                ]
                cite_value = [value for value in tokens if self._is_cite_token(value)]
                # Fixed: the failure message previously was print(...) (None)
                # and referenced the accumulated list instead of this sample.
                assert len(cite) < 40, f"too long!!! need:{len(cite)}"
                if len(cite) > 0:
                    # Close the final citation span at the end of the sample.
                    if cite[len(cite) - 1] != data.token_len:
                        cite.append(data.token_len)
                pad_side = self.tokenizer_.padding_side_
                assert pad_side == "right" or pad_side == "left"
                # pad the tokens to align
                while len(tokens) < batch_seq_len:
                    if pad_side == "right":
                        tokens.append(self.tokenizer_.pad_id_)
                    else:
                        tokens.insert(0, self.tokenizer_.pad_id_)
                batch_tokens.append(tokens)
                cites.append(cite.copy())
                cites_value.append(cite_value.copy())
                # Prefer precomputed citation embeddings when available.
                if data.citation_embeds is None:
                    docs.append(data.citation_tokens)
                else:
                    docs.append(data.citation_embeds)
                attention_masks.append(self.tokenizer_.mask_from(tokens))
                labels = data.labels
                if labels is None:
                    # Default to language-modeling labels (copy of the input).
                    labels = tokens.copy()
                else:
                    labels = labels.copy()
                batch_labels.append(labels)

            lora_batch_data_config.append(
                LLMBatchConfig(
                    adapter_name_=adapter,
                    batch_start_idx_=adapter_start_idx,
                    batch_end_idx_=adapter_end_idx,
                )
            )
            adapter_start_idx = adapter_end_idx

        self.__dispatch_task_out()

        return LLMModelInput(
            batch_cites=cites,
            batch_cites_value=cites_value,
            batch_docs=docs,
            batch_prompt_len=prompt_len,
            batch_configs_=lora_batch_data_config,
            batch_tokens_=batch_tokens,
            batch_labels_=batch_labels,
            batch_masks_=attention_masks,
            gradient_checkpoint_="recompute",
        )
c2cite/evaluator.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import time
4
+ import sys
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Tuple, Union, Optional
8
+
9
+ import torch
10
+
11
+ from .adapters import MixLoraConfig
12
+ from .common import InputData, LLMBatchConfig, LLMModelInput, Prompt, Tokens
13
+ from .model import LLMModel
14
+ from .tasks import BasicMetric, BasicTask, CommonSenseTask, task_dict
15
+ from .tokenizer import Tokenizer
16
+ from moe_peft.prompter import Prompter
17
+ from moe_peft.generator import _batch_generate
18
+ from moe_peft.solutions import get_output
19
+
20
@dataclass
class GenerateData:
    """Per-prompt decoding state tracked during batch generation.

    NOTE(review): this class is duplicated in c2cite/generator.py — keep the
    two definitions in sync (or import one from the other).
    """

    # Name of the adapter this prompt belongs to.
    adapter_name_: str = None
    # Index of the prompt within its GenerateConfig.prompts list.
    prompt_index_: int = None
    # Token length of the prompt prefix (used later to strip the prompt
    # from the generated sequence).
    prefix_length_: int = None
    # Accumulated token ids: prompt tokens plus tokens generated so far.
    raw_tokens_: Tokens = None
26
+
27
+
28
@dataclass
class GenerateConfig:
    """Generation settings plus runtime state for one adapter.

    NOTE(review): duplicated in c2cite/generator.py — keep in sync.
    """

    adapter_name: str = None
    prompts: List[Union[str, Tuple[str, str]]] = None
    prompt_template: str = None
    # Generate Arguments
    batch_size: int = 8
    stop_token: str = None
    temperature: float = 1
    top_p: float = 0.9
    top_k: float = 50
    do_sample: bool = True
    repetition_penalty: float = 1.1
    renormalize_logits: bool = True
    # Do not set these manually
    prompter_: Prompter = None
    stop_token_: torch.Tensor = None
    data_: List[GenerateData] = None

    # Set prompt_template_ to enable the prompter
    def generate_prompt(self, instruction: str, input: str = None) -> str:
        """Render one prompt through the (lazily created) Prompter."""
        if self.prompter_ is None:
            self.prompter_ = Prompter(self.prompt_template)

        return self.prompter_.generate_prompt(instruction=instruction, input=input)

    def get_prompts(self) -> List[str]:
        """Render every configured prompt; tuples are (instruction, input)."""
        prompts = []
        for prompt in self.prompts:
            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
            prompts.append(self.generate_prompt(*args))

        return prompts

    def get_response(self, output: str) -> str:
        """Extract the response part of a raw model output."""
        if self.prompter_ is None:
            return output.strip()
        else:
            return self.prompter_.get_response(output)

    def reset_parameters(self):
        """Re-initialize the runtime fields before a new generation run."""
        self.prompter_ = Prompter(self.prompt_template)
        self.stop_token_ = None
        self.data_ = []
72
+
73
+
74
@dataclass
class EvaluateConfig:
    """One (adapter, task) evaluation job and its scheduling cursors."""

    adapter_name: str = None
    task_name: str = None
    data_path: str = None
    batch_size: int = 16
    router_profile: bool = False
    # Do not set these manually
    task_: BasicTask = None
    data_: List[InputData] = None
    metric_: BasicMetric = None
    # Cursors used by the dispatcher; rollback_start_idx_ lets evaluate()
    # replay a batch after an OOM-triggered concurrency reduction.
    rollback_start_idx_: int = 0
    batch_start_idx_: int = 0
    batch_end_idx_: int = 0

    def _dataload_fn(self, tokenizer: Tokenizer, **tokenizer_kwargs):
        """Load the task data and pre-encode inputs (and citations, if any)."""
        data = self.task_.loading_data(False, self.data_path)
        for idx, data_point in enumerate(data):
            # Evaluation expects plain-text inputs, not Prompt objects.
            assert not isinstance(data_point.inputs, Prompt)

            data_point.tokens = tokenizer.encode(data_point.inputs, **tokenizer_kwargs)
            data_point.prefix_length_ = len(data_point.tokens)
            if data_point.citations is not None:
                # Prefer precomputed citation embeddings when present;
                # otherwise tokenize each citation document.
                if data_point.citation_embeds is None:
                    data_point.citation_tokens = [tokenizer.encode(c, **tokenizer_kwargs)
                                                  for c in data_point.citations]
                else:
                    data_point.citation_tokens = data_point.citation_embeds
            if idx % 10000 == 0:
                logging.info(f"Encode text data: {idx}/{len(data)}")

        return data

    @staticmethod
    def from_config(config: Dict[str, any]) -> List["EvaluateConfig"]:
        """Build one EvaluateConfig per ';'-separated task name in `config`.

        Task names missing from task_dict are silently skipped.
        """
        adapter_name = config["name"]
        data_path = config.get("data", None)
        task_list = config.get("task_name", "casual").split(";")
        path_list = (
            [None] * len(task_list) if data_path is None else data_path.split(";")
        )
        config_list = []
        for task_name_, data_path_ in zip(task_list, path_list):
            if task_name_ not in task_dict:
                continue
            config_list.append(
                EvaluateConfig(
                    adapter_name=adapter_name,
                    task_name=task_name_,
                    data_path=data_path_,
                    batch_size=config["evaluate_batch_size"],
                )
            )

        return config_list

    def prepare(self, tokenizer: Tokenizer, device: str):
        """Load task, data and metric; precompute label token ids for
        common-sense tasks (stored in label_indices_)."""
        self.reset_parameters()
        assert (
            self.task_name != "casual"
        ), "Auto evaluation is not currently available for casual supervised fine-tuning tasks."
        self.task_ = task_dict[self.task_name]
        self.data_ = self._dataload_fn(tokenizer)
        self.metric_ = self.task_.loading_metric()
        if isinstance(self.task_, CommonSenseTask):
            labels = self.task_.label_list()
            label_indices = [0] * len(labels)
            for idx, label in enumerate(labels):
                # Use the last token of " <label>" as the class logit index.
                ids = tokenizer.encode(" " + label)
                label_indices[idx] = ids[-1]
            self.label_indices_ = torch.tensor(
                label_indices, dtype=torch.int64, device=device
            )
        else:
            self.label_indices_ = None

    def reset_parameters(self):
        """Clear loaded task state and rewind all batch cursors."""
        self.task_ = None
        self.data_ = None
        self.metric_ = None
        self.rollback_start_idx_ = 0
        self.batch_start_idx_ = 0
        self.batch_end_idx_ = 0
157
+
158
+
159
def _prepare_tasks(model, tokenizer, configs):
    """Prepare every evaluation config and propagate the router-profiling
    flag into each MixLoRA MoE layer of the corresponding adapter."""
    for config in configs:
        config.prepare(tokenizer, model.device_)
        # Router profiling only applies to MixLoRA adapters.
        if not isinstance(model.adapter_configs_[config.adapter_name], MixLoraConfig):
            continue
        for layer in model.model_.layers_:
            if config.adapter_name in layer.mlp_.moes_:
                layer.mlp_.moes_[config.adapter_name].router_profile_ = (
                    config.router_profile
                )
169
+
170
+
171
def _dispatch_task_in(tokenizer, configs, concurrent_jobs, max_seq_len):
    """Assemble the next evaluation batch from up to `concurrent_jobs` configs.

    Advances each participating config's batch cursor and returns
    (current_configs, sequence_lengths, batch_labels, more_grounds,
    LLMModelInput) ready for model.forward().
    """
    batch_data_config = []
    sequence_lengths = []
    current_configs = []
    batch_tokens = []
    batch_labels = []
    more_grounds = []
    atten_masks = []
    max_tokens_len = 0
    for config in configs:
        if len(current_configs) >= concurrent_jobs:
            break
        # Skip configs whose data is exhausted.
        if config.batch_start_idx_ >= len(config.data_):
            continue
        config.batch_end_idx_ = min(
            config.batch_start_idx_ + config.batch_size, len(config.data_)
        )
        batch_start_idx = len(batch_tokens)
        for idx in range(config.batch_start_idx_, config.batch_end_idx_):
            if idx >= len(config.data_):
                break
            tokens = config.data_[idx].tokens
            labels = config.data_[idx].labels
            grounds = config.data_[idx].grounds
            # Truncate over-long inputs to the sequence limit.
            if len(tokens) > max_seq_len:
                tokens = tokens[:max_seq_len]
            max_tokens_len = max(len(tokens), max_tokens_len)
            batch_tokens.append(tokens)
            if labels:
                # NOTE(review): [labels].copy() wraps labels in a one-element
                # list; if labels is itself a list, labels.copy() was likely
                # intended (compare grounds below) — confirm downstream shape.
                batch_labels.append([labels].copy())
            if grounds:
                more_grounds.append(grounds.copy())

        config.batch_start_idx_ = config.batch_end_idx_
        current_configs.append(config)
        batch_data_config.append(
            LLMBatchConfig(
                adapter_name_=config.adapter_name,
                batch_start_idx_=batch_start_idx,
                batch_end_idx_=len(batch_tokens),
            )
        )

    # Shrink the padding target to the longest sequence actually batched.
    max_seq_len = min(max_seq_len, max_tokens_len)

    for tokens in batch_tokens:
        sequence_lengths.append(len(tokens) - 1)
        # NOTE(review): non-truncated entries alias config.data_[idx].tokens,
        # so this pads the stored data in place — confirm that is intended.
        while len(tokens) < max_seq_len:
            tokens.append(tokenizer.pad_id_)
        atten_masks.append(tokenizer.mask_from(tokens))

    return (
        current_configs,
        sequence_lengths,
        batch_labels,
        more_grounds,
        LLMModelInput(
            batch_configs_=batch_data_config,
            batch_tokens_=batch_tokens,
            batch_masks_=atten_masks,
            inference_mode_=True,
        ),
    )
234
+
235
+
236
def _compute_metrcis(model, current_configs, sequence_lengths, batch_labels, outputs):
    """Score one forward pass: pool per-sequence logits and feed predictions
    plus references into each config's metric.

    NOTE(review): name is misspelled ("metrcis") but kept — callers in this
    file use it.
    """
    for idx, output in enumerate(outputs):
        config: EvaluateConfig = current_configs[idx]
        task: BasicTask = config.task_
        metric: BasicMetric = config.metric_
        start_idx = output.batch_start_idx_
        end_idx = output.batch_end_idx_
        logits = output.logits

        if config.router_profile:
            adapter_config = model.adapter_configs_[config.adapter_name]
            if isinstance(adapter_config, MixLoraConfig):
                router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
                # NOTE(review): the inner loops below shadow the outer `idx`;
                # harmless here (enumerate reassigns it), but rename-worthy.
                for layer in model.model_.layers_:
                    if config.adapter_name not in layer.mlp_.moes_:
                        continue
                    for idx, val in enumerate(
                        layer.mlp_.moes_[config.adapter_name].profiler_
                    ):
                        router_statistic_[idx] += val
                for idx, val in enumerate(router_statistic_):
                    logging.info(
                        # /32 presumably averages over layers — confirm.
                        f"{config.adapter_name}: expert {idx}, load = {val/32}"
                    )

        batch_size = logits.shape[0]
        # Take the logits at each sequence's last real token position.
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device),
            sequence_lengths[start_idx:end_idx],
        ]
        labels = torch.tensor(
            batch_labels[start_idx:end_idx],
            dtype=task.label_dtype_,
            device=logits.device,
        )
        if task.task_type_ == "common_sense":
            # Restrict to the candidate label token ids, then pick the best.
            pooled_logits = pooled_logits[:, config.label_indices_]
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
        elif task.task_type_ == "single_label_classification":
            pooled_logits = pooled_logits.softmax(-1).argmax(-1)
            pooled_logits = pooled_logits.to(task.label_dtype_)
        elif task.task_type_ != "multi_label_classification":
            raise ValueError(f"unknown task type {task.task_type_}")

        metric.add_batch(
            predictions=pooled_logits.detach().cpu(), references=labels.detach().cpu()
        )
        logging.info(f"{config.adapter_name} evaluate data:")
        logging.info(f"    step: {config.batch_start_idx_}/{len(config.data_)}")
285
+
286
+
287
def _compute_result(model, configs, save_file):
    """Collect final metric results per config and save or print them.

    `save_file` is treated as a directory; results go to
    <save_file>/<adapter_name>.json.
    """
    results = []
    for config in configs:
        result = {
            "adapter_name": config.adapter_name,
            "task_name": config.task_name,
            "date_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "metrics": {},
        }
        compute_results = config.metric_.compute()
        result["metrics"] = compute_results
        if config.router_profile:
            adapter_config = model.adapter_configs_[config.adapter_name]
            if isinstance(adapter_config, MixLoraConfig):
                router_statistic_ = list(0 for _ in range(adapter_config.num_experts_))
                for layer in model.model_.layers_:
                    if config.adapter_name not in layer.mlp_.moes_:
                        continue
                    for idx, val in enumerate(
                        layer.mlp_.moes_[config.adapter_name].profiler_
                    ):
                        router_statistic_[idx] += val
                    # Reset the profiler so a later run starts fresh.
                    layer.mlp_.moes_[config.adapter_name].profiler_ = None
                # /32 presumably averages over layers — TODO confirm.
                result["router_profile"] = list(val / 32 for val in router_statistic_)

        results.append(result)

    if save_file is not None:
        if not os.path.exists(save_file):
            os.makedirs(save_file)
        # NOTE(review): `config` here is the last loop variable, so all
        # results are written under the final adapter's name — confirm
        # whether per-adapter files were intended.
        file_path = save_file + os.sep + f"{config.adapter_name}.json"
        with open(file_path, "w") as f:
            json.dump(results, f, indent=4)
        logging.info(f"saving evaluation result to {file_path}")
    else:
        print(json.dumps(results, indent=4))

    return results
325
+
326
def _dispatch_task_in2(
    tokenizer,
    configs: List[GenerateConfig],  # uses config.data_, config.batch_size, config.adapter_name
    concurrent_jobs: int,
    strategy: str = "fair",
):
    """Pop pending generation jobs off each config's data_ queue.

    "fair" splits the job budget evenly across configs; "fifo" lets the
    first configs consume the whole budget. Popped items are removed from
    config.data_, which is the sole scheduling signal.

    Returns (current_jobs, batch_config, input_tokens, max_tokens_len,
    min_tokens_len).
    """
    assert strategy in ["fair", "fifo"], f"Unknown dispatch strategy {strategy}"
    current_jobs = []
    batch_config = []
    input_tokens = []
    max_tokens_len = 0
    min_tokens_len = sys.maxsize
    for config in configs:
        if len(batch_config) >= concurrent_jobs:
            break

        if len(config.data_) == 0:
            continue
        print(f"count down:{len(config.data_)}")
        if strategy == "fair":
            per_task_jobs = max(concurrent_jobs // len(configs), 1)
        else:
            per_task_jobs = concurrent_jobs

        per_task_jobs = min(per_task_jobs, config.batch_size)

        batch_start_idx = len(input_tokens)
        while per_task_jobs > 0 and len(config.data_) > 0:
            per_task_jobs = per_task_jobs - 1
            data = config.data_.pop(0)
            current_jobs.append(data)
            tokens = data.tokens
            max_tokens_len = max(len(tokens), max_tokens_len)
            min_tokens_len = min(len(tokens), min_tokens_len)
            input_tokens.append(tokens)

        batch_config.append(
            LLMBatchConfig(
                adapter_name_=config.adapter_name,
                batch_start_idx_=batch_start_idx,
                batch_end_idx_=len(input_tokens),
            )
        )

    return (
        current_jobs,
        batch_config,
        input_tokens,
        max_tokens_len,
        min_tokens_len,
    )
377
+
378
+
379
def _generate_then_compute_metrics(
    model, tokenizer, concurrent_jobs, \
    max_gen_len, current_configs: List[EvaluateConfig],\
    require_attention: Optional[int] = -1, require_hide: Optional[int] = -1
):
    """Run free-form generation for attribution-style tasks and feed each
    output (with its QA pairs / answer / docs / query) into the metric.

    Drains config.data_ batch by batch via _dispatch_task_in2; jobs that
    are still running after a round are pushed back into data_.
    """
    # grounds are the QA pairs
    metric = current_configs[0].metric_.metric_

    ###outputs, hidden_output, hidden_atten = model.forward(input_args)

    # TODO(review): current_configs should be converted to GenerateConfig
    # here; it is currently an EvaluateConfig.
    #cnt = 50
    #cases = []
    while True:  # config.data_ shrinks as jobs complete; it drives the loop
        dispatch_args = _dispatch_task_in2(tokenizer, current_configs, concurrent_jobs)
        # dispatch_args holds: current_jobs,
        # batch_config (LLMBatchConfig(taskname, start, end)),
        # batch_tokens, max_length, min_length

        if len(dispatch_args[0]) == 0:
            break
        use_cache = True
        cache_implementation = model.model_.cache_implementation()
        if cache_implementation is None:
            logging.warn(
                "Cache disabled by model, use cache_implementation to force enable."
            )
            use_cache = False
        outputs, running_jobs = _batch_generate(
            model,
            tokenizer,
            max_gen_len,
            use_cache,
            require_attention,
            require_hide,
            cache_implementation,
            None,
            *dispatch_args,
        )
        # Requeue unfinished jobs for the next dispatch round.
        for data in running_jobs:
            current_configs[0].data_.append(data)

        print(f"\noutput:{outputs[0]}\n")
        # NOTE(review): only outputs[0] / dispatch_args[0][0] are scored each
        # round — confirm batches larger than one are intended to be ignored.
        metric.add_batch(
            {
                'output': outputs[0],
                'qa_pairs': dispatch_args[0][0].grounds,
                'answer': dispatch_args[0][0].labels,
                'docs': dispatch_args[0][0].citations,
                'query': dispatch_args[0][0].query,
            }
        )
430
+
431
+
432
@torch.inference_mode()
def evaluate(
    model: LLMModel,
    tokenizer: Tokenizer,
    configs: List[EvaluateConfig],
    max_concurrent_jobs: int = None,
    retrying_steps: int = 20,
    max_seq_len: int = 512,
    save_file: str = None,
    require_attention: Optional[int] = -1,
    require_hide: Optional[int] = -1,
) -> Dict:
    """Evaluate every config against the model.

    Batches are dispatched with adaptive concurrency: on CUDA OOM the
    concurrency is reduced, the affected configs roll back to their last
    committed cursor, and concurrency recovers after `retrying_steps`
    successful rounds. 'attribute' tasks use generation-based scoring;
    all others use a single forward pass.
    """

    if max_concurrent_jobs is None:
        max_concurrent_jobs = len(configs)
        logging.info(
            f"Setting max_concurrent_jobs to {max_concurrent_jobs} automatically"
        )

    assert max_concurrent_jobs > 0
    assert retrying_steps > 0

    _prepare_tasks(model, tokenizer, configs)

    concurrent_jobs = max_concurrent_jobs
    retrying_count = 0
    while True:
        # Gradually restore concurrency after an OOM-triggered reduction.
        if concurrent_jobs < max_concurrent_jobs and retrying_count > 0:
            retrying_count -= 1
            if retrying_count == 0:
                concurrent_jobs += 1
                logging.info(f"recovering concurrent jobs to {concurrent_jobs}")

        current_configs, sequence_lengths, batch_labels, grounds, input_args = _dispatch_task_in(
            tokenizer, configs, concurrent_jobs, max_seq_len
        )
        # current_configs: configs participating in this batch
        # sequence_lengths: length of each token sequence in the batch
        # batch_labels, grounds: references for scoring
        # input_args: LLMModelInput with batch configs
        # (adapter_name, start, end), tokens and attention masks
        if len(current_configs) == 0:
            break

        try:
            if current_configs[0].task_.task_type_ == 'attribute':
                _generate_then_compute_metrics(
                    model,
                    tokenizer,
                    concurrent_jobs,
                    max_seq_len,
                    current_configs,
                    require_attention,
                    require_hide
                )
            else:
                _compute_metrcis(
                    model,
                    current_configs,
                    sequence_lengths,
                    batch_labels,
                    model.forward(input_args),
                )

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                concurrent_jobs -= 1
                if concurrent_jobs == 0:
                    raise e
                logging.warn(
                    f"deprecating concurrent jobs to {concurrent_jobs} due to OOM."
                )
                # rollback
                retrying_count = retrying_steps
                for config in current_configs:
                    config.batch_start_idx_ = config.rollback_start_idx_
                    logging.info(
                        f"{config.adapter_name}: rollback to {config.batch_start_idx_}/{len(config.data_)}"
                    )
                continue
            else:
                raise e

        # Commit this round's progress so a future rollback replays only
        # the failed batch.
        for config in current_configs:
            config.rollback_start_idx_ = config.batch_start_idx_

    return _compute_result(model, configs, save_file)
c2cite/executors/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+
4
+ import torch
5
+
6
+ from .common import BasicExecutor
7
+ from .cpu import CPUExecutor
8
+ from .cuda import CUDAExecutor
9
+ from .mps import MPSExecutor
10
+
11
# Registry mapping MOE_PEFT_EXECUTOR_TYPE values to executor classes.
executor_dict = {
    "CUDA": CUDAExecutor,
    "MPS": MPSExecutor,
    "CPU": CPUExecutor,
}
16
+
17
+
18
def _init_executor():
    """Instantiate the process-wide executor backend.

    The MOE_PEFT_EXECUTOR_TYPE environment variable, when set, forces a
    specific backend; otherwise the best available one is chosen in the
    order CUDA > MPS > CPU.
    """
    forced = os.getenv("MOE_PEFT_EXECUTOR_TYPE")
    if forced is not None:
        forced = forced.upper()
        if forced not in executor_dict:
            raise ValueError(f"Assigning unknown executor type {forced}")
        return executor_dict[forced]()
    if torch.cuda.is_available():
        return CUDAExecutor()
    if torch.backends.mps.is_available():
        return MPSExecutor()
    return CPUExecutor()
31
+
32
+
33
+ executor: BasicExecutor = _init_executor()
34
+
35
+
36
class no_cache(object):
    """Context manager that empties the executor's device cache and runs the
    garbage collector on both entry and exit, minimizing resident memory
    around the wrapped block."""

    def __enter__(self):
        executor.empty_cache()
        gc.collect()
        return self

    def __exit__(self, type, value, traceback):
        executor.empty_cache()
        gc.collect()
45
+
46
+
47
+ __all__ = [
48
+ "BasicExecutor",
49
+ "CUDAExecutor",
50
+ "MPSExecutor",
51
+ "CPUExecutor",
52
+ "executor",
53
+ "no_cache",
54
+ ]
c2cite/executors/common.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+
4
+ import torch
5
+ from transformers.utils import is_torch_bf16_available_on_device
6
+
7
+ from moe_peft.utils import NoneContexts
8
+
9
+
10
class BasicExecutor:
    """Abstract device-backend interface (CUDA/MPS/CPU implementations).

    Subclasses override the NotImplementedError stubs; the concrete helpers
    here are shared defaults that delegate to torch.
    """

    def name(self) -> str:
        """Human-readable backend name, e.g. "NVIDIA CUDA"."""
        raise NotImplementedError()

    def device_name(self) -> str:
        """torch device string, e.g. "cuda" / "mps" / "cpu"."""
        raise NotImplementedError()

    def default_device_name(self) -> str:
        """Device string with a default index; falls back to device_name()."""
        return self.device_name()

    def is_available(self) -> bool:
        raise NotImplementedError()

    def is_initialized(self) -> bool:
        raise NotImplementedError()

    def is_bf16_supported(self) -> bool:
        return is_torch_bf16_available_on_device(self.device_name())

    def manual_seed(self, seed: int):
        # Seed both Python's RNG and torch's global RNG.
        random.seed(seed)
        torch.manual_seed(seed)

    def empty_cache(self):
        raise NotImplementedError()

    def use_deterministic_algorithms(self, mode: bool):
        torch.use_deterministic_algorithms(mode)

    def allow_tf32(self, mode: bool):
        raise NotImplementedError()

    def set_rng_state(self, device, state):
        raise NotImplementedError()

    def get_rng_state(self, device):
        raise NotImplementedError()

    def fork_rng(self, rng_devices: list):
        """Fork the RNG state for the given devices (restored on exit)."""
        return torch.random.fork_rng(
            devices=rng_devices, device_type=self.device_name()
        )

    def autocast(self, **kwargs):
        # Default: no mixed precision — a no-op context manager.
        return NoneContexts()

    def init_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Allocate an uninitialized tensor shaped like `tensor`."""
        return torch.empty_like(tensor)

    def index_fill(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, value: torch.Tensor
    ):
        input.index_fill_(dim, index, value)

    def index_copy(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, source: torch.Tensor
    ):
        input.index_copy_(dim, index, source)

    def check_available(self):
        """Log and return whether this backend is usable right now."""
        if not self.is_available():
            logging.error(f"{self.name()} not available.")
            return False
        if not self.is_initialized():
            logging.error(f"{self.name()} not initialized.")
            return False
        logging.info(f"{self.name()} initialized successfully.")
        return True
c2cite/executors/cpu.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import logging
3
+
4
+ import torch
5
+
6
+ from .common import BasicExecutor
7
+
8
+
9
class CPUExecutor(BasicExecutor):
    """Executor backend that runs everything on the host CPU."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "CPU"

    def device_name(self) -> str:
        return "cpu"

    def is_available(self) -> bool:
        return True

    def is_initialized(self) -> bool:
        # The CPU needs no initialization; check_available() below is
        # overridden so this False never blocks availability checks.
        return False

    def empty_cache(self):
        # No device cache to release on CPU.
        pass

    def allow_tf32(self, mode: bool):
        # TF32 is a CUDA feature; only mode=False is meaningful here.
        assert not mode, "Enabling tf32 for CPU."

    def set_rng_state(self, device: int, state: torch.Tensor):
        assert device == 0
        torch.set_rng_state(state)

    def get_rng_state(self, device: int):
        assert device == 0
        return torch.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        # Save/restore only the CPU RNG state around the block.
        assert len(rng_devices) == 0
        cpu_rng_state = torch.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(cpu_rng_state)

    def check_available(self):
        # CPU is always usable; skip the base-class availability checks.
        logging.info(f"{self.name()} initialized successfully.")
        return True
c2cite/executors/cuda.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .common import BasicExecutor
4
+
5
+
6
class CUDAExecutor(BasicExecutor):
    """Executor backend for NVIDIA GPUs via torch.cuda."""

    def __init__(self) -> None:
        super().__init__()
        # Eagerly initialize the CUDA runtime so is_initialized() holds.
        torch.cuda.init()

    def name(self) -> str:
        return "NVIDIA CUDA"

    def device_name(self) -> str:
        return "cuda"

    def default_device_name(self) -> str:
        return "cuda:0"

    def is_available(self) -> bool:
        return torch.cuda.is_available()

    def is_initialized(self) -> bool:
        return torch.cuda.is_initialized()

    def is_bf16_supported(self) -> bool:
        return torch.cuda.is_bf16_supported()

    def manual_seed(self, seed: int):
        super().manual_seed(seed)
        # Also seed every visible CUDA device.
        torch.cuda.manual_seed_all(seed)

    def empty_cache(self):
        torch.cuda.empty_cache()

    def use_deterministic_algorithms(self, mode: bool):
        # cudnn benchmark autotuning is nondeterministic, so it is the
        # inverse of deterministic mode.
        torch.backends.cudnn.benchmark = not mode
        torch.backends.cudnn.deterministic = mode

    def allow_tf32(self, mode: bool):
        torch.backends.cudnn.allow_tf32 = mode
        torch.backends.cuda.matmul.allow_tf32 = mode

    def set_rng_state(self, device, state):
        with torch.cuda.device(device):
            return torch.cuda.set_rng_state(state)

    def get_rng_state(self, device):
        with torch.cuda.device(device):
            return torch.cuda.get_rng_state()

    def autocast(self, **kwargs):
        # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch
        # in favor of torch.amp.autocast("cuda", ...) — confirm the minimum
        # supported torch version before migrating.
        return torch.cuda.amp.autocast(**kwargs)
c2cite/executors/mps.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+ import torch
4
+
5
+ from .common import BasicExecutor
6
+
7
+
8
class MPSExecutor(BasicExecutor):
    """Executor backend for Apple-silicon GPUs via torch.mps."""

    def __init__(self) -> None:
        super().__init__()

    def name(self) -> str:
        return "APPLE MPS"

    def device_name(self) -> str:
        return "mps"

    def is_available(self) -> bool:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()

    def is_initialized(self) -> bool:
        # TODO: change to official implementation
        return not torch.mps._is_in_bad_fork()

    def manual_seed(self, seed: int):
        super().manual_seed(seed)
        torch.mps.manual_seed(seed)

    def empty_cache(self):
        torch.mps.empty_cache()

    def allow_tf32(self, mode: bool):
        # TF32 is a CUDA feature; only mode=False is meaningful here.
        assert not mode, "Enabling tf32 for MPS devices."

    def set_rng_state(self, device: int, state: torch.Tensor):
        assert device == 0
        return torch.mps.set_rng_state(state)

    def get_rng_state(self, device: int):
        assert device == 0
        return torch.mps.get_rng_state()

    @contextlib.contextmanager
    def fork_rng(self, rng_devices: list):
        # TODO: change to official implementation
        # Save/restore both the CPU and the single MPS RNG state.
        assert len(rng_devices) == 1 and rng_devices[0] == 0
        cpu_rng_state = torch.get_rng_state()
        device_rng_states = torch.mps.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.mps.set_rng_state(device_rng_states)

    def autocast(self, **kwargs):
        # TODO: change to official implementation
        # running with compatible mode
        return torch.cuda.amp.autocast(**kwargs)

    def init_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        # Zero-fill instead of empty_like — presumably a workaround for
        # MPS uninitialized-memory behavior; TODO confirm.
        return torch.zeros_like(tensor)

    def index_fill(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, value: torch.Tensor
    ):
        # Deliberate no-op on MPS (index_fill_ unsupported/broken there) —
        # NOTE(review): callers must not rely on the fill taking effect.
        pass

    def index_copy(
        self, input: torch.Tensor, dim: int, index: torch.Tensor, source: torch.Tensor
    ):
        # index_add_ used as a stand-in for index_copy_; equivalent only
        # when the target rows are zero — TODO confirm with init_tensor().
        input.index_add_(dim, index, source)
c2cite/generator.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import re
8
+ import matplotlib.pyplot as plt
9
+
10
+ from moe_peft.common import LLMBatchConfig, LLMModelInput, Tokens, cache_factory
11
+ from moe_peft.executors import executor
12
+ from moe_peft.model import LLMModel
13
+ from moe_peft.prompter import Prompter
14
+ from moe_peft.tokenizer import Tokenizer
15
+ from moe_peft.solutions import get_output
16
+
17
+
18
+ @dataclass
19
+ class GenerateData:
20
+ adapter_name_: str = None
21
+ prompt_index_: int = None
22
+ prefix_length_: int = None
23
+ raw_tokens_: Tokens = None
24
+
25
+
26
+ @dataclass
27
+ class GenerateConfig:
28
+ adapter_name: str = None
29
+ prompts: List[Union[str, Tuple[str, str]]] = None
30
+ prompt_template: str = None
31
+ # Generate Arguments
32
+ batch_size: int = 8
33
+ stop_token: str = None
34
+ temperature: float = 1
35
+ top_p: float = 0.9
36
+ top_k: float = 50
37
+ do_sample: bool = True
38
+ repetition_penalty: float = 1.1
39
+ renormalize_logits: bool = True
40
+ # Do not set these manually
41
+ prompter_: Prompter = None
42
+ stop_token_: torch.Tensor = None
43
+ data_: List[GenerateData] = None
44
+
45
+ # Set prompt_template_ to enable the prompter
46
+ def generate_prompt(self, instruction: str, input: str = None) -> str:
47
+ if self.prompter_ is None:
48
+ self.prompter_ = Prompter(self.prompt_template)
49
+
50
+ return self.prompter_.generate_prompt(instruction=instruction, input=input)
51
+
52
+ def get_prompts(self) -> List[str]:
53
+ prompts = []
54
+ for prompt in self.prompts:
55
+ args = prompt if isinstance(prompt, Tuple) else (prompt, None)
56
+ prompts.append(self.generate_prompt(*args))
57
+
58
+ return prompts
59
+
60
+ def get_response(self, output: str) -> str:
61
+ if self.prompter_ is None:
62
+ return output.strip()
63
+ else:
64
+ return self.prompter_.get_response(output)
65
+
66
+ def reset_parameters(self):
67
+ self.prompter_ = Prompter(self.prompt_template)
68
+ self.stop_token_ = None
69
+ self.data_ = []
70
+
71
+
72
+ def _logits_sample_top_p(probs, p, filter_value=float("-inf"), min_tokens_to_keep=1):
73
+ sorted_logits, sorted_indices = torch.sort(probs, descending=False)
74
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
75
+ sorted_indices_to_remove = cumulative_probs <= (1 - p)
76
+ sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
77
+ indices_to_remove = sorted_indices_to_remove.scatter(
78
+ 1, sorted_indices, sorted_indices_to_remove
79
+ )
80
+ return probs.masked_fill(indices_to_remove, filter_value)
81
+
82
+
83
+ def _logits_sample_top_k(probs, k, filter_value=float("-inf")):
84
+ top_k = min(k, probs.size(-1)) # Safety check
85
+ indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
86
+ return probs.masked_fill(indices_to_remove, filter_value)
87
+
88
+
89
+ def _logits_repetition_penalty(prev_tokens, probs, penalty):
90
+ score = torch.gather(probs, 1, prev_tokens)
91
+ score = torch.where(score < 0, score * penalty, score / penalty)
92
+ probs.scatter_(1, prev_tokens, score)
93
+ return probs
94
+
95
+
96
+ def id2token(x):
97
+ if x == 0:
98
+ return 128002
99
+ elif x == 1:
100
+ return 128003
101
+ elif x == 2:
102
+ return 128004
103
+ elif x == 3:
104
+ return 128005
105
+ elif x == 4:
106
+ return 128008
107
+ elif x >= 5:
108
+ return 128005 + x
109
+ else:
110
+ assert False, "wrong router"
111
+
112
+ def logits_process(
113
+ probs: torch.Tensor,
114
+ prev_tokens: torch.Tensor,
115
+ cite_flag = False,
116
+ temperature=0.9,
117
+ top_p=0,
118
+ top_k=0,
119
+ do_sample=True,
120
+ repetition_penalty=1.01,
121
+ renormalize_logits=True,
122
+ ):
123
+ if cite_flag == False:
124
+ process_conditions = any([repetition_penalty > 0])
125
+ sample_conditions = any([temperature > 0, top_p > 0 and top_p <= 1.0, top_k > 0])
126
+
127
+ if not do_sample and sample_conditions:
128
+ do_sample = True
129
+ logging.warn("do_sample force to enabled.")
130
+
131
+ if repetition_penalty > 0:
132
+ probs = _logits_repetition_penalty(prev_tokens, probs, repetition_penalty)
133
+
134
+ if process_conditions and renormalize_logits:
135
+ probs = probs.log_softmax(-1)
136
+
137
+ if temperature > 0:
138
+ probs = probs / temperature
139
+
140
+ if top_k > 0:
141
+ probs = _logits_sample_top_k(probs, top_k)
142
+
143
+ if top_p > 0 and top_p <= 1.0:
144
+ probs = _logits_sample_top_p(probs, top_p)
145
+
146
+ if sample_conditions and renormalize_logits:
147
+ probs = probs.log_softmax(-1)
148
+ else:
149
+ do_sample = False
150
+
151
+ if do_sample:
152
+ probs = torch.softmax(probs, dim=-1)
153
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
154
+ else:
155
+ next_token = torch.argmax(probs, dim=-1)
156
+
157
+ if cite_flag:
158
+ for i in range(probs.shape[0]):
159
+ next_token[i] = id2token(next_token[i] + 1)
160
+ return next_token.reshape(-1)
161
+
162
+
163
def _extract_effective_tokens(
    tokenizer: Tokenizer,
    prefix_length: int,
    tokens: Tokens,
    remove_prefix=True,
    remove_pad=True,
    remove_eos=True,
):
    """Strip the prompt prefix, trailing padding and EOS from a token list.

    Truncation at the first pad/EOS id only happens when that id is
    actually present, so already-clean sequences pass through unchanged.
    """
    result = tokens[prefix_length:] if remove_prefix else tokens

    if remove_pad and tokenizer.pad_id_ in result:
        result = result[: result.index(tokenizer.pad_id_)]

    if remove_eos and tokenizer.eos_id_ in result:
        result = result[: result.index(tokenizer.eos_id_)]

    return result
183
+
184
+
185
def _gen_outputs(
    tokenizer: Tokenizer,
    config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    tokens: torch.Tensor,
):
    """Decode the current token matrix into per-adapter response strings.

    Used by the streaming callback: each job's row is cleaned of prompt,
    padding and EOS, decoded, post-processed by its adapter's config, and
    grouped under the adapter name.
    """
    token_rows = tokens.tolist()
    packed_outputs: Dict[str, List[str]] = {}
    for idx, job in enumerate(current_jobs):
        cleaned = _extract_effective_tokens(
            tokenizer,
            job.prefix_length_,
            token_rows[idx],
            remove_prefix=True,
            remove_pad=True,
            remove_eos=True,
        )
        response = config_dict[job.adapter_name_].get_response(tokenizer.decode(cleaned))
        packed_outputs.setdefault(job.adapter_name_, []).append(response)

    return packed_outputs
212
+
213
+
214
def _dispatch_task_in(
    configs: List[GenerateConfig],
    concurrent_jobs: int,
    strategy: str = "fair",
):
    """Pop the next batch of pending prompts from the per-task queues.

    Under "fair" each task gets roughly an equal share of the concurrent
    job budget; under "fifo" earlier tasks may fill the whole budget.
    Consumed jobs are removed from each config's ``data_`` queue, which is
    what eventually terminates the caller's scheduling loop.

    Returns:
        Tuple of (current_jobs, batch_config, input_tokens,
        max_tokens_len, min_tokens_len).
    """
    assert strategy in ["fair", "fifo"], f"Unknown dispatch strategy {strategy}"
    current_jobs = []
    batch_config = []
    input_tokens = []
    longest = 0
    shortest = sys.maxsize
    for config in configs:
        if len(batch_config) >= concurrent_jobs:
            break

        if not config.data_:
            continue

        # Budget for this task under the chosen strategy, capped by the
        # task's own batch size.
        if strategy == "fair":
            budget = max(concurrent_jobs // len(configs), 1)
        else:
            budget = concurrent_jobs
        budget = min(budget, config.batch_size)

        first_slot = len(input_tokens)
        while budget > 0 and config.data_:
            budget -= 1
            job = config.data_.pop(0)
            current_jobs.append(job)
            job_tokens = job.raw_tokens_
            longest = max(longest, len(job_tokens))
            shortest = min(shortest, len(job_tokens))
            input_tokens.append(job_tokens)

        batch_config.append(
            LLMBatchConfig(
                adapter_name_=config.adapter_name,
                batch_start_idx_=first_slot,
                batch_end_idx_=len(input_tokens),
            )
        )

    return (
        current_jobs,
        batch_config,
        input_tokens,
        longest,
        shortest,
    )
264
+
265
+
266
def _dispatch_task_out(
    tokenizer: Tokenizer,
    # config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    tokens: torch.Tensor,
    stop_reached: torch.Tensor,
    attentions,
    hides,
    require_attention,
    require_hide
):
    """Split finished jobs from still-running ones after a decode round.

    Finished rows (``stop_reached`` True) are cleaned, decoded, and have
    any ``<|reserved_special_token_N|>`` markers rewritten to ``[N]``
    citation brackets. Unfinished rows keep their accumulated tokens
    (prefix retained) and are returned for re-dispatch.

    ``attentions``/``hides``/``require_attention``/``require_hide`` are
    currently unused — the attention/hidden-state post-processing they fed
    is commented out.

    Returns:
        (packed_outputs, running_jobs): decoded strings for finished jobs
        and the job objects still in flight.
    """
    tokens = tokens.tolist()
    stop_reached = stop_reached.view(-1).tolist()
    packed_outputs: List[str] = []
    packed_add = []  # NOTE(review): never populated or returned
    running_jobs: List[GenerateData] = []
    for idx, data in enumerate(current_jobs):  # each entry is a GenerateData job
        if stop_reached[idx]:
            output_tokens = _extract_effective_tokens(
                tokenizer,
                data.prefix_length_,
                tokens[idx],
                remove_prefix=True,
                remove_pad=True,
                remove_eos=True,
            )
            output_s = tokenizer.decode(output_tokens).strip()
            # Reserved special tokens encode citations; render them as [N].
            output = re.sub(r'<\|reserved_special_token_(\d+)\|>', r'[\1]', output_s)
            packed_outputs.append(output)
        else:
            # Keep the full (prefix included) token stream so decoding can
            # resume where it stopped.
            data.tokens = _extract_effective_tokens(
                tokenizer,
                data.prefix_length_,
                tokens[idx],
                remove_prefix=False,
                remove_pad=True,
                remove_eos=False,
            )
            running_jobs.append(data)

    return packed_outputs, running_jobs
318
+
319
+
320
def _batch_generate(
    model: LLMModel,
    tokenizer: Tokenizer,
    max_gen_len: Optional[int],
    use_cache: bool,
    require_attention: Optional[int],
    require_hide: Optional[int],
    cache_implementation: Optional[str],
    stream_callback: Optional[Callable],
    #config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    batch_config: List[LLMBatchConfig],
    input_tokens: List[Tokens],
    max_tokens_len: int,
    min_tokens_len: int,
):
    """Citation-aware batched decoding loop.

    Decodes token-by-token, feeding the model the positions and values of
    all citation special tokens seen so far (``cite``/``cite_v``) plus the
    first job's document tokens, and tracking ``[``-opened citation spans.

    NOTE(review): the citation bookkeeping reads ``input_tokens[0]`` and
    ``current_jobs[0]`` only, so this path effectively assumes
    batch_size == 1 — confirm before enabling larger batches.

    Returns the (finished_outputs, running_jobs) pair produced by
    ``_dispatch_task_out``.
    """
    executor.empty_cache()
    device = torch.device(model.device_)
    batch_size = len(input_tokens)
    if max_gen_len is None:
        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)
    past_key_values = (
        cache_factory(
            cache_implementation=cache_implementation,
            config=model.model_.model_config(),
            batch_size=batch_size,
            max_cache_len=total_len,
        )
        if cache_implementation is not None
        else None
    )

    # Right-pad every prompt into a fixed (batch, total_len) token matrix.
    tokens = torch.full(
        (batch_size, total_len), tokenizer.pad_id_, dtype=torch.int64, device=device
    )
    for k, t in enumerate(input_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.int64, device=device)

    def condition(i):
        # True for citation-related reserved special-token ids.
        return (128010 <= i <= 128255) or i in {128005, 128004, 128003, 128002, 128008}

    prompt_len = len(input_tokens[0])
    # Positions and values of citation special tokens already in the prompt.
    cite = [index for index, value in enumerate(input_tokens[0]) if condition(value)]
    cite_v = [value for value in input_tokens[0] if condition(value)]

    prev_pos = 0
    stop_reached = torch.tensor([False] * batch_size, device=device)
    input_text_mask = tokens != tokenizer.pad_id_

    hidden_states = []
    hidden_attentions = []
    #arti_mask = torch.ones(batch_size, total_len, device=device, dtype=torch.int64)
    cite_start = -1  # position of the '[' opening the current citation span; -1 if none
    #flag = -1
    plac = []  # positions of digit tokens generated inside citation spans
    for cur_pos in range(min_tokens_len, total_len):
        input_data = LLMModelInput(
            batch_configs_=batch_config,
            batch_tokens_=tokens[:, prev_pos:cur_pos].tolist(),
            #batch_masks_ = arti_mask,
            batch_cites = [cite],
            batch_cites_value = [cite_v],
            batch_docs = [current_jobs[0].citation_tokens],
            batch_prompt_len = [prompt_len],
            inference_mode_=True,
        )
        outputs = model.forward(input_data, past_key_values)
        #hidden_states.append(hidden_state)
        #hidden_attentions.append(hidden_attention)

        for output in outputs:
            #config = config_dict[output.adapter_name]
            start_idx = output.batch_start_idx_
            end_idx = output.batch_end_idx_

            next_token = logits_process(
                output.logits[:, -1],  # logits for the last position only
                tokens[start_idx:end_idx, :cur_pos],
                cite_flag = output.cite_flag,
            )

            # Inside the prompt region, keep the prompt token instead of the
            # freshly sampled one.
            next_token = torch.where(
                input_text_mask[start_idx:end_idx, cur_pos],
                tokens[start_idx:end_idx, cur_pos],
                next_token,
            ).to(torch.int64)
            if output.cite_flag == True:  # record freshly generated citation tokens
                for i in range(start_idx, end_idx):
                    if input_text_mask[i, cur_pos]:  # prompt positions were recorded up front
                        continue
                    cite.append(cur_pos)
                    # NOTE(review): appends a tensor while prompt-derived
                    # entries are plain ints — confirm downstream handles both.
                    cite_v.append(next_token)

            tokens[start_idx:end_idx, cur_pos] = next_token
            # A row stops once it emits EOS outside the prompt region.
            stop_criteria = (~input_text_mask[start_idx:end_idx, cur_pos]) & (
                next_token == torch.tensor(
                    [tokenizer.eos_id_], dtype=torch.int64, device=device
                )
            )
            stop_reached[start_idx:end_idx] |= stop_criteria
            if cite_start != -1:
                # Sentence-ending punctuation closes the open citation span.
                if tokenizer.decode(next_token)[-1] in ['.','!','?']:
                    #arti_mask[start_idx:end_idx, cite_start:cur_pos] = 0
                    #tokens[start_idx:end_idx, cur_pos] = tokenizer.encode(tokenizer.decode(next_token)[-1])[-1]
                    cite_start = -1
                if tokenizer.decode(next_token)[-1] in ['0','1','2','3','4','5','6','7','8','9']:
                    plac.append(cur_pos)
                # tokens[start_idx:end_idx, cur_pos] = (tokens[start_idx:end_idx, cur_pos] + 2)

            # '[' opens a citation span.
            if tokenizer.decode(next_token)[-1] == '[' or tokenizer.decode(next_token) == '[':
                if cite_start == -1:
                    cite_start = cur_pos
                    #flag = cur_pos

        # Force-stop everything when the token buffer is about to run out.
        stop_reached |= total_len - cur_pos == 1

        if any(stop_reached):
            break

        if use_cache:
            prev_pos = cur_pos

    # Dead experimental attention-visualization code, retained as-is:
    """input_data = LLMModelInput(
        batch_configs_=batch_config,
        batch_tokens_=tokens[:,:hidden_attention.shape[0]].tolist(),
        inference_mode_=True,
    )"""
    #outputs, _, attn = model.forward(input_data, None, require_attention, require_hide)
    """for i in plac:

        plt.figure(figsize=(hidden_attention.shape[0], 5), dpi = 50)
        print("painting")
        plt.bar(range(hidden_attention.shape[0]), attn[:,i].cpu().numpy())
        plt.xticks(range(hidden_attention.shape[0]), [tokenizer.decode(j) for j in tokens[0][:hidden_attention.shape[0]]], fontsize = 8)
        plt.savefig("high_res_heatmap.svg", dpi=50)
        print("ok~")
        input()
    """
    """attn[torch.arange(hidden_attention.shape[0]), torch.arange(hidden_attention.shape[0])] = 0.0
    attn = torch.nn.functional.normalize(attn, p=2, dim=1)
    attn = attn[min_tokens_len:hidden_attention.shape[0],min_tokens_len:hidden_attention.shape[0]]

    plt.figure(figsize=(hidden_attention.shape[0] - min_tokens_len, hidden_attention.shape[0] - min_tokens_len))
    plt.imshow(attn.cpu().numpy(), cmap='viridis', vmin = 0, vmax = 0.1)
    plt.colorbar(label='Value')
    plt.xticks(range(hidden_attention.shape[0] - min_tokens_len), [tokenizer.decode(i) for i in tokens[0][min_tokens_len:hidden_attention.shape[0]]], fontsize = 10)
    plt.yticks(range(hidden_attention.shape[0] - min_tokens_len), [tokenizer.decode(i) for i in tokens[0][min_tokens_len:hidden_attention.shape[0]]], fontsize = 10)
    plt.savefig("high_res_heatmap.png", dpi=200)
    plt.show()
    print("ok~")
    input()"""
    """token2 = tokens * arti_mask
    lst = token2[0].tolist()
    lst = [ele for ele in lst if ele != 0]
    tokens = torch.tensor(lst, dtype=torch.int64, device=device).unsqueeze(0)"""

    return _dispatch_task_out(
        tokenizer, current_jobs, tokens, stop_reached, hidden_states, hidden_attentions, require_attention, require_hide
    )
485
+
486
+
487
def _batch_generate_original(
    model: LLMModel,
    tokenizer: Tokenizer,
    max_gen_len: Optional[int],
    use_cache: bool,
    cache_implementation: Optional[str],
    stream_callback: Optional[Callable],
    config_dict: Dict[str, GenerateConfig],
    current_jobs: List[GenerateData],
    batch_config: List[LLMBatchConfig],
    input_tokens: List[Tokens],
    max_tokens_len: int,
    min_tokens_len: int,
):
    """Plain (non-citation) batched decoding loop.

    Decodes token-by-token with each adapter's own sampling settings,
    optionally streaming intermediate outputs via ``stream_callback``,
    until every row hits its stop token or the buffer fills.
    """
    executor.empty_cache()
    device = torch.device(model.device_)
    batch_size = len(input_tokens)
    if max_gen_len is None:
        max_gen_len = model.config_.max_seq_len_ - max_tokens_len
    total_len = min(model.config_.max_seq_len_, max_gen_len + max_tokens_len)

    past_key_values = (
        cache_factory(
            cache_implementation=cache_implementation,
            config=model.model_.model_config(),
            batch_size=batch_size,
            max_cache_len=total_len,
        )
        if cache_implementation is not None
        else None
    )

    # Right-pad every prompt into a fixed (batch, total_len) token matrix.
    tokens = torch.full(
        (batch_size, total_len), tokenizer.pad_id_, dtype=torch.int64, device=device
    )
    for k, t in enumerate(input_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.int64, device=device)

    prev_pos = 0
    stop_reached = torch.tensor([False] * batch_size, device=device)
    input_text_mask = tokens != tokenizer.pad_id_
    for cur_pos in range(min_tokens_len, total_len):
        input_data = LLMModelInput(
            batch_configs_=batch_config,
            batch_tokens_=tokens[:, prev_pos:cur_pos].tolist(),
            inference_mode_=True,
        )
        outputs = model.forward(input_data, past_key_values)
        for output in outputs:
            config = config_dict[output.adapter_name]
            start_idx = output.batch_start_idx_
            end_idx = output.batch_end_idx_

            # BUG FIX: logits_process gained a third `cite_flag` parameter;
            # the previous positional call shifted every sampling option by
            # one slot (temperature landed in cite_flag, top_p in
            # temperature, ...). Pass the options by keyword instead.
            next_token = logits_process(
                output.logits[:, -1],
                tokens[start_idx:end_idx, :cur_pos],
                temperature=config.temperature,
                top_p=config.top_p,
                top_k=config.top_k,
                do_sample=config.do_sample,
                repetition_penalty=config.repetition_penalty,
                renormalize_logits=config.renormalize_logits,
            )

            # Inside the prompt region keep the prompt token.
            next_token = torch.where(
                input_text_mask[start_idx:end_idx, cur_pos],
                tokens[start_idx:end_idx, cur_pos],
                next_token,
            ).to(torch.int64)
            tokens[start_idx:end_idx, cur_pos] = next_token
            # A row stops once it emits its stop token outside the prompt.
            stop_criteria = (~input_text_mask[start_idx:end_idx, cur_pos]) & (
                next_token == config.stop_token_
            )
            stop_reached[start_idx:end_idx] |= stop_criteria

        # Force-stop when the token buffer is about to run out.
        stop_reached |= total_len - cur_pos == 1

        if any(stop_reached):
            break

        if stream_callback is not None:
            stream_callback(
                cur_pos,
                _gen_outputs(
                    tokenizer,
                    config_dict,
                    current_jobs,
                    tokens,
                ),
            )

        if use_cache:
            prev_pos = cur_pos

    # BUG FIX: _dispatch_task_out no longer takes config_dict and now expects
    # attention/hidden-state arguments; pass empty/disabled placeholders.
    return _dispatch_task_out(
        tokenizer, current_jobs, tokens, stop_reached, [], [], -1, -1
    )
584
+
585
+
586
@torch.inference_mode()
def generate(
    model: LLMModel,
    tokenizer: Tokenizer,
    configs: List[GenerateConfig],
    max_gen_len: Optional[int] = None,
    use_cache: bool = True,
    dispatch_strategy: str = "fair",
    concurrent_jobs: Optional[int] = None,
    cache_implementation: Optional[str] = None,
    stream_callback: Optional[Callable] = None,
):
    """Run batched generation for every prompt of every adapter config.

    Prompts are tokenized up front, then repeatedly dispatched in batches
    (``dispatch_strategy`` "fair" or "fifo") until every job finishes.

    Returns:
        Dict mapping adapter name to its list of generated strings.
    """
    if concurrent_jobs is None:
        concurrent_jobs = len(configs)
        logging.info(f"Setting concurrent jobs to {concurrent_jobs} automatically")

    assert concurrent_jobs > 0

    # prepare for generation
    device = torch.device(model.device_)
    config_dict = {}
    for config in configs:
        config.reset_parameters()
        config_dict[config.adapter_name] = config
        if config.stop_token is not None:
            # Encode with a leading space so the stop token matches its
            # mid-sentence form; keep only the final piece.
            stop_token = tokenizer.encode(" " + config.stop_token, False)[-1]
        else:
            stop_token = tokenizer.eos_id_
        config.stop_token_ = torch.tensor(
            [stop_token], dtype=torch.int64, device=device
        )
        for idx, prompt in enumerate(config.prompts):
            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
            tokens = tokenizer.encode(config.generate_prompt(*args))
            assert (
                len(tokens) < model.config_.max_seq_len_
            ), "Inputs exceeded max sequence length of model."
            config.data_.append(
                GenerateData(
                    adapter_name_=config.adapter_name,
                    prompt_index_=idx,
                    prefix_length_=len(tokens),
                    raw_tokens_=tokens,
                )
            )

    if use_cache and cache_implementation is None:
        cache_implementation = model.model_.cache_implementation()
        if cache_implementation is None:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning(
                "Cache disabled by model, use cache_implementation to force enable."
            )
            use_cache = False

    packed_outputs: Dict[str, List] = {}

    # Each config's data_ queue shrinks as jobs are dispatched; the loop
    # ends once every queue is drained.
    while True:
        dispatch_args = _dispatch_task_in(configs, concurrent_jobs, dispatch_strategy)
        # dispatch_args = (current_jobs, batch_config, input_tokens,
        #                  max_tokens_len, min_tokens_len)
        if len(dispatch_args[0]) == 0:
            break

        # BUG FIX: _batch_generate takes (require_attention, require_hide)
        # after use_cache and no longer takes config_dict; the previous call
        # shifted every argument and was one parameter short (TypeError).
        outputs, running_jobs = _batch_generate(
            model,
            tokenizer,
            max_gen_len,
            use_cache,
            -1,  # require_attention disabled
            -1,  # require_hide disabled
            cache_implementation,
            stream_callback,
            *dispatch_args,
        )

        # BUG FIX: _dispatch_task_out returns a flat list of finished
        # outputs, not a dict; rebuild the adapter mapping by pairing the
        # outputs with the jobs that finished (dispatch order, minus the
        # ones still running — identity comparison on the job objects).
        finished_jobs = [job for job in dispatch_args[0] if job not in running_jobs]
        for job, output in zip(finished_jobs, outputs):
            packed_outputs.setdefault(job.adapter_name_, []).append(output)

        # Unfinished jobs go back to their queue for the next round.
        for data in running_jobs:
            config_dict[data.adapter_name_].data_.append(data)

    return packed_outputs
c2cite/model.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+ import torch.nn.functional as F
8
+
9
+
10
+ import torch
11
+ from huggingface_hub import snapshot_download
12
+ from transformers import AutoModelForCausalLM
13
+
14
+ from moe_peft.adapters import (
15
+ LoraMoeConfig,
16
+ MixLoraConfig,
17
+ MolaConfig,
18
+ lora_config_factory,
19
+ moe_layer_factory,
20
+ router_loss_factory,
21
+ )
22
+ from moe_peft.common import (
23
+ CHECKPOINT_CLASSES,
24
+ AdapterConfig,
25
+ Linear,
26
+ LLMCache,
27
+ LLMDecoder,
28
+ LLMForCausalLM,
29
+ LLMModelConfig,
30
+ LLMModelInput,
31
+ LLMModelOutput,
32
+ LLMMoeBlock,
33
+ LLMOutput,
34
+ LoraConfig,
35
+ unpack_router_logits,
36
+ )
37
+ from moe_peft.executors import executor
38
+ from moe_peft.models import from_pretrained
39
+ from moe_peft.tasks import SequenceClassificationTask, task_dict
40
+ from moe_peft.utils import is_package_available
41
+
42
+ if is_package_available("bitsandbytes"):
43
+ from transformers import BitsAndBytesConfig
44
+ else:
45
+ from moe_peft.utils import BitsAndBytesConfig
46
+
47
+
48
class CasualOutputLayer(LLMOutput):
    """Causal-LM output head: projects hidden states to vocabulary logits.

    (The historical "Casual" spelling is kept because other modules refer
    to this class by name.)
    """

    def __init__(self, vocab_size: int, weight: torch.nn.Linear):
        super().__init__()
        self.vocab_size_: int = vocab_size
        self.lm_head_: torch.nn.Module = weight

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # float() keeps the loss numerically stable with fp16/bf16 heads.
        return self.lm_head_(data).float()

    def loss(
        self, input_ids: torch.Tensor, output_logits: torch.Tensor, labels,
        cites: Optional[List] = None, cites_v: Optional[List] = None, prompt_lens: Optional[List] = None
    ) -> torch.Tensor:
        """Next-token cross-entropy loss.

        When ``cites``/``cites_v`` are given, citation-token positions are
        masked out of the labels (ignore_index -100) so the language loss
        does not compete with the citation router. When ``prompt_lens`` is
        given, loss is averaged per sample over the completion region only;
        otherwise the standard shifted full-sequence loss is used.
        """
        if isinstance(labels, torch.Tensor):
            labels = (
                labels.clone()
                .detach()
                .to(dtype=torch.long, device=output_logits.device)
            )
        else:
            labels = torch.tensor(labels, dtype=torch.long, device=output_logits.device)

        loss_fn = torch.nn.CrossEntropyLoss()
        if cites:
            # Blank out citation positions so cross-entropy ignores them.
            for i in range(len(labels)):
                for j in range(len(cites_v[i])):
                    labels[i][cites[i][j]] = -100
            loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

        if prompt_lens is None:
            # BUG FIX: `len(prompt_lens)` / `range(len(None))` previously
            # raised TypeError when no prompt lengths were supplied (the
            # parameter's default). Fall back to the classic shifted
            # full-sequence loss in that case.
            return loss_fn(
                output_logits[..., :-1, :].contiguous().view(-1, self.vocab_size_),
                labels[..., 1:].contiguous().view(-1),
            )

        total = 0
        for i in range(len(prompt_lens)):
            # Shift by one: logits at position t predict the token at t+1,
            # starting from the last prompt position.
            total += loss_fn(
                output_logits[i, prompt_lens[i] - 1:-1, :].contiguous().view(-1, self.vocab_size_),
                labels[i, prompt_lens[i]:].contiguous().view(-1),
            )
        return total / len(prompt_lens)
90
+
91
+
92
class ClassificationOutputLayer(LLMOutput):
    """Sequence-classification head: pools the last non-pad position and scores it.

    Supports single-label (cross-entropy) and multi-label (BCE-with-logits)
    tasks. The score projection is kept in float32 regardless of model dtype.
    """

    def __init__(
        self,
        task_type: str,
        num_labels: int,
        label_dtype: torch.dtype,
        hidden_size: int,
        pad_token_id: int,
        device: str,
        weight: Optional[torch.Tensor],
    ):
        super().__init__()
        self.label_dtype_ = label_dtype
        self.num_labels_ = num_labels
        self.task_type_ = task_type
        self.pad_id_ = pad_token_id
        self.score_ = torch.nn.Linear(
            hidden_size,
            self.num_labels_,
            bias=False,
            dtype=torch.float32,
            device=device,
        )
        if weight is None:
            # Fresh head: Kaiming init with the same gain nn.Linear uses.
            torch.nn.init.kaiming_normal_(self.score_.weight, a=math.sqrt(5))
        else:
            # Restore from a checkpoint dict produced by state_dict() below.
            with torch.no_grad():
                self.score_.weight.copy_(weight["classifier"])

    def state_dict(self):
        # Only the classifier projection is adapter-specific.
        return {"classifier": self.score_.weight}

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.score_(data.to(torch.float32))

    def loss(
        self, input_ids: torch.Tensor, output_logits: torch.Tensor, labels
    ) -> torch.Tensor:
        """Classification loss over the last non-pad token of each sequence."""
        if isinstance(labels, torch.Tensor):
            labels = (
                labels.clone()
                .detach()
                .to(dtype=self.label_dtype_, device=output_logits.device)
            )
        else:
            labels = torch.tensor(
                labels, dtype=self.label_dtype_, device=output_logits.device
            )
        batch_size = input_ids.shape[0]
        # Index of the token right before the first pad in each row.
        # NOTE(review): argmax returns 0 when a row contains no padding,
        # which makes this -1 (the last position) — confirm that is the
        # intended fallback.
        sequence_lengths = (torch.eq(input_ids, self.pad_id_).int().argmax(-1) - 1).to(
            output_logits.device
        )
        pooled_logits = output_logits[
            torch.arange(batch_size, device=output_logits.device), sequence_lengths
        ]
        if self.task_type_ == "single_label_classification":
            loss_fn = torch.nn.CrossEntropyLoss()
            return loss_fn(pooled_logits.view(-1, self.num_labels_), labels.view(-1))
        elif self.task_type_ == "multi_label_classification":
            loss_fn = torch.nn.BCEWithLogitsLoss()
            return loss_fn(pooled_logits, labels)
        else:
            raise ValueError(f"unknown task type {self.task_type_}")
155
+
156
+
157
class OutputLayer(torch.nn.Module):
    """Dispatches hidden states to each adapter's own output head.

    ``layers_`` maps adapter name -> head module; ``forward`` slices the
    batch by each adapter's row range and emits one ``LLMModelOutput`` per
    adapter, carrying the head's logits and its loss function.
    """

    def __init__(self):
        super().__init__()
        self.layers_: Dict[str, torch.nn.Module] = {}

    def forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> List[LLMModelOutput]:
        results = []
        for batch_cfg in input_args.batch_configs_:
            name = batch_cfg.adapter_name_
            assert name != "" and name in self.layers_
            head = self.layers_[name]
            rows = data[batch_cfg.batch_start_idx_ : batch_cfg.batch_end_idx_]
            results.append(
                LLMModelOutput(
                    adapter_name=name,
                    logits=head.forward(rows),
                    loss_fn_=head.loss,
                )
            )

        return results
182
+
183
+
184
def init_lora_layer_weight(
    transformer_layer: LLMDecoder,
    llm_config: LLMModelConfig,
    lora_config: LoraConfig,
    lora_weights: Optional[Dict[str, torch.Tensor]],
):
    """Attach (and optionally load) LoRA/MoE adapter weights on one decoder layer.

    Depending on the adapter type this either fuses a MoE block into the MLP
    (MixLoRA), plugs per-linear gates (LoRAMoE on MLP projections, MoLA on
    all projections), or installs plain LoRA pairs. When ``lora_weights`` is
    None all weights are freshly initialized; otherwise they are looked up
    by their serialized module names.
    """
    target_modules = lora_config.target_modules_
    attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
    attn_state_dict: Dict[str, torch.Tensor]
    mlp_state_dict: Dict[str, torch.Tensor]
    all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
    all_state_dict.update(mlp_state_dict)
    moe_init_strategy = "none"
    if isinstance(lora_config, MixLoraConfig):
        model_prefix_name = "mixlora"
        moe_layer_name_list = list(mlp_state_dict.keys())
        moe_init_strategy = "fused_mlp"
    elif isinstance(lora_config, LoraMoeConfig):
        model_prefix_name = "loramoe"
        moe_layer_name_list = list(mlp_state_dict.keys())
        moe_init_strategy = "plugin"
    elif isinstance(lora_config, MolaConfig):
        model_prefix_name = "mola"
        moe_layer_name_list = list(all_state_dict.keys())
        moe_init_strategy = "plugin"
    else:
        model_prefix_name = "base_model.model.model"
        moe_layer_name_list = []

    assert len(moe_layer_name_list) == 0 or moe_init_strategy in ["plugin", "fused_mlp"]

    if moe_init_strategy == "fused_mlp":
        # MixLoRA: one shared gate living on the MLP block itself.
        transformer_layer.mlp_.moes_[lora_config.adapter_name] = moe_layer_factory(
            llm_config.dim_,
            llm_config.device_,
            lora_config,
            (
                None
                if lora_weights is None
                else lora_weights[
                    f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
                ]
            ),
        )

    for proj_name, lora_linear in all_state_dict.items():
        lora_linear: Linear
        if proj_name not in target_modules or not target_modules[proj_name]:
            continue
        module_name = (
            "self_attn"
            if proj_name in attn_state_dict
            else ("mlp" if proj_name in mlp_state_dict else None)
        )
        module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
        if proj_name in moe_layer_name_list:
            if moe_init_strategy == "plugin":
                # init for gating mechanisms
                lora_linear.moes_[lora_config.adapter_name] = moe_layer_factory(
                    lora_linear.in_features_,
                    llm_config.device_,
                    lora_config,
                    (
                        lora_weights.get(f"{module_name}.moe_gate.weight", None)
                        if lora_weights is not None
                        else None
                    ),
                )

            # One LoRA A/B pair per expert on this projection.
            for expert_idx in range(lora_config.num_experts_):
                if lora_weights is None:
                    lora_a = None
                    lora_b = None
                else:
                    lora_a = lora_weights.get(
                        f"{module_name}.experts.{expert_idx}.lora_A.weight", None
                    )
                    lora_b = lora_weights.get(
                        f"{module_name}.experts.{expert_idx}.lora_B.weight", None
                    )

                lora_linear.init_lora_weight(
                    lora_config.expert_config(expert_idx), (lora_a, lora_b)
                )
        else:
            # Plain (non-MoE) LoRA on this projection.
            if lora_weights is None:
                lora_a = None
                lora_b = None
            else:
                lora_a = lora_weights.get(f"{module_name}.lora_A.weight", None)
                lora_b = lora_weights.get(f"{module_name}.lora_B.weight", None)

            lora_linear.init_lora_weight(lora_config, (lora_a, lora_b))
277
+
278
+
279
def get_lora_layer_weight(
    transformer_layer: LLMDecoder,
    lora_config: LoraConfig,
    lora_weights: Dict[str, torch.Tensor],
):
    """Collect this layer's adapter weights into ``lora_weights`` (in place).

    Inverse of ``init_lora_layer_weight``: walks the same serialized module
    names and records gate weights and per-expert / plain LoRA A/B weights.
    """
    target_modules = lora_config.target_modules_
    attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
    attn_state_dict: Dict[str, torch.Tensor]
    mlp_state_dict: Dict[str, torch.Tensor]
    all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
    all_state_dict.update(mlp_state_dict)
    if isinstance(lora_config, MixLoraConfig):
        model_prefix_name = "mixlora"
        gate_layer_name = (
            f"mixlora.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
        )
        moe_layer_name_list = list(mlp_state_dict.keys())
    elif isinstance(lora_config, LoraMoeConfig):
        model_prefix_name = "loramoe"
        moe_layer_name_list = list(mlp_state_dict.keys())
    elif isinstance(lora_config, MolaConfig):
        model_prefix_name = "mola"
        moe_layer_name_list = list(all_state_dict.keys())
    else:
        model_prefix_name = "base_model.model.model"
        moe_layer_name_list = []

    # for fused MoEs such as MixLoRA
    mlp_moe_layer: LLMMoeBlock = transformer_layer.mlp_.moes_.get(
        lora_config.adapter_name, None
    )
    if mlp_moe_layer is not None:
        # NOTE(review): gate_layer_name is only bound in the MixLoraConfig
        # branch above; if a non-MixLoRA adapter ever has a fused MLP MoE,
        # this raises NameError — confirm that cannot happen.
        lora_weights[gate_layer_name] = mlp_moe_layer.gate_.weight

    for proj_name, lora_linear in all_state_dict.items():
        lora_linear: Linear
        if proj_name not in target_modules or not target_modules[proj_name]:
            continue
        module_name = (
            "self_attn"
            if proj_name in attn_state_dict
            else ("mlp" if proj_name in mlp_state_dict else None)
        )
        module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
        if proj_name in moe_layer_name_list:
            # For fused MoEs the experts live on mlp_moe_layer; for plugged
            # MoEs they live on the linear itself.
            moe_layer = (
                lora_linear.moes_[lora_config.adapter_name]
                if lora_config.adapter_name in lora_linear.moes_
                else mlp_moe_layer
            )
            # for plugged MoEs such as LoRAMoE, MoLA, etc.
            if lora_config.adapter_name in lora_linear.moes_:
                lora_weights[f"{module_name}.moe_gate.weight"] = lora_linear.moes_[
                    lora_config.adapter_name
                ].gate_.weight

            for expert_idx in range(moe_layer.experts_):
                moe_lora_name = f"moe.{lora_config.adapter_name}.experts.{expert_idx}"
                lora_obj = lora_linear.loras_.get(moe_lora_name, None)
                if lora_obj is not None:
                    lora_weights[
                        f"{module_name}.experts.{expert_idx}.lora_A.weight"
                    ] = lora_obj.lora_a_.weight
                    lora_weights[
                        f"{module_name}.experts.{expert_idx}.lora_B.weight"
                    ] = lora_obj.lora_b_.weight

        else:
            lora_obj = lora_linear.loras_.get(lora_config.adapter_name, None)
            if lora_obj is not None:
                lora_weights[f"{module_name}.lora_A.weight"] = lora_obj.lora_a_.weight
                lora_weights[f"{module_name}.lora_B.weight"] = lora_obj.lora_b_.weight
352
+
353
def get_atten_tar(x, y, device, dtype):
    """Build attention-target weighting tensors.

    Returns:
        alpha: shape (y,), ramping from 0 toward 1 as ``1 - exp(-i/200)``.
        beta: shape (x-1, x-1); column j holds ``exp(-2*k) * award(k)`` for
            k = 1..x-1, divided by the running cumulative sum of
            ``exp(-2*m)`` up to row j.
    """
    positions = torch.arange(0, y, device=device, dtype=dtype)
    offsets = torch.arange(1, x, device=device, dtype=dtype)  # 1 .. x-1
    decay = torch.tensor(-2, device=device, dtype=dtype)

    alpha = (1 - torch.exp(-(positions / 200))).detach()

    # Running cumulative sum of exp(-2), exp(-4), ..., exp(-2*(x-1)).
    cum_base = torch.empty(x - 1, device=device, dtype=dtype)
    cum_base[0] = torch.exp(decay)
    for step in range(1, x - 1):
        cum_base[step] = cum_base[step - 1] + torch.exp(decay * (step + 1))

    # Small positive reward that grows with the offset, detached from autograd.
    reward = (0.1 * (0.5 - 1 / (offsets + 1)) + 0.2).detach()
    weights = (torch.exp(decay * offsets) * reward).expand(offsets.shape[0], x - 1).T
    beta = (weights / cum_base).detach()

    return alpha, beta
372
+
373
+
374
class LLMModel(torch.nn.Module):
    """Wrapper around a causal LM that adds citation-aware heads.

    On top of the base ``LLMForCausalLM`` this module owns:
      * ``routerup`` — a 2-way router deciding "cite here" vs "normal token";
      * ``cite_output`` / ``doc_proj`` — projections used to score citation
        logits against per-document embeddings;
      * ``alpha`` / ``beta`` — fixed schedules (from ``get_atten_tar``) used
        by the attention-supervision loss;
    plus the usual adapter (LoRA/MoE) management API.

    State note: ``self.doc_embeds`` is populated inside ``_prepare_inputs``
    during the prefill pass (sequence length > 1) and is read again during
    single-token decode steps and in ``forward`` — calls are order-dependent.
    """

    def __init__(self, model: LLMForCausalLM):
        super().__init__()
        args: LLMModelConfig = model.config_
        if args.vocab_size_ >= torch.finfo(args.dtype_).max:
            # NOTE(review): logging.warn is deprecated; logging.warning is the
            # non-deprecated spelling — left as-is (doc-only change).
            logging.warn(
                f"vocab_size >= max({args.dtype_}), consider load model with higher precision."
            )
        self.model_ = model
        self.config_ = args
        # configs
        self.name_or_path_ = args.name_or_path_
        self.vocab_size_ = args.vocab_size_
        self.device_ = args.device_
        self.dtype_ = args.dtype_

        # Per-head mixing weight for attention matrices (n_heads, 1).
        self.attention_weight = torch.nn.Parameter(torch.empty(
            model.layers_[0].self_attn_.n_heads_,1,dtype=args.dtype_,device=args.device_,))

        # 2-way router head: column 1 > column 0 at the last position
        # triggers citation mode at inference (see forward()).
        self.routerup = torch.nn.Parameter(torch.empty(
            model.config_.dim_, 2,dtype=args.dtype_,device=args.device_,))
        """self. routerdown = torch.nn.Parameter(torch.empty(
            model.config_.dim_ * 2, 2,dtype=args.dtype_,device=args.device_,))"""
        # Projection producing the citation-scoring space.
        self.cite_output = torch.nn.Parameter(torch.empty(
            model.config_.dim_,model.config_.dim_,dtype=args.dtype_,device=args.device_,))
        # Projection applied to document embeddings (orthogonally initialized).
        self.doc_proj = torch.nn.Parameter(torch.empty(
            model.config_.dim_, model.config_.dim_,dtype=args.dtype_,device=args.device_,))

        # Fixed loss schedules; hard-coded for up to 40 citations / 3000 tokens.
        self.alpha, self.beta= get_atten_tar(40, 3000, args.device_, args.dtype_)
        self.silu = torch.nn.SiLU()

        self.output_ = OutputLayer()
        # adapter configs
        self.adapter_configs_: Dict[str, LoraConfig] = {}

    def token2id(self, t):
        """Map a special citation token id to a compact document index.

        Reserved ids 128002/128003/128004/128005/128008 map to 0..4 and the
        range 128010..128255 maps to 5..250; anything else returns -1
        ("not a citation token").
        """
        if isinstance(t, torch.Tensor):
            x = t.item()
        else:
            x = t
        if x == 128002:
            return 0
        elif x == 128003:
            return 1
        elif x == 128004:
            return 2
        elif x == 128005:
            return 3
        elif x == 128008:
            return 4
        elif x >= 128010 and x <= 128255:
            return x - 128005
        else:
            return -1

    def attention_target(self, i, j, T):
        """Target attention mass for citation i at position j (step T).

        NOTE(review): ``self.award`` is never assigned anywhere in this class
        (only ``alpha``/``beta`` are stored by __init__); calling this method
        would raise AttributeError — looks like stale code, confirm before use.
        """
        return self.alpha[j] * self.beta[T, i] * self.award[i]

    def _prepare_inputs(
        self, input_args: LLMModelInput, past_key_values: Optional[LLMCache] = None
    ):
        """Turn a batch request into (ids, embeddings, masks, cache positions).

        Side effect: on a prefill pass (seq len > 1) this (re)builds
        ``self.doc_embeds`` — one mean-pooled embedding per document per batch
        element — and overwrites the embedding at each citation position with
        the referenced document embedding. On decode steps (seq len == 1) it
        reuses the previously built ``self.doc_embeds``.
        """
        assert input_args.batch_tokens_ is not None, "Model have no input."
        assert (
            input_args.gradient_checkpoint_ == "none" or past_key_values is None
        ), "Cache is incompatible with gradient checkpointing."
        assert (
            not input_args.inference_mode_ or input_args.gradient_checkpoint_ == "none"
        ), "Can not use gradient checkpoint when inference."

        # prepare inputs
        if isinstance(input_args.batch_tokens_, torch.Tensor):
            # NOTE(review): Tensor.to() does not accept requires_grad= — this
            # branch would raise TypeError if batch_tokens_ is ever a Tensor;
            # presumably callers always pass lists. Confirm.
            input_ids = input_args.batch_tokens_.to(
                dtype=torch.int64, device=self.device_, requires_grad=False
            )
        else:
            input_ids = torch.tensor(
                input_args.batch_tokens_, dtype=torch.int64, device=self.device_, requires_grad=False
            )

        inputs_embeds = self.model_.embed_tokens(input_ids)

        """if input_ids.shape[-1] > 1:
            self.doc_embeds = []
            cites = input_args.batch_cites
            docs = input_args.batch_docs
            for doc in docs:
                doc = doc.clone().to(self.device_)
                doc = doc @ self.doc_proj
                self.doc_embeds.append(doc)
            for i, cite in enumerate(cites):
                for c in range(len(input_args.batch_cites_value[i])):
                    inputs_embeds[i, cite[c]] = self.doc_embeds[i][self.token2id(input_args.batch_cites_value[i][c]) - 1].to(self.device_)
        else:
            fk = self.token2id(input_ids[0,0])
            if fk != -1:
                inputs_embeds[0][0] = self.doc_embeds[0][fk - 1].to(self.device_)"""

        docs = input_args.batch_docs
        if input_ids.shape[-1] > 1:
            # Prefill: build one mean-pooled embedding per document
            # (docs[i][j][1:] skips the leading token, presumably BOS — confirm).
            self.doc_embeds = []
            cites = input_args.batch_cites
            if not isinstance(docs[0][0], torch.Tensor):
                for i in range(len(docs)):
                    d = []
                    for j in range(len(docs[i])):
                        temp = self.model_.embed_tokens(torch.tensor(
                            docs[i][j][1:], dtype=torch.int64, device=self.device_, requires_grad=False))
                        temp = torch.mean(temp, dim = 0)
                        d.append(temp)
                    d = torch.stack(d)
                    self.doc_embeds.append(d)
            # Replace each citation-token embedding with its document embedding.
            for i, cite in enumerate(cites):
                for c in range(len(input_args.batch_cites_value[i])):
                    doc_ind = self.token2id(input_args.batch_cites_value[i][c]) - 1
                    # NOTE(review): `assert cond, print(...)` evaluates print
                    # eagerly and uses None as the message — likely unintended.
                    assert doc_ind >= 0, print("fake cite token")
                    inputs_embeds[i, cite[c]] = self.doc_embeds[i][doc_ind].to(self.device_)
        else:
            # Decode step: if the single new token is a citation token, swap in
            # the cached document embedding (batch index 0 only).
            fk = self.token2id(input_ids[0,0]) - 1
            if fk >= 0:
                inputs_embeds[0][0] = self.doc_embeds[0][fk].to(self.device_)

        if input_args.gradient_checkpoint_ != "none":
            inputs_embeds.requires_grad_(True)

        # prepare cache
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )

        if past_seen_tokens is None:
            past_seen_tokens = 0

        cache_position = torch.arange(
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

        # prepare mask
        if input_args.batch_masks_ is not None:
            # 2d mask is passed through the layers
            if isinstance(input_args.batch_masks_, torch.Tensor):
                attention_mask = input_args.batch_masks_.to(
                    dtype=torch.int64, device=self.device_
                )
            else:
                attention_mask = torch.tensor(
                    input_args.batch_masks_, dtype=torch.int64, device=self.device_
                )
        else:
            attention_mask = None

        if self.config_.attn_implementation_ != "flash_attn":
            causal_mask = self.model_.causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values
            )
        else:
            causal_mask = attention_mask

        return input_ids, inputs_embeds, attention_mask, causal_mask, cache_position

    def _call_decoder_stack_original(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Original decoder loop: returns hidden states and per-adapter MoE
        router logits (kept for reference; forward() uses _call_decoder_stack).
        """
        # decoder layers
        num_adapters = len(input_args.batch_configs_)
        all_router_logits = [[] for _ in range(num_adapters)]
        gradient_checkpoint = CHECKPOINT_CLASSES[input_args.gradient_checkpoint_]

        for decoder_layer in self.model_.decoder_stack():
            hidden_states, *router_logits = gradient_checkpoint(
                decoder_layer.forward,
                hidden_states,
                input_args,
                rotary_emb,
                attention_mask,
                cache_position,
                past_key_value,
            )
            if len(router_logits) == 0:
                continue
            # collecting router logits
            assert len(router_logits) == num_adapters
            for idx in range(num_adapters):
                if router_logits[idx] is not None:
                    all_router_logits[idx].append(router_logits[idx])

        hidden_states = self.model_.norm(hidden_states)

        return hidden_states, all_router_logits


    def _call_decoder_stack(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
        #require_attention: Optional[int] = -1,
        #require_hide: Optional[int] = -1,
    ):
        """Run the decoder stack, additionally collecting the attention
        matrices of the last three layers (hard-coded indices 29/30/31 —
        assumes a 32-layer model; TODO confirm for other model sizes).
        """
        # decoder layers
        gradient_checkpoint = CHECKPOINT_CLASSES[input_args.gradient_checkpoint_]

        #hidden_output = []
        #hidden_atten = []
        attention_matrixs = []
        for idx, decoder_layer in enumerate(self.model_.decoder_stack()):
            hidden_states, attention_matrix = gradient_checkpoint(
                decoder_layer.forward,
                hidden_states,
                input_args,
                rotary_emb,
                attention_mask,
                cache_position,
                past_key_value,
            )
            if idx in [31,30,29]:
                attention_matrixs.append(attention_matrix)
            """if require_hide == len(self.model_.layers_) or require_hide == idx:
                hidden_output.append(hidden_states)
            if require_attention == len(self.model_.layers_) or require_attention == idx:
                hidden_atten.append(hidden_attention)"""

        hidden_states = self.model_.norm(hidden_states)

        return hidden_states, attention_matrixs#hidden_atten, hidden_output

    # compute the model: output probs
    def forward(
        self, input_args: LLMModelInput, past_key_values: Optional[LLMCache] = None
    ) -> List[LLMModelOutput]:
        """Full forward pass.

        Training (labels present): returns per-adapter outputs with ``loss``
        plus an ``aux_loss`` combining attention supervision, router loss and
        citation loss (each scaled by its configured coefficient).
        Inference (labels None): if the router prefers "cite" at the last
        position, ``logits`` are replaced by document-similarity logits and
        ``cite_flag`` is set.
        """
        input_ids, inputs_embeds, attention_mask, causal_mask, cache_position = (
            self._prepare_inputs(input_args, past_key_values)
        )

        labels = input_args.batch_labels_

        # Clear consumed fields so downstream code cannot reuse them.
        input_args.batch_labels_ = None
        input_args.batch_tokens_ = None
        input_args.batch_masks_ = None

        # embed positions
        hidden_states = inputs_embeds

        rotary_emb = self.model_.rotary_embed(
            hidden_states, cache_position.unsqueeze(0)
        )

        hidden_states, attention_matrixs = self._call_decoder_stack(
            hidden_states,
            input_args,
            rotary_emb,
            causal_mask,
            cache_position,
            past_key_values,
            #require_attention,
            #require_hide,
        )
        # Collapse heads of the last collected attention matrix for inspection.
        attention_matrixs[-1] = attention_matrixs[-1].permute(0,2,3,1)
        attention_matrixs[-1] = torch.sum(attention_matrixs[-1], dim = -1).squeeze().to('cpu').detach()
        #print(attention_matrixs[-1].shape)
        #print(torch.mean(attention_matrixs[-1][input_args.batch_cites[0][0] + 1:input_args.batch_cites[0][2],input_args.batch_cites[0][0]]))
        # NOTE(review): debug visualization block below — imports matplotlib/
        # seaborn inside forward, writes to a hard-coded path ("/yy21/heatmap")
        # and calls input(), which BLOCKS execution. Must be removed before any
        # non-interactive use. Kept verbatim (doc-only change).
        import numpy as np
        import matplotlib.pyplot as plt
        import seaborn as sns
        plt.figure(figsize=(8, 6))
        print(f"len:{input_args.batch_prompt_len[0]}")
        print(attention_matrixs[-1].shape)
        sns.heatmap(attention_matrixs[-1][input_args.batch_prompt_len[0]:,input_args.batch_prompt_len[0]:], annot=False, cmap="YlGnBu", vmin = 0, vmax = 0.2, xticklabels=False, yticklabels=False)
        plt.savefig("/yy21/heatmap", bbox_inches='tight', dpi=300)
        input()
        #route_logits = hidden_states @ (self.routerup @ self.routerdown)
        route_logits = hidden_states @ self.routerup
        # Citation logits: cosine similarity between projected hidden states
        # and (detached) document embeddings.
        hidden_cites = hidden_states @ self.cite_output
        norm_cite_logits = F.normalize(hidden_cites, p = 2, dim = 2)
        cite_logits = []
        for batch in range(hidden_states.shape[0]):
            #norm_doc = F.normalize(self.doc_embeds[batch], p = 2, dim = 1)
            norm_doc = F.normalize(self.doc_embeds[batch].detach(), p = 2, dim = 1)
            cite_logits.append(norm_cite_logits[batch] @ norm_doc.T)
            #cite_logits.append(norm_cite_logits[batch])


        # calculate loss
        output = self.output_(hidden_states, input_args)
        #att_s = hidden_atten[0].sum(dim = 1).squeeze() / 32  # (collapses the list to one value)
        assert isinstance(output, List)
        for indx, lora_config in enumerate(input_args.batch_configs_):
            output_data = output[indx]
            assert isinstance(output_data, LLMModelOutput)
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            output_data.batch_start_idx_ = start_idx
            output_data.batch_end_idx_ = end_idx
            #print(f"router:{route_logits[0,-1]}")
            #print(f"cite:{cite_logits}")
            # Inference-time routing: only batch element 0 is consulted.
            if (labels is None) and (route_logits[0, -1, 1] > route_logits[0, -1, 0]):
                output_data.logits = cite_logits[0].unsqueeze(0)
                #output_data.logits = hidden_states[0].unsqueeze(0)
                output_data.cite_flag = True
            else:
                output_data.cite_flag = False
            if labels is None:
                continue
            # compute loss when labels provided
            output_data.loss = output_data.loss_fn_(
                input_ids[start_idx:end_idx],
                output_data.logits,
                labels[start_idx:end_idx],
                input_args.batch_cites,
                input_args.batch_cites_value,
                input_args.batch_prompt_len
            )
            output_data.loss_fn_ = None
            # TODO (original note, translated): merge route_logits handling
            # with the filtering below.
            # Drop citations that fall inside the prompt — aux losses only
            # supervise generated positions. NOTE: mutates input_args in place.
            for idx in range(len(input_args.batch_cites)):
                new_cites = []
                new_cites_v = []
                for i in range(len(input_args.batch_cites[idx])):
                    if input_args.batch_cites[idx][i] >= input_args.batch_prompt_len[idx]:
                        new_cites.append(input_args.batch_cites[idx][i])
                        if i < len(input_args.batch_cites_value[idx]):
                            new_cites_v.append(input_args.batch_cites_value[idx][i])
                input_args.batch_cites[idx] = new_cites
                input_args.batch_cites_value[idx] = new_cites_v
            if output_data.aux_loss is None:
                output_data.aux_loss = self.attn_mat_coin * 0.01 * self.attention_loss_fn(attention_matrixs, causal_mask, input_args.batch_cites, input_args.batch_prompt_len)
            else:
                output_data.aux_loss += self.attn_mat_coin * 0.01 * self.attention_loss_fn(attention_matrixs, causal_mask, input_args.batch_cites, input_args.batch_prompt_len)
            print(f"1:{output_data.aux_loss}")
            # Align cites/values lengths before the router/citation losses.
            for idx in range(len(input_args.batch_cites)):
                if len(input_args.batch_cites[idx]) > len(input_args.batch_cites_value[idx]):
                    input_args.batch_cites[idx] = input_args.batch_cites[idx][:-1]
            # Router labels: 1 at citation positions, 0 elsewhere.
            output_data.aux_loss += self.router_coin * 10 * self.compute_route_loss(route_logits, input_args.batch_cites)
            print(f"2:{output_data.aux_loss}")
            #output_data.aux_loss += self.cite_coin * self.compute_cite_loss2(hidden_states, input_args.batch_cites,input_args.batch_cites_value,batch_doc_embed)
            output_data.aux_loss += self.cite_coin * 100 * self.compute_cite_loss(cite_logits, input_args.batch_cites,input_args.batch_cites_value)
            print(f"3:{output_data.aux_loss}")
        return output

    # NOTE(review): defined without `self` — this is a static factory in
    # practice (LLMModel.from_pretrained(path, ...)); presumably a missing
    # @staticmethod decorator. The inner `from_pretrained(...)` call resolves
    # to the module-level helper imported from .models, not to this method.
    def from_pretrained(
        name_or_path: str,
        device: str,
        bits: int = None,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        load_dtype: torch.dtype = torch.bfloat16,
        compute_dtype: torch.dtype = torch.bfloat16,
        double_quant: bool = True,
        quant_type: str = "nf4",
    ) -> "LLMModel":
        """Load a HF causal LM (optionally 4/8-bit quantized) and wrap it."""
        # load_dtype will change the precision of LLaMA pre-trained model
        # when loading with quantization (bits = 8 or bits = 4), load_dtype will only influence the actual computing precision
        if load_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
            raise ValueError(f"unsupported load dtype {load_dtype}")

        if compute_dtype not in [torch.bfloat16, torch.float16, torch.float32]:
            raise ValueError(f"unsupported compute dtype {compute_dtype}")

        if load_dtype in [torch.bfloat16, torch.float16]:
            logging.info("Loading model with half precision.")

        # BFloat16 is only supported after Ampere GPUs
        if not executor.is_bf16_supported():
            if load_dtype == torch.bfloat16:
                logging.warning("bf16 is not available. deprecated to fp16.")
                load_dtype = torch.float16

            if bits in [4, 8] and compute_dtype == torch.bfloat16:
                logging.warning("bf16 is not available. deprecated to fp16.")
                compute_dtype = torch.float16

        if bits in [4, 8]:
            logging.info(f"Loading model with quantization, bits = {bits}.")
            llm_model = AutoModelForCausalLM.from_pretrained(
                name_or_path,
                device_map=device,
                trust_remote_code=True,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=bits == 4,
                    load_in_8bit=bits == 8,
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=double_quant,
                    bnb_4bit_quant_type=quant_type,
                ),
                torch_dtype=load_dtype,
            )
        else:
            llm_model = AutoModelForCausalLM.from_pretrained(
                name_or_path,
                device_map=device,
                trust_remote_code=True,
                torch_dtype=load_dtype,
            )

        # Freeze the base model; only adapter/head parameters train.
        llm_model.requires_grad_(False)

        model = from_pretrained(
            llm_model,
            attn_impl=attn_impl,
            use_sliding_window=use_sliding_window,
            device=device,
        )

        logging.info(f"Use {attn_impl} as attention implementation.")

        return LLMModel(model)

    def init_adapter(
        self, config: AdapterConfig, weight: Optional[Dict[str, torch.Tensor]] = None
    ):
        """Register an adapter: set loss coefficients, build the output layer,
        initialize (or load) the citation heads, and init LoRA layers.

        When ``weight`` is None the heads are freshly initialized; otherwise
        they are copied from the checkpoint dict keyed by adapter name.
        """
        self.attn_mat_coin = config.atten_coin
        self.router_coin = config.router_coin
        self.cite_coin = config.cite_coin
        # Patch for MixLoRA
        if isinstance(config, MixLoraConfig) and config.act_fn_ is None:
            config.act_fn_ = self.config_.hidden_act_

        self.adapter_configs_[config.adapter_name] = config
        # init output layer
        if config.task_name in task_dict and isinstance(
            task_dict[config.task_name], SequenceClassificationTask
        ):
            output_layer = ClassificationOutputLayer(
                **task_dict[config.task_name].init_kwargs(),
                hidden_size=self.config_.dim_,
                pad_token_id=self.config_.pad_token_id_,
                device=self.device_,
                weight=weight,
            )
        else:
            output_layer = CasualOutputLayer(
                vocab_size=self.config_.vocab_size_, weight=self.model_.lm_head_
            )

        if weight is None:
            torch.nn.init.kaiming_normal_(self.attention_weight, mode='fan_in', nonlinearity='relu')
            torch.nn.init.kaiming_normal_(self.routerup, mode='fan_in', nonlinearity='relu')
            #torch.nn.init.kaiming_normal_(self.routerdown, mode='fan_in', nonlinearity='relu')
            torch.nn.init.kaiming_normal_(self.cite_output, mode='fan_in', nonlinearity='relu')
            torch.nn.init.orthogonal_(self.doc_proj)
        else:
            # NOTE(review): weight.get(..., None) would make copy_ fail with a
            # TypeError if a key is missing — effectively a hard requirement
            # that all four head tensors exist in the checkpoint.
            with torch.no_grad():
                self.attention_weight.copy_(weight.get(f"{config.adapter_name}.attention_mat_weight", None))
                self.routerup.copy_(weight.get(f"{config.adapter_name}.router_weight_up", None))
                #self.routerdown.copy_(weight.get(f"{config.adapter_name}.router_weight_down", None))
                self.cite_output.copy_(weight.get(f"{config.adapter_name}.cite_weight", None))
                self.doc_proj.copy_(weight.get(f"{config.adapter_name}.doc_weight", None))
        self.output_.layers_[config.adapter_name] = output_layer
        if type(config) is not AdapterConfig:
            # init transformer layers
            for transformer_layer in self.model_.layers_:
                init_lora_layer_weight(transformer_layer, self.config_, config, weight)
        else:
            assert weight is None, "can not load basic adapter with weight"

        return config.adapter_name

    def get_adapter_weight_dict(self, adapter_name: str) -> Dict[str, torch.Tensor]:
        """Collect every trainable tensor belonging to *adapter_name*:
        output layer, the four citation heads, and all per-layer LoRA weights.
        """
        # return the lora weight and target_module's name
        lora_weight_dict = self.output_.layers_[adapter_name].state_dict()
        atten_name = f"{adapter_name}.attention_mat_weight"
        lora_weight_dict[atten_name] = self.attention_weight
        route_name = f"{adapter_name}.router_weight_up"
        lora_weight_dict[route_name] = self.routerup
        """route_name = f"{adapter_name}.router_weight_down"
        lora_weight_dict[route_name] = self.routerdown"""
        cite_name = f"{adapter_name}.cite_weight"
        lora_weight_dict[cite_name] = self.cite_output
        doc_name = f"{adapter_name}.doc_weight"
        lora_weight_dict[doc_name] = self.doc_proj
        lora_config = self.adapter_configs_[adapter_name]
        for transformer_layer in self.model_.layers_:
            get_lora_layer_weight(transformer_layer, lora_config, lora_weight_dict)

        return lora_weight_dict

    def unload_adapter(
        self, adapter_name: str
    ) -> Tuple[LoraConfig, Dict[str, torch.Tensor]]:
        """Remove an adapter from the model and return its (config, weights).

        Walks every layer, popping plain LoRA weights, fused MLP MoEs, and
        plugged per-linear MoEs (with their per-expert LoRA weights).
        """
        assert adapter_name in self.adapter_configs_, "adapter not exist"
        lora_weight = self.get_adapter_weight_dict(adapter_name)
        lora_config = self.adapter_configs_.pop(adapter_name)
        self.output_.layers_.pop(adapter_name)
        for transformer_layer in self.model_.layers_:
            attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
            attn_state_dict: Dict[str, torch.Tensor]
            mlp_state_dict: Dict[str, torch.Tensor]
            lora_layer_list = list(attn_state_dict.values())
            lora_layer_list.extend(mlp_state_dict.values())

            for lora_layer in lora_layer_list:
                if adapter_name in lora_layer.loras_:
                    lora_layer.loras_.pop(adapter_name, None)
                elif adapter_name in transformer_layer.mlp_.moes_:
                    # Fused MoE (e.g. MixLoRA): drop each expert's LoRA pair.
                    for expert_idx in range(
                        transformer_layer.mlp_.moes_[adapter_name].experts_
                    ):
                        moe_lora_name = f"moe.{adapter_name}.experts.{expert_idx}"
                        lora_layer.loras_.pop(moe_lora_name, None)

                    transformer_layer.mlp_.moes_.pop(adapter_name)
                elif adapter_name in lora_layer.moes_:
                    # Plugged MoE (e.g. LoRAMoE / MoLA) on this linear.
                    for expert_idx in range(lora_layer.moes_[adapter_name].experts_):
                        moe_lora_name = f"moe.{adapter_name}.experts.{expert_idx}"
                        lora_layer.loras_.pop(moe_lora_name, None)

                    lora_layer.moes_.pop(lora_config.adapter_name, None)

        return lora_config, lora_weight

    def load_adapter(self, name_or_path: str, adapter_name: Optional[str] = None):
        """Load an adapter from a local path or the HF Hub and register it.

        Returns the adapter name (defaults to *name_or_path*).
        """
        if adapter_name is None:
            adapter_name = name_or_path

        if not os.path.exists(name_or_path):
            name_or_path = snapshot_download(repo_id=name_or_path, repo_type="model")
        with open(
            name_or_path + os.sep + "adapter_config.json", "r", encoding="utf8"
        ) as fp:
            lora_config = lora_config_factory(json.load(fp))
        lora_config.adapter_name = adapter_name
        lora_weight = torch.load(
            name_or_path + os.sep + "adapter_model.bin",
            map_location=self.device_,
            weights_only=False,
        )

        self.init_adapter(lora_config, lora_weight)
        return adapter_name

    def compute_route_loss(self, logits, cites):
        """Cross-entropy for the 2-way router.

        Labels: 1 at citation positions, 0 elsewhere; logits are L2-normalized
        and shifted by one (predict the NEXT position's label).
        """
        nrom_logits = logits / torch.norm(logits, dim = -1, keepdim=True)
        b, l, v = logits.shape
        """for c in cites:
            if c[-1] == l:
                del c[-1]"""
        label = []
        for k in range(b):
            label.append([1 if i in cites[k] else 0 for i in range(l)])

        # NOTE(review): `label` is always a list here (built just above), so
        # the Tensor branch is dead code.
        if isinstance(label, torch.Tensor):
            label = (
                label.clone()
                .detach()
                .to(dtype=torch.long, device=logits.device)
            )
        else:
            label = torch.tensor(label, dtype=torch.long, device=logits.device)

        loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn(
            nrom_logits[..., :-1, :].contiguous().view(-1, v),
            label[..., 1:].contiguous().view(-1),
        )

    def compute_cite_loss2(self, logits, cites, cites_v, docs_pos):
        """Alternative citation loss scoring hidden states against document
        positions taken from the sequence itself (currently unused in forward;
        see the commented-out call there).
        """
        b = len(logits)
        docs_pos = [torch.tensor(i) for i in docs_pos]
        doc_embeds = []
        norm_logits = [F.normalize(logits[batch], p = 2, dim = 1) for batch in range(logits.shape[0])]
        for i in range(b):
            doc_embeds.append(norm_logits[i][docs_pos[i]].transpose(0,1))
        b_logits = []

        for i in range(len(cites)):
            b_logits.append(norm_logits[i] @ doc_embeds[i])
        for k in range(len(cites_v)):
            cites_v[k] = [self.token2id(i) for i in cites_v[k]]

        # Labels: document index at citation positions, -100 (ignored) elsewhere.
        labels = []
        for k in range(b):
            labels.append([-100 for _ in range(logits[k].shape[0])])
            for i, v in zip(cites[k], cites_v[k]):
                labels[k][i] = v - 1

        if isinstance(labels[0], torch.Tensor):
            for k in range(b):
                labels[k] = (
                    labels[k].clone()
                    .detach()
                    .to(dtype=torch.long, device=logits[0].device)
                )
        else:
            for k in range(b):
                labels[k] = torch.tensor(labels[k], dtype=torch.long, device=logits[0].device)

        loss_fn = torch.nn.CrossEntropyLoss(ignore_index = -100)

        loss = 0
        for k in range(b):
            if len(cites[k]) != 0:
                loss += loss_fn(
                    b_logits[k][..., :-1, :].contiguous().view(-1, b_logits[k].shape[-1]),
                    labels[k][..., 1:].contiguous().view(-1),
                )
        return loss / b

    def compute_cite_loss(self, logits, cites, cites_v):
        """Cross-entropy over citation-similarity logits.

        Labels are the (0-based) document index at each citation position and
        -100 elsewhere; shifted by one position like the LM loss. Batch
        elements with no citations contribute zero.
        """
        b = len(logits)

        for k in range(len(cites_v)):
            """if len(cites[k]) > len(cites_v[k]):
                del cites[k][-1]"""
            cites_v[k] = [self.token2id(i) for i in cites_v[k]]

        labels = []
        for k in range(b):
            labels.append([-100 for _ in range(logits[k].shape[0])])
            for i, v in zip(cites[k], cites_v[k]):
                labels[k][i] = v - 1

        if isinstance(labels[0], torch.Tensor):
            for k in range(b):
                labels[k] = (
                    labels[k].clone()
                    .detach()
                    .to(dtype=torch.long, device=logits[0].device)
                )
        else:
            for k in range(b):
                labels[k] = torch.tensor(labels[k], dtype=torch.long, device=logits[0].device)

        loss_fn = torch.nn.CrossEntropyLoss(ignore_index = -100)

        loss = 0
        for k in range(b):
            if len(cites[k]) != 0:
                loss += loss_fn(
                    logits[k][..., :-1, :].contiguous().view(-1, logits[k].shape[-1]),
                    labels[k][..., 1:].contiguous().view(-1),
                )
        return loss / b


    def attention_loss_fn(self, mat, mask, cites, prompt_len):
        """Hinge penalty pushing attention toward earlier citation anchors.

        Args:
            mat: list of per-layer attention matrices (last three layers).
            mask: additive causal mask applied before the softmax.
            cites: per-batch lists of citation column positions (T entries,
                each the column index of c_i — original note, translated).
            prompt_len: per-batch prompt lengths (unused here; filtering is
                done by the caller).
        """
        mat = torch.stack(mat, dim = 0)
        mat = mat.permute(1,0,3,4,2)
        #final_mat = torch.matmul(mat, self.attention_weight).squeeze(-1)
        final_mat = torch.mean(mat, dim = -1)  # average over heads
        final_mat += mask
        final_mat = F.softmax(final_mat, dim=-1)
        loss = torch.tensor(0.0, dtype = final_mat.dtype, device = final_mat.device)
        num_layer = final_mat.shape[1]
        for batch in range(final_mat.shape[0]):
            if len(cites[batch]) == 0:
                continue
            # For each citation span (k, k+1), penalize attention to every
            # earlier citation i <= k that falls below its alpha/beta target.
            for k in range(len(cites[batch]) - 1):
                for i in range(k + 1):
                    if cites[batch][k] == cites[batch][k + 1] - 1:
                        continue  # empty span between consecutive citations
                    loss_now = (self.alpha[cites[batch][k]:cites[batch][k + 1] - 1] * self.beta[k - i, k]).expand(1, num_layer,-1) - final_mat[batch,:,cites[batch][k]:cites[batch][k + 1] - 1,cites[batch][i]]
                    loss += F.relu(loss_now).sum() / (cites[batch][k + 1] - cites[batch][k])

        return loss
c2cite/models/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modeling_chatglm import GLMForCausalLM
2
+ from .modeling_gemma import GemmaForCausalLM
3
+ from .modeling_gemma2 import Gemma2ForCausalLM
4
+ from .modeling_llama import LlamaForCausalLM
5
+ from .modeling_mistral import MistralForCausalLM
6
+ from .modeling_mistral import MistralForCausalLM as Qwen2ForCausalLM
7
+ from .modeling_phi import PhiForCausalLM
8
+ from .modeling_phi3 import Phi3ForCausalLM
9
+
10
# Maps a Hugging Face `config.model_type` string to the wrapper class that
# implements it. NOTE: "qwen2" reuses the Mistral implementation (aliased at
# import time above).
model_dict = {
    "llama": LlamaForCausalLM,
    "gemma": GemmaForCausalLM,
    "gemma2": Gemma2ForCausalLM,
    "mistral": MistralForCausalLM,
    "qwen2": Qwen2ForCausalLM,
    "phi": PhiForCausalLM,
    "phi3": Phi3ForCausalLM,
    "chatglm": GLMForCausalLM,
}
20
+
21
+
22
def from_pretrained(llm_model, **kwargs):
    """Instantiate the wrapper matching *llm_model*'s architecture.

    Looks up ``llm_model.config.model_type`` in ``model_dict`` and delegates
    to that class's ``from_pretrained``; raises RuntimeError for unsupported
    architectures.
    """
    model_type = llm_model.config.model_type
    if model_type not in model_dict:
        raise RuntimeError(f"Model {model_type} not supported.")
    return model_dict[model_type].from_pretrained(llm_model, **kwargs)
29
+
30
+
31
# Public API of the models package: the per-architecture wrappers plus the
# dispatching `from_pretrained` factory.
__all__ = [
    "LlamaForCausalLM",
    "GemmaForCausalLM",
    "MistralForCausalLM",
    "Qwen2ForCausalLM",
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    "from_pretrained",
    "GLMForCausalLM",
]
c2cite/models/modeling_chatglm.py ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import LayerNorm
9
+ from transformers.utils import is_flash_attn_2_available
10
+
11
+ from moe_peft.common import (
12
+ FeedForward,
13
+ Linear,
14
+ LLMAttention,
15
+ LLMCache,
16
+ LLMDecoder,
17
+ LLMFeedForward,
18
+ LLMForCausalLM,
19
+ LLMModelConfig,
20
+ LLMModelInput,
21
+ collect_plugin_router_logtis,
22
+ flash_attention_forward,
23
+ slice_tensor,
24
+ )
25
+ from moe_peft.executors import executor
26
+ from moe_peft.utils import copy_parameters
27
+
28
+
29
@dataclass
class GLMConfig(LLMModelConfig):
    """ChatGLM-specific model configuration, extending the shared config.

    Field meanings mirror the upstream ChatGLM `config.json` keys.
    """

    # Normalization / residual layout.
    post_layer_norm: bool = True
    rmsnorm: bool = True
    layernorm_epsilon: float = 1e-5
    apply_residual_connection_post_layernorm: bool = False
    fp32_residual_connection: bool = False
    # Attention geometry (kv_channels = per-head key/value width).
    kv_channels: int = 128
    multi_query_attention: bool = False
    multi_query_group_num: int = 2
    # Attention numerics.
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True
    # Rotary embedding variant and scaling ratio.
    original_rope: bool = True
    add_bias_linear: bool = False
    # -1 means "not set"; upstream pads the vocab for tensor-parallel shards.
    padded_vocab_size: int = -1
    rope_ratio: float = 1
46
+
47
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split *tensor* into *num_partitions* equal chunks along its last dim.

    ``torch.split`` returns views into the original storage; pass
    ``contiguous_split_chunks=True`` to force each chunk contiguous.
    """
    split_size = tensor.size(-1) // num_partitions
    chunks = torch.split(tensor, split_size, dim=-1)
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks
62
+
63
+
64
class RotaryEmbedding(nn.Module):
    """Precomputes the cos/sin rotary-position cache used by GLM attention."""

    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
        super().__init__()
        exponent = torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponent))
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
        self,
        seq_len: int,
        n_elem: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000,
    ):
        """Build the (seq_len, n_elem // 2, 2) cache of [cos, sin] pairs.

        Frequencies follow the RoPE recipe: theta_i = base^(-2(i-1)/d),
        with *base* scaled by ``rope_ratio``.
        """
        effective_base = base * self.rope_ratio
        exponents = torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem
        theta = 1.0 / (effective_base ** exponents)

        # Angle for every (position, frequency) pair.
        positions = torch.arange(seq_len, dtype=torch.float, device=device)
        angles = torch.outer(positions, theta).float()

        cache = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

        # Mimic complex32 behaviour for reduced-precision dtypes so results
        # match the reference implementation.
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        # NOTE: `offset` is accepted for interface compatibility but unused,
        # matching the original implementation.
        return self.forward_impl(
            max_seq_len,
            self.dim,
            dtype=self.inv_freq.dtype,
            device=self.inv_freq.device,
        )
110
+
111
+
112
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Rotate the first `rot_dim` channels of `x` by the (cos, sin) pairs in `rope_cache`.

    Kept TorchScript-compatible (decorated with @torch.jit.script); the tail
    channels beyond `rot_dim` pass through unchanged.
    """
    # x: [b, np, sq, hn]
    b, np, sq, _ = x.shape
    # Each cache entry covers two interleaved channels.
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:, :sq]
    # View channels as (pairs, 2) so each pair can be rotated as a 2-D vector.
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
    # Standard 2-D rotation: (x0*cos - x1*sin, x1*cos + x0*sin).
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    # Re-interleave the rotated pairs and append the untouched tail channels.
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
131
+
132
+
133
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learned scale.

    The mean-square statistic is computed in float32 for stability, then the
    result is cast back to the input dtype. Extra keyword arguments are
    accepted (and ignored) for interface compatibility with LayerNorm.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Note: weight starts uninitialized; callers load real values into it.
        self.weight = torch.nn.Parameter(
            torch.empty(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        orig_dtype = hidden_states.dtype
        # Mean of squares over the last dimension, accumulated in fp32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normed).to(orig_dtype)
147
+
148
+
149
class CoreAttention(torch.nn.Module):
    """Eager (matmul + softmax) scaled-dot-product attention for GLM.

    Expects query/key/value in ``[b, np, s, hn]`` layout and returns the
    context in ``[b, sq, np * hn]``. When `apply_query_key_layer_scaling` is
    on, scores are additionally divided by the layer number before softmax
    and multiplied back afterwards (the Megatron numerical-stability trick).
    """

    def __init__(self, config: GLMConfig, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Layer scaling only makes sense with an fp32 softmax; force it on.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.is_causal = True

        projection_size = config.kv_channels * config.n_heads_

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.n_heads_
        self.num_attention_heads_per_partition = config.n_heads_

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            # Fold the layer number into the score divisor; undone after the
            # fp32 cast below via `self.coeff`.
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        """Compute attention; `attention_mask` is boolean with True = masked.

        If no mask is given and the score matrix is square, a causal
        (lower-triangular) mask is synthesized on the fly.
        """
        # [b, np, sq, sk]
        output_size = (
            query_layer.size(0),
            query_layer.size(1),
            query_layer.size(2),
            key_layer.size(2),
        )

        # [b, np, sq, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            output_size[0] * output_size[1], output_size[2], -1
        )
        # [b, np, sk, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)

        # preallocting input tensor: [b * np, sq, sk]
        # (beta=0 below means its contents are never read, only its storage.)
        matmul_input_buffer = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer,  # [b * np, sq, hn]
            key_layer.transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        if self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()
        if self.coeff is not None:
            # Undo the layer-number division now that we are in fp32.
            attention_scores = attention_scores * self.coeff
        if (
            attention_mask is None
            and attention_scores.shape[2] == attention_scores.shape[3]
        ):
            # Square scores with no mask => prefill; build a causal mask
            # (True marks positions to be masked out).
            attention_mask = torch.ones(
                output_size[0],
                1,
                output_size[2],
                output_size[3],
                device=attention_scores.device,
                dtype=torch.bool,
            )
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(
                attention_mask, float("-inf")
            )
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.type_as(value_layer)

        # query layer shape: [b * np, sq, hn]
        # value layer shape: [b, np, sk, hn]
        # attention shape: [b, np, sq, sk]
        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(0),
            value_layer.size(1),
            query_layer.size(1),
            value_layer.size(3),
        )
        # change view [b * np, sk, hn]
        value_layer = value_layer.view(
            output_size[0] * output_size[1], value_layer.size(2), -1
        )
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [b, sq, np, hn]
        context_layer = context_layer.transpose(1, 2).contiguous()
        # [b, sq, np, hn] --> [b, sq, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer
268
+
269
+
270
class FlashAttention2(CoreAttention):
    """CoreAttention variant that dispatches to FlashAttention-2 kernels.

    Accepts the same [b, np, s, hn] inputs as CoreAttention and returns the
    context in [b, sq, np * hn]; masking is delegated to
    `flash_attention_forward`.
    """

    def __init__(self, *args, **kwargs):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(*args, **kwargs)

    def forward(self, query_states, key_states, value_states, attention_mask):
        # Flash kernels expect [b, s, np, hn]; swap the head and seq axes.
        query_states, key_states, value_states = (
            t.transpose(1, 2) for t in (query_states, key_states, value_states)
        )

        batch_size = query_states.shape[0]
        query_length = query_states.shape[1]

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=self.is_causal,
        )

        # Merge heads back into a single hidden dimension.
        return attn_output.reshape(
            batch_size, query_length, self.hidden_size_per_partition
        ).contiguous()
296
+
297
+
298
# Maps the configured `attn_implementation_` string to the attention backend
# class used by GLMSelfAttention.
CORE_ATTENTION_CLASSES = {
    "eager": CoreAttention,
    "flash_attn": FlashAttention2,
}
302
+
303
+
304
class GLMSelfAttention(LLMAttention):
    """GLM self-attention wrapping a fused QKV projection and an output dense.

    Both projections are wrapped in the project's LoRA-aware `Linear`; the
    score computation itself is delegated to a CORE_ATTENTION_CLASSES backend.
    Supports multi-query attention (shared K/V groups) when configured.
    """

    def __init__(
        self,
        qkv_layer: torch.nn.Module,
        dense_layer: torch.nn.Module,
        config: GLMConfig,
        layer_idx,
    ):
        super(GLMSelfAttention, self).__init__()
        self.layer_idx = layer_idx

        self.projection_size = config.kv_channels * config.n_heads_

        # Per attention head and per-partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.n_heads_
        self.num_attention_heads_per_partition = config.n_heads_
        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        if self.multi_query_attention:
            # With MQA the fused projection holds full Q plus one K and one V
            # slice per group, hence the smaller hidden size.
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size
                + 2
                * self.hidden_size_per_attention_head
                * self.num_multi_query_groups_per_partition
            )

        # QKV layer.
        self.query_key_value = Linear(base_layer=qkv_layer, device=config.device_)
        # Core attention layer.
        self.core_attention = CORE_ATTENTION_CLASSES[config.attn_implementation_](
            config, self.layer_idx
        )

        # Dense layer.
        self.dense = Linear(base_layer=dense_layer, device=config.device_)

    def state_dict(self) -> Dict[str, Linear]:
        # Exposes the LoRA-wrapped projections under their adapter names.
        return {"qkv_proj": self.query_key_value, "dense": self.dense}

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run attention over `hidden_states`, updating the KV cache if given."""
        mixed_x_layer = self.query_key_value(hidden_states, input_args)

        if self.multi_query_attention:
            # Carve the fused projection into full-Q and grouped-K/V slices.
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(
                mixed_x_layer, 3
            )

        # [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]
        ]

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        if past_key_value is not None:
            # Append this step's K/V to the cache and read back the full run.
            key_layer, value_layer = past_key_value.update(
                key_layer,
                value_layer,
                self.layer_idx,
                {"cache_position": cache_position},
            )

        if self.multi_query_attention:
            # Broadcast each K/V group across its query heads:
            # [b, ng, s, hn] -> [b, ng, np/ng, s, hn] -> [b, np, s, hn].
            key_layer = key_layer.unsqueeze(2)
            key_layer = key_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + key_layer.size()[3:]
            )
            value_layer = value_layer.unsqueeze(2)
            value_layer = value_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + value_layer.size()[3:]
            )

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
        )

        output = self.dense(context_layer, input_args)

        return output
458
+
459
+
460
def swiglu(x):
    """SwiGLU activation: split the last dim in half, gate one half with SiLU.

    Returns a tensor whose last dimension is half of the input's.
    """
    gate, value = torch.chunk(x, 2, dim=-1)
    return F.silu(gate) * value
463
+
464
+
465
class GLMMLP(LLMFeedForward):
    """GLM feed-forward block: up-projection, SwiGLU, down-projection.

    Provides three entry points used by the framework: the plain batch pass,
    a single-adapter LoRA pass, and a MixLoRA pass that routes token subsets
    to per-expert LoRA adapters.
    """

    def __init__(
        self,
        dense_h_to_4h: torch.nn.Module,
        dense_4h_to_h: torch.nn.Module,
        config: GLMConfig,
    ) -> None:
        super().__init__()
        # Both projections are wrapped in the LoRA-aware Linear.
        self.dense_h_to_4h: Linear = Linear(dense_h_to_4h, config.device_)
        self.dense_4h_to_h: Linear = Linear(dense_4h_to_h, config.device_)

        # SwiGLU halves the last dimension, which the fused 4h projection
        # accounts for by producing gate and value halves together.
        self.activation_func = swiglu

    def state_dict(self) -> Dict[str, torch.nn.Module]:
        return {
            "dense_h_to_4h": self.dense_h_to_4h,
            "dense_4h_to_h": self.dense_4h_to_h,
        }

    def _batch_forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Standard forward pass through the (possibly LoRA-wrapped) projections."""
        # [b, sq, h] -> [b, sq, 4hp]
        intermediate_parallel = self.dense_h_to_4h(data, input_args)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [b, sq, 4hp] -> [b, sq, h]
        output = self.dense_4h_to_h(intermediate_parallel, input_args)
        return output

    def _lora_forward(
        self, lora_name: str, act_fn: torch.nn.Module, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Forward using a single named LoRA adapter, falling back to the base
        layer when that adapter is absent on a projection.

        Note: `act_fn` is accepted for interface compatibility but this block
        always applies `self.activation_func` (SwiGLU).
        """
        if lora_name in self.dense_h_to_4h.loras_:
            hidden_states = self.dense_h_to_4h.loras_[lora_name].forward(
                self.dense_h_to_4h.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.dense_h_to_4h.base_layer_.forward(hidden_states)

        hidden_states = self.activation_func(hidden_states)

        if lora_name in self.dense_4h_to_h.loras_:
            hidden_states = self.dense_4h_to_h.loras_[lora_name].forward(
                self.dense_4h_to_h.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.dense_4h_to_h.base_layer_.forward(hidden_states)

        return hidden_states

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """MixLoRA forward: run each expert's LoRA on the token rows routed to
        it (per `expert_mask`) and return the per-expert output list.

        The base up-projection is computed once and shared across experts.
        """
        common_dense_h_to_4h = self.dense_h_to_4h.base_layer_.forward(
            hidden_states.to(input_dtype)
        ).to(hidden_states.dtype)
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # Row indices of tokens assigned to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.dense_h_to_4h.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                act_result = self.activation_func(
                    self.dense_h_to_4h.loras_[lora_name].forward(
                        slice_tensor(common_dense_h_to_4h, top_x, input_dtype),
                        lora_data,
                    )
                )
            else:
                act_result = self.activation_func(
                    slice_tensor(common_dense_h_to_4h, top_x, input_dtype)
                )

            if lora_name in self.dense_4h_to_h.loras_:
                final_expert_states.append(
                    self.dense_4h_to_h.loras_[lora_name].forward(
                        self.dense_4h_to_h.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(
                    self.dense_4h_to_h.base_layer_.forward(act_result)
                )

        return final_expert_states
551
+
552
+
553
class GLMDecoderLayer(LLMDecoder):
    """One GLM transformer block: pre-norm attention and MLP with dropout on
    each sub-layer output before the residual add.

    When `apply_residual_connection_post_layernorm` is set, the residual is
    taken from the layernorm output instead of the sub-layer input (matching
    the upstream ChatGLM option).
    """

    def __init__(
        self, self_attn: GLMSelfAttention, mlp: FeedForward, config: GLMConfig
    ) -> None:
        super().__init__()
        self.layer_id_ = self_attn.layer_idx
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Input layer norm.
        self.input_layernorm = LayerNormFunc(
            config.dim_,
            eps=config.layernorm_epsilon,
            device=config.device_,
            dtype=config.dtype_,
        )
        # Self-attention layer.
        self.self_attn_: GLMSelfAttention = self_attn
        self.hidden_dropout = config.hidden_dropout_

        # Post attention layer norm.
        self.post_layernorm = LayerNormFunc(
            config.dim_,
            eps=config.layernorm_epsilon,
            device=config.device_,
            dtype=config.dtype_,
        )
        # mlp
        self.mlp_: FeedForward = mlp

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        # (attention modules, feed-forward modules) for adapter bookkeeping.
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        layernorm_output = self.input_layernorm(hidden_states)

        attention_output = self.self_attn_.forward(
            layernorm_output,
            input_args,
            rotary_pos_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Dropout is active only when training (inference_mode_ False).
        layernorm_input = F.dropout(
            attention_output,
            p=self.hidden_dropout,
            training=not input_args.inference_mode_,
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_layernorm(layernorm_input)

        # MLP.
        mlp_output, router_logits = self.mlp_(layernorm_output, input_args)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = F.dropout(
            mlp_output, p=self.hidden_dropout, training=not input_args.inference_mode_
        )
        output = residual + output

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        # NOTE(review): `*router_logits` assumes the MLP / collector returns an
        # iterable of per-adapter logits — confirm against FeedForward's contract.
        return output, *router_logits
645
+
646
+
647
class GLMEmbedding(torch.nn.Module):
    """Token-embedding front end for the GLM transformer.

    Looks tokens up in a word-embedding table sized by the padded vocabulary
    and optionally promotes the result to float32 for fp32 residual paths.
    """

    def __init__(self, config: "GLMConfig"):
        super(GLMEmbedding, self).__init__()

        self.hidden_size = config.dim_
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeddings = self.word_embeddings(input_ids)
        # When fp32 residuals are requested, cast up-front so every residual
        # addition downstream runs in full precision.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings
669
+
670
+
671
class GLMForCausalLM(LLMForCausalLM):
    """ChatGLM causal-LM wrapper adapting a HuggingFace ChatGLM checkpoint to
    this project's decoder-stack interface (embedding, RoPE cache, layers,
    final norm, LM head)."""

    def __init__(self, config: GLMConfig) -> None:
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_

        # Embedding layer.
        self.embed_tokens_ = GLMEmbedding(config)
        # Rotary Position Embedding.
        # Populated later by `from_pretrained`; None until then.
        self.rotary_emb_layer: RotaryEmbedding = None
        # Encoder(Decoder) layers.
        self.layers_: List[GLMDecoderLayer] = []
        # Final layer norm.
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        if self.config_.post_layer_norm:
            self.final_layernorm_ = LayerNormFunc(
                config.dim_,
                eps=config.layernorm_epsilon,
                device=config.device_,
                dtype=config.dtype_,
            )
        else:
            self.final_layernorm_ = nn.Identity()
        # Output layer.
        self.lm_head_ = torch.nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=config.add_bias_linear,
            dtype=config.dtype_,
            device=config.device_,
        )

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Builds the full-length RoPE cache, then gathers the rows for the
        # positions of the last sequence in the batch.
        # NOTE(review): `position_ids[-1]` indexes only the final batch row —
        # this presumes identical positions across the batch; confirm callers.
        return self.rotary_emb_layer(max_seq_len=self.config_.max_seq_len_)[
            None, position_ids[-1]
        ]

    def decoder_stack(self) -> List[LLMDecoder]:
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.final_layernorm_(hidden_states)

    def get_masks(
        self,
        input_ids: torch.Tensor,
        past_key_values: LLMCache,
        padding_mask: torch.Tensor,
    ):
        """Build ChatGLM's boolean attention mask (True = masked position),
        combining the causal triangle, the cached prefix, and padding."""
        batch_size, seq_length, _ = input_ids.shape
        full_attention_mask = torch.ones(
            batch_size, seq_length, seq_length, device=input_ids.device
        )
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values.get_seq_length()
        if past_length:
            # Cached positions are fully visible to every query.
            full_attention_mask = torch.cat(
                (
                    torch.ones(
                        batch_size, seq_length, past_length, device=input_ids.device
                    ),
                    full_attention_mask,
                ),
                dim=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            # Upstream ChatGLM trick: re-open rows whose query token is padding
            # so softmax never sees an all-masked row.
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # Invert: downstream `masked_fill` expects True where attention is blocked.
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        # `cache_position` is unused; ChatGLM derives everything from the
        # cache length inside `get_masks`.
        return self.get_masks(input_tensor, past_key_values, attention_mask)

    def model_config(self) -> GLMConfig:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a loaded HuggingFace ChatGLM model into a GLMForCausalLM,
        copying (frozen) weights layer by layer."""
        assert not use_sliding_window, "ChatGLM model does not support SWA."
        # Get the config from LLM model and input args.
        llm_config = llm_model.config
        config = GLMConfig(
            # LLM model args.
            name_or_path_=llm_config._name_or_path,
            device_=device,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.multi_query_group_num,
            n_layers_=llm_config.num_layers,
            hidden_act_=swiglu,
            hidden_dropout_=llm_config.hidden_dropout,
            vocab_size_=llm_config.vocab_size,
            pad_token_id_=llm_config.pad_token_id,
            max_seq_len_=llm_config.seq_length,
            attn_implementation_=attn_impl,
            dtype_=llm_model.dtype,
            # ChatGLM args.
            post_layer_norm=llm_config.post_layer_norm,
            rmsnorm=llm_config.rmsnorm,
            layernorm_epsilon=llm_config.layernorm_epsilon,
            apply_residual_connection_post_layernorm=llm_config.apply_residual_connection_post_layernorm,
            fp32_residual_connection=llm_config.fp32_residual_connection,
            apply_query_key_layer_scaling=llm_config.apply_query_key_layer_scaling,
            kv_channels=llm_config.kv_channels,
            multi_query_attention=llm_config.multi_query_attention,
            multi_query_group_num=llm_config.multi_query_group_num,
            attention_softmax_in_fp32=llm_config.attention_softmax_in_fp32,
            original_rope=llm_config.original_rope,
            add_bias_linear=llm_config.add_bias_linear,
            padded_vocab_size=llm_config.padded_vocab_size,
            # Older ChatGLM configs lack rope_ratio; default to 1.
            rope_ratio=(
                llm_config.rope_ratio if hasattr(llm_config, "rope_ratio") else 1
            ),
        )

        model = GLMForCausalLM(config)
        llm_model.requires_grad_(False)

        copy_parameters(
            llm_model.transformer.embedding,
            model.embed_tokens_,
        )

        rotary_dim = (
            config.dim_ // config.n_heads_
            if config.kv_channels is None
            else config.kv_channels
        )
        # ChatGLM rotates only half of the rotary channels, hence dim // 2.
        model.rotary_emb_layer = RotaryEmbedding(
            dim=rotary_dim // 2,
            rope_ratio=config.rope_ratio,
            original_impl=config.original_rope,
            device=device,
            dtype=config.dtype_,
        )

        for idx, layer in enumerate(llm_model.transformer.encoder.layers):
            # Get self-attention layer.
            self_attention = GLMSelfAttention(
                qkv_layer=layer.self_attention.query_key_value,
                dense_layer=layer.self_attention.dense,
                config=config,
                layer_idx=idx,
            )
            # Get MLP layer.
            mlp = FeedForward(
                GLMMLP(layer.mlp.dense_h_to_4h, layer.mlp.dense_4h_to_h, config=config)
            )
            # Create a transformer block.
            encoder = GLMDecoderLayer(self_attention, mlp, config)
            copy_parameters(layer.input_layernorm, encoder.input_layernorm)
            copy_parameters(layer.post_attention_layernorm, encoder.post_layernorm)
            model.layers_.append(encoder)

        if config.post_layer_norm:
            copy_parameters(
                llm_model.transformer.encoder.final_layernorm,
                model.final_layernorm_,
            )

        copy_parameters(llm_model.transformer.output_layer, model.lm_head_)

        return model
c2cite/models/modeling_gemma.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers.models.gemma import modeling_gemma
5
+
6
+ from moe_peft.common import FeedForward
7
+ from moe_peft.executors import executor
8
+ from moe_peft.models.modeling_llama import (
9
+ LLAMA_ATTENTION_CLASSES as GEMMA_ATTENTION_CLASSES,
10
+ )
11
+ from moe_peft.models.modeling_llama import (
12
+ LlamaConfig,
13
+ LlamaDecoderLayer,
14
+ LlamaForCausalLM,
15
+ LlamaMLP,
16
+ )
17
+ from moe_peft.utils import copy_parameters
18
+
19
+
20
class GemmaRMSNorm(nn.Module):
    """RMS normalization with Gemma's (1 + weight) scaling convention."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.norm_eps_ = eps
        self.weight_ = weight

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.norm_eps_)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, scale by (1 + w) *before* casting back down.
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16);
        # see https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.to(torch.float32))
        scaled = normed * (1.0 + self.weight_.to(torch.float32))
        return scaled.to(x.dtype)
35
+
36
+
37
class GemmaEmbedding(nn.Module):
    """Embedding lookup that rescales outputs by Gemma's sqrt(hidden_dim) factor."""

    def __init__(self, embedding: torch.Tensor, pad_token: int, normalizer: float):
        super().__init__()
        self.token_embedding_: torch.Tensor = embedding
        self.padding_idx_: int = pad_token
        self.normalizer_: float = normalizer

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        embedded = F.embedding(tokens, self.token_embedding_, padding_idx=self.padding_idx_)
        # Cast the scale to the embedding dtype before multiplying: Gemma
        # downcasts sqrt(3072)=55.4256 to 55.5 in float16.
        # See https://github.com/huggingface/transformers/pull/29402
        scale = torch.tensor(self.normalizer_, dtype=embedded.dtype)
        return embedded * scale
52
+
53
+ def _patch_hidden_act(config: modeling_gemma.GemmaConfig) -> str:
54
+ if hasattr(config, "hidden_activation") and config.hidden_activation is not None:
55
+ return config.hidden_activation
56
+ else:
57
+ return config.hidden_act
58
+
59
+
60
class GemmaForCausalLM(LlamaForCausalLM):
    """Gemma causal LM built on the Llama implementation, swapping in Gemma's
    RMSNorm variant and the sqrt(dim)-scaled embedding."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__(config)

    @staticmethod
    def from_pretrained(
        llm_model: modeling_gemma.GemmaForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Convert a HuggingFace Gemma model into this project's wrapper,
        reusing the Llama decoder/MLP/attention machinery."""
        assert not use_sliding_window, "Gemma model does not support SWA."
        llm_config: modeling_gemma.GemmaConfig = llm_model.config
        llm_args = LlamaConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.head_dim,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            # Handles the hidden_act -> hidden_activation config rename.
            hidden_act_=_patch_hidden_act(llm_config),
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # -1 sentinel keeps F.embedding's padding_idx argument valid.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = GemmaForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        # Gemma scales embeddings by sqrt(hidden_size).
        model.embed_tokens_ = GemmaEmbedding(
            llm_model.model.embed_tokens.weight,
            llm_args.pad_token_id_,
            llm_args.dim_**0.5,
        )
        model.norm_ = GemmaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            # Gemma reuses Llama's attention backends unchanged.
            decoder.self_attn_ = GEMMA_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            # Norms must be Gemma's (1 + w) variant, not Llama's.
            decoder.input_layernorm_ = GemmaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = GemmaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
c2cite/models/modeling_gemma2.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers.models.gemma2 import modeling_gemma2
7
+ from transformers.models.gemma2.modeling_gemma2 import apply_rotary_pos_emb, repeat_kv
8
+ from transformers.utils import is_flash_attn_2_available
9
+
10
+ from moe_peft.common import (
11
+ FeedForward,
12
+ Linear,
13
+ LLMAttention,
14
+ LLMCache,
15
+ LLMDecoder,
16
+ LLMForCausalLM,
17
+ LLMModelConfig,
18
+ LLMModelInput,
19
+ collect_plugin_router_logtis,
20
+ flash_attention_forward,
21
+ prepare_4d_causal_attention_mask,
22
+ )
23
+ from moe_peft.executors import executor
24
+ from moe_peft.models.modeling_gemma import GemmaEmbedding, GemmaRMSNorm
25
+ from moe_peft.models.modeling_llama import LlamaMLP
26
+ from moe_peft.utils import copy_parameters, is_package_available
27
+
28
+
29
@dataclass
class Gemma2Config(LLMModelConfig):
    """Gemma-2 model configuration, extending the shared LLMModelConfig."""

    # Epsilon for RMSNorm numerical stability.
    rms_norm_eps_: float = 1e-6
    # Tanh soft-cap applied to attention logits (None disables capping).
    attn_logit_softcapping_: float = 50.0
    # Tanh soft-cap applied to the final LM-head logits (None disables capping).
    final_logit_softcapping_: float = 30.0
    # Queries are scaled by query_pre_attn_scalar**-0.5 instead of head_dim**-0.5.
    query_pre_attn_scalar_: int = 224
    # When True, even-indexed layers restrict attention to a sliding window.
    use_sliding_window_: bool = False
    # Sliding-window size in tokens (used only when use_sliding_window_ is True).
    sliding_window_: int = 4096
37
+
38
+
39
class Gemma2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables for Gemma-2.

    Inverse frequencies are precomputed at construction time; each forward
    call produces the (cos, sin) tables for a batch of position ids, cast to
    the dtype of the activations.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base**(2i / dim) for i in [0, dim / 2)
        exponents = (
            torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
            / self.dim
        )
        inv_freq = 1.0 / (self.base**exponents)
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        batch = position_ids.shape[0]
        freq_cols = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        pos_rows = position_ids[:, None, :].float()
        # Compute angles in float32: bfloat16 loses precision on long contexts
        # (see https://github.com/huggingface/transformers/pull/29285).
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (freq_cols.float() @ pos_rows.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos()
            sin = table.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
79
+
80
+
81
# Multi-headed attention from 'Attention Is All You Need' paper.
class Gemma2Attention(LLMAttention):
    """Eager (matmul + softmax) multi-head attention for Gemma-2.

    Wraps the base model's q/k/v/o projections in the project's adapter-aware
    `Linear` so LoRA weights can be applied per request via `input_args`.
    Applies Gemma-2's query scaling (query_pre_attn_scalar**-0.5) and the
    optional tanh soft-cap on attention logits.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        o_proj: nn.Module,
        layer_idx: int,
        config: Gemma2Config,
    ):
        super().__init__()
        # attention
        self.q_proj_: Linear = Linear(q_proj, config.device_)
        self.k_proj_: Linear = Linear(k_proj, config.device_)
        self.v_proj_: Linear = Linear(v_proj, config.device_)
        self.o_proj_: Linear = Linear(o_proj, config.device_)
        # config
        self.layer_idx_ = layer_idx
        self.config_ = config
        self.dim_ = config.dim_
        self.n_heads_ = config.n_heads_
        self.n_kv_heads_ = config.n_kv_heads_
        # GQA: each KV head is shared by n_rep_ query heads.
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.head_dim_ = config.head_dim_
        self.dtype_ = config.dtype_
        self.is_causal_ = True

        # Gemma-2 scales queries by query_pre_attn_scalar**-0.5 rather than
        # the conventional head_dim**-0.5.
        self.scaling_ = config.query_pre_attn_scalar_**-0.5
        # Sliding-window attention is active only on even-indexed layers.
        self.sliding_window_ = (
            config.sliding_window_
            if config.use_sliding_window_ and not bool(layer_idx % 2)
            else None
        )

    def state_dict(self) -> Dict[str, Linear]:
        """Expose the four projection wrappers for adapter injection/saving."""
        return {
            "q_proj": self.q_proj_,
            "k_proj": self.k_proj_,
            "v_proj": self.v_proj_,
            "o_proj": self.o_proj_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run eager attention and return the projected output tensor."""
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj_(hidden_states, input_args)
        key_states = self.k_proj_(hidden_states, input_args)
        value_states = self.v_proj_(hidden_states, input_args)

        # [bsz, seq, heads*dim] -> [bsz, heads, seq, dim]
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # Append this step's K/V to the cache; the cache may use the
            # sliding window / cache position to manage eviction.
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window_,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # Expand KV heads to match the number of query heads (GQA).
        key_states = repeat_kv(key_states, self.n_rep_)
        value_states = repeat_kv(value_states, self.n_rep_)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling_
        )

        # Soft-cap attention logits: cap * tanh(logits / cap).
        if self.config_.attn_logit_softcapping_ is not None:
            attn_weights = attn_weights / self.config_.attn_logit_softcapping_
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config_.attn_logit_softcapping_

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.view(bsz, q_len, -1)
        return self.o_proj_(attn_output, input_args)
190
+
191
+
192
class Gemma2FlashAttention2(Gemma2Attention):
    """Flash-Attention-2 variant of Gemma-2 attention.

    Same projections and rotary/cache handling as `Gemma2Attention`, but the
    score computation is delegated to `flash_attention_forward`. Logit
    soft-capping is only passed through when flash-attn >= 2.6.0 is installed.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        o_proj: nn.Module,
        layer_idx: int,
        config: Gemma2Config,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(q_proj, k_proj, v_proj, o_proj, layer_idx, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run flash attention and return the projected output tensor."""
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj_(hidden_states, input_args)
        key_states = self.k_proj_(hidden_states, input_args)
        value_states = self.v_proj_(hidden_states, input_args)

        # [bsz, seq, heads*dim] -> [bsz, heads, seq, dim]
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window_,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        if attention_mask is not None:
            # Trim cached K/V to the mask's sequence length along dim 2.
            # NOTE(review): assumes attention_mask is [bsz, seq_len] here —
            # confirm against the caller's mask layout.
            seq_len = attention_mask.shape[1]
            key_states = key_states[:, :, :seq_len]
            value_states = value_states[:, :, :seq_len]

        # Flash-attn kernels expect [bsz, seq, heads, dim].
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Flash attention does not support fp32; downcast and restore after.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            is_causal=self.is_causal_,
            softmax_scale=self.scaling_,
            sliding_window=(
                self.sliding_window_ if self.config_.use_sliding_window_ else None
            ),
            # softcap support landed in flash-attn 2.6.0; silently dropped below.
            softcap=(
                self.config_.attn_logit_softcapping_
                if is_package_available("flash_attn", "2.6.0")
                else None
            ),
        ).to(input_dtype)

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj_(attn_output, input_args)

        return attn_output
287
+
288
+
289
# Maps the config's `attn_implementation_` value to the attention class used
# when assembling the model in `Gemma2ForCausalLM.from_pretrained`.
GEMMA2_ATTENTION_CLASSES = {
    "eager": Gemma2Attention,
    "flash_attn": Gemma2FlashAttention2,
}
293
+
294
+
295
class Gemma2DecoderLayer(LLMDecoder):
    """One Gemma-2 transformer block.

    Gemma-2 sandwiches both sublayers between two RMSNorms (input/post-attention
    for self-attention, pre/post-feedforward for the MLP) and alternates
    sliding-window and global attention per layer (even layers slide).
    Submodules are populated externally by `Gemma2ForCausalLM.from_pretrained`.
    """

    def __init__(self, layer_idx: int, config: Gemma2Config) -> None:
        super().__init__()
        self.layer_id_: int = layer_idx
        self.self_attn_: Gemma2Attention = None
        self.mlp_: FeedForward = None
        self.input_layernorm_: GemmaRMSNorm = None
        self.post_attention_layernorm_: GemmaRMSNorm = None

        self.config_ = config
        # Even-indexed layers (0, 2, ...) use sliding-window attention.
        self.is_sliding_ = not bool(layer_idx % 2)
        self.pre_feedforward_layernorm_: GemmaRMSNorm = None
        self.post_feedforward_layernorm_: GemmaRMSNorm = None
        self.sliding_window_ = config.sliding_window_

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention weights, MLP weights) for adapter management."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Apply attention and MLP sublayers; returns (hidden_states, *router_logits)."""
        if (
            self.config_.use_sliding_window_
            and self.is_sliding_
            and attention_mask is not None
        ):
            if self.config_.attn_implementation_ == "flash_attn":
                if past_key_value is not None:  # when decoding
                    # BUGFIX: was `self.sliding_window` (missing trailing
                    # underscore), which raised AttributeError — the attribute
                    # set in __init__ is `sliding_window_`.
                    attention_mask = attention_mask[:, -self.sliding_window_ :]
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                # Mask out keys further back than the sliding window.
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool),
                    diagonal=-self.sliding_window_,
                )
                attention_mask = torch.where(
                    sliding_window_mask, min_dtype, attention_mask
                )
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window_ :]

        residual = hidden_states

        hidden_states = self.input_layernorm_(hidden_states)

        hidden_states = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = self.post_attention_layernorm_(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm_(hidden_states)
        hidden_states, router_logits = self.mlp_.forward(hidden_states, input_args)
        hidden_states = self.post_feedforward_layernorm_(hidden_states)
        hidden_states = residual + hidden_states

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
369
+
370
+
371
class Gemma2OutputLayer(nn.Module):
    """Final Gemma-2 projection: LM head followed by optional logit soft-capping.

    The soft-cap squeezes logits into (-cap, cap) via `cap * tanh(logits / cap)`;
    a cap of None disables it.
    """

    def __init__(self, config: Gemma2Config):
        super().__init__()
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.final_logit_softcapping_ = config.final_logit_softcapping_

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.lm_head_(hidden_states)
        cap = self.final_logit_softcapping_
        if cap is None:
            return logits
        return cap * torch.tanh(logits / cap)
390
+
391
+
392
class Gemma2ForCausalLM(LLMForCausalLM):
    """Gemma-2 causal LM assembled from a HuggingFace checkpoint.

    Holds the embedding, decoder stack, final norm, rotary embedding and
    soft-capped LM head; the actual weights are wired in by `from_pretrained`,
    which wraps the HF modules in this project's adapter-aware layers.
    """

    def __init__(self, config: Gemma2Config) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Populated later by from_pretrained.
        self.embed_tokens_: GemmaEmbedding = None
        self.norm_: GemmaRMSNorm = None
        self.rotary_emb_ = Gemma2RotaryEmbedding(
            config.head_dim_,
            max_position_embeddings=config.max_seq_len_,
            base=config.rope_theta_,
            device=config.device_,
        )
        self.lm_head_ = Gemma2OutputLayer(config)
        self.layers_: List[Gemma2DecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Token ids -> scaled embeddings (see GemmaEmbedding)."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) RoPE tables for the given positions."""
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        """Return the ordered list of decoder layers."""
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final RMSNorm before the LM head."""
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the 4D additive causal mask for the current step."""

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def cache_implementation(self) -> str:
        """Select the KV-cache type: hybrid when sliding windows are active."""
        if self.config_.use_sliding_window_ and self.config_.sliding_window_:
            return "hybrid"
        else:
            return "dynamic"

    def model_config(self) -> Gemma2Config:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_gemma2.Gemma2PreTrainedModel,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Wrap a loaded HF Gemma-2 model into a Gemma2ForCausalLM.

        The HF weights are frozen and shared (not copied) except for the LM
        head, which is copied into the soft-capping output layer.
        """
        llm_config: modeling_gemma2.Gemma2Config = llm_model.config
        # Translate the HF config into this project's field-per-field config.
        model_config = Gemma2Config(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.head_dim,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_activation,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            attn_logit_softcapping_=llm_config.attn_logit_softcapping,
            final_logit_softcapping_=llm_config.final_logit_softcapping,
            query_pre_attn_scalar_=llm_config.query_pre_attn_scalar,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Sentinel padding id when the checkpoint defines none.
        if model_config.pad_token_id_ is None:
            model_config.pad_token_id_ = -1

        model = Gemma2ForCausalLM(model_config)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = GemmaEmbedding(
            llm_model.model.embed_tokens.weight,
            model_config.pad_token_id_,
            model_config.dim_**0.5,
        )
        model.norm_ = GemmaRMSNorm(
            llm_model.model.norm.weight, model_config.rms_norm_eps_
        )
        copy_parameters(llm_model.lm_head, model.lm_head_.lm_head_)

        # Rebuild each decoder layer around the HF layer's modules.
        for layer_idx, layer in enumerate(llm_model.model.layers):
            decoder = Gemma2DecoderLayer(layer_idx, model_config)
            decoder.self_attn_ = GEMMA2_ATTENTION_CLASSES[
                model_config.attn_implementation_
            ](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                layer_idx,
                model_config,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    model_config,
                )
            )
            decoder.input_layernorm_ = GemmaRMSNorm(
                layer.input_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = GemmaRMSNorm(
                layer.post_attention_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.pre_feedforward_layernorm_ = GemmaRMSNorm(
                layer.pre_feedforward_layernorm.weight, model_config.rms_norm_eps_
            )
            decoder.post_feedforward_layernorm_ = GemmaRMSNorm(
                layer.post_feedforward_layernorm.weight, model_config.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
c2cite/models/modeling_llama.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers.activations import ACT2FN
8
+ from transformers.models.llama import modeling_llama
9
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
10
+ from transformers.utils import is_flash_attn_2_available
11
+
12
+ from moe_peft.common import (
13
+ ROPE_INIT_FUNCTIONS,
14
+ FeedForward,
15
+ Linear,
16
+ LLMAttention,
17
+ LLMCache,
18
+ LLMDecoder,
19
+ LLMFeedForward,
20
+ LLMForCausalLM,
21
+ LLMModelConfig,
22
+ LLMModelInput,
23
+ collect_plugin_router_logtis,
24
+ eager_attention_forward,
25
+ flash_attention_forward,
26
+ prepare_4d_causal_attention_mask,
27
+ slice_tensor,
28
+ )
29
+ from moe_peft.executors import executor
30
+ from moe_peft.utils import copy_parameters
31
+
32
+
33
@dataclass
class LlamaConfig(LLMModelConfig):
    """LLaMA-family model configuration on top of the shared LLMModelConfig."""

    # Epsilon used by the RMSNorm layers.
    rms_norm_eps_: float = 1e-6
    # Optional HF-style rope scaling dict; LlamaRotaryEmbedding reads
    # "rope_type" (or the legacy "type") from it.
    rope_scaling_: Optional[Dict[str, Any]] = None
37
+
38
+
39
class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding for LLaMA with optional rope scaling.

    Supports the rope variants registered in ROPE_INIT_FUNCTIONS. For
    "dynamic" variants, the inverse-frequency buffer is re-derived whenever
    positions grow past the cached length and restored once they shrink back.
    """

    def __init__(
        self,
        config: Optional[LlamaConfig],
        scaling_factor=1.0,
        rope_type="default",
    ):
        super().__init__()
        if config is None:
            # BUGFIX: the previous `config is None` branch (and the
            # rope_kwargs built before it) dereferenced `config` anyway and
            # crashed with AttributeError. Fail fast with a clear error; the
            # Optional annotation is kept for interface compatibility.
            raise ValueError("LlamaRotaryEmbedding requires a LlamaConfig instance")
        self.rope_kwargs = {
            "rope_type": rope_type,
            "factor": scaling_factor,
            "dim": config.head_dim_,
            "base": config.rope_theta_,
            "max_position_embeddings": config.max_seq_len_,
        }
        # BC: "rope_type" was originally "type"
        if config.rope_scaling_ is not None:
            self.rope_type = config.rope_scaling_.get(
                "rope_type", config.rope_scaling_.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_seq_len_
        self.original_max_seq_len = config.max_seq_len_

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, config.device_, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so the dynamic path can restore the original table on reset.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Grow or reset `inv_freq` for dynamic rope variants."""
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) tables in x.dtype for the given position ids."""
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor,
        # equivalent to scaling attention.
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
126
+
127
+
128
# Multi-headed attention from 'Attention Is All You Need' paper.
class LlamaAttention(LLMAttention):
    """Eager multi-head attention for LLaMA with adapter-aware projections.

    NOTE(review): unlike `LlamaFlashAttention`, this forward returns a tuple
    `(output, attention_matrix)` — the attention matrix comes from
    `eager_attention_forward`. Callers must handle both return shapes;
    confirm this asymmetry is intentional.
    """

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: LlamaConfig,
    ):
        super().__init__()
        # attention
        self.wq_: Linear = Linear(wq, args.device_)  # dim * dim
        self.wk_: Linear = Linear(wk, args.device_)  # dim * dim
        self.wv_: Linear = Linear(wv, args.device_)  # dim * dim
        self.wo_: Linear = Linear(wo, args.device_)  # dim * dim
        # config
        self.layer_idx_ = idx
        self.dim_ = args.dim_
        self.n_heads_ = args.n_heads_
        self.n_kv_heads_ = args.n_kv_heads_
        # GQA: each KV head serves n_rep_ query heads.
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.head_dim_ = args.head_dim_
        self.dtype_ = args.dtype_
        self.is_causal_ = True

    def state_dict(self) -> Dict[str, Linear]:
        """Expose projections under HF-style names for adapter injection."""
        return {
            "q_proj": self.wq_,
            "k_proj": self.wk_,
            "v_proj": self.wv_,
            "o_proj": self.wo_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run eager attention; returns (projected output, attention matrix)."""
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # conver shape to multi head
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # apply rotary embedding
        cos, sin = rotary_emb
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # for llama2 need to repeat the heads
        # before dim: batch_size, n_kv_head, seq_len, head_dim
        # after dim: batch_size, n_head, seq_len, head_dim
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        attention_score, attention_matrix = eager_attention_forward(xq, xk, xv, attention_mask)
        attention_score = attention_score.reshape(batch_size, max_seq_len, -1)

        # get output attention score
        return self.wo_.forward(attention_score, input_args), attention_matrix
212
+
213
+
214
class LlamaFlashAttention(LlamaAttention):
    """Flash-Attention-2 variant of the LLaMA attention layer.

    Projections, rotary application and KV caching are identical to the eager
    base class; the score computation is delegated to
    `flash_attention_forward`. Returns only the projected output tensor.
    """

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: LlamaConfig,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(wq, wk, wv, wo, idx, args)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        bsz, seq_len, _ = hidden_states.shape

        q = self.wq_.forward(hidden_states, input_args)
        k = self.wk_.forward(hidden_states, input_args)
        v = self.wv_.forward(hidden_states, input_args)

        # [bsz, seq, heads*dim] -> [bsz, heads, seq, dim]
        q = q.view(bsz, seq_len, self.n_heads_, self.head_dim_).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_kv_heads_, self.head_dim_).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_kv_heads_, self.head_dim_).transpose(1, 2)

        # Rotate queries/keys, then merge this step's K/V into the cache.
        cos, sin = rotary_emb
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_value is not None:
            k, v = past_key_value.update(
                k,
                v,
                self.layer_idx_,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Flash-attn kernels expect [bsz, seq, heads, dim].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Flash attention has no fp32 path; downcast and restore afterwards.
        orig_dtype = q.dtype
        if orig_dtype == torch.float32:
            compute_dtype = (
                torch.bfloat16 if executor.is_bf16_supported() else torch.float16
            )
            q = q.to(compute_dtype)
            k = k.to(compute_dtype)
            v = v.to(compute_dtype)

        out = flash_attention_forward(
            q,
            k,
            v,
            attention_mask,
            seq_len,
            is_causal=self.is_causal_,
        ).to(orig_dtype)

        out = out.reshape(bsz, seq_len, -1).contiguous()
        return self.wo_.forward(out, input_args)
292
+
293
+
294
# Maps the config's `attn_implementation_` value to the attention class used
# when assembling the LLaMA model.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attn": LlamaFlashAttention,
}
298
+
299
+
300
class LlamaMLP(LLMFeedForward):
    """LLaMA gated feed-forward (SwiGLU-style) with LoRA / MixLoRA hooks.

    w1 = gate_proj, w2 = down_proj, w3 = up_proj; the base computation is
    w2(act(w1(x)) * w3(x)). The `_lora_forward` / `_mixlora_forward` variants
    apply per-adapter LoRA deltas around each projection.
    """

    def __init__(
        self, w1: nn.Module, w2: nn.Module, w3: nn.Module, args: LlamaConfig
    ) -> None:
        super().__init__()
        # feed forward
        self.w1_: Linear = Linear(w1, args.device_)
        self.w2_: Linear = Linear(w2, args.device_)
        self.w3_: Linear = Linear(w3, args.device_)
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        """Expose projections under HF-style names for adapter injection."""
        return {
            "gate_proj": self.w1_,
            "down_proj": self.w2_,
            "up_proj": self.w3_,
        }

    def _batch_forward(
        self, data: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Standard gated-MLP pass; each Linear applies its own adapters."""
        w1 = self.w1_.forward(data, input_args)
        w3 = self.w3_.forward(data, input_args)
        return self.w2_.forward(self.act_(w1) * w3, input_args)

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        # Applying LoRA weights to FFN weights
        # Each projection falls back to its frozen base layer when the named
        # adapter is absent.
        if lora_name in self.w1_.loras_:
            w1 = self.w1_.loras_[lora_name].forward(
                self.w1_.base_layer_.forward(data), data
            )
        else:
            w1 = self.w1_.base_layer_.forward(data)

        if lora_name in self.w3_.loras_:
            w3 = self.w3_.loras_[lora_name].forward(
                self.w3_.base_layer_.forward(data), data
            )
        else:
            w3 = self.w3_.base_layer_.forward(data)

        act_result = act_fn(w1) * w3
        if lora_name in self.w2_.loras_:
            return self.w2_.loras_[lora_name].forward(
                self.w2_.base_layer_.forward(act_result), act_result
            )
        else:
            return self.w2_.base_layer_.forward(act_result)

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """Per-expert forward for MixLoRA: returns one output tensor per expert.

        The frozen base projections (w1/w3) are computed once over all tokens;
        each expert then applies its LoRA delta only to the token rows routed
        to it (selected via `expert_mask`).
        """
        common_w1 = self.w1_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        common_w3 = self.w3_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # Indices of tokens routed to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.w1_.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                w1 = self.w1_.loras_[lora_name].forward(
                    slice_tensor(common_w1, top_x, input_dtype), lora_data
                )
            else:
                lora_data = None
                w1 = slice_tensor(common_w1, top_x, input_dtype)

            if lora_name in self.w3_.loras_:
                w3 = self.w3_.loras_[lora_name].forward(
                    slice_tensor(common_w3, top_x, input_dtype),
                    slice_tensor(hidden_states, top_x, input_dtype, lora_data),
                )
            else:
                w3 = slice_tensor(common_w3, top_x, input_dtype)

            act_result = act_fn(w1) * w3
            if lora_name in self.w2_.loras_:
                final_expert_states.append(
                    self.w2_.loras_[lora_name].forward(
                        self.w2_.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(self.w2_.base_layer_.forward(act_result))

        return final_expert_states
393
+
394
+
395
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), as used by Llama."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.norm_eps_ = eps
        self.weight_ = weight

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        original_dtype = data.dtype
        # Accumulate the mean-square in fp32 for numerical stability.
        mean_square = data.to(torch.float32).square().mean(dim=-1, keepdim=True)
        normalized = data * torch.rsqrt(mean_square + self.norm_eps_)
        # Scale by the (frozen) weight, then restore the input dtype.
        return (self.weight_ * normalized).to(original_dtype)
407
+
408
+
409
class LlamaDecoderLayer(LLMDecoder):
    """One Llama transformer block: pre-norm self-attention and pre-norm MLP,
    each followed by a residual add."""

    def __init__(self, layer_id: int) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        # Sub-modules are populated afterwards by LlamaForCausalLM.from_pretrained.
        self.self_attn_: LlamaAttention = None
        self.mlp_: FeedForward = None
        self.input_layernorm_: LlamaRMSNorm = None
        self.post_attention_layernorm_: LlamaRMSNorm = None

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention projections, MLP projections) keyed by HF names."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run one decoder block; returns (hidden_states, attention_matrix)."""

        residual = hidden_states
        hidden_states = self.input_layernorm_(hidden_states)
        # Self Attention (expects a 2-tuple back from the attention module).
        hidden_states, attention_matrix = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm_(hidden_states)
        # NOTE(review): PhiDecoderLayer unpacks (output, router_logits) from
        # FeedForward.forward, while here the raw return value is used as a
        # tensor — confirm FeedForward's return arity for this code path.
        hidden_states = self.mlp_.forward(hidden_states, input_args)
        hidden_states = residual + hidden_states

        return hidden_states, attention_matrix
450
+
451
+
452
class LlamaEmbedding(nn.Module):
    """Token-embedding lookup over a shared (frozen) weight tensor."""

    def __init__(self, embedding: torch.Tensor, pad_token: int):
        super().__init__()
        self.token_embedding_: torch.Tensor = embedding
        self.padding_idx_: int = pad_token

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Functional lookup so the weight tensor can be shared with the HF model.
        return F.embedding(
            tokens, self.token_embedding_, padding_idx=self.padding_idx_
        )
461
+
462
+
463
class LlamaForCausalLM(LLMForCausalLM):
    """Llama causal-LM rebuilt around LoRA/MoE-capable projection wrappers,
    constructed from a frozen HF model via ``from_pretrained``."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Populated by from_pretrained below.
        self.embed_tokens_: LlamaEmbedding = None
        self.norm_: LlamaRMSNorm = None
        self.rotary_emb_ = LlamaRotaryEmbedding(config)
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layers_: List[LlamaDecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to embeddings."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) rotary tables for the given positions."""
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final RMS norm before the LM head."""
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the 4-D causal attention mask for the decoder stack."""

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> LlamaConfig:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_llama.LlamaForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Freeze a HF Llama model and rebuild it with LoRA-capable layers.

        Weight tensors for embeddings/norms are shared with the HF model;
        lm_head weights are copied.
        """
        assert not use_sliding_window, "Llama model does not support SWA."
        llm_config: modeling_llama.LlamaConfig = llm_model.config
        llm_args = LlamaConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            rope_scaling_=llm_config.rope_scaling,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # -1 marks "no pad token" for downstream code.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = LlamaForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = LlamaEmbedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = LlamaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Wrap every HF decoder layer with the LoRA-capable equivalents.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            decoder.self_attn_ = LLAMA_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = LlamaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = LlamaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
c2cite/models/modeling_mistral.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers.models.mistral import modeling_mistral
7
+ from transformers.models.qwen2 import modeling_qwen2
8
+ from transformers.utils import is_flash_attn_2_available
9
+
10
+ from moe_peft.common import (
11
+ FeedForward,
12
+ LLMCache,
13
+ LLMModelInput,
14
+ flash_attention_forward,
15
+ )
16
+ from moe_peft.executors import executor
17
+ from moe_peft.models.modeling_llama import (
18
+ LlamaAttention,
19
+ LlamaConfig,
20
+ LlamaDecoderLayer,
21
+ LlamaEmbedding,
22
+ LlamaForCausalLM,
23
+ LlamaMLP,
24
+ LlamaRMSNorm,
25
+ apply_rotary_pos_emb,
26
+ repeat_kv,
27
+ )
28
+ from moe_peft.utils import copy_parameters
29
+
30
+
31
@dataclass
class MistralConfig(LlamaConfig):
    """Llama config extended with sliding-window attention fields
    (shared between Mistral and Qwen2)."""

    # Qwen2-only switch; Mistral leaves it False.
    use_sliding_window_: bool = False
    # Qwen2-only: layers at/above this index use SWA (None = all layers).
    max_window_layers_: Optional[int] = None
    # Window size for sliding-window attention (None = disabled).
    sliding_window_: Optional[int] = None
36
+
37
+
38
class MistralFlashAttention(LlamaAttention):
    """Flash-attention-2 variant of Llama attention with optional
    sliding-window support (Mistral / Qwen2)."""

    def __init__(
        self,
        wq: nn.Module,
        wk: nn.Module,
        wv: nn.Module,
        wo: nn.Module,
        idx: int,
        args: MistralConfig,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(wq, wk, wv, wo, idx, args)
        # Qwen2
        self.use_sliding_window_ = args.use_sliding_window_
        self.max_window_layers_ = args.max_window_layers_
        # Mistral and Qwen2
        self.sliding_window_ = args.sliding_window_

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # NOTE(review): this returns a single tensor, while LlamaDecoderLayer
        # unpacks two values from self_attn_.forward — confirm the base
        # LlamaAttention's return convention for the flash path.
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        # convert shape to multi head: (batch, heads, seq, head_dim)
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        kv_seq_len = xk.shape[-2]
        if past_key_value is not None:
            kv_seq_len += cache_position[0]

        # apply rotary embedding
        cos, sin = rotary_emb
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx_) > 0
            if (
                self.sliding_window_ is not None
                and kv_seq_len > self.sliding_window_
                and cache_has_contents
            ):
                # Keep only the last (window - 1) cached tokens.
                slicing_tokens = 1 - self.sliding_window_

                past_key = past_key_value[self.layer_idx_][0]
                past_value = past_key_value[self.layer_idx_][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.sliding_window_ - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.sliding_window - 1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                # Trim the mask to match the sliced cache, keeping the new token.
                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat(
                        [attention_mask, torch.ones_like(attention_mask[:, -1:])],
                        dim=-1,
                    )

            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # Expand KV heads to match the query-head count (GQA).
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        # Flash attention does not support fp32; downcast if needed.
        input_dtype = xq.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            xq = xq.to(target_dtype)
            xk = xk.to(target_dtype)
            xv = xv.to(target_dtype)

        # Flash attention expects (batch, seq, heads, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Qwen2 gates SWA per layer via max_window_layers_; Mistral (fields
        # left at None) applies the window whenever sliding_window_ is set.
        if (
            (self.use_sliding_window_ is None or self.use_sliding_window_)
            and self.sliding_window_ is not None
            and (
                self.max_window_layers_ is None
                or self.layer_idx_ >= self.max_window_layers_
            )
        ):
            sliding_window = self.sliding_window_
        else:
            sliding_window = None

        attn_output = flash_attention_forward(
            xq,
            xk,
            xv,
            attention_mask,
            max_seq_len,
            is_causal=self.is_causal_,
            sliding_window=sliding_window,
        ).to(input_dtype)

        attn_output = attn_output.reshape(
            batch_size, max_seq_len, self.dim_
        ).contiguous()
        attn_output = self.wo_.forward(attn_output, input_args)

        return attn_output
171
+
172
+
173
# Dispatch table from attn_implementation_ to the attention class.
MISTRAL_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attn": MistralFlashAttention,
}
177
+
178
+
179
class MistralForCausalLM(LlamaForCausalLM):
    """Mistral (and Qwen2-compatible) causal LM; reuses the Llama decoder
    structure with sliding-window-aware attention."""

    def __init__(self, config: MistralConfig) -> None:
        super().__init__(config)

    @staticmethod
    def from_pretrained(
        llm_model: modeling_mistral.MistralForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Freeze a HF Mistral/Qwen2 model and rebuild it with LoRA-capable layers."""
        llm_config: modeling_mistral.MistralConfig = llm_model.config
        llm_args = MistralConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # compatible with qwen2
        if isinstance(llm_config, modeling_qwen2.Qwen2Config):
            llm_args.max_window_layers_ = llm_config.max_window_layers

        # -1 marks "no pad token" for downstream code.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = MistralForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        model.embed_tokens_ = LlamaEmbedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = LlamaRMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Wrap every HF decoder layer with the LoRA-capable equivalents.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = LlamaDecoderLayer(idx)
            decoder.self_attn_ = MISTRAL_ATTENTION_CLASSES[
                llm_args.attn_implementation_
            ](
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                LlamaMLP(
                    layer.mlp.gate_proj,
                    layer.mlp.down_proj,
                    layer.mlp.up_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = LlamaRMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = LlamaRMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
c2cite/models/modeling_phi.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers.activations import ACT2FN
8
+ from transformers.models.phi import modeling_phi
9
+ from transformers.models.phi.modeling_phi import (
10
+ PhiRotaryEmbedding,
11
+ apply_rotary_pos_emb,
12
+ repeat_kv,
13
+ )
14
+ from transformers.utils import is_flash_attn_2_available
15
+
16
+ from moe_peft.common import (
17
+ FeedForward,
18
+ Linear,
19
+ LLMAttention,
20
+ LLMCache,
21
+ LLMDecoder,
22
+ LLMFeedForward,
23
+ LLMForCausalLM,
24
+ LLMModelConfig,
25
+ LLMModelInput,
26
+ collect_plugin_router_logtis,
27
+ eager_attention_forward,
28
+ flash_attention_forward,
29
+ prepare_4d_causal_attention_mask,
30
+ slice_tensor,
31
+ )
32
+ from moe_peft.executors import executor
33
+ from moe_peft.utils import copy_parameters
34
+
35
+
36
@dataclass
class PhiConfig(LLMModelConfig):
    """Phi-1/1.5/2 model configuration on top of the shared LLMModelConfig."""

    # Epsilon for every nn.LayerNorm in the model.
    layer_norm_eps_: float = 1e-05
    # Dropout applied to attention/MLP outputs inside each decoder layer.
    resid_pdrop_: float = 0.0
    # Dropout applied right after the token embedding.
    embd_pdrop_: float = 0.0
    # Rotary channels per head; derived from partial_rotary_factor in from_pretrained.
    rotary_emb_dim_: int = 0
    # Whether q/k get an extra per-head LayerNorm.
    qk_layernorm_: bool = False
43
+
44
+
45
def apply_partial_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    rotary_emb_dim: int,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate only the first ``rotary_emb_dim`` channels of q/k; the
    remaining channels pass through unchanged (Phi partial RoPE)."""
    q_rot = xq[..., :rotary_emb_dim]
    q_pass = xq[..., rotary_emb_dim:]
    k_rot = xk[..., :rotary_emb_dim]
    k_pass = xk[..., rotary_emb_dim:]

    # [batch_size, seq_length, num_heads, head_dim // partial_rotary_factor]
    q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)

    # [batch_size, seq_length, num_heads, head_dim]
    rotated_q = torch.cat((q_rot, q_pass), dim=-1)
    rotated_k = torch.cat((k_rot, k_pass), dim=-1)
    return rotated_q, rotated_k
69
+
70
+
71
+ # Multi-headed attention from 'Attention Is All You Need' paper.
72
class PhiAttention(LLMAttention):
    """Multi-headed attention from 'Attention Is All You Need' (eager Phi variant).

    Wraps the HF projections in LoRA-capable Linear layers and supports
    Phi's partial rotary embedding and optional q/k layer norm.
    """

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        dense: nn.Module,
        idx: int,
        config: PhiConfig,
    ):
        super().__init__()
        # attention projections (LoRA-capable wrappers)
        self.wq_: Linear = Linear(q_proj, config.device_)
        self.wk_: Linear = Linear(k_proj, config.device_)
        self.wv_: Linear = Linear(v_proj, config.device_)
        self.dense_: Linear = Linear(dense, config.device_)
        # config
        self.layer_idx_ = idx
        self.dim_ = config.dim_
        self.n_heads_ = config.n_heads_
        self.n_kv_heads_ = config.n_kv_heads_
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.rotary_emb_dim_ = config.rotary_emb_dim_
        self.head_dim_ = config.head_dim_
        self.dtype_ = config.dtype_
        self.is_causal_ = True
        # qk norm
        self.qk_layernorm_: bool = config.qk_layernorm_
        if self.qk_layernorm_:
            # BUGFIX: the original referenced undefined attributes
            # (self.hidden_size_, self.num_heads_) and a nonexistent
            # config.norm_eps_, raising AttributeError whenever
            # qk_layernorm_ is True. Use head_dim_ (== hidden_size //
            # num_attention_heads, see from_pretrained) and the config's
            # layer-norm epsilon, matching HF's PhiAttention.
            self.q_layernorm_ = nn.LayerNorm(
                self.head_dim_,
                eps=config.layer_norm_eps_,
                elementwise_affine=True,
            )
            self.k_layernorm_ = nn.LayerNorm(
                self.head_dim_,
                eps=config.layer_norm_eps_,
                elementwise_affine=True,
            )
        else:
            self.q_layernorm_ = nn.Identity()
            self.k_layernorm_ = nn.Identity()

    def state_dict(self) -> Dict[str, Linear]:
        """Expose the wrapped projections keyed by their HF names."""
        return {
            "q_proj": self.wq_,
            "k_proj": self.wk_,
            "v_proj": self.wv_,
            "dense": self.dense_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Eager attention forward; returns the projected attention output."""
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        xq = self.q_layernorm_(xq)
        xk = self.k_layernorm_(xk)

        # convert shape to multi head: (batch, heads, seq, head_dim)
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb

        # partial rotary embedding
        xq, xk = apply_partial_rotary_emb(
            xq,
            xk,
            self.rotary_emb_dim_,
            cos,
            sin,
            cache_position.unsqueeze(0),
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb_dim_,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # before dim: batch_size, n_kv_head, seq_len, head_dim
        # after dim: batch_size, n_head, seq_len, head_dim
        xk = repeat_kv(xk, self.n_rep_)
        xv = repeat_kv(xv, self.n_rep_)

        # Phi computes q@k in fp32 for stability.
        attention_score = eager_attention_forward(
            xq.to(torch.float32), xk.to(torch.float32), xv, attention_mask
        )

        attention_score = attention_score.reshape(batch_size, max_seq_len, -1)
        attention_score = self.dense_.forward(attention_score, input_args)

        return attention_score
186
+
187
+
188
class PhiFlashAttention2(PhiAttention):
    """Flash-attention-2 variant of PhiAttention (same projections and
    rotary handling; only the attention kernel differs)."""

    def __init__(
        self,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        dense: nn.Module,
        idx: int,
        args: PhiConfig,
    ):
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(q_proj, k_proj, v_proj, dense, idx, args)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        batch_size, max_seq_len, _ = hidden_states.shape

        xq = self.wq_.forward(hidden_states, input_args)
        xk = self.wk_.forward(hidden_states, input_args)
        xv = self.wv_.forward(hidden_states, input_args)

        xq = self.q_layernorm_(xq)
        xk = self.k_layernorm_(xk)

        # convert shape to multi head: (batch, heads, seq, head_dim)
        xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose(
            1, 2
        )
        xk = xk.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        xv = xv.view(
            batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        cos, sin = rotary_emb

        # partial rotary embedding
        xq, xk = apply_partial_rotary_emb(
            xq,
            xk,
            self.rotary_emb_dim_,
            cos,
            sin,
            cache_position.unsqueeze(0),
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb_dim_,
                "cache_position": cache_position,
            }
            xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)

        # Flash attention expects (batch, seq, heads, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Flash attention does not support fp32; downcast if needed.
        input_dtype = xq.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            xq = xq.to(target_dtype)
            xk = xk.to(target_dtype)
            xv = xv.to(target_dtype)

        attn_output = flash_attention_forward(
            xq,
            xk,
            xv,
            attention_mask,
            max_seq_len,
            is_causal=self.is_causal_,
        ).to(input_dtype)

        attn_output = attn_output.reshape(
            batch_size, max_seq_len, self.dim_
        ).contiguous()
        attn_output = self.dense_.forward(attn_output, input_args)

        return attn_output
+
281
+
282
# Dispatch table from attn_implementation_ to the attention class.
PHI_ATTENTION_CLASSES = {
    "eager": PhiAttention,
    "flash_attn": PhiFlashAttention2,
}
286
+
287
+
288
class PhiMLP(LLMFeedForward):
    """Phi two-layer MLP (fc1 -> act -> fc2) with LoRA and MixLoRA expert paths."""

    def __init__(self, fc1: nn.Module, fc2: nn.Module, args: PhiConfig) -> None:
        super().__init__()
        # feed forward
        self.fc1_: Linear = Linear(fc1, args.device_)
        self.fc2_: Linear = Linear(fc2, args.device_)
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        """Expose the wrapped projections keyed by their HF names."""
        return {
            "fc1": self.fc1_,
            "fc2": self.fc2_,
        }

    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Plain MLP forward with adapters applied inside each Linear."""
        hidden_states = self.fc1_.forward(hidden_states, input_args)
        hidden_states = self.act_(hidden_states)
        hidden_states = self.fc2_.forward(hidden_states, input_args)
        return hidden_states

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """MLP forward for one named adapter; each projection falls back to
        the frozen base layer when the adapter has no weights for it."""
        if lora_name in self.fc1_.loras_:
            # LoRA wrapper signature: forward(base_output, original_input).
            hidden_states = self.fc1_.loras_[lora_name].forward(
                self.fc1_.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.fc1_.base_layer_.forward(hidden_states)

        hidden_states = act_fn(hidden_states)

        if lora_name in self.fc2_.loras_:
            hidden_states = self.fc2_.loras_[lora_name].forward(
                self.fc2_.base_layer_.forward(hidden_states), hidden_states
            )
        else:
            hidden_states = self.fc2_.base_layer_.forward(hidden_states)

        return hidden_states

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        """Per-expert MLP outputs for a MixLoRA MoE layer.

        The frozen fc1 projection is computed once over all tokens; each
        expert applies its LoRA deltas only to the rows routed to it.
        """
        # Shared base fc1 projection, reused by every expert.
        common_fc1 = self.fc1_.base_layer_.forward(hidden_states.to(input_dtype)).to(
            hidden_states.dtype
        )
        final_expert_states = []
        for expert_idx in range(expert_mask.shape[0]):
            # Row indices of the tokens routed to this expert.
            _, top_x = torch.where(expert_mask[expert_idx])

            lora_name = f"moe.{moe_name}.experts.{expert_idx}"
            if lora_name in self.fc1_.loras_:
                lora_data = slice_tensor(hidden_states, top_x, input_dtype)
                act_result = act_fn(
                    self.fc1_.loras_[lora_name].forward(
                        slice_tensor(common_fc1, top_x, input_dtype), lora_data
                    )
                )
            else:
                act_result = act_fn(slice_tensor(common_fc1, top_x, input_dtype))

            if lora_name in self.fc2_.loras_:
                final_expert_states.append(
                    self.fc2_.loras_[lora_name].forward(
                        self.fc2_.base_layer_.forward(act_result), act_result
                    )
                )
            else:
                final_expert_states.append(self.fc2_.base_layer_.forward(act_result))

        return final_expert_states
+ return final_expert_states
362
+
363
+
364
class PhiDecoderLayer(LLMDecoder):
    """One Phi transformer block with a *parallel* residual: attention and
    MLP both consume the same layer-normed input, and their outputs are
    summed with the residual."""

    def __init__(
        self, layer_id: int, self_attn: LLMAttention, mlp: FeedForward, args: PhiConfig
    ) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        self.self_attn_ = self_attn
        self.mlp_ = mlp
        # Phi uses a single pre-norm shared by both branches.
        self.input_layernorm_ = nn.LayerNorm(
            args.dim_, eps=args.layer_norm_eps_, dtype=args.dtype_, device=args.device_
        )
        self.resid_pdrop_ = args.resid_pdrop_

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        """Return (attention projections, MLP projections) keyed by HF names."""
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        """Run one Phi block; returns (hidden_states, *router_logits)."""
        residual = hidden_states
        hidden_states = self.input_layernorm_(hidden_states)
        # Self Attention
        attn_outputs = self.self_attn_.forward(
            hidden_states,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        # Residual dropout is active only when training (not inference mode).
        attn_outputs = F.dropout(
            attn_outputs, self.resid_pdrop_, not input_args.inference_mode_
        )
        # Fully Connected — note: fed the normed input, not the attention output.
        feed_forward_outputs, router_logits = self.mlp_.forward(
            hidden_states, input_args
        )
        feed_forward_outputs = F.dropout(
            feed_forward_outputs, self.resid_pdrop_, not input_args.inference_mode_
        )
        # Parallel residual combine.
        hidden_states = attn_outputs + feed_forward_outputs + residual

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
+ return hidden_states, *router_logits
418
+
419
+
420
class PhiEmbedding(nn.Module):
    """Token embedding followed by embedding dropout, as in HF Phi."""

    def __init__(self, config: PhiConfig):
        super().__init__()
        self.embed_tokens = nn.Embedding(
            config.vocab_size_,
            config.dim_,
            config.pad_token_id_,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop_)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Lookup then dropout, in a single expression.
        return self.embed_dropout(self.embed_tokens(input_ids))
+ return self.embed_dropout(inputs_embeds)
435
+
436
+
437
class PhiLayerNorm(nn.Module):
    """Thin wrapper around nn.LayerNorm built from a PhiConfig."""

    def __init__(self, config: PhiConfig) -> None:
        super().__init__()
        norm_kwargs = dict(
            eps=config.layer_norm_eps_,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layernorm_ = nn.LayerNorm(config.dim_, **norm_kwargs)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.layernorm_.forward(data)
+ return self.layernorm_(data)
449
+
450
+
451
class PhiForCausalLM(LLMForCausalLM):
    """Phi causal-LM rebuilt around LoRA/MoE-capable projection wrappers,
    constructed from a frozen HF model via ``from_pretrained``."""

    def __init__(self, config: PhiConfig) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        self.embed_tokens_ = PhiEmbedding(config)
        self.final_layernorm_ = PhiLayerNorm(config)
        # Rotary tables cover only rotary_emb_dim_ channels (partial RoPE).
        self.rotary_emb_ = PhiRotaryEmbedding(
            dim=config.rotary_emb_dim_,
            max_position_embeddings=config.max_seq_len_,
            base=config.rope_theta_,
            device=config.device_,
        )
        # Phi's lm_head has a bias, unlike Llama's.
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=True,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layers_: List[PhiDecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids to (dropout-regularized) embeddings."""
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # PhiRotaryEmbedding takes a sequence length, derived from the
        # largest position id in the batch.
        return self.rotary_emb_(input_tensor, seq_len=position_ids[-1, -1] + 1)

    def decoder_stack(self) -> List[LLMDecoder]:
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final layer norm before the LM head."""
        return self.final_layernorm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:
        """Build the 4-D causal attention mask for the decoder stack."""

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> PhiConfig:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_phi.PhiForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        device: str = executor.default_device_name(),
    ):
        """Freeze a HF Phi model and rebuild it with LoRA-capable layers."""
        assert not use_sliding_window, "Phi model does not support SWA."
        llm_config: modeling_phi.PhiConfig = llm_model.config
        llm_args = PhiConfig(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            resid_pdrop_=llm_config.resid_pdrop,
            embd_pdrop_=llm_config.embd_pdrop,
            max_seq_len_=llm_config.max_position_embeddings,
            layer_norm_eps_=llm_config.layer_norm_eps,
            rope_theta_=llm_config.rope_theta,
            partial_rotary_factor_=llm_config.partial_rotary_factor,
            qk_layernorm_=llm_config.qk_layernorm,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Partial RoPE: only this many leading channels per head get rotated.
        llm_args.rotary_emb_dim_ = int(
            llm_args.partial_rotary_factor_ * llm_args.head_dim_
        )

        # -1 marks "no pad token" for downstream code.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = PhiForCausalLM(llm_args)
        llm_model.requires_grad_(False)
        copy_parameters(llm_model.model.embed_tokens, model.embed_tokens_.embed_tokens)
        copy_parameters(
            llm_model.model.final_layernorm, model.final_layernorm_.layernorm_
        )
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Wrap every HF decoder layer with the LoRA-capable equivalents.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = PhiDecoderLayer(
                idx,
                PHI_ATTENTION_CLASSES[llm_args.attn_implementation_](
                    layer.self_attn.q_proj,
                    layer.self_attn.k_proj,
                    layer.self_attn.v_proj,
                    layer.self_attn.dense,
                    idx,
                    llm_args,
                ),
                FeedForward(
                    PhiMLP(
                        layer.mlp.fc1,
                        layer.mlp.fc2,
                        llm_args,
                    )
                ),
                llm_args,
            )
            copy_parameters(layer.input_layernorm, decoder.input_layernorm_)
            model.layers_.append(decoder)

        return model
c2cite/models/modeling_phi3.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers.activations import ACT2FN
8
+ from transformers.models.phi3 import modeling_phi3
9
+ from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
10
+ from transformers.utils import is_flash_attn_2_available
11
+
12
+ from moe_peft.common import (
13
+ FeedForward,
14
+ Linear,
15
+ LLMAttention,
16
+ LLMCache,
17
+ LLMDecoder,
18
+ LLMFeedForward,
19
+ LLMForCausalLM,
20
+ LLMModelConfig,
21
+ LLMModelInput,
22
+ collect_plugin_router_logtis,
23
+ eager_attention_forward,
24
+ flash_attention_forward,
25
+ prepare_4d_causal_attention_mask,
26
+ slice_tensor,
27
+ )
28
+ from moe_peft.executors import executor
29
+ from moe_peft.utils import copy_parameters
30
+
31
+ from .modeling_gemma2 import Gemma2RotaryEmbedding as Phi3RotaryEmbedding
32
+ from .modeling_llama import LlamaEmbedding as Phi3Embedding
33
+ from .modeling_llama import LlamaRMSNorm as Phi3RMSNorm
34
+
35
+
36
@dataclass
class Phi3Config(LLMModelConfig):
    """Phi-3 specific hyper-parameters on top of the shared LLMModelConfig."""

    # Epsilon used by every RMSNorm in the model.
    rms_norm_eps_: float = 1e-6
    # Training-time context length; LongRoPE switches between short/long
    # factors when the running sequence exceeds this value.
    original_max_position_embeddings_: int = 4096
    # HF-style rope_scaling dict; when set, must contain "type" == "longrope"
    # plus "short_factor"/"long_factor" lists (see Phi3LongRoPEScaledRotaryEmbedding).
    rope_scaling_: Optional[Dict[str, Any]] = None
    use_sliding_window_: bool = False
    # Sliding-window width used by the flash-attention path.
    sliding_window_: int = 4096
    # Residual dropout probability applied after attention and after the MLP.
    resid_pdrop_: float = 0.0
44
+
45
+
46
class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """LongRoPE rotary embedding for Phi-3.

    Chooses the "short" or "long" frequency rescale factors depending on
    whether the current sequence exceeds the model's original training
    context, and multiplies cos/sin by a log-based attention scaling factor.

    NOTE(review): relies on ``self.dim``, ``self.base`` and
    ``self.max_position_embeddings`` being set by the
    ``Gemma2RotaryEmbedding`` base class -- confirm those attribute names
    against modeling_gemma2.
    """

    def __init__(self, dim, config: Phi3Config, device=None):
        super().__init__(dim, config.max_seq_len_, config.rope_theta_, device)

        # Per-frequency rescale factors from the HF rope_scaling dict.
        self.short_factor = config.rope_scaling_["short_factor"]
        self.long_factor = config.rope_scaling_["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings_

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Use the long factors only once positions pass the original context.
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(
                self.long_factor, dtype=torch.float32, device=x.device
            )
        else:
            ext_factors = torch.tensor(
                self.short_factor, dtype=torch.float32, device=x.device
            )

        # Recompute inv_freq on every call with the chosen factors applied.
        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float()
            / self.dim
        )
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)

        # Attention temperature rescaling from the LongRoPE paper: no-op at
        # or below the original context, sqrt-log growth beyond it.
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1
                + math.log(scale) / math.log(self.original_max_position_embeddings)
            )

        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
103
+
104
+
105
class Phi3Attention(LLMAttention):
    """Eager (matmul + softmax) self-attention for Phi-3.

    Phi-3 fuses query/key/value into a single ``qkv_proj`` linear; the
    fused output is split back into per-head Q/K/V tensors in ``forward``.
    """

    def __init__(
        self, qkv_proj: nn.Module, o_proj: nn.Module, layer_idx: int, args: Phi3Config
    ) -> None:
        super().__init__()
        # attention projections, wrapped so LoRA adapters can attach
        self.qkv_proj_ = Linear(qkv_proj, args.device_)
        self.o_proj_ = Linear(o_proj, args.device_)
        # config
        self.layer_idx_ = layer_idx
        self.args_ = args
        self.dim_ = args.dim_
        self.n_heads_ = args.n_heads_
        self.n_kv_heads_ = args.n_kv_heads_
        # grouped-query attention: how many times each KV head is repeated
        self.n_rep_ = self.n_heads_ // self.n_kv_heads_
        self.rope_theta_ = args.rope_theta_
        self.head_dim_ = self.dim_ // self.n_heads_
        self.dtype_ = args.dtype_
        self.is_causal_ = True

    def state_dict(self) -> Dict[str, Linear]:
        return {
            "qkv_proj": self.qkv_proj_,
            "o_proj": self.o_proj_,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Split the fused projection output into Q / K / V along the last dim.
        qkv = self.qkv_proj_.forward(hidden_states, input_args)
        query_pos = self.n_heads_ * self.head_dim_
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.n_kv_heads_ * self.head_dim_]
        value_states = qkv[..., query_pos + self.n_kv_heads_ * self.head_dim_ :]

        # Reshape to (bsz, n_heads, q_len, head_dim).
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # apply rotary embedding
        # NOTE(review): cache_position.unsqueeze(0) is passed in the
        # position_ids slot of apply_rotary_pos_emb -- confirm against the
        # installed transformers signature.
        cos, sin = rotary_emb
        assert query_states.dtype == key_states.dtype
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, cache_position.unsqueeze(0)
        )

        # Append the new K/V to the layer's cache (if one is in use).
        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # Expand KV heads to match the number of query heads (GQA).
        value_states = repeat_kv(value_states, self.n_rep_)
        key_states = repeat_kv(key_states, self.n_rep_)

        attn_output = eager_attention_forward(
            query_states, key_states, value_states, attention_mask
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)

        return self.o_proj_(attn_output, input_args)
184
+
185
+
186
class Phi3FlashAttention2(Phi3Attention):
    """Flash-Attention-2 variant of :class:`Phi3Attention`.

    Adds sliding-window handling: once the cached KV length exceeds the
    window, the oldest cached entries are sliced off before the cache
    update, and the attention mask is adjusted to match.
    """

    def __init__(
        self, qkv_proj: nn.Module, o_proj: nn.Module, layer_idx: int, args: Phi3Config
    ) -> None:
        assert is_flash_attn_2_available(), "Flash Attention is not available"
        super().__init__(qkv_proj, o_proj, layer_idx, args)
        self.sliding_window_ = args.sliding_window_

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):

        bsz, q_len, _ = hidden_states.size()

        # cutting: split the fused qkv projection into Q / K / V
        qkv = self.qkv_proj_.forward(hidden_states, input_args)
        query_pos = self.n_heads_ * self.head_dim_
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.n_kv_heads_ * self.head_dim_]
        value_states = qkv[..., query_pos + self.n_kv_heads_ * self.head_dim_ :]

        # viewing: reshape to (bsz, heads, seq, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.n_heads_, self.head_dim_
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.n_kv_heads_, self.head_dim_
        ).transpose(1, 2)

        # NOTE(review): treats cache_position[0] as the number of already
        # cached tokens -- assumes cache_position starts at the first new
        # token index; confirm against the caller.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += cache_position[0]

        # apply rotary embedding
        # NOTE(review): unlike the eager path, no position ids are passed
        # here -- cos/sin are assumed to already be per-position.
        cos, sin = rotary_emb
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Activate slicing cache
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx_) > 0
            if (
                self.sliding_window_ is not None
                and kv_seq_len > self.sliding_window_
                and cache_has_contents
            ):
                # Keep only the last (sliding_window - 1) cached positions.
                slicing_tokens = 1 - self.sliding_window_

                # NOTE(review): indexes the cache object directly
                # (past_key_value[idx][0/1]) -- requires the LLMCache
                # implementation to support subscripting.
                past_key = past_key_value[self.layer_idx_][0]
                past_value = past_key_value[self.layer_idx_][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.sliding_window_ - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.sliding_window - 1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat(
                        [attention_mask, torch.ones_like(attention_mask[:, -1:])],
                        dim=-1,
                    )

            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx_, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.n_rep_)
        value_states = repeat_kv(value_states, self.n_rep_)

        # Flash attention kernels require fp16/bf16 inputs; downcast fp32
        # and restore the original dtype on the way out.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if executor.is_bf16_supported():
                target_dtype = torch.bfloat16
            else:
                target_dtype = torch.float16
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # flash_attention_forward expects (bsz, seq, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            is_causal=self.is_causal_,
            sliding_window=self.sliding_window_,
        ).to(input_dtype)

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj_(attn_output, input_args)

        return attn_output
305
+
306
+
307
# Maps the `attn_implementation_` config string to an attention class.
PHI3_ATTENTION_CLASSES = {
    "eager": Phi3Attention,
    "flash_attn": Phi3FlashAttention2,
}
311
+
312
+
313
class Phi3MLP(LLMFeedForward):
    """Phi-3 fused gate/up MLP with optional per-adapter LoRA paths.

    ``gate_up_proj`` emits ``[gate | up]`` concatenated along the last
    dimension; the gated product ``up * act(gate)`` is then fed through
    ``down_proj``.
    """

    def __init__(self, gate: nn.Module, down: nn.Module, args: Phi3Config) -> None:
        super().__init__()
        # Wrap both projections so LoRA adapters can attach to them.
        self.gate_up_proj_ = Linear(gate, args.device_)
        self.down_proj_ = Linear(down, args.device_)
        self.act_ = ACT2FN[args.hidden_act_]

    def state_dict(self) -> Dict[str, nn.Module]:
        return {"gate_up_proj": self.gate_up_proj_, "down_proj": self.down_proj_}

    def _batch_forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        # Full forward including any active adapters on both projections.
        fused = self.gate_up_proj_(hidden_states, input_args)
        gate_part, up_part = fused.chunk(2, dim=-1)
        return self.down_proj_(up_part * self.act_(gate_part), input_args)

    def _lora_forward(
        self, lora_name: str, act_fn: nn.Module, data: torch.Tensor
    ) -> torch.Tensor:
        # Base projection first; layer the named LoRA delta on top when
        # that adapter exists for the projection.
        base_fused = self.gate_up_proj_.base_layer_.forward(data)
        if lora_name in self.gate_up_proj_.loras_:
            fused = self.gate_up_proj_.loras_[lora_name].forward(base_fused, data)
        else:
            fused = base_fused

        gate_part, up_part = fused.chunk(2, dim=-1)
        activated = act_fn(gate_part) * up_part

        base_down = self.down_proj_.base_layer_.forward(activated)
        if lora_name in self.down_proj_.loras_:
            return self.down_proj_.loras_[lora_name].forward(base_down, activated)
        return base_down

    def _mixlora_forward(
        self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
    ):
        # The base gate_up output is shared by all experts, so compute it
        # once for the whole batch.
        shared_fused = self.gate_up_proj_.base_layer_.forward(
            hidden_states.to(input_dtype)
        ).to(hidden_states.dtype)

        expert_outputs = []
        for expert_idx in range(expert_mask.shape[0]):
            # Tokens routed to this expert.
            _, token_idx = torch.where(expert_mask[expert_idx])
            adapter = f"moe.{moe_name}.experts.{expert_idx}"

            fused_slice = slice_tensor(shared_fused, token_idx, input_dtype)
            if adapter in self.gate_up_proj_.loras_:
                fused_slice = self.gate_up_proj_.loras_[adapter].forward(
                    fused_slice,
                    slice_tensor(hidden_states, token_idx, input_dtype),
                )

            gate_part, up_part = fused_slice.chunk(2, dim=-1)
            activated = up_part * act_fn(gate_part)

            down_out = self.down_proj_.base_layer_.forward(activated)
            if adapter in self.down_proj_.loras_:
                down_out = self.down_proj_.loras_[adapter].forward(down_out, activated)
            expert_outputs.append(down_out)

        return expert_outputs
394
+
395
+
396
class Phi3DecoderLayer(LLMDecoder):
    """A single Phi-3 transformer layer: pre-norm attention and pre-norm
    MLP, each followed by residual dropout."""

    def __init__(self, layer_id: int, config: Phi3Config) -> None:
        super().__init__()
        self.layer_id_: int = layer_id
        # Sub-modules are attached after construction by from_pretrained.
        self.self_attn_: Optional[Phi3Attention] = None
        self.mlp_: Optional[FeedForward] = None
        self.input_layernorm_: Optional[Phi3RMSNorm] = None
        self.post_attention_layernorm_: Optional[Phi3RMSNorm] = None
        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop_)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop_)

    def state_dict(self) -> Tuple[Dict[str, nn.Module], Dict[str, nn.Module]]:
        return self.self_attn_.state_dict(), self.mlp_.state_dict()

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_args: LLMModelInput,
        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_value: Optional[LLMCache] = None,
    ):
        # Attention sub-block: pre-norm, attend, residual add with dropout.
        normed = self.input_layernorm_(hidden_states)
        attn_out = self.self_attn_.forward(
            normed,
            input_args,
            rotary_emb,
            attention_mask,
            cache_position,
            past_key_value,
        )
        hidden_states = hidden_states + self.resid_attn_dropout(attn_out)

        # MLP sub-block: pre-norm, feed-forward, residual add with dropout.
        normed = self.post_attention_layernorm_(hidden_states)
        mlp_out, router_logits = self.mlp_.forward(normed, input_args)
        hidden_states = hidden_states + self.resid_mlp_dropout(mlp_out)

        if input_args.output_router_logits_:
            router_logits = collect_plugin_router_logtis(
                router_logits, input_args, self
            )

        return hidden_states, *router_logits
444
+
445
+
446
class Phi3ForCausalLM(LLMForCausalLM):
    """Phi-3 causal LM wrapper used by this library.

    Sub-modules (embedding, final norm, decoder layers) are populated by
    :meth:`from_pretrained`, which copies weights out of a HuggingFace
    ``Phi3ForCausalLM``.
    """

    def _init_rope(self):
        # Plain RoPE unless the HF config carried a "longrope" scaling dict.
        if self.config_.rope_scaling_ is None:
            return Phi3RotaryEmbedding(
                self.config_.head_dim_,
                max_position_embeddings=self.config_.max_seq_len_,
                base=self.config_.rope_theta_,
                device=self.config_.device_,
            )
        else:
            scaling_type = self.config_.rope_scaling_["type"]
            # Only LongRoPE scaling is supported for Phi-3.
            assert scaling_type == "longrope", ValueError(
                f"Unknown RoPE scaling type {scaling_type}"
            )
            return Phi3LongRoPEScaledRotaryEmbedding(
                self.config_.head_dim_,
                config=self.config_,
                device=self.config_.device_,
            )

    def __init__(self, config: Phi3Config) -> None:
        super().__init__()
        self.config_ = config
        self.padding_idx_ = config.pad_token_id_
        self.vocab_size_ = config.vocab_size_
        # Filled in by from_pretrained.
        self.embed_tokens_: Optional[Phi3Embedding] = None
        # Final RMSNorm (annotation fixed: this holds a Phi3RMSNorm,
        # not a Phi3Embedding -- see from_pretrained below).
        self.norm_: Optional[Phi3RMSNorm] = None
        self.rotary_emb_ = self._init_rope()
        self.lm_head_ = nn.Linear(
            config.dim_,
            config.vocab_size_,
            bias=False,
            dtype=config.dtype_,
            device=config.device_,
        )
        self.layers_: List[Phi3DecoderLayer] = []

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens_(input_ids)

    def rotary_embed(
        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns the (cos, sin) pair shared by every decoder layer.
        return self.rotary_emb_(input_tensor, position_ids)

    def decoder_stack(self) -> List[LLMDecoder]:
        return self.layers_

    def norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm_(hidden_states)

    def causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[LLMCache],
    ) -> torch.Tensor:

        return prepare_4d_causal_attention_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
        )

    def model_config(self) -> Phi3Config:
        return self.config_

    @staticmethod
    def from_pretrained(
        llm_model: modeling_phi3.Phi3ForCausalLM,
        attn_impl: str = "eager",
        use_sliding_window: bool = False,
        # NOTE(review): default evaluated once at class-definition time.
        device: str = executor.default_device_name(),
    ):
        """Build a Phi3ForCausalLM wrapper around a loaded HF model,
        freezing the HF weights and reusing/copying them layer by layer."""
        llm_config: modeling_phi3.Phi3Config = llm_model.config
        llm_args = Phi3Config(
            name_or_path_=llm_config.name_or_path,
            vocab_size_=llm_config.vocab_size,
            dim_=llm_config.hidden_size,
            head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
            intermediate_=llm_config.intermediate_size,
            n_layers_=llm_config.num_hidden_layers,
            n_heads_=llm_config.num_attention_heads,
            n_kv_heads_=llm_config.num_key_value_heads,
            hidden_act_=llm_config.hidden_act,
            rms_norm_eps_=llm_config.rms_norm_eps,
            resid_pdrop_=llm_config.resid_pdrop,
            max_seq_len_=llm_config.max_position_embeddings,
            rope_theta_=llm_config.rope_theta,
            rope_scaling_=llm_config.rope_scaling,
            original_max_position_embeddings_=llm_config.original_max_position_embeddings,
            pad_token_id_=llm_config.pad_token_id,
            attn_implementation_=attn_impl,
            use_sliding_window_=use_sliding_window,
            sliding_window_=llm_config.sliding_window,
            device_=torch.device(device),
            dtype_=llm_model.dtype,
        )

        # Sentinel padding id when the tokenizer defines none.
        if llm_args.pad_token_id_ is None:
            llm_args.pad_token_id_ = -1

        model = Phi3ForCausalLM(llm_args)
        # Base weights stay frozen; only adapters are trained.
        llm_model.requires_grad_(False)
        model.embed_tokens_ = Phi3Embedding(
            llm_model.model.embed_tokens.weight, llm_args.pad_token_id_
        )
        model.norm_ = Phi3RMSNorm(llm_model.model.norm.weight, llm_args.rms_norm_eps_)
        copy_parameters(llm_model.lm_head, model.lm_head_)

        # Re-wrap each HF decoder layer with the adapter-aware modules.
        for idx, layer in enumerate(llm_model.model.layers):
            decoder = Phi3DecoderLayer(idx, llm_args)
            decoder.self_attn_ = PHI3_ATTENTION_CLASSES[llm_args.attn_implementation_](
                layer.self_attn.qkv_proj,
                layer.self_attn.o_proj,
                idx,
                llm_args,
            )
            decoder.mlp_ = FeedForward(
                Phi3MLP(
                    layer.mlp.gate_up_proj,
                    layer.mlp.down_proj,
                    llm_args,
                )
            )
            decoder.input_layernorm_ = Phi3RMSNorm(
                layer.input_layernorm.weight, llm_args.rms_norm_eps_
            )
            decoder.post_attention_layernorm_ = Phi3RMSNorm(
                layer.post_attention_layernorm.weight, llm_args.rms_norm_eps_
            )
            model.layers_.append(decoder)

        return model
c2cite/prompter.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os.path as osp
4
+ from typing import Dict, Optional, Union
5
+
6
# Built-in prompt templates, keyed by name. Each template provides a
# with-input variant, a no-input variant, and the delimiter used to
# extract the model's response.
prompt_templates = {
    "moe_peft": {
        "description": "Default Prompt Template Provided by MoE-PEFT",
        "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n",
        "prompt_no_input": "### Instruction:\n{instruction}\n\n### Output:\n",
        "response_split": "### Output:",
    },
    "alpaca": {
        "description": "Template used by Alpaca-LoRA.",
        "prompt_input": "Below is an instruction that describes a task, "
        + "paired with an input that provides further context. "
        + "Write a response that appropriately completes the request.\n\n"
        + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
        "prompt_no_input": "Below is an instruction that describes a task. "
        + "Write a response that appropriately completes the request.\n\n"
        + "### Instruction:\n{instruction}\n\n### Response:\n",
        "response_split": "### Response:",
    },
}
25
+
26
+
27
# manage templates and prompt building.
class Prompter:
    """Resolves a prompt template and renders instruction/input pairs with it."""

    def __init__(self, template: Optional[Union[Dict, str]] = None):
        # A string argument is either a path to a JSON template file or
        # the name of a built-in template; a dict is used as-is; None
        # selects the library default.
        if isinstance(template, str):
            if osp.exists(template):
                with open(template) as fp:
                    template = json.load(fp)
            else:
                template = prompt_templates[template]
        self.template = (
            template if template is not None else prompt_templates["moe_peft"]
        )

        logging.info(f"Using prompt template: {self.template['description']}")

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Pick the with-input or no-input variant and fill it in; when a
        # label (= response / output) is given, append it plus a newline.
        key = "prompt_input" if input else "prompt_no_input"
        fields = {"instruction": instruction}
        if input:
            fields["input"] = input
        res = self.template[key].format(**fields)
        if label:
            res += f"{label}\n"
        logging.debug(res)
        return res

    def get_response(self, output: str) -> str:
        # Everything after the final response delimiter is the answer.
        _, _, tail = output.rpartition(self.template["response_split"])
        return tail.strip()
c2cite/solutions.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
# From the "peering" paper.
def get_output(layers, hidden, output, ans_len):
    """Placeholder for peering-style output extraction.

    NOTE(review): not implemented -- both the 32-layer branch and the
    fallback branch are `pass`, so this always returns None regardless
    of its arguments.
    """
    if layers == 32:
        pass
    else:
        pass
c2cite/tasks/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Package entry point for tasks: imports the task modules, lets each one
# register its entries into the shared `task_dict`, and re-exports the
# public task/metric classes.
from . import glue_tasks, qa_tasks, attribute_tasks
from .common import (
    AutoMetric,
    BasicMetric,
    BasicTask,
    CasualTask,
    CommonSenseTask,
    MultiTask,
    SequenceClassificationTask,
    task_dict,
)
from .qa_tasks import QuestionAnswerTask

# Each task module appends its own tasks to the shared registry.
glue_tasks.update_task_dict(task_dict)
qa_tasks.update_task_dict(task_dict)
attribute_tasks.update_task_dict(task_dict)


__all__ = [
    "BasicMetric",
    "AutoMetric",
    "BasicTask",
    "CasualTask",
    "SequenceClassificationTask",
    "CommonSenseTask",
    "QuestionAnswerTask",
    "MultiTask",
    "task_dict",
]
c2cite/tasks/attribute_tasks.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from typing import List, Optional
4
+
5
+ import datasets as hf_datasets
6
+ import torch
7
+ import json
8
+ import re
9
+ import os
10
+ from tqdm import tqdm
11
+
12
+ from transformers import BertTokenizer, BertModel
13
+
14
+
15
+ from moe_peft.common import InputData
16
+
17
+ from moe_peft.tasks.common import AttributeTask, BasicMetric, AutoMetric
18
+
19
+
20
class AttributedAnswerTask(AttributeTask):
    """Base class for attributed (citation-grounded) answer tasks."""

    def __init__(self) -> None:
        super().__init__()

    def loading_metric(self, metrics: List[str]):
        # Builds the shared "attribute" metric bundle for the given names.
        # NOTE(review): subclasses such as ASQA override this with a
        # zero-argument signature -- confirm which call shape the
        # evaluator actually uses.
        return AutoMetric("attribute", metrics)
28
+
29
class ASQA(AttributedAnswerTask):
    """ASQA attributed-QA task (ALCE benchmark).

    Builds citation-grounded prompts from the ALCE ASQA retrieval file.
    ``sub`` selects whether documents are rendered from full text
    ("vani") or from their summaries.
    """

    def __init__(self, sub: str = 'vani'):
        super().__init__()
        # Three instruction variants; only `inst_special_token` is used below.
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
    ) -> List[InputData]:
        """Load ASQA examples and render Llama-3-style chat prompts with
        <|reserved_special_token_i|> citation markers for each document."""
        # Few-shot prompting is force-disabled here, overriding the argument.
        few_shot = False #################################

        num_docs = 5
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/asqa_eval_gtr_top100.json" # go up two levels, then into the dataset directory
        file_path = os.path.join(current_dir, relative_path)

        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for ASQA")
        ret: List[InputData] = []
        #cnt = 5
        """tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        model = BertModel.from_pretrained('bert-large-uncased')
        device = 'cuda:6'
        model = model.to(device)
        model.eval()"""
        for data_point in tqdm(data):
            #if cnt == 0:
            #    break
            #cnt = cnt - 1
            #prompt = ""
            # Llama-3 chat header followed by the citation instruction.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            #prompt += self.inst_new
            prompt += self.inst_special_token
            if few_shot:
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            #prompt += f"\n\n\nQusetion: {data_point['qa_pairs'][0]['question']}\n\n"
            # NOTE(review): "Qusetion" typo is preserved -- it is part of the
            # runtime prompt text; changing it would change model inputs.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                    'summary': data_point['docs'][i]['summary'],
                })
            #random.shuffle(cites)
            for i in range(num_docs):
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
                #docs += f"Document <|reserved_special_token_{i+1}|>(Title: {cites[i]['title']}): {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
                #docs += f"Document [{i+1}](Title: {cites[i]['title']}): {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
            cites = [cites[i]['text'] if self.sub=='vani' else cites[i]['summary'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            # prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            #citation_embeds = sents_embed(cites, model, tokenizer, device)
            ret.append(InputData(inputs=prompt, labels=data_point['answer'], \
                grounds=data_point['qa_pairs'], citations = cites,# citation_embeds = citation_embeds,\
                query = data_point['question']))

        return ret

    def loading_metric(self):
        # NOTE(review): `metric_list` is not defined or imported in the code
        # visible here -- confirm it is defined elsewhere in this module,
        # otherwise this raises NameError at call time.
        config = {}
        config['task'] = 'asqa'
        config['metric'] = metric_list['asqa']
        return AutoMetric("attribute", config)
97
+
98
+
99
class ELI5(AttributedAnswerTask):
    """ELI5 attributed long-form QA task from the ALCE benchmark.

    Builds Llama-3 chat prompts in which each retrieved passage is tagged with
    a reserved special token (``<|reserved_special_token_i|>``) that the model
    is expected to emit as its citation marker.
    """

    def __init__(self, sub: str = 'vani'):
        # sub == 'vani' uses the raw passage text as document context;
        # any other value uses the pre-computed passage summary instead.
        super().__init__()
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Same instruction minus the "[1][2][3]" citation-format sentence; used
        # with the reserved-special-token citation scheme.
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
                     ) -> List[InputData]:
        """Load eli5_eval_bm25_top100.json and build one InputData per question.

        ``path`` overrides the default dataset location. Note that few-shot
        prompting is force-disabled below regardless of the ``few_shot``
        argument, and ``is_train`` is accepted only for interface parity.
        """
        few_shot = False  ############## NOTE(review): overrides the caller's choice
        num_docs = 5  # only the top-5 retrieved passages are shown to the model
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/eli5_eval_bm25_top100.json"  # two levels up, into the dataset directory
        file_path = os.path.join(current_dir, relative_path)
        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for ELI5")
        ret: List[InputData] = []
        for data_point in tqdm(data):
            # Llama-3 chat preamble: system turn, then start of the user turn.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += self.inst_special_token
            if few_shot:  # dead branch while few_shot is forced False above
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            # NOTE(review): "Qusetion" is a typo, but it is part of the runtime
            # prompt and is used consistently across tasks in this file, so it
            # is deliberately left unchanged.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                    'summary': data_point['docs'][i]['summary'],
                })
            for i in range(num_docs):
                # Each passage is labelled with a reserved special token rather
                # than a numeric [i] marker.
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text'] if self.sub=='vani' else cites[i]['summary']}\n"
            # Flatten to the plain strings used downstream as citation targets.
            cites = [cites[i]['text'] if self.sub=='vani' else cites[i]['summary'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            ret.append(InputData(inputs=prompt, labels=data_point['answer'], \
                grounds=data_point['claims'], citations = cites, \
                query = data_point['question']))

        return ret

    def loading_metric(self):
        """Return the attribution metric evaluator configured for ELI5."""
        config = {}
        config['task'] = 'eli5'
        config['metric'] = metric_list['eli5']
        return AutoMetric("attribute", config)
155
+
156
class Qampari(AttributedAnswerTask):
    """QAMPARI attributed list-QA task from the ALCE benchmark.

    Prompts are built like the other ALCE tasks in this module; documents are
    tagged with reserved special tokens used as citation markers.
    """

    def __init__(self, sub: str = 'vani'):
        # `sub` is stored for interface parity with the sibling tasks; the
        # QAMPARI docs carry no 'summary' field, so only raw text is used.
        super().__init__()
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # Same instruction minus the "[1][2][3]" citation-format sentence.
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, path: str = None, few_shot: bool = True
                     ) -> List[InputData]:
        """Load qampari_eval_gtr_top100.json and build one InputData per question.

        ``path`` overrides the default dataset location. Few-shot prompting is
        force-disabled below regardless of the ``few_shot`` argument.
        """
        few_shot = False  ############## NOTE(review): overrides the caller's choice
        num_docs = 5  # only the top-5 retrieved passages are shown to the model
        current_dir = os.path.dirname(os.path.abspath(__file__))
        relative_path = "../../dataset/ALCE-data/qampari_eval_gtr_top100.json"  # two levels up, into the dataset directory
        file_path = os.path.join(current_dir, relative_path)
        with open(path if path is not None else file_path,'r',encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Qampari")
        ret: List[InputData] = []
        for data_point in tqdm(data):
            # Llama-3 chat preamble: system turn, then start of the user turn.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += self.inst_special_token
            if few_shot:  # dead branch while few_shot is forced False above
                prompt += f"Here is an example:\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the \"Planet of the Apes\" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled \"Planet of the Apes\". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time warp to a future where apes subjugate humans (unlike the original film, the humans can speak). Roddy McDowall returned to the franchise as Galen, a chimpanzee who joins the astronauts.\nDocument [2](Title: Planet of the Apes (1968 film)): chimpanzees: animal psychologist Zira (Kim Hunter) and surgeon Galen (Wright King). While unable to speak as his throat wound is healing, called \"Bright Eyes\" by Zira and placed with one of the captive primitive humans he later names \"Nova\", Taylor observes the enhanced society of talking apes and in a strict caste system: the gorillas being the military police, hunters and workers; the orangutans overseeing the affairs of government, science, and religion; and intellectual chimpanzees being mostly scientists. While their society is a theocracy similar to the beginnings of the human Industrial Era, the apes consider the primitive humans as\nDocument [3](Title: Planet of the Apes (1968 film)): Planet of the Apes (1968 film) Planet of the Apes is a 1968 American science fiction film directed by Franklin J. Schaffner. It stars Charlton Heston, Roddy McDowall, Kim Hunter, Maurice Evans, James Whitmore, James Daly and Linda Harrison. The screenplay by Michael Wilson and Rod Serling was loosely based on the 1963 French novel \"La Plan\u00e8te des Singes\" by Pierre Boulle. Jerry Goldsmith composed the groundbreaking avant-garde score. It was the first in a series of five films made between 1968 and 1973, all produced by Arthur P. Jacobs and released by 20th Century Fox. The film tells the\nDocument [4](Title: Planet of the Apes): Rupert Wyatt. To portray ape characters realistically, the production avoided practical effects in favor of performance capture acting, partnering with New Zealand visual effects company Weta Digital. Wyatt cast James Franco as Will Rodman, while veteran performance capture actor Andy Serkis signed on to star as Caesar. \"Rise\" debuted on August 5, 2011. Critics reviewed it positively, especially praising the visual effects and Serkis's performance. It was a major box office hit, taking in $482 million globally, more than five times its $93 million budget. Weta's special effects earned the film two Visual Effects Society Awards and an Oscar nomination\nDocument [5](Title: Planet of the Apes): film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the\n\nAnswer:In the 1968 film Planet of the Apes, Galen was played by Wright King [2]. And in the tv series Planet of the Apes, Galen was played by Roddy McDowall [1].\n\n\n"
            # NOTE(review): "Qusetion" typo is part of the runtime prompt, kept
            # as-is for consistency with the sibling tasks.
            prompt += f"\n\n\nQusetion: {data_point['question']}\n\n"
            docs = ""
            cites = []
            for i in range(num_docs):
                cites.append({
                    'text': data_point['docs'][i]['text'],
                    'title': data_point['docs'][i]['title'],
                })
            for i in range(num_docs):
                # Each passage is labelled with a reserved special token rather
                # than a numeric [i] marker.
                docs += f"Document <|reserved_special_token_{i+1}|>: {cites[i]['text']}\n"
            # Flatten to the plain strings used downstream as citation targets.
            cites = [cites[i]['text'] for i in range(num_docs)]
            prompt += docs
            prompt += f"\nAnswer:"
            ret.append(InputData(inputs=prompt, labels=data_point['answers'], \
                citations = cites, \
                query = data_point['question']))
        return ret

    def loading_metric(self):
        """Return the attribution metric evaluator configured for QAMPARI."""
        config = {}
        config['task'] = 'qam'
        config['metric'] = metric_list['qam']
        return AutoMetric("attribute", config)
210
+
211
+
212
class QouteSum(AttributedAnswerTask):
    """QuoteSum quoted-summary task (class name typo preserved for callers).

    Examples sharing a qid are grouped; each group yields one prompt whose
    documents are tagged with reserved special tokens used as citation markers.
    """

    def __init__(self, sub: str = 'vani'):
        # `sub` selects the data variant: 'alce' / 'ans' / default 'vani'.
        super().__init__()
        self.sub = sub
        self.inst = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        # NOTE(review): "bilities" below is a typo in the runtime prompt; left
        # unchanged because this string is (potentially) fed to the model.
        self.inst2 = 'Based on the information contained in the document, answer the question with details to the best of your bilities. Think step by step and explain your answer if that will help better understand the answer.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'

    def loading_data(self, is_train: bool = False, path: str = None,
                     few_shot: bool = True ) -> List[InputData]:
        """Load the QuoteSum jsonl split and build one InputData per qid group.

        ``path`` overrides the default dataset location; few-shot prompting is
        force-disabled below regardless of the ``few_shot`` argument.
        """
        few_shot = False  ########### NOTE(review): overrides the caller's choice
        if is_train:
            few_shot = False
        ret: List[InputData] = []
        examples_by_qid = {}
        with open(f"/yy21/MoE-PEFT/dataset/{'qoutesum_alce' if self.sub == 'alce' else ( 'qoutesum_ans' if self.sub == 'ans' else 'qoutesum')}/{'train' if is_train else 'test'}.jsonl" if path is None else path, 'r') as f:
            for line in f:
                # Group all annotations that belong to the same question.
                example = json.loads(line.strip())
                if example['qid'] not in examples_by_qid:
                    examples_by_qid[example['qid']] = [example]
                else:
                    examples_by_qid[example['qid']].append(example)

        examples = list(examples_by_qid.values())
        for example in examples:
            # Llama-3 chat preamble: system turn, then start of the user turn.
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += self.inst_special_token
            if few_shot:  # dead branch while few_shot is forced False above
                if self.sub == 'alce':
                    prompt += f" Here are some examples:\nQuestion: how much power does a wind turbine produce?\nDocument [1](Title:): Compact wind acceleration turbine: It is generally thought that since the amount of power produced by a wind turbine is proportional to the cube of the wind speed, any acceleration benefit is potentially statistically significant in the economics of wind. As noted though this is an inaccurate as it ignores the impact of the exit to area ratio and is therefore an apples to oranges comparison. In the case of a typical CWAT/DAWT the power result in perfect theoretical operation once adjusted for the area of the shroud is actually the square of the velocity at the rotor. As the CWAT/DAWT diverges from theoretical function the power increase drops significantly according\nDocument [2](Title:): Sustainable architecture: roof ledge. Small-scale rooftop wind turbines have been known to be able to generate power from 10% to up to 25% of the electricity required of a regular domestic household dwelling. Turbines for residential scale use are usually between 7 feet (2 m) to 25 feet (8 m) in diameter and produce electricity at a rate of 900 watts to 10,000 watts at their tested wind speed. Building integrated wind turbine performance can be enhanced with the addition of an aerofoil wing on top of a roof mounted turbine. Solar water heaters, also called solar domestic hot water systems, can\nDocument [3](Title:): Turby wind turbine: can because horizontal axis (HAWT) types cannot change their pitch to face the wind directly. The turbine measures 2.0m (6'7\") in diameter by 2.9m (9'6\") high (including generator), and weighs 136 kg (300 lb). It is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts), and can survive winds of 55 m/s (123 mph, 107kts). The rated power at 14 m/s is 2.5 kW (3.35 hp). The AC output from the synchronous generator is rectified to DC, then inverted to AC at 230V 50 Hz. Core International developed the turbine\nAnswer: One source states the amount of power produced by a wind turbine is proportional to the cube of the wind speed [1]. Other sources state Turbines for residential scale use produce electricity at a rate of 900 watts to 10,000 watts, and is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts) [2][3]."
                elif self.sub == 'vani':
                    prompt += f" Here are some examples:\nQuestion: how much power does a wind turbine produce?\n[1] Compact wind acceleration turbine: It is generally thought that since the amount of power produced by a wind turbine is proportional to the cube of the wind speed, any acceleration benefit is potentially statistically significant in the economics of wind. As noted though this is an inaccurate as it ignores the impact of the exit to area ratio and is therefore an apples to oranges comparison. In the case of a typical CWAT/DAWT the power result in perfect theoretical operation once adjusted for the area of the shroud is actually the square of the velocity at the rotor. As the CWAT/DAWT diverges from theoretical function the power increase drops significantly according\n[2] Sustainable architecture: roof ledge. Small-scale rooftop wind turbines have been known to be able to generate power from 10% to up to 25% of the electricity required of a regular domestic household dwelling. Turbines for residential scale use are usually between 7 feet (2 m) to 25 feet (8 m) in diameter and produce electricity at a rate of 900 watts to 10,000 watts at their tested wind speed. Building integrated wind turbine performance can be enhanced with the addition of an aerofoil wing on top of a roof mounted turbine. Solar water heaters, also called solar domestic hot water systems, can\n[3] Turby wind turbine: can because horizontal axis (HAWT) types cannot change their pitch to face the wind directly. The turbine measures 2.0m (6'7\") in diameter by 2.9m (9'6\") high (including generator), and weighs 136 kg (300 lb). It is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts), and can survive winds of 55 m/s (123 mph, 107kts). The rated power at 14 m/s is 2.5 kW (3.35 hp). The AC output from the synchronous generator is rectified to DC, then inverted to AC at 230V 50 Hz. Core International developed the turbine\nAnswer: One source states the [ 1 amount of power produced by a wind turbine is proportional to the cube of the wind speed ] . Other sources state [ 2 Turbines for residential scale use ] [ 2 produce electricity at a rate of 900 watts to 10,000 watts ] , and [ 3 is specified to generate power in winds of between 4 m/s (9 mph, 7.8kts) and 14 m/s (31 mph, 27.2kts) ] .\n\nQuestion: a component is what?\n[1] Modular programming: in Dart, Go or Java) is sometimes used instead of module. In other implementations, this is a distinct concept; in Python a package is a collection of modules, while in Java 9 the introduction of the new module concept (a collection of packages with enhanced access control) is planned. Furthermore, the term \"package\" has other uses in software (for example .NET NuGet packages). A component is a similar concept, but typically refers to a higher level; a component is a piece of a whole system, while a module is a piece of an individual program. The scale of the term\n[2] Physical body: the system at a point in time changes from identifying the object to not identifying it. Also an object's identity is created at the first point in time that the simplest model of the system consistent with perception identifies it. An object may be composed of components. A component is an object completely within the boundary of a containing object. In classical mechanics a physical body is collection of matter having properties including mass, velocity, momentum and energy. The matter exists in a volume of three-dimensional space. This space is its extension. Under Newtonian gravity the gravitational field further away\nQuoted summary: [ 1 A component is a similar concept, but typically refers to a higher level; a component is a piece of a whole system, while a module is a piece of an individual program ] in terms of [ 1 Modular programming ] . Whereas in the [ 2 Physical body ] , a [ 2 component is an object completely within the boundary of a containing object ] ."
                elif self.sub == 'ans':
                    pass
            # NOTE(review): "Qusetion" typo kept — it is part of the runtime prompt.
            prompt += f"\n\nQusetion: {example[0]['question']}\n"
            docs = ""
            sources = []
            citations = []
            # Collect up to 8 (title{i}, source{i}) pairs; stop at the first
            # missing title key, so `sources` may hold fewer than 8 entries.
            for i in range(8):
                if f"title{i+1}" not in example[0]:
                    break
                sources.append({'title': example[0][f'title{i+1}'],
                                'doc': example[0][f"source{i+1}"]}
                               )
            # FIX: was `for i in range(8)`, which raised IndexError whenever the
            # loop above collected fewer than 8 sources.
            for i in range(len(sources)):
                if sources[i]['doc'] != "":
                    docs += f"Document <|reserved_special_token_{i+1}|>: {sources[i]['doc']}\n"
                    citations.append(sources[i]['doc'])
                else:
                    break
            if len(citations) == 0:
                # No usable documents for this question: skip it entirely.
                continue
            prompt += docs
            prompt += f"\nAnswer:"
            if is_train:
                # One training example per human-written summary of the group.
                for e in example:
                    ret.append(InputData(inputs = prompt + cite2token(e['summary']),
                                         citations=citations, prompt = prompt))
            else:
                ret.append(InputData(inputs=prompt, labels=[e['summary'] for e in example], \
                    grounds=[i for e in example for i in e['covered_short_answers']], \
                    citations=citations, query = example[0]['question']))
        return ret

    def loading_metric(self):
        """Return the attribution metric for QuoteSum ('alce' uses qsum-a)."""
        config = {}
        config['task'] = 'qsum'
        if self.sub == 'alce':
            config['metric'] = metric_list['qsum-a']
        else:
            config['metric'] = metric_list['qsum']
        return AutoMetric("attribute", config)
303
+
304
+
305
class Front(AttributedAnswerTask):
    """Front attribution task loaded from pre-built SFT/DPO JSON dumps."""

    def __init__(self, sub):
        # `sub` selects the dump: 'sft' or anything else (dpo).
        super().__init__()
        self.inst = 'Extract the relevant content from the provided documents and then use the extracted content to guide answer generation and cite the sources properly.'
        self.sub = sub

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Build InputData records from the Front dump matching this instruction."""
        few_shot = False  ##############
        source_file = ("/yy21/MoE-PEFT/dataset/front/sft.json"
                       if self.sub == 'sft'
                       else "/yy21/MoE-PEFT/dataset/front/dpo.json")
        with open(source_file, 'r', encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Front")
        ret: List[InputData] = []
        doc_pattern = r"Document \[(\d+)\]: (.*?)(?=Document \[\d+\]:|$)"
        for data_point in data:
            # Keep only records that carry exactly this task's instruction.
            if data_point['instruction'] != self.inst:
                continue
            raw_input = data_point['input']
            chat_header = ("<|start_header_id|>system<|end_header_id|>\n\n"
                           + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n")
            prompt = cite2token(chat_header + self.inst + raw_input + "\nAnswer:")
            # The question sits between "Question: " and the first blank line.
            q_start = len("Question: ")
            q_end = raw_input.find("\n\n", q_start)
            question = raw_input[q_start:q_end]
            # Extract each numbered document body after the question section.
            doc_matches = re.findall(doc_pattern, raw_input[q_end + 2:], re.DOTALL)
            cites = [body.strip() for _, body in doc_matches]
            # The gold answer follows the "[ANSWER]" marker in the output field.
            marker = "[ANSWER]"
            raw_output = data_point['output']
            answer = cite2token(raw_output[raw_output.find(marker) + len(marker):])
            if is_train:
                ret.append(InputData(inputs=prompt + answer, prompt=prompt, citations=cites))
            else:
                ret.append(InputData(inputs=prompt, labels=answer,
                                     citations=cites, query=question))
        return ret

    def loading_metric(self):
        """Return the attribution metric evaluator configured for Front."""
        metric_config = {'task': 'front', 'metric': metric_list['front']}
        return AutoMetric("attribute", metric_config)
353
+
354
+
355
class Synsciqa(AttributedAnswerTask):
    """SynSciQA attributed-QA training task.

    Parses each record's instruction for a [BEGIN OF SOURCES]..[END OF SOURCES]
    block and the quoted question, rewrites the "(author, year)" citations in
    the response into reserved special tokens, and emits training examples.
    """

    def __init__(self, sub):
        # `sub` selects the dataset variant: 'synsci', 'synsci+' or 'synsci++'.
        super().__init__()
        self.sub = sub
        self.inst = lambda query: f"Can you respond to the question {query} by only relying on the sources. Ignore all sources that do not provide an answer to the question. Do not include any knowledge from outside of these sources. Only write a single paragraph. Each sentence must end with the reference in the form of (author, year, page number). Stricly follow this format. Citing multiple sources in one sentence is not allowed. However, if no source addresses the question, admit truthfully that no answer can be given. Answer the question concisly and avoid being verbose."
        self.inst_a = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_new = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite all of them at the end of the sentences. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document in each sentence.'

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Load the SynSciQA variant selected by ``self.sub``.

        Only training examples are produced: when ``is_train`` is False the
        loop appends nothing and an empty list is returned.
        """
        few_shot = False  ##############
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # Resolve the dataset file two levels up, under dataset/SynSciQA.
        if self.sub == 'synsci':
            relative_path = "../../dataset/SynSciQA/SynSciQA.json"
        elif self.sub == 'synsci+':
            relative_path = "../../dataset/SynSciQA/SynSciQA+.json"
        elif self.sub == 'synsci++':
            relative_path = "../../dataset/SynSciQA/SynSciQA++.json"
        else:
            # FIX: previously an unrecognized sub fell through and crashed
            # below with an opaque NameError on `relative_path`.
            raise ValueError(f"Unknown SynSciQA variant: {self.sub!r}")
        file_path = os.path.join(current_dir, relative_path)
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        logging.info("Preparing data for SynsciQA")
        ret: List[InputData] = []
        for line in tqdm(data):
            data_point = line["instruction"]
            answer = line["response"]
            # Slice out the newline-separated source list.
            doc_start = data_point.find("[BEGIN OF SOURCES]")
            doc_end = data_point.find("[END OF SOURCES]")
            documents = data_point[doc_start + len("[BEGIN OF SOURCES]"): doc_end].strip().split("\n")
            # FIX: the assert message used to be `print(...)`, which evaluates
            # to None, so a failing assert carried no message.
            assert len(documents) > 0, "No docs detected!"

            # The question is the first double-quoted string after the sources.
            data_point = data_point[doc_end + len("[END OF SOURCES]"):]
            pattern = r'"([^"]*)"'
            query = re.findall(pattern, data_point)
            prompt = "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            prompt += f"\n\nQuestion: {query[0]}\n"

            docs = ""
            citations = []
            index_map = []
            index_map2 = []
            for i, d in enumerate(documents):
                # Each source line is "<citation id>: <content>".
                Ids = d[:d.find(":")]
                cont = d[d.find(":") + 2:]
                docs += f"Document <|reserved_special_token_{i+1}|>: {cont}\n"
                citations.append(cont)
                # Map both "(id)" and bare "id" spellings to the special token.
                index_map.append({'index': f"({Ids})", 'ID': f'<|reserved_special_token_{i+1}|>'})
                index_map2.append({'index': f"{Ids}", 'ID': f'<|reserved_special_token_{i+1}|>'})
            index_map = {item['index']: item['ID'] for item in index_map}
            index_map2 = {item['index']: item['ID'] for item in index_map2}
            prompt += docs
            prompt += "\nAnswer:"
            # Replace every occurrence of a known citation id in the answer
            # with its reserved special token (parenthesized form first).
            pattern = re.compile('|'.join(map(re.escape, index_map)))
            answer = pattern.sub(lambda m: index_map[m.group()], answer)
            pattern = re.compile('|'.join(map(re.escape, index_map2)))
            answer = pattern.sub(lambda m: index_map2[m.group()], answer)
            # Collapse "( <tok> ; <tok> )" pairs into adjacent tokens.
            pattern = r'\(\s*(<\|[^|]+\|>)\s*;\s*(<\|[^|]+\|>)\s*\)'
            answer = re.sub(pattern, r'\1\2', answer)

            # Skip answers that ended up with no citation token at all.
            pattern = r'<\|reserved_special_token_\d+\|>'
            if re.search(pattern, answer) is None:
                continue
            # Skip answers still containing raw "(author, year, p. X)" citations
            # that the rewrites above failed to map.
            pattern = r"\((?:[^)]*,){2}[^)]*p\.[^)]*\)"
            if re.findall(pattern, answer):
                continue
            if is_train:
                ret.append(InputData(
                    inputs = prompt + answer, citations = citations, prompt = prompt
                ))
        return ret

    def loading_metric(self):
        """Return the attribution metric (shares the 'front' configuration)."""
        config = {}
        config['task'] = 'front'
        config['metric'] = metric_list['front']
        return AutoMetric("attribute", config)
456
+
457
+
458
class Reinf(AttributedAnswerTask):
    """Attributed-answer task over the reinforcement training split."""

    def __init__(self):
        super().__init__()
        # Instruction variants; `inst_special_token` is the one actually
        # injected into the prompt in `loading_data`.
        self.inst_a = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'
        self.inst_special_token = 'Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.'

    def loading_data(self, is_train: bool = False, few_shot: bool = True
                     ) -> List[InputData]:
        """Load the combined reinforcement set and build `InputData` records.

        Answers without `[i]` citations, or with citations pointing past the
        document list, are skipped; surviving citations are rewritten into
        `<|reserved_special_token_i|>` markers.
        """
        few_shot = False  # few-shot prompting disabled for this task
        with open("/yy21/MoE-PEFT/dataset/reinforcement/combined_train.json", 'r', encoding='utf-8') as file:
            data = json.load(file)
        logging.info("Preparing data for Reinforcement")
        ret: List[InputData] = []
        cite_pattern = r'\[(\d+)\]'
        for sample in tqdm(data):
            answer = sample["output"][0]
            cited = re.findall(cite_pattern, answer)
            # Require at least one citation, all within range of the docs.
            if not cited:
                continue
            if max(int(c) for c in cited) > len(sample["docs"]):
                continue

            documents = sample["docs"]
            answer = self.get_ans(answer)

            prompt = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                "You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            )
            prompt += self.inst_special_token
            prompt += f"\n\nQuestion: {sample['question']}\n"

            citations = [d["text"] for d in documents]
            prompt += "".join(
                f"Document <|reserved_special_token_{i + 1}|>: {text}\n"
                for i, text in enumerate(citations)
            )
            prompt += "\nAnswer:"
            if is_train:
                ret.append(InputData(
                    inputs=prompt + answer, citations=citations, prompt=prompt
                ))
        return ret

    def get_ans(self, sent):
        """Rewrite `[i]` citation marks to reserved special tokens."""
        return re.sub(r'\[(\d+)\]',
                      lambda m: f"<|reserved_special_token_{m.group(1)}|>",
                      sent)

    def loading_metric(self):
        """Build the attribution metric for the 'front' task."""
        cfg = {'task': 'front', 'metric': metric_list['front']}
        return AutoMetric("attribute", cfg)
521
+
522
def sents_embed(sents, model, tokenizer, device):
    """Encode each sentence with `model` and stack the pooled embeddings.

    Returns a tensor of shape (len(sents), hidden) built from the model's
    `pooler_output` for each sentence, computed without gradients.
    """
    pooled = []
    with torch.no_grad():
        for text in sents:
            encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
            pooled.append(model(**encoded).pooler_output)
    return torch.stack(pooled).squeeze(1)
532
+
533
def cite2token(sent):
    """Convert `[i]` citation marks into `<|reserved_special_token_i|>`."""
    return re.sub(r'\[(\d+)\]', r'<|reserved_special_token_\g<1>|>', sent)
537
+
538
# Metric bundles per task name, consumed by the Task classes' loading_metric.
# 'front' relies entirely on the external attribute metric, hence empty.
metric_list = {
    'asqa': ['cite_pr', 'length', 'short_ans'],
    'qsum': ['rouge_all', 'semqa_f1', 'semqa_short'],
    'qsum-a': ['rouge_all','semqa_short', 'cite_pr', 'length', 'semqa_f1'],
    'eli5': ['cite_pr', 'eli5_acc', 'length'],
    'qam': ['cite_pr', 'qampari'],
    'front': [],
}
546
+
547
def update_task_dict(task_dict):
    """Register every attributed-answer task into `task_dict` (in place)."""
    registered = {
        "asqa": ASQA(),
        "qsum": QouteSum('vani'),
        "qsum-a": QouteSum('alce'),
        "qsum-ans": QouteSum('ans'),
        "eli5": ELI5(),
        "front-s": Front('sft'),
        "front-d": Front('dpo'),
        "synsci": Synsciqa('synsci'),
        "synsci+": Synsciqa('synsci+'),
        "synsci++": Synsciqa('synsci++'),
        "rein": Reinf(),
        "qam": Qampari(),
    }
    task_dict.update(registered)
564
+
565
if __name__ == '__main__':
    # Ad-hoc smoke test: construct a task and run its data loading.
    # NOTE(review): QouteSum is constructed with an argument everywhere else
    # ('vani'/'alce'/'ans') — confirm the no-arg call here is intended.
    asqa = QouteSum()
    asqa.loading_data()
c2cite/tasks/common.py ADDED
@@ -0,0 +1,1045 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import json
5
+ import copy
6
+ import string
7
+ from nltk import sent_tokenize
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ from rouge import Rouge
11
+ import collections
12
+ from rouge_score import rouge_scorer, scoring
13
+ import functools
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple
15
+
16
+ import transformers
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
18
+ import datasets as hf_datasets
19
+ import evaluate as hf_evaluate
20
+ import torch
21
+
22
+ from moe_peft.common import InputData, Prompt
23
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
24
+
25
# Lazily-initialized evaluation models, populated on first use.
# NOTE(review): `global` at module level is a no-op statement; kept as-is.
global autoais_model, autoais_tokenizer
autoais_model = None
autoais_tokenizer = None
qa_pipeline = None
# Safe positional lookup: None when the index is out of range.
get_docs_by_index = lambda i,docs: docs[i] if i < len(docs) else None
ais_LLM = None

# Device hosting all evaluation models (NLI + QA pipeline).
evaluate_device = 'cuda:6'
#QA_MODEL = "gaotianyu1350/roberta-large-squad"
QA_MODEL = "/yy21/qa_model"
#AUTOAIS_MODEL = "google/t5_xxl_true_nli_mixture"
AUTOAIS_MODEL = "/yy21/autoais"
37
+
38
class BasicMetric:
    """No-op base class for metrics.

    Subclasses override `add_batch` to accumulate predictions/references
    and `compute` to return the final scores.
    """

    def __init__(self) -> None:
        pass

    def add_batch(self, predictions: torch.Tensor,
                  references: Optional[torch.Tensor] = None):
        """Accumulate one batch of results; no-op in the base class.

        The original class defined `add_batch` twice (a one-argument
        `add_batch(data)` and a two-argument form); the second definition
        silently shadowed the first. Merged into one signature with an
        optional `references`, backward-compatible with both call styles.
        """
        pass

    def compute(self) -> Dict[str, Any]:
        """Return the metric's results; no-op in the base class."""
        pass
50
+
51
+
52
+ from statistics import harmonic_mean
53
+
54
def normalize_answers(text):
    """Normalize a QA answer, TriviaQA-style.

    Lowercases, turns every punctuation character into a space, removes
    the articles a/an/the, and collapses runs of whitespace.
    """
    punct = set(string.punctuation)
    lowered = text.lower()
    no_punct = "".join(" " if ch in punct else ch for ch in lowered)
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct)
    return " ".join(no_articles.split())
73
+
74
+
75
def strip_attribution_tokens(text):
    """Drop `[ i ... ]` attribution markup, keeping only the quoted text."""
    quoted = re.compile(r'\[ ([1-9]) ([^\[\]]*) \]')
    return quoted.sub(r'\2', text)
78
+
79
+
80
def non_quoted(text):
    """Return only the text outside of `[ i ... ]` quoted spans."""
    span = re.compile(r'\[ ([1-9]) ([^\[\]]*) \]')
    return span.sub('', text)
83
+
84
+
85
def only_quoted(text, sources='1-9', sep=' '):
    """Join the contents of the `[ i ... ]` spans whose source id matches
    the `sources` character class, separated by `sep`."""
    pattern = r'\[ [{}] ([^\[\]]*) \]'.format(sources)
    pieces = (m.group(1) for m in re.finditer(pattern, text))
    return sep.join(pieces)
88
+
89
+
90
def quoted_sources(text):
    """Return the sorted unique source ids quoted in the answer."""
    ids = {int(m.group(1)) for m in re.finditer(r'\[ ([1-9]) [^\[\]]* \]', text)}
    return sorted(ids)
93
+
94
+
95
def score_all(data, scorer, aggr_measure, score_keys, preprocess_func=None, bootstrap=False):
    """
    Aggregates across all targets per sample.

    For every example, scores the prediction against each reference target
    and keeps the maximum according to `aggr_measure`.

    Args:
        data: list of dicts with fields 'answer' (list of reference strings)
            and 'output' (the predicted string).
        scorer: object exposing `.score(target=..., prediction=...) -> dict`.
        aggr_measure: key used to pick the best reference. If it contains
            'rouge' the score values are rouge Score tuples (compared by
            `.fmeasure`); 'independent' maximizes each key separately,
            possibly over different references.
        score_keys: score keys collected into the returned structures.
        preprocess_func: optional text transform applied to both target and
            prediction before scoring.
        bootstrap: when True, also return bootstrap (mid, low, high) CIs.

    Returns:
        Per-example scores (a list for rouge measures, a dict of lists
        otherwise), all scaled by 100; with `bootstrap=True`, a
        `(bootstrap_results, all_scores)` tuple.
    """
    all_targets = [d['answer'] for d in data]
    all_predictions = [d['output'] for d in data]

    # Fixed seed so BootstrapAggregator's resampling is reproducible.
    np.random.seed(1337)

    is_rouge_measure = 'rouge' in aggr_measure

    if preprocess_func is not None:
        scoring_func = lambda target, prediction: scorer.score(target=preprocess_func(target), prediction=preprocess_func(prediction))
    else:
        scoring_func = scorer.score

    aggregator = scoring.BootstrapAggregator()
    all_scores = [] if is_rouge_measure else dict((k,[]) for k in score_keys)
    for targets, prediction in zip(all_targets, all_predictions):
        # Max across references by aggr_measure
        if is_rouge_measure:
            max_scores = max([scoring_func(target, prediction) for target in targets], key=lambda x: x[aggr_measure].fmeasure)

            aggregator.add_scores(max_scores)
            all_scores.append(max_scores[aggr_measure].fmeasure*100)
        else:
            if aggr_measure == 'independent':
                # Best value per key, each possibly from a different reference.
                max_scores = {}
                for key in score_keys:
                    max_scores[key] = max([scoring_func(target, prediction)[key] for target in targets])
            else:
                max_scores = max([scoring_func(target, prediction) for target in targets], key=lambda x: x[aggr_measure])

            aggregator.add_scores(max_scores)
            for key in score_keys:
                all_scores[key].append(max_scores[key]*100)

    if not bootstrap:
        return all_scores

    result = aggregator.aggregate()
    # Rouge results are Score namedtuples; plain dict scores are floats.
    postprocess_result = (lambda x: x.fmeasure*100) if is_rouge_measure else (lambda x: x*100)
    bootstrap_results = {}
    for key in score_keys:
        bootstrap_results[key] = (postprocess_result(result[key].mid), postprocess_result(result[key].low), postprocess_result(result[key].high))
    return bootstrap_results, all_scores
144
+
145
## ROUGE ##

# ROUGE over all references (max-aggregated by rougeLsum F-measure), with
# attribution tokens stripped before scoring; returns bootstrap CIs too.
score_all_rouge = functools.partial(score_all, scorer=rouge_scorer.RougeScorer(rouge_types=("rouge1", "rouge2", "rougeLsum", "rougeL")), aggr_measure='rougeLsum', score_keys=("rouge1", "rouge2", "rougeLsum"), preprocess_func=strip_attribution_tokens, bootstrap=True)
148
+
149
+ ## F1 ##
150
+
151
+ class _f1_scorer:
152
+ def score(self, target, prediction):
153
+ """Computes token F1 score for a single target and prediction."""
154
+ prediction_tokens = prediction.split()
155
+ target_tokens = target.split()
156
+ common = (collections.Counter(prediction_tokens) &
157
+ collections.Counter(target_tokens))
158
+ num_same = sum(common.values())
159
+ if len(target_tokens) == 0 and len(prediction_tokens) == 0:
160
+ return {'F1': 1.0, 'recall': 1.0, 'precision': 1.0}
161
+ elif len(target_tokens) == 0 and len(prediction_tokens) > 0:
162
+ return {'F1': 0.0, 'recall': 1.0, 'precision': 0.0}
163
+ elif len(target_tokens) > 0 and len(prediction_tokens) == 0:
164
+ return {'F1': 0.0, 'recall': 0.0, 'precision': 1.0}
165
+ elif num_same == 0:
166
+ return {'F1': 0.0, 'recall': 0.0, 'precision': 0.0}
167
+ else:
168
+ precision = 1.0 * num_same / len(prediction_tokens)
169
+ recall = 1.0 * num_same / len(target_tokens)
170
+ f1 = (2 * precision * recall) / (precision + recall)
171
+ return {'F1': f1, 'recall': recall, 'precision': precision}
172
+
173
+
174
+ score_all_f1 = functools.partial(score_all, scorer=_f1_scorer(), aggr_measure='F1', score_keys=("F1", "recall", "precision"))
175
+
176
+
177
def preprocess_quotes_f1(text, sep=' ', sources='1-7'):
    """Extract the quoted spans for `sources` and QA-normalize the result."""
    quoted = only_quoted(text, sep=sep, sources=sources)
    return normalize_answers(quoted)
180
+
181
+
182
def score_semqa_f1(data, harmonic=False):
    """SemQA attribution F1.

    Token precision/recall/F1 of the quoted spans are computed per source id
    (1-7), then averaged per example over the sources whose `docs` entry is
    truthy. With `harmonic=True` the harmonic mean of all precisions and
    recalls replaces the mean of F1s.

    Args:
        data: list of dicts with 'answer', 'output' and 'docs' fields;
            `docs` is indexed by source id (1-7).
        harmonic: aggregation switch, see above.

    Returns:
        Mean SemQA F1 over all examples.
    """
    examples = [d['docs'] for d in data]
    per_source_prf1 = {}
    for source in range(1, 8):
        # Restrict quoted-span extraction to this single source id.
        preprocess_quotes_f1_partial_sources = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
        scores = score_all_f1(data, aggr_measure='independent', preprocess_func=preprocess_quotes_f1_partial_sources)

        for aggr_measure in ('F1', 'recall', 'precision'):
            per_source_prf1[f'{aggr_measure}_source_{source}'] = scores[aggr_measure]

    semqa_f1s = []
    for i in range(len(examples)):
        precisions, recalls, f1s = [], [] , []
        for source in range(1,8):
            # Only sources actually present in this example's docs count.
            if examples[i][source]:
                precisions.append(per_source_prf1[f'precision_source_{source}'][i])
                recalls.append(per_source_prf1[f'recall_source_{source}'][i])
                f1s.append(per_source_prf1[f'F1_source_{source}'][i])
        if harmonic:
            f1 = harmonic_mean(precisions + recalls)
        else:
            f1 = np.mean(f1s)
        semqa_f1s.append(f1)

    return np.mean(semqa_f1s)
207
+
208
+
209
+ score_all_recall = functools.partial(score_all, scorer=_f1_scorer(), aggr_measure='recall', score_keys=("recall",))
210
+
211
+
212
def score_semqa_short_recall(data):
    """SemQA short-answer recall.

    If the QA pairs carry a 'num' field the data is ASQA-style and is
    delegated to `compute_str_em`; otherwise, token recall of the quoted
    spans is computed per source (1-7) and averaged, per example, over the
    sources that actually appear in that example's targets.
    """
    # ASQA-style short answers: remap to compute_str_em's expected schema.
    if 'num' in data[0]['qa_pairs'][0].keys():
        remapped = [
            {
                'qa_pairs': [{'short_answers': qa['ans']} for qa in d['qa_pairs']],
                'output': d['output'],
            }
            for d in data
        ]
        return compute_str_em(remapped)

    # Drop examples whose targets are missing or all empty.
    paired = []
    kept_targets = []
    for d in data:
        targets = d['qa_pairs']
        if len(targets) == 0 or all(t == '' for t in targets):
            continue
        paired.append({'answer': targets, 'output': d['output']})
        kept_targets.append(targets)

    # Per-source token recall over the quoted spans.
    per_source_recall = {}
    for source in range(1, 8):
        extract = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
        scores = score_all_recall(paired, preprocess_func=extract)
        per_source_recall[f'recall_source_{source}'] = scores['recall']

    # Average each example's recall over the sources present in its targets.
    semqa_recalls = []
    for i, targets in enumerate(kept_targets):
        recalls = []
        for source in range(1, 8):
            extract = functools.partial(preprocess_quotes_f1, sep=' ', sources=f'{source}')
            if any(extract(t) for t in targets):
                recalls.append(per_source_recall[f'recall_source_{source}'][i])
        semqa_recalls.append(np.mean(recalls))

    return np.mean(semqa_recalls)
259
+
260
+
261
def exact_presence(short_answers, context):
    """Verify if any of the answers is present in the given context.

    Args:
        short_answers: list of short answers to look for in the context
        context: a paragraph to search for short answers

    Returns:
        True if any normalized short answer is a substring of the
        normalized context.
    """
    normalized_context = normalize_answer(context)
    return any(normalize_answer(sa) in normalized_context for sa in short_answers)
278
+
279
+
280
def normalize_answer(s):
    """Lowercase, delete punctuation, drop articles, collapse whitespace."""
    punct = set(string.punctuation)
    # Punctuation is deleted (not replaced by spaces), matching SQuAD-style
    # normalization.
    text = "".join(ch for ch in s.lower() if ch not in punct)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
295
+
296
+
297
def remove_citations(sent):
    """Strip `[n]`-style citation marks and ' |' separators from a sentence."""
    without_spaced = re.sub(r" \[\d+", "", sent)
    without_brackets = re.sub(r"\[\d+", "", without_spaced)
    return without_brackets.replace(" |", "").replace("]", "")
299
+
300
+
301
def load_auto_ais():
    """Lazily load the NLI model/tokenizer used for citation scoring.

    Populates the module-level `autoais_model` / `autoais_tokenizer`
    globals on the configured evaluation device in bfloat16.
    """
    global autoais_model, autoais_tokenizer
    print('Initializing eval model for citation precision and recall...')
    autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16, device_map=evaluate_device, )
    autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL, use_fast=False)
    print('Done!')
307
+
308
+ def _run_nli_autoais(passage, claim, test = False):
309
+ """
310
+ Run inference for assessing AIS between a premise and .hypothesis
311
+ Adapted from https://github.com/google-research-datasets/Attributed-QA/blob/main/evaluation.py
312
+ """
313
+ if not test:
314
+ global autoais_model, autoais_tokenizer
315
+ if not autoais_model:
316
+ load_auto_ais()
317
+ input_text = "premise: {} hypothesis: {}".format(passage, claim)
318
+ input_ids = autoais_tokenizer(input_text, return_tensors="pt").input_ids.to(autoais_model.device)
319
+ with torch.inference_mode():
320
+ outputs = autoais_model.generate(input_ids, max_new_tokens=10)
321
+ result = autoais_tokenizer.decode(outputs[0], skip_special_tokens=True)
322
+ inference = 1 if result == "1" else 0
323
+ return inference
324
+ else:
325
+ res = 114514
326
+
327
+ return res
328
+
329
+
330
def compute_autoais(data,
                    qampari=False,
                    at_most_sents = 50,
                    at_most_citations=3,
                    entail_function = _run_nli_autoais):
    """
    Compute AutoAIS score (citation recall and precision).

    Each output is split into sentences; a sentence counts as supported
    when the NLI model judges the concatenation of its cited documents to
    entail the citation-stripped sentence. Precision additionally checks,
    for multi-citation sentences, whether each cited document is necessary.

    Args:
        data: requires field `output` and `docs`
            - docs should be a list of items with fields `title` and `text` (or `phrase` and `sent` for QA-extracted docs)
        qampari: treat the output as a comma-separated answer list, with
            each piece prefixed by the query, instead of sentence-tokenizing.
        at_most_sents: cap on sentences scored per output.
        at_most_citations: cap on citations considered per sentence.
        entail_function: NLI callable `(premise, hypothesis) -> 0/1`.

    Returns:
        Dict with `citation_rec` and `citation_prec` on a 0-100 scale.
    """

    global autoais_model, autoais_tokenizer


    ais_scores = []
    ais_scores_prec = []

    sent_total = 0
    sent_mcite = 0
    sent_mcite_support = 0
    sent_mcite_overcite = 0
    autoais_log = []
    for item in tqdm(data):
        # Get sentences by using NLTK
        if qampari:
            sents = [item['query'] + " " + x.strip() for x in
                     item['output'].rstrip().rstrip(".").rstrip(",").split(",")]
        else:
            sents = sent_tokenize(item['output'])[:at_most_sents]
        if len(sents) == 0:
            # Empty output: zero recall and precision for this item.
            ais_scores.append(0.0)
            ais_scores_prec.append(0.0) # len(sents))
            continue

        target_sents = [remove_citations(sent).strip() for sent in sents]

        entail = 0
        entail_prec = 0
        total_citations = 0
        for sent_id, sent in enumerate(sents):
            target_sent = target_sents[sent_id] # Citation removed and (if opted for) decontextualized
            joint_entail = -1 # Undecided

            # Find references; supports both "[3]" and grouped "[1, 2]" forms.
            #ref = [int(r[1:]) - 1 for r in re.findall(r"\[\d+", sent)] # In text citation id starts from 1
            matches = re.findall(r"\[(\d+(?:,\s*\d+)*)\]", sent)
            ref = [int(num)-1 for match in matches for num in match.replace(' ', '').split(',')]
            if len(ref) == 0:
                # No citations
                joint_entail = 0
            elif any([ref_id >= len(item['docs']) for ref_id in ref]):
                # Citations out of range
                joint_entail = 0
            else:
                if at_most_citations is not None:
                    ref = ref[:at_most_citations]
                total_citations += len(ref)
                joint_passage = '\n'.join([(item['docs'][psgs_id]) for psgs_id in ref])

            # If not directly rejected by citation format error, calculate the recall score
            if joint_entail == -1:
                joint_entail = entail_function(joint_passage, target_sent)
                autoais_log.append({
                    #"question": item['question'],
                    "output": item['output'],
                    "claim": sent,
                    "passage": [joint_passage],
                    "model_type": "NLI",
                    "model_output": joint_entail,
                })

            entail += joint_entail
            if len(ref) > 1:
                sent_mcite += 1

            # calculate the precision score if applicable
            if joint_entail and len(ref) > 1:
                sent_mcite_support += 1
                # Precision check: did the model cite any unnecessary documents?
                for psgs_id in ref:
                    # condition A: this document alone entails the sentence
                    passage = item['docs'][psgs_id]
                    nli_result = entail_function(passage, target_sent)

                    # condition B: the remaining documents entail it without this one
                    if not nli_result:
                        subset_exclude = copy.deepcopy(ref)
                        subset_exclude.remove(psgs_id)
                        passage = '\n'.join([item['docs'][pid] for pid in subset_exclude])
                        nli_result =entail_function(passage, target_sent)
                        if nli_result: # psgs_id is not necessary
                            # NOTE(review): `flag` is assigned but never read.
                            flag = 0
                            sent_mcite_overcite += 1
                        else:
                            entail_prec += 1
                    else:
                        entail_prec += 1
            else:
                entail_prec += joint_entail
        sent_total += len(sents)
        ais_scores.append(entail / len(sents))
        ais_scores_prec.append(entail_prec / total_citations if total_citations > 0 else 0) # len(sents))

    if sent_mcite > 0 and sent_mcite_support > 0:
        print(
            "Among all sentences, %.2f%% have multiple citations, among which %.2f%% are supported by the joint set, among which %.2f%% overcite." % (
                100 * sent_mcite / sent_total,
                100 * sent_mcite_support / sent_mcite,
                100 * sent_mcite_overcite / sent_mcite_support
            ))

    return {
        "citation_rec": 100 * np.mean(ais_scores),
        "citation_prec": 100 * np.mean(ais_scores_prec),
    }
450
+
451
+
452
def compute_f1(a_gold, a_pred):
    """Token-level F1 between two strings, after answer normalization."""

    def _tokens(s):
        # Empty/None strings tokenize to the empty list.
        return normalize_answer(s).split() if s else []

    gold = _tokens(a_gold)
    pred = _tokens(a_pred)

    overlap = collections.Counter(gold) & collections.Counter(pred)
    hits = sum(overlap.values())

    if not gold or not pred:
        # If either is no-answer, F1 is 1 only when both are empty.
        return int(gold == pred)
    if hits == 0:
        return 0

    precision = hits / len(pred)
    recall = hits / len(gold)
    return (2 * precision * recall) / (precision + recall)
478
+
479
+
480
def compute_exact(a_gold, a_pred):
    """Return 1 when both strings normalize to the same text, else 0."""
    same = normalize_answer(a_gold) == normalize_answer(a_pred)
    return int(same)
484
+
485
+
486
def compute_qa(data):
    """Compute QA-based accuracy.

    Runs the SQuAD question-answering pipeline over each item's
    citation-stripped output, asking every question in its `qa_pairs`,
    and scores the extracted answers against the gold short answers.

    Args:
        data: requires fields `qa_pairs` (each with 'question' and
            'short_answers') and `output`.

    Returns:
        QA metrics (QA-EM, QA-F1, QA-Hit), each on a 0-100 scale.
    """
    if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
        # `logging.warn` is a deprecated alias; use `logging.warning`.
        logging.warning("Warning: no QA pairs found in data")
        return {
            'QA-EM': 0,
            'QA-F1': 0,
            'QA-Hit': 0,
        }

    # Lazily build the shared question-answering pipeline.
    global qa_pipeline
    if not qa_pipeline:
        qa_pipeline = transformers.pipeline("question-answering", model=QA_MODEL, device = evaluate_device)

    em, f1, bins = [], [], []
    for item in tqdm(data):
        question = [qa_pair['question'] for qa_pair in item['qa_pairs']]
        # The pipeline requires a non-empty context; fall back to " ".
        context = item['output'] if len(item['output']) > 0 else " "
        results = qa_pipeline(question=question, context=remove_citations(context), handle_impossible_answer=True)
        loc_counter, loc_em, loc_f1 = 0, 0, 0

        for idx, res in enumerate(results):
            answers = item["qa_pairs"][idx]["short_answers"]
            prediction = res["answer"]

            # Best match over the acceptable gold short answers.
            loc_em += max([compute_exact(a, prediction) for a in answers])
            loc_f1 += max([compute_f1(a, prediction) for a in answers])
            loc_counter += 1

        em.append(loc_em / loc_counter)
        f1.append(loc_f1 / loc_counter)
        bins.append(loc_em == loc_counter)  # all questions answered exactly

    return {
        'QA-EM': 100 * np.mean(em),
        'QA-F1': 100 * np.mean(f1),
        'QA-Hit': 100 * np.mean(bins)
    }
535
+
536
+
537
def compute_claims(data):
    """Claim-level accuracy (ELI5-style).

    For each item, counts the fraction of gold claims (`qa_pairs`) that the
    citation-stripped output entails per the NLI model; averages over items
    and scales to 0-100.
    """
    global autoais_model, autoais_tokenizer
    if autoais_model is None:
        # Lazily load the NLI model on the configured evaluation device.
        autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16,
                                                              device_map=evaluate_device)
        autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL, use_fast=False)
    scores = []
    for item in tqdm(data):
        normalized_output = remove_citations(item['output'])
        entail = 0
        claims = item["qa_pairs"]
        for claim in claims:
            # NLI direction: does the model output entail the gold claim?
            entail += _run_nli_autoais(normalized_output, claim)
        scores.append(entail / len(claims))
    return 100 * np.mean(scores)
556
+
557
+
558
def compute_qampari_f1(data, cot=False):
    """QAMPARI list-answer precision/recall/F1.

    Each output is parsed as a comma-separated answer list. Recall is also
    reported against at most 5 gold answers (`rec_top5`).

    Args:
        data: requires fields `output` and `answer` (a list of alias lists).
        cot: when True, drop a leading chain-of-thought segment before the
            first ':' and parse only what follows.

    Returns:
        Dict with mean `num_preds` and precision/recall/F1 (0-100 scale).
    """
    prec = []
    rec = []
    rec_top5 = []
    f1 = []
    f1_top5 = []

    num_preds = []
    for item in data:
        if cot:
            if ":" in item['output']:
                o = ':'.join(item['output'].split(":")[1:]) # try to separate the COT part and the answer list part.
            else:
                o = ""
        else:
            o = item['output']

        preds = [normalize_answer(x.strip()) for x in remove_citations(o).rstrip().rstrip(".").rstrip(",").split(",")]
        preds = [p for p in preds if len(p) > 0] # delete empty answers
        num_preds.append(len(preds))
        # Each gold answer is a list of acceptable aliases; a gold answer is
        # recalled when any alias appears among the predictions.
        answers = [[normalize_answer(x) for x in ans] for ans in item['answer']]
        flat_answers = [item for sublist in answers for item in sublist]
        prec.append(sum([p in flat_answers for p in preds]) / len(preds) if len(preds) > 0 else 0)

        rec.append(sum([any([x in preds for x in a]) for a in answers]) / len(answers))
        rec_top5.append(min(5, sum([any([x in preds for x in a]) for a in answers])) / min(5, len(answers)))
        if (prec[-1] + rec[-1]) == 0:
            f1.append(0)
        else:
            f1.append(2 * prec[-1] * rec[-1] / (prec[-1] + rec[-1]))
        if (prec[-1] + rec_top5[-1]) == 0:
            f1_top5.append(0)
        else:
            f1_top5.append(2 * prec[-1] * rec_top5[-1] / (prec[-1] + rec_top5[-1]))

    return {
        "num_preds": np.mean(num_preds),
        "qampari_prec": 100 * np.mean(prec),
        "qampari_rec": 100 * np.mean(rec),
        "qampari_rec_top5": 100 * np.mean(rec_top5),
        "qampari_f1": 100 * np.mean(f1),
        "qampari_f1_top5": 100 * np.mean(f1_top5),
    }
601
+
602
+
603
def compute_str_em(data):
    """Compute STR-EM metric (only for ASQA)
    Args:
        data: requires field `qa_pairs/short_answers` and `output`
    Returns:
        STR-EM on a 0-100 scale; 0 when no QA pairs are present.
    """
    if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
        return 0

    acc = []
    for item in data:
        loc_acc = []
        if len(item['qa_pairs']) == 0:
            continue
        # Only the first QA pair is checked; the all-pairs loop below was
        # deliberately disabled.
        loc_acc.append(exact_presence(item['qa_pairs'][0]['short_answers'], item["output"]))
        """for qa_pair in item['qa_pairs']:
            loc_acc.append(exact_presence(qa_pair['short_answers'], item["output"]))"""
        acc.append(float(np.mean(loc_acc)))
    return 100 * np.mean(acc) if len(acc) > 0 else 0
623
+
624
+
625
def compute_mauve(data):
    """Compute MAUVE fluency score between human and model continuations.

    Both sides are formed as "<query> <text>", whitespace-normalized,
    truncated to 100 words, and stripped of trailing punctuation.
    Requires fields `query`, `answer` and `output`.
    """

    logging.info("Computing MAUVE...")
    human_data = []
    model_data = []
    for item in data:
        # Remove ending punctuations
        # Remove any new lines
        # Truncate by 100 words
        human_data.append(
            ' '.join((item['query'] + " " + item['answer'].strip()).split()[:100]).rstrip(string.punctuation))
        model_data.append(
            ' '.join((item['query'] + " " + item['output'].strip()).split()[:100]).rstrip(string.punctuation))

    # Imported lazily: mauve is an optional, heavy dependency.
    import mauve
    out = mauve.compute_mauve(
        p_text=human_data,
        q_text=model_data,
        device_id=0,
        max_text_length=512,
        verbose=True,
        batch_size=8,
        featurize_model_name="gpt2-large"
    )
    return out.mauve * 100
651
+
652
+
653
def compute_rouge_l(data):
    """Average ROUGE-L recall/precision/F over all items.

    Items missing either `output` or `answer` contribute zero (a warning is
    printed), while the denominator remains the full item count — matching
    the original behavior.

    Args:
        data: list of dicts with `output` and `answer` string fields.

    Returns:
        Dict with mean 'r', 'p' and 'f' ROUGE-L components.
    """
    total = len(data)
    res = {
        "r": 0.0,
        "p": 0.0,
        "f": 0.0
    }
    rouge = Rouge()  # hoisted out of the loop: no need to rebuild per item
    for item in data:
        if item['output'] and item['answer']:
            scores = rouge.get_scores(item['output'], item['answer'])
            res['r'] += scores[0]['rouge-l']['r']
            res['p'] += scores[0]['rouge-l']['p']
            res['f'] += scores[0]['rouge-l']['f']
        else:
            print('Warning: no hypothesis or references')
    res['r'] /= total
    res['p'] /= total
    res['f'] /= total

    return res
675
+
676
+
677
def compute_length(data):
    """Mean number of space-separated tokens in `output` across items."""
    counts = [len(item['output'].split(' ')) for item in data]
    return sum(counts) / len(data)
679
+
680
+
681
# Dispatch table: metric name -> scoring function. Each function receives
# the record list accumulated by AttributeMetric.add_batch.
metric_list = {
    'cite_pr': compute_autoais,
    'asqa_acc': compute_qa,
    'eli5_acc': compute_claims,
    'qampari': compute_qampari_f1,
    'short_ans': compute_str_em,
    # 'fluence': compute_mauve,
    'rouge': compute_rouge_l,
    'length': compute_length,
    'rouge_all': score_all_rouge,
    'semqa_f1': score_semqa_f1,  # acts like precision
    'semqa_short': score_semqa_short_recall,  # acts like recall
}
694
+
695
# Per-metric field whitelist: AttributeMetric.add_batch keeps only these
# keys from each record. Values are always None; only key membership is
# used. NOTE(review): 'semqa' has no entry in metric_list above — it is
# kept only for the (disabled) combined-score path in compute().
data_list = {
    'cite_pr': {'output': None, 'docs': None, 'query': None},
    'asqa_acc': {'output': None,'qa_pairs': None, 'query': None},
    'eli5_acc': {'output': None, 'qa_pairs': None},
    'qampari': {'output': None, 'answer': None},
    'short_ans': {'qa_pairs': None, 'output': None},
    # 'fluence': {'query': None, 'answer': None, 'output': None},
    'rouge': {'output': None, 'answer': None},
    'length': {'output': None},
    'rouge_all': {'answer': None, 'output': None},
    'semqa_f1': {'answer': None, 'output': None, 'docs': None},
    'semqa_short':{'output': None, 'qa_pairs': None},
    'semqa': {}
}
709
+
710
+
711
+
712
class AttributeMetric:
    """Accumulates per-example records for the attribution task and
    computes the metrics listed in ``config['metric']``.

    Each record added via :meth:`add_batch` is filtered down to the
    fields each metric declares in the module-level ``data_list``.
    """

    def __init__(self, config):
        self.task = 'attribute'
        self.metrics = config['metric']
        self.flag = False
        # One bucket per known metric; only the configured metrics'
        # buckets are ever filled.
        self.data = {
            'cite_pr': [],
            'asqa_acc': [],
            'eli5_acc': [],
            'qampari': [],
            'short_ans': [],
            'fluence': [],
            'rouge': [],
            'length': [],
            'rouge_all': [],
            'semqa_f1': [],
            'semqa_short': [],
            'semqa': [],
        }

    def add_batch(self, data):
        """Route one record (output, qa_pairs, answer, docs, query) into
        the buckets of the configured metrics, keeping only the fields
        each metric needs."""
        for metric in self.metrics:
            self.data[metric].append(
                {k: v for k, v in data.items() if k in data_list[metric]}
            )

    def compute(self):
        """Run every configured metric and return {metric_name: score}."""
        ans = {}
        for metric in self.metrics:
            # Fixed: the old assert used `logging.info(...)` as its
            # message, which always evaluates to None; use a real message.
            assert metric in metric_list, f"Invalid metric: {metric}"
            # Citation precision/recall needs QAMPARI-specific handling
            # when evaluated alongside the qampari metric.
            if metric == 'cite_pr' and 'qampari' in self.metrics:
                ans[metric] = metric_list[metric](data=self.data[metric], qampari=True)
            else:
                ans[metric] = metric_list[metric](data=self.data[metric])
        return ans
754
+
755
class AutoMetric(BasicMetric):
    """Metric wrapper selecting either the local AttributeMetric or a
    HuggingFace ``evaluate`` metric by task name.

    ``MOE_PEFT_METRIC_PATH`` (env var) optionally prefixes the path
    passed to ``evaluate.load``.
    """

    def __init__(self, task_name: str, config: Optional[List] = None) -> None:
        # Fixed: `config` now defaults to None so callers that only use
        # HF metrics (e.g. SequenceClassificationTask.loading_metric)
        # can omit it; it is only consumed by the "attribute" task.
        super().__init__()
        path_prefix = os.getenv("MOE_PEFT_METRIC_PATH")
        if path_prefix is None:
            path_prefix = ""
        elif not path_prefix.endswith(os.sep):
            path_prefix += os.sep

        if task_name == "attribute":
            self.metric_ = AttributeMetric(config)
        elif ":" in task_name:
            # "name:subset" selects a metric configuration.
            split = task_name.split(":")
            self.metric_ = hf_evaluate.load(path_prefix + split[0], split[1])
        else:
            self.metric_ = hf_evaluate.load(path_prefix + task_name)

    def add_batch(self, predictions: torch.Tensor, references: torch.Tensor):
        """Forward a batch of predictions/references to the wrapped metric."""
        self.metric_.add_batch(predictions=predictions, references=references)

    def compute(self) -> Dict[str, Any]:
        """Finalize and return the wrapped metric's scores."""
        return self.metric_.compute()
777
+
778
+
779
class BasicTask:
    """Abstract interface every concrete task fills in."""

    def __init__(self) -> None:
        pass

    @property
    def peft_task_type(self) -> str:
        """PEFT task-type string (e.g. "CAUSAL_LM", "SEQ_CLS")."""

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load the task's dataset as a list of InputData items."""

    def loading_metric(self) -> BasicMetric:
        """Return the metric object used to score this task."""

    def init_kwargs(self) -> Dict:
        """Extra keyword arguments used when building the model."""
        return {}
797
+
798
+
799
+ # Casual Fine-tuning Tasks
800
+ # Instant-Created Class
801
class CasualTask(BasicTask):
    """Plain causal-LM supervised fine-tuning on an instruction dataset."""

    @property
    def peft_task_type(self) -> str:
        return "CAUSAL_LM"

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        assert path is not None, "Casual supervised fine-tuning requires data path."
        assert is_train, "Casual supervised fine-tuning task only supports training."
        # Resolve the dataset source: a local json(l) file, a
        # "name:subset" hub reference, or a plain hub name.
        if path.endswith(".json") or path.endswith(".jsonl"):
            dataset = hf_datasets.load_dataset("json", data_files=path)
        elif ":" in path:
            parts = path.split(":")
            dataset = hf_datasets.load_dataset(parts[0], parts[1])
        else:
            dataset = hf_datasets.load_dataset(path)
        # Wrap every training record into an InputData/Prompt pair.
        return [
            InputData(
                inputs=Prompt(
                    instruction=record["instruction"],
                    input=record.get("input", None),
                    label=record.get("output", None),
                )
            )
            for record in dataset["train"]
        ]
832
+
833
+
834
+ # Sequence Classification
835
class SequenceClassificationTask(BasicTask):
    """Sequence-classification task backed by a HuggingFace dataset and a
    HuggingFace ``evaluate`` metric.

    ``dataload_function`` maps one raw record to ``(inputs, labels)``.
    """

    def __init__(
        self,
        task_name: str,
        task_type: str,
        label_dtype: torch.dtype,
        num_labels: int,
        dataload_function: Callable,
        # Setting to `None` corresponds to the task name.
        metric_name: Optional[str] = None,
        # The default values are "train" and "validation".
        subset_map: Optional[Tuple[str, str]] = ("train", "validation"),
    ) -> None:
        super().__init__()
        self.task_name_ = task_name
        self.task_type_ = task_type
        self.label_dtype_ = label_dtype
        self.num_labels_ = num_labels
        self.dataload_function_ = dataload_function
        # The metric defaults to the task name when not given explicitly.
        if metric_name is None:
            self.metric_name_ = task_name
        else:
            self.metric_name_ = metric_name
        self.subset_map_ = subset_map

    @property
    def peft_task_type(self) -> str:
        return "SEQ_CLS"

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        """Load the train or validation split and convert each record via
        ``dataload_function`` into an InputData(inputs, labels) item."""
        if ":" in self.task_name_:
            split = self.task_name_.split(":")
            data = hf_datasets.load_dataset(
                split[0] if path is None else path, split[1]
            )
        else:
            data = hf_datasets.load_dataset(self.task_name_ if path is None else path)
        data = data[self.subset_map_[0] if is_train else self.subset_map_[1]]
        logging.info(f"Preparing data for {self.task_name_.upper()}")
        ret: List[InputData] = []
        for data_point in data:
            inputs, labels = self.dataload_function_(data_point)
            assert isinstance(labels, List)
            ret.append(InputData(inputs=inputs, labels=labels))

        return ret

    def loading_metric(self) -> BasicMetric:
        # Fixed: AutoMetric.__init__ declares a second required `config`
        # parameter; calling it with one argument raised TypeError.
        # `config` is only meaningful for the "attribute" task.
        return AutoMetric(self.metric_name_, None)

    def init_kwargs(self) -> Dict:
        """Model-construction kwargs for a classification head."""
        return {
            "task_type": self.task_type_,
            "num_labels": self.num_labels_,
            "label_dtype": self.label_dtype_,
        }
893
+
894
+
895
+ # Common Sense
896
class CommonSenseTask(BasicTask):
    """Base class for multiple-choice common-sense QA tasks."""

    def __init__(self) -> None:
        super().__init__()
        self.label_dtype_ = None
        self.task_type_ = "common_sense"

    @property
    def peft_task_type(self) -> str:
        return "QUESTION_ANS"

    def label_list(self) -> List[str]:
        """Candidate answer labels; provided by subclasses."""
908
+
909
+
910
class AttributeTask(BasicTask):
    """Base class for attributed (citation-grounded) generation tasks."""

    def __init__(self) -> None:
        super().__init__()
        self.label_dtype_ = None
        self.task_type_ = "attribute"

    @property
    def peft_task_type(self) -> str:
        return "ATTRIBUTE"
919
+
920
# Global task registry (name -> BasicTask); task modules populate it via
# their update_task_dict(task_dict) hooks (see e.g. glue_tasks.py).
task_dict = {}
921
+
922
+
923
+ # Multi-Task (Only for train)
924
# Multi-Task (Only for train)
class MultiTask(BasicTask):
    """Training-only task that concatenates data from several registered
    tasks, given as a ";"-separated list of task names."""

    def __init__(self, task_names: str) -> None:
        super().__init__()
        self.task_type_ = "multi_task"
        self.label_dtype_ = None
        # Resolve each name through the global registry.
        self.task_list_: List[BasicTask] = [
            task_dict[name] for name in task_names.split(";")
        ]

    def loading_data(
        self, is_train: bool = True, path: Optional[str] = None
    ) -> List[InputData]:
        logging.info(f"Preparing data for {len(self.task_list_)} tasks")
        assert is_train
        # Optional ";"-separated per-task data paths, aligned by index.
        path_list = None if path is None else path.split(";")
        merged: List[InputData] = []
        for idx, sub_task in enumerate(self.task_list_):
            sub_path = "" if path_list is None else path_list[idx].strip()
            merged.extend(
                sub_task.loading_data(is_train, None if len(sub_path) == 0 else sub_path)
            )
        return merged
945
+
946
+
947
def main():
    """Ad-hoc evaluation driver for re-scoring saved model outputs.

    NOTE(review): the triple-quoted blocks are disabled experiments kept
    verbatim; only the QAMPARI re-scoring path below is active.
    """
    """source = '/yy21/MoE-PEFT/dataset/APO/preference_data.jsonl'
    data = []
    with open(source, 'r') as f:
        for line in f:
            y = json.loads(line)
            output = ""
            for s in y['statements']:
                if isinstance(s, List):
                    for i in s:
                        output += i + " "
                else:
                    dot = s['statement'].strip()[-1]
                    output += s['statement'].strip()[:-1]
                    if 'revised_used_document' in s:
                        for i in s['revised_used_document']:
                            output += '[' + i + ']'
                    else:
                        if len(s['used_document']) != 0:
                            for i in s['used_document']:
                                output += '[' + i + ']'
                    output += dot + ' '

            docs = [d['text'] for d in y['documents']]
            fk = {
                'query': y['query'],
                'output': output,
                'docs': docs,
            }
            ans = compute_autoais(fk)
            print(ans)"""

    # Return everything after the "[ANSWER]" marker; the trailing 4
    # characters are also dropped (presumably an end-of-text tag --
    # TODO confirm against the generation format).
    def split_docs_and_answer(input_str):

        if "[ANSWER]" not in input_str:
            return ""
        index = input_str.find("[ANSWER]")
        ans = input_str[index + len("[ANSWER]"):][:-4].strip()

        return ans

    test_data = []
    # Pair each generated response (jsonl, one per line) with its source
    # dataset record (json list) by line index.
    with open('/yy21/test_qamp_v2.jsonl', "r", encoding="utf-8") as fuck:
        with open('/yy21/MoE-PEFT/dataset/front_output/qampari.json', "r", encoding="utf-8") as f:
            data = json.load(f)
            for idx, line in enumerate(fuck):
                opt = json.loads(line)

                # Normalize citation tags: "[ref_3]" -> "[3]".
                ori_output = re.sub(r'\[ref_(\d+)\]', r'[\1]', opt['response'])
                #qa_pairs = data[idx]['qa_pairs']
                answer = data[idx]['answer']
                query = data[idx]['question']

                output = split_docs_and_answer(ori_output)
                ori_docs = []
                for i in range(5):
                    ori_docs.append(data[idx]['docs'][i]['text'])
                # NOTE(review): 'output' is the raw response, not the
                # split_docs_and_answer() result computed above -- the
                # stripped `output` variable is unused here; verify
                # which one compute_autoais should receive.
                fk = {
                    #'qa_pairs' : qa_pairs,
                    'answer' : answer,
                    'query' : query,
                    'docs' : ori_docs,
                    'output' : ori_output
                }
                test_data.append(fk)
    # Citation precision/recall over all assembled records.
    ans = compute_autoais(test_data, qampari=True)
    print(ans)
    """with open('/yy21/test_eli5_output0.jsonl', "r", encoding="utf-8") as fuck,\
        open('/yy21/test_eli5_output.jsonl', "w", encoding="utf-8") as outputf:
        for idx, line in enumerate(fuck):
            opt = json.loads(line)
            opt['accuracy'] = acc[idx]
            outputf.write(json.dumps(opt, ensure_ascii=False) + '\n')"""
    """ with open('/yy21/MoE-PEFT/dataset/front_output/eli5.json', "r", encoding="utf-8") as f:
        data = json.load(f)
        test_data = []
        for data_point in data:

            ori_output = data_point['output']
            qa_pairs = data_point['claims']
            answer = data_point['answer']
            query = data_point['question']

            output = split_docs_and_answer(ori_output)
            ori_docs = []
            for i in range(5):
                ori_docs.append(data_point['docs'][i]['text'])
            fk = {
                'qa_pairs' : qa_pairs,
                'answer' : answer,
                'query' : query,
                'docs' : ori_docs,
                'output' : output
            }
            test_data.append(fk)
        ans = compute_claims(test_data)
        print(ans)"""


if __name__ == "__main__":
    main()
c2cite/tasks/glue_tasks.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .common import SequenceClassificationTask
4
+
5
+
6
def update_task_dict(task_dict):
    """Register the eight GLUE classification subtasks into `task_dict`."""

    def pair_loader(*fields):
        # Bind `fields` now (factory argument), avoiding the classic
        # late-binding closure pitfall in loops.
        return lambda data_point: (
            [data_point[f] for f in fields],
            [int(data_point["label"])],
        )

    # (task name, number of labels, record loader)
    specs = [
        ("glue:cola", 2, pair_loader("sentence")),
        ("glue:mnli", 3, pair_loader("premise", "hypothesis")),
        ("glue:mrpc", 2, pair_loader("sentence1", "sentence2")),
        ("glue:qnli", 2, pair_loader("question", "sentence")),
        ("glue:qqp", 2, pair_loader("question1", "question2")),
        ("glue:rte", 2, pair_loader("sentence1", "sentence2")),
        ("glue:sst2", 2, pair_loader("sentence")),
        # wnli joins its sentence pair into a single input string.
        (
            "glue:wnli",
            2,
            lambda data_point: (
                [data_point["sentence1"] + " </s> " + data_point["sentence2"]],
                [int(data_point["label"])],
            ),
        ),
    ]
    task_dict.update(
        {
            name: SequenceClassificationTask(
                task_name=name,
                task_type="single_label_classification",
                num_labels=num_labels,
                label_dtype=torch.long,
                dataload_function=loader,
            )
            for name, num_labels, loader in specs
        }
    )