Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete file list.
- Prism/Dream/Dream_Baseline/LICENSE +201 -0
- Prism/Dream/Dream_Prism/LICENSE +201 -0
- Prism/Dream/Dream_Prism/eval_instruct/.gitignore +26 -0
- Prism/Dream/Dream_Prism/eval_instruct/README.md +16 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py +7 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py +512 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py +56 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py +493 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py +232 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py +1839 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py +0 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py +59 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py +0 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py +174 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py +166 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py +328 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py +736 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py +554 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py +25 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py +17 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py +25 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py +188 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py +61 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py +56 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py +17 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py +563 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py +41 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py +257 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py +1459 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py +731 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py +155 -0
- Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py +552 -0
- Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml +134 -0
- Prism/Dream/Dream_Prism/eval_instruct/requirements.txt +1 -0
- Prism/Dream/Dream_Prism/eval_instruct/setup.py +5 -0
- Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py +188 -0
- Prism/Dream/Dream_Prism/metrics/humaneval_eval.py +234 -0
- Prism/Dream/Dream_Prism/metrics/math500_eval.py +205 -0
- Prism/Dream/Dream_Prism/metrics/mbpp_eval.py +281 -0
- Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh +31 -0
- Prism/Dream/Dream_Prism/scripts/run_humaneval.sh +31 -0
- Prism/Dream/Dream_Prism/scripts/run_math500.sh +30 -0
- Prism/Dream/Dream_Prism/scripts/run_mbpp.sh +31 -0
- Prism/Dream/Dream_Prism/src/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py +71 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py +4 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py +324 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py +468 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py +633 -0
Prism/Dream/Dream_Baseline/LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License, Version 2.0, January 2004 — http://www.apache.org/licenses/
+(standard, unmodified Apache-2.0 license text, 201 lines)
Prism/Dream/Dream_Prism/LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License, Version 2.0 (identical to Prism/Dream/Dream_Baseline/LICENSE above)
Prism/Dream/Dream_Prism/eval_instruct/.gitignore
ADDED
@@ -0,0 +1,26 @@
+env
+*.pyc
+output/
+output5/
+data/
+lm_cache
+.idea
+build
+dist
+*.egg-info
+venv
+.venv/
+.vscode/
+temp
+__pycache__
+.ipynb_checkpoints
+temp
+test_logs/
+# IPython
+profile_default/
+ipython_config.py
+# don't track (the default location of) the cached requests
+lm_eval/caching/.cache
+# don't track files created by wandb
+wandb
+examples/wandb
Prism/Dream/Dream_Prism/eval_instruct/README.md
ADDED
@@ -0,0 +1,16 @@
+# Dream-Instruct Evaluation Toolkit
+This toolkit contains the code Dream-Instruct models make use of for evaluation.
+
+## Quickstart
+To install the toolkit, run:
+```
+pip install -e ".[ifeval,math]"
+```
+
+We provide a script to evaluate [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B):
+```
+bash eval.sh
+```
+
+## Acknowledgement
+This is a fork of [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py
ADDED
@@ -0,0 +1,7 @@
+import logging
+import os
+
+from .evaluator import evaluate, simple_evaluate
+
+
+__version__ = "0.4.8"
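Since `__init__.py` re-exports `evaluate` and `simple_evaluate`, the harness can be driven from Python as well as from the CLI. A minimal sketch, assuming this fork preserves upstream lm-evaluation-harness semantics — the backend, checkpoint, and task values below are illustrative, not taken from this repo:

```python
# Minimal programmatic sketch — assumes upstream lm-evaluation-harness behavior;
# the model backend, pretrained checkpoint, and task are illustrative only.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                                      # backend name, as in `--model`
    model_args="pretrained=EleutherAI/pythia-160m",  # comma-separated backend args
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size=1,
)
if results is not None:  # non-main ranks return None in distributed runs
    print(results["results"])
```

The keyword arguments mirror the `evaluator.simple_evaluate(...)` call made by `cli_evaluate` in `__main__.py` below.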
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from lm_eval import evaluator, utils
|
| 10 |
+
from lm_eval.evaluator import request_caching_arg_to_dict
|
| 11 |
+
from lm_eval.loggers import EvaluationTracker, WandbLogger
|
| 12 |
+
from lm_eval.tasks import TaskManager
|
| 13 |
+
from lm_eval.utils import (
|
| 14 |
+
handle_non_serializable,
|
| 15 |
+
make_table,
|
| 16 |
+
simple_parse_args_string,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def try_parse_json(value: str) -> Union[str, dict, None]:
|
| 21 |
+
if value is None:
|
| 22 |
+
return None
|
| 23 |
+
try:
|
| 24 |
+
return json.loads(value)
|
| 25 |
+
except json.JSONDecodeError:
|
| 26 |
+
if "{" in value:
|
| 27 |
+
raise argparse.ArgumentTypeError(
|
| 28 |
+
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
|
| 29 |
+
)
|
| 30 |
+
return value
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _int_or_none_list_arg_type(
|
| 34 |
+
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
|
| 35 |
+
):
|
| 36 |
+
def parse_value(item):
|
| 37 |
+
item = item.strip().lower()
|
| 38 |
+
if item == "none":
|
| 39 |
+
return None
|
| 40 |
+
try:
|
| 41 |
+
return int(item)
|
| 42 |
+
except ValueError:
|
| 43 |
+
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
|
| 44 |
+
|
| 45 |
+
items = [parse_value(v) for v in value.split(split_char)]
|
| 46 |
+
num_items = len(items)
|
| 47 |
+
|
| 48 |
+
if num_items == 1:
|
| 49 |
+
# Makes downstream handling the same for single and multiple values
|
| 50 |
+
items = items * max_len
|
| 51 |
+
elif num_items < min_len or num_items > max_len:
|
| 52 |
+
raise argparse.ArgumentTypeError(
|
| 53 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
|
| 54 |
+
)
|
| 55 |
+
elif num_items != max_len:
|
| 56 |
+
logging.warning(
|
| 57 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
|
| 58 |
+
"Missing values will be filled with defaults."
|
| 59 |
+
)
|
| 60 |
+
default_items = [parse_value(v) for v in defaults.split(split_char)]
|
| 61 |
+
items.extend(
|
| 62 |
+
default_items[num_items:]
|
| 63 |
+
) # extend items list with missing defaults
|
| 64 |
+
|
| 65 |
+
return items
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def check_argument_types(parser: argparse.ArgumentParser):
|
| 69 |
+
"""
|
| 70 |
+
Check to make sure all CLI args are typed, raises error if not
|
| 71 |
+
"""
|
| 72 |
+
for action in parser._actions:
|
| 73 |
+
if action.dest != "help" and not action.const:
|
| 74 |
+
if action.type is None:
|
| 75 |
+
raise ValueError(
|
| 76 |
+
f"Argument '{action.dest}' doesn't have a type specified."
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 83 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--tasks",
|
| 89 |
+
"-t",
|
| 90 |
+
default=None,
|
| 91 |
+
type=str,
|
| 92 |
+
metavar="task1,task2",
|
| 93 |
+
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--model_args",
|
| 97 |
+
"-a",
|
| 98 |
+
default="",
|
| 99 |
+
type=try_parse_json,
|
| 100 |
+
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
|
| 101 |
+
)
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--num_fewshot",
|
| 104 |
+
"-f",
|
| 105 |
+
type=int,
|
| 106 |
+
default=None,
|
| 107 |
+
metavar="N",
|
| 108 |
+
help="Number of examples in few-shot context",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--batch_size",
|
| 112 |
+
"-b",
|
| 113 |
+
type=str,
|
| 114 |
+
default=1,
|
| 115 |
+
metavar="auto|auto:N|N",
|
| 116 |
+
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--max_batch_size",
|
| 120 |
+
type=int,
|
| 121 |
+
default=None,
|
| 122 |
+
metavar="N",
|
| 123 |
+
help="Maximal batch size to try with --batch_size auto.",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--device",
|
| 127 |
+
type=str,
|
| 128 |
+
default=None,
|
| 129 |
+
help="Device to use (e.g. cuda, cuda:0, cpu).",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--output_path",
|
| 133 |
+
"-o",
|
| 134 |
+
default=None,
|
| 135 |
+
type=str,
|
| 136 |
+
metavar="DIR|DIR/file.json",
|
| 137 |
+
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--limit",
|
| 141 |
+
"-L",
|
| 142 |
+
type=float,
|
| 143 |
+
default=None,
|
| 144 |
+
metavar="N|0<N<1",
|
| 145 |
+
help="Limit the number of examples per task. "
|
| 146 |
+
"If <1, limit is a percentage of the total number of examples.",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--use_cache",
|
| 150 |
+
"-c",
|
| 151 |
+
type=str,
|
| 152 |
+
default=None,
|
| 153 |
+
metavar="DIR",
|
| 154 |
+
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--cache_requests",
|
| 158 |
+
type=str,
|
| 159 |
+
default=None,
|
| 160 |
+
choices=["true", "refresh", "delete"],
|
| 161 |
+
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--check_integrity",
|
| 165 |
+
action="store_true",
|
| 166 |
+
help="Whether to run the relevant part of the test suite for the tasks.",
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--write_out",
|
| 170 |
+
"-w",
|
| 171 |
+
action="store_true",
|
| 172 |
+
default=False,
|
| 173 |
+
help="Prints the prompt for the first few documents.",
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--log_samples",
|
| 177 |
+
"-s",
|
| 178 |
+
action="store_true",
|
| 179 |
+
default=False,
|
| 180 |
+
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--system_instruction",
|
| 184 |
+
type=str,
|
| 185 |
+
default=None,
|
| 186 |
+
help="System instruction to be used in the prompt",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--apply_chat_template",
|
| 190 |
+
type=str,
|
| 191 |
+
nargs="?",
|
| 192 |
+
const=True,
|
| 193 |
+
default=False,
|
| 194 |
+
help=(
|
| 195 |
+
"If True, apply chat template to the prompt. "
|
| 196 |
+
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
|
| 197 |
+
"To apply a specific template from the available list of templates, provide the template name as an argument. "
|
| 198 |
+
"E.g. `--apply_chat_template template_name`"
|
| 199 |
+
),
|
| 200 |
+
)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--fewshot_as_multiturn",
|
| 203 |
+
action="store_true",
|
| 204 |
+
default=False,
|
| 205 |
+
help="If True, uses the fewshot as a multi-turn conversation",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--show_config",
|
| 209 |
+
action="store_true",
|
| 210 |
+
default=False,
|
| 211 |
+
help="If True, shows the the full config of all tasks at the end of the evaluation.",
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--include_path",
|
| 215 |
+
type=str,
|
| 216 |
+
default=None,
|
| 217 |
+
metavar="DIR",
|
| 218 |
+
help="Additional path to include if there are external tasks to include.",
|
| 219 |
+
)
|
| 220 |
+
parser.add_argument(
|
| 221 |
+
"--gen_kwargs",
|
| 222 |
+
type=try_parse_json,
|
| 223 |
+
default=None,
|
| 224 |
+
help=(
|
| 225 |
+
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
|
| 226 |
+
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
|
| 227 |
+
),
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--verbosity",
|
| 231 |
+
"-v",
|
| 232 |
+
type=str.upper,
|
| 233 |
+
default=None,
|
| 234 |
+
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
|
| 235 |
+
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
|
| 236 |
+
)
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--wandb_args",
|
| 239 |
+
type=str,
|
| 240 |
+
default="",
|
| 241 |
+
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--wandb_config_args",
|
| 245 |
+
type=str,
|
| 246 |
+
default="",
|
| 247 |
+
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--hf_hub_log_args",
|
| 251 |
+
type=str,
|
| 252 |
+
default="",
|
| 253 |
+
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--predict_only",
|
| 257 |
+
"-x",
|
| 258 |
+
action="store_true",
|
| 259 |
+
default=False,
|
| 260 |
+
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
|
| 261 |
+
)
|
| 262 |
+
default_seed_string = "0,1234,1234,1234"
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--seed",
|
| 265 |
+
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
|
| 266 |
+
default=default_seed_string, # for backward compatibility
|
| 267 |
+
help=(
|
| 268 |
+
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
|
| 269 |
+
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
|
| 270 |
+
"respectively, or a single integer to set the same seed for all four.\n"
|
| 271 |
+
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
|
| 272 |
+
"(for backward compatibility).\n"
|
| 273 |
+
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
|
| 274 |
+
"Here numpy's seed is not set since the second value is `None`.\n"
|
| 275 |
+
"E.g, `--seed 42` sets all four seeds to 42."
|
| 276 |
+
),
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--trust_remote_code",
|
| 280 |
+
action="store_true",
|
| 281 |
+
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--confirm_run_unsafe_code",
|
| 285 |
+
action="store_true",
|
| 286 |
+
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--metadata",
|
| 290 |
+
type=json.loads,
|
| 291 |
+
default=None,
|
| 292 |
+
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
|
| 293 |
+
)
|
| 294 |
+
return parser
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
| 298 |
+
check_argument_types(parser)
|
| 299 |
+
return parser.parse_args()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
|
| 303 |
+
if not args:
|
| 304 |
+
# we allow for args to be passed externally, else we parse them ourselves
|
| 305 |
+
parser = setup_parser()
|
| 306 |
+
args = parse_eval_args(parser)
|
| 307 |
+
|
| 308 |
+
if args.wandb_args:
|
| 309 |
+
wandb_args_dict = simple_parse_args_string(args.wandb_args)
|
| 310 |
+
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
|
| 311 |
+
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
|
| 312 |
+
|
| 313 |
+
utils.setup_logging(args.verbosity)
|
| 314 |
+
eval_logger = logging.getLogger(__name__)
|
| 315 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 316 |
+
|
| 317 |
+
# update the evaluation tracker args with the output path and the HF token
|
| 318 |
+
if args.output_path:
|
| 319 |
+
args.hf_hub_log_args += f",output_path={args.output_path}"
|
| 320 |
+
if os.environ.get("HF_TOKEN", None):
|
| 321 |
+
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
|
| 322 |
+
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
|
| 323 |
+
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
|
| 324 |
+
|
| 325 |
+
if args.predict_only:
|
| 326 |
+
args.log_samples = True
|
| 327 |
+
if (args.log_samples or args.predict_only) and not args.output_path:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
"Specify --output_path if providing --log_samples or --predict_only"
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if args.fewshot_as_multiturn and args.apply_chat_template is False:
|
| 333 |
+
raise ValueError(
|
| 334 |
+
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if args.include_path is not None:
|
| 338 |
+
eval_logger.info(f"Including path: {args.include_path}")
|
| 339 |
+
metadata = (
|
| 340 |
+
simple_parse_args_string(args.model_args)
|
| 341 |
+
if isinstance(args.model_args, str)
|
| 342 |
+
else args.model_args
|
| 343 |
+
if isinstance(args.model_args, dict)
|
| 344 |
+
else {}
|
| 345 |
+
) | (
|
| 346 |
+
args.metadata
|
| 347 |
+
if isinstance(args.metadata, dict)
|
| 348 |
+
else simple_parse_args_string(args.metadata)
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
|
| 352 |
+
|
| 353 |
+
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
|
| 354 |
+
eval_logger.warning(
|
| 355 |
+
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
if args.limit:
|
| 359 |
+
eval_logger.warning(
|
| 360 |
+
" --limit SHOULD ONLY BE USED FOR TESTING."
|
| 361 |
+
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
if args.tasks is None:
|
| 365 |
+
eval_logger.error("Need to specify task to evaluate.")
|
| 366 |
+
sys.exit()
|
| 367 |
+
elif args.tasks == "list":
|
| 368 |
+
print(task_manager.list_all_tasks())
|
| 369 |
+
sys.exit()
|
| 370 |
+
elif args.tasks == "list_groups":
|
| 371 |
+
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
|
| 372 |
+
sys.exit()
|
| 373 |
+
elif args.tasks == "list_tags":
|
| 374 |
+
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
|
| 375 |
+
sys.exit()
|
| 376 |
+
elif args.tasks == "list_subtasks":
|
| 377 |
+
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
|
| 378 |
+
sys.exit()
|
| 379 |
+
else:
|
| 380 |
+
if os.path.isdir(args.tasks):
|
| 381 |
+
import glob
|
| 382 |
+
|
| 383 |
+
task_names = []
|
| 384 |
+
yaml_path = os.path.join(args.tasks, "*.yaml")
|
| 385 |
+
for yaml_file in glob.glob(yaml_path):
|
| 386 |
+
config = utils.load_yaml_config(yaml_file)
|
| 387 |
+
task_names.append(config)
|
| 388 |
+
else:
|
| 389 |
+
task_list = args.tasks.split(",")
|
| 390 |
+
task_names = task_manager.match_tasks(task_list)
|
| 391 |
+
for task in [task for task in task_list if task not in task_names]:
|
| 392 |
+
if os.path.isfile(task):
|
| 393 |
+
config = utils.load_yaml_config(task)
|
| 394 |
+
task_names.append(config)
|
| 395 |
+
task_missing = [
|
| 396 |
+
task for task in task_list if task not in task_names and "*" not in task
|
| 397 |
+
] # we don't want errors if a wildcard ("*") task name was used
|
| 398 |
+
|
| 399 |
+
if task_missing:
|
| 400 |
+
missing = ", ".join(task_missing)
|
| 401 |
+
eval_logger.error(
|
| 402 |
+
f"Tasks were not found: {missing}\n"
|
| 403 |
+
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
|
| 404 |
+
)
|
| 405 |
+
raise ValueError(
|
| 406 |
+
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
|
| 410 |
+
if args.trust_remote_code:
|
| 411 |
+
eval_logger.info(
|
| 412 |
+
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
|
| 413 |
+
)
|
| 414 |
+
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
|
| 415 |
+
# because it's already been determined based on the prior env var before launching our
|
| 416 |
+
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
|
| 417 |
+
import datasets
|
| 418 |
+
|
| 419 |
+
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
|
| 420 |
+
|
| 421 |
+
args.model_args = args.model_args + ",trust_remote_code=True"
|
| 422 |
+
eval_logger.info(
|
| 423 |
+
f"Selected Tasks: {task_names}"
|
| 424 |
+
) if eval_logger.getEffectiveLevel() >= logging.INFO else print(
|
| 425 |
+
f"Selected Tasks: {task_names}"
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
request_caching_args = request_caching_arg_to_dict(
|
| 429 |
+
cache_requests=args.cache_requests
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
results = evaluator.simple_evaluate(
|
| 433 |
+
model=args.model,
|
| 434 |
+
model_args=args.model_args,
|
| 435 |
+
tasks=task_names,
|
| 436 |
+
num_fewshot=args.num_fewshot,
|
| 437 |
+
batch_size=args.batch_size,
|
| 438 |
+
max_batch_size=args.max_batch_size,
|
| 439 |
+
device=args.device,
|
| 440 |
+
use_cache=args.use_cache,
|
| 441 |
+
limit=args.limit,
|
| 442 |
+
check_integrity=args.check_integrity,
|
| 443 |
+
write_out=args.write_out,
|
| 444 |
+
log_samples=args.log_samples,
|
| 445 |
+
evaluation_tracker=evaluation_tracker,
|
| 446 |
+
system_instruction=args.system_instruction,
|
| 447 |
+
apply_chat_template=args.apply_chat_template,
|
| 448 |
+
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
| 449 |
+
gen_kwargs=args.gen_kwargs,
|
| 450 |
+
task_manager=task_manager,
|
| 451 |
+
predict_only=args.predict_only,
|
| 452 |
+
random_seed=args.seed[0],
|
| 453 |
+
numpy_random_seed=args.seed[1],
|
| 454 |
+
torch_random_seed=args.seed[2],
|
| 455 |
+
fewshot_random_seed=args.seed[3],
|
| 456 |
+
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
|
| 457 |
+
metadata=metadata,
|
| 458 |
+
**request_caching_args,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
if results is not None:
|
| 462 |
+
if args.log_samples:
|
| 463 |
+
samples = results.pop("samples")
|
| 464 |
+
dumped = json.dumps(
|
| 465 |
+
results, indent=2, default=handle_non_serializable, ensure_ascii=False
|
| 466 |
+
)
|
| 467 |
+
if args.show_config:
|
| 468 |
+
print(dumped)
|
| 469 |
+
|
| 470 |
+
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
|
| 471 |
+
|
| 472 |
+
# Add W&B logging
|
| 473 |
+
if args.wandb_args:
|
| 474 |
+
try:
|
| 475 |
+
wandb_logger.post_init(results)
|
| 476 |
+
wandb_logger.log_eval_result()
|
| 477 |
+
if args.log_samples:
|
| 478 |
+
wandb_logger.log_eval_samples(samples)
|
| 479 |
+
except Exception as e:
|
| 480 |
+
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
|
| 481 |
+
|
| 482 |
+
evaluation_tracker.save_results_aggregated(
|
| 483 |
+
results=results, samples=samples if args.log_samples else None
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if args.log_samples:
|
| 487 |
+
for task_name, config in results["configs"].items():
|
| 488 |
+
evaluation_tracker.save_results_samples(
|
| 489 |
+
task_name=task_name, samples=samples[task_name]
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
if (
|
| 493 |
+
evaluation_tracker.push_results_to_hub
|
| 494 |
+
or evaluation_tracker.push_samples_to_hub
|
| 495 |
+
):
|
| 496 |
+
evaluation_tracker.recreate_metadata_card()
|
| 497 |
+
|
| 498 |
+
print(
|
| 499 |
+
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
|
| 500 |
+
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
|
| 501 |
+
)
|
| 502 |
+
print(make_table(results))
|
| 503 |
+
if "groups" in results:
|
| 504 |
+
print(make_table(results, "groups"))
|
| 505 |
+
|
| 506 |
+
if args.wandb_args:
|
| 507 |
+
# Tear down wandb run once all the logging is done.
|
| 508 |
+
wandb_logger.run.finish()
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
if __name__ == "__main__":
|
| 512 |
+
cli_evaluate()
|
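# --- Illustrative usage (not part of the diff): `cli_evaluate` above is a thin CLI
# --- wrapper around `evaluator.simple_evaluate`. A minimal programmatic sketch,
# --- assuming this package is installed as `lm_eval`; the model and task names below
# --- are placeholders, not values taken from this repository.
from lm_eval import evaluator
from lm_eval.utils import make_table

results = evaluator.simple_evaluate(
    model="hf",                    # HuggingFace-backed model type
    model_args="pretrained=gpt2",  # comma-separated model kwargs, as on the CLI
    tasks=["gsm8k"],               # task names, resolved by the TaskManager
    num_fewshot=5,
    batch_size=1,
    limit=8,                       # evaluate only a few docs, for a quick smoke test
)
if results is not None:
    print(make_table(results))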
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py
ADDED
@@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union

from lm_eval.api.instance import Instance


class Filter(ABC):
    """
    Filter classes operate on a per-task level.
    They take all model outputs (`instance.resps` for all `task.instances`)
    across all instances of a task, and perform operations.
    In a single run, one can configure any number of separate filters or lists of filters.
    """

    def __init__(self, **kwargs) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

    @abstractmethod
    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
        """
        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
        Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
        if passed [<inst.resps for instance 0>, <inst.resps for instance 1>], it should return
        [<filtered resps for instance 0>, <filtered resps for instance 1>].
        """
        return resps


@dataclass
class FilterEnsemble:
    """
    FilterEnsemble creates a pipeline applying multiple filters.
    Its intended usage is to stack multiple post-processing steps in order.
    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
    pipeline separately.
    """

    name: str
    filters: List[Callable[[], Filter]]

    def apply(self, instances: List[Instance]) -> None:
        resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
        resps, docs = list(resps), list(docs)

        for f in self.filters:
            # apply filters in sequence
            resps = f().apply(resps, docs)

        # add the end results after filtering to the `filtered_resps` of their respective source instances,
        # keyed by `self.name`: each FilterEnsemble applied in a given run should use a different name.
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp
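# --- Illustrative sketch (not part of the diff): a trivial Filter chained through a
# --- FilterEnsemble. The instance wiring is schematic; real `Instance` objects are
# --- built by the task machinery.
from lm_eval.api.filter import Filter, FilterEnsemble

class StripFilter(Filter):
    def apply(self, resps, docs):
        # one list of responses per instance; order must be preserved
        return [[r.strip() for r in instance_resps] for instance_resps in resps]

# `filters` holds zero-argument callables that each construct a fresh Filter
ensemble = FilterEnsemble(name="strip", filters=[StripFilter])
# calling ensemble.apply(task.instances) would then populate
# inst.filtered_resps["strip"] on every instance with the cleaned responses.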
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py
ADDED
@@ -0,0 +1,493 @@
import abc
import hashlib
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm

from lm_eval import utils


eval_logger = logging.getLogger(__name__)

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)
        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            context: str
                Context string
            gen_kwargs: dict
                A dictionary of keyword arguments to pass to the generation function, e.g. top_k, until, etc.
        :return: list[str]
            A list of model generated continuations.
            continuation: str
                The generated continuation.
        """
        pass

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """
        Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

        :param chat_history: list[dict[str, str]]
            A list of dictionaries with keys 'role' and 'content'.
            Values are strings representing the role name and the content of the message, respectively.
        :param add_generation_prompt: bool
            Whether to append an assistant generation prefix (e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message.
        :return: str
            A string representing the chat history in a format that can be used as input to the LM.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument dict.

        Parameters:
        - arg_dict: A dict of keyword arguments to construct the LM with.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """

        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._rank

    @property
    def world_size(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Returns the chat template structure for user/assistant messages if a template is provided.
        This method is intended to be overridden in a subclass to define a specific chat template format.
        For models that do not support chat templates, this method returns None by default.
        """

        return ""

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook


### SQLite-based caching of LM responses
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm) -> None:
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            if remaining_reqs:
                # actually run the LM on the requests that do not have cached results
                rem_res = getattr(self.lm, attr)(remaining_reqs)
            else:
                rem_res = []

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
                self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Set and get the appropriate chat template for the model.
        This method sets the tokenizer's chat_template and returns the template string for reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
        1. If the model's tokenizer has multiple templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the default template from the list if no specific template is provided.
            c. Raise an error if no default template exists and no specific template is provided.
        2. If the model's tokenizer has a single template or no template:
            a. Use the tokenizer's chat template if available.
            b. Fall back to the default chat template if no tokenizer chat template exists.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert boolean chat_template to None to ensure compatibility with the adapted logic
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        try:
            template = (
                self.tokenizer.chat_template or self.tokenizer.default_chat_template
            )
        except AttributeError:
            return None

        if isinstance(template, dict):
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If the user didn't pass a chat template, use the default template from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        # Cases when the model has a single template or no template
        else:
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template
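# --- Illustrative sketch (not part of the diff): wrapping an LM in CachingLM so repeated
# --- requests are answered from the SQLite cache. `DummyLM` (lm_eval/models/dummy.py) is
# --- assumed here as a stand-in for a real model class; the cache path is arbitrary.
from lm_eval.api.model import CachingLM
from lm_eval.models.dummy import DummyLM

lm = CachingLM(DummyLM(), "cache/dummy_responses.db")
# Attributes other than loglikelihood / loglikelihood_rolling / generate_until pass
# straight through to the wrapped LM; those three first consult the cache, run the
# underlying model only for uncached requests, and write new results back.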
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py
ADDED
@@ -0,0 +1,232 @@
import logging
import warnings
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union

import datasets


if TYPE_CHECKING:
    from random import Random

    from lm_eval.api.task import ConfigurableTask, Task

eval_logger = logging.getLogger("lm-eval")


class ContextSampler:
    def __init__(
        self,
        docs: list[dict],
        task: Union["Task", "ConfigurableTask"],
        fewshot_indices: Optional[Iterable] = None,
        rnd: Optional["Random"] = None,
    ) -> None:
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_text", None) is not None
        ):
            self.doc_to_text = partial(
                self.task.doc_to_text,
                doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
            )
        else:
            self.doc_to_text = self.task.doc_to_text

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_target", None) is not None
        ):
            self.doc_to_target = partial(
                self.task.doc_to_target,
                doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
            )
        else:
            self.doc_to_target = self.task.doc_to_target

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_choice", None) is not None
        ):
            self.doc_to_choice = partial(
                self.task.doc_to_choice,
                doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
            )
        else:
            self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs to the specified indices
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
        # draw an extra fewshot sample if using same split as evaluating on
        prefix = gen_prefix + " " if gen_prefix else ""
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        labeled_examples = ""
        for doc in selected_docs:
            doc_content = self.doc_to_text(doc)
            doc_target = self.doc_to_target(doc)
            if self.config.doc_to_choice is None or isinstance(doc_content, str):
                labeled_examples += doc_content
            else:
                labeled_examples += self.doc_to_choice(doc)[doc_content]

            if doc_target != "":
                if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
                    # TODO: add logger warn once here.
                    warnings.warn(
                        "Both target_delimiter and target start with a space. This may cause issues.",
                        Warning,
                        stacklevel=2,
                    )
                labeled_examples += self.target_delimiter
                labeled_examples += prefix
                labeled_examples += (
                    str(doc_target[0])
                    if isinstance(doc_target, list)
                    else doc_target
                    if self.config.doc_to_choice is None or isinstance(doc_target, str)
                    else str(self.doc_to_choice(doc)[doc_target])
                )
            labeled_examples += self.fewshot_delimiter

        return labeled_examples

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        gen_prefix: Optional[str] = None,
    ):
        # TODO: Do we need any other delimiter?
        prefix = gen_prefix + " " if gen_prefix else ""
        chat_history = []
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        if fewshot_as_multiturn:
            for doc in selected_docs:
                doc_content = self.doc_to_text(doc)
                doc_target = self.doc_to_target(doc)
                chat_history.append(
                    {
                        "role": "user",
                        "content": doc_content
                        if self.config.doc_to_choice is None
                        or isinstance(doc_content, str)
                        else self.doc_to_choice(doc)[doc_content],
                    }
                )
                chat_history.append(
                    {
                        "role": "assistant",
                        "content": prefix + str(doc_target[0])
                        if isinstance(doc_target, list)
                        else prefix + doc_target
                        if self.config.doc_to_choice is None
                        or isinstance(doc_target, str)
                        else prefix + str(self.doc_to_choice(doc)[doc_target]),
                    }
                )
        else:
            # get fewshot context as one user turn
            chat_history.append(
                {
                    "role": "user",
                    "content": self.get_context(
                        doc, num_fewshot, gen_prefix=gen_prefix
                    ),
                }
            )

        return chat_history

    def sample(self, n: int):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


class FirstNSampler(ContextSampler):
    def sample(self, n: int):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
        """
        assert n <= len(self.docs), (
            f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
        )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """

        pass


class ManualSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """ """
        pass


SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


def get_sampler(name: str):
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY.keys())}"
        )
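# --- Illustrative sketch (not part of the diff): resolving a sampler class by name, as a
# --- ConfigurableTask does when building few-shot contexts. `docs` and `task` below are
# --- placeholders for the real task machinery.
from random import Random
from lm_eval.api.samplers import get_sampler

sampler_cls = get_sampler("first_n")  # -> FirstNSampler
# sampler = sampler_cls(docs, task, rnd=Random(1234))
# sampler.sample(5) would then return the first five few-shot docs in order.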
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py
ADDED
@@ -0,0 +1,1839 @@
| 1 |
+
import abc
|
| 2 |
+
import ast
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from dataclasses import asdict, dataclass
|
| 9 |
+
from inspect import getsource
|
| 10 |
+
from typing import (
|
| 11 |
+
Any,
|
| 12 |
+
Dict,
|
| 13 |
+
Iterable,
|
| 14 |
+
Iterator,
|
| 15 |
+
List,
|
| 16 |
+
Literal,
|
| 17 |
+
Mapping,
|
| 18 |
+
Optional,
|
| 19 |
+
Tuple,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import datasets
|
| 24 |
+
import numpy as np
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from lm_eval import utils
|
| 28 |
+
from lm_eval.api import samplers
|
| 29 |
+
from lm_eval.api.instance import Instance, OutputType
|
| 30 |
+
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
|
| 31 |
+
from lm_eval.api.registry import (
|
| 32 |
+
AGGREGATION_REGISTRY,
|
| 33 |
+
DEFAULT_METRIC_REGISTRY,
|
| 34 |
+
get_aggregation,
|
| 35 |
+
get_metric,
|
| 36 |
+
get_metric_aggregation,
|
| 37 |
+
is_higher_better,
|
| 38 |
+
)
|
| 39 |
+
from lm_eval.caching.cache import load_from_cache, save_to_cache
|
| 40 |
+
from lm_eval.filters import build_filter_ensemble
|
| 41 |
+
from lm_eval.prompts import get_prompt
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
ALL_OUTPUT_TYPES = [
|
| 45 |
+
"loglikelihood",
|
| 46 |
+
"multiple_choice",
|
| 47 |
+
"loglikelihood_rolling",
|
| 48 |
+
"generate_until",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
eval_logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class TaskConfig(dict):
|
| 56 |
+
# task naming/registry
|
| 57 |
+
task: Optional[str] = None
|
| 58 |
+
task_alias: Optional[str] = None
|
| 59 |
+
tag: Optional[Union[str, list]] = None
|
| 60 |
+
# HF dataset options.
|
| 61 |
+
# which dataset to use,
|
| 62 |
+
# and what splits for what purpose
|
| 63 |
+
custom_dataset: Optional[Callable] = None
|
| 64 |
+
dataset_path: Optional[str] = None
|
| 65 |
+
dataset_name: Optional[str] = None
|
| 66 |
+
dataset_kwargs: Optional[dict] = None
|
| 67 |
+
training_split: Optional[str] = None
|
| 68 |
+
validation_split: Optional[str] = None
|
| 69 |
+
test_split: Optional[str] = None
|
| 70 |
+
fewshot_split: Optional[str] = (
|
| 71 |
+
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
|
| 72 |
+
)
|
| 73 |
+
# formatting / prompting options.
|
| 74 |
+
# see docs/advanced_task_guide.md for more info
|
| 75 |
+
process_docs: Optional[Callable] = None
|
| 76 |
+
doc_to_text: Optional[Union[Callable, str]] = None
|
| 77 |
+
doc_to_target: Optional[Union[Callable, str]] = None
|
| 78 |
+
doc_to_image: Union[Callable, str] = None
|
| 79 |
+
doc_to_audio: Union[Callable, str] = None
|
| 80 |
+
unsafe_code: bool = False
|
| 81 |
+
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
|
| 82 |
+
process_results: Optional[Union[Callable, str]] = None
|
| 83 |
+
use_prompt: Optional[str] = None
|
| 84 |
+
description: str = ""
|
| 85 |
+
target_delimiter: str = " "
|
| 86 |
+
fewshot_delimiter: str = "\n\n"
|
| 87 |
+
fewshot_config: Optional[dict] = None
|
| 88 |
+
# runtime configuration options
|
| 89 |
+
num_fewshot: Optional[int] = None
|
| 90 |
+
# scoring options
|
| 91 |
+
metric_list: Optional[list] = None
|
| 92 |
+
output_type: OutputType = "generate_until"
|
| 93 |
+
generation_kwargs: Optional[dict] = None
|
| 94 |
+
repeats: int = 1
|
| 95 |
+
filter_list: Optional[Union[str, list]] = None
|
| 96 |
+
should_decontaminate: bool = False
|
| 97 |
+
doc_to_decontamination_query: Optional[str] = None
|
| 98 |
+
gen_prefix: Optional[str] = None
|
| 99 |
+
metadata: Optional[dict] = (
|
| 100 |
+
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def __post_init__(self) -> None:
|
| 104 |
+
if self.generation_kwargs is not None:
|
| 105 |
+
if self.output_type != "generate_until":
|
| 106 |
+
eval_logger.warning(
|
| 107 |
+
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if "temperature" in self.generation_kwargs:
|
| 111 |
+
self.generation_kwargs["temperature"] = float(
|
| 112 |
+
self.generation_kwargs["temperature"]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if "until" not in self.generation_kwargs:
|
| 116 |
+
self.generation_kwargs["until"] = [self.fewshot_delimiter]
|
| 117 |
+
else:
|
| 118 |
+
if self.output_type == "generate_until":
|
| 119 |
+
# ensure that we greedily generate in absence of explicit arguments otherwise
|
| 120 |
+
self.generation_kwargs = {
|
| 121 |
+
"until": (
|
| 122 |
+
None
|
| 123 |
+
if self.fewshot_delimiter is None
|
| 124 |
+
else [self.fewshot_delimiter]
|
| 125 |
+
),
|
| 126 |
+
"do_sample": False,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
def __getitem__(self, item):
|
| 130 |
+
return getattr(self, item)
|
| 131 |
+
|
| 132 |
+
def __setitem__(self, item, value):
|
| 133 |
+
return setattr(self, item, value)
|
| 134 |
+
|

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dumps the current config as a dictionary object, in a printable format.
        Null fields will not be printed.
        Used for dumping results alongside full task configuration.

        :return: dict
            A printable dictionary version of the TaskConfig object.

        # TODO: should any default value in the TaskConfig not be printed?
        """
        cfg_dict = asdict(self)
        # remove values that are `None`
        for k, v in list(cfg_dict.items()):
            if v is None:
                cfg_dict.pop(k)
            elif k == "metric_list":
                for metric_dict in v:
                    for metric_key, metric_value in metric_dict.items():
                        if callable(metric_value):
                            metric_dict[metric_key] = self.serialize_function(
                                metric_value, keep_callable=keep_callable
                            )
                cfg_dict[k] = v
            elif callable(v):
                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serializes a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using 'getsource'.
        """
        if keep_callable:
            return value
        else:
            try:
                return getsource(value)
            except (TypeError, OSError):
                return str(value)
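
    # Usage note: `to_dict()` is what gets dumped alongside results; any callables
    # (e.g. a custom `process_results`) are replaced by their source text via
    # `serialize_function`, keeping the dumped config serializable.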


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation.

    A `doc` can be any python object which represents one instance of evaluation.
    This is usually a dictionary, e.g.
        {"question": ..., "answer": ...} or a tuple
        (question, answer)
    """

    VERSION: Optional[Union[int, str]] = None

    # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
    # or a path to a custom `datasets` loading script.
    DATASET_PATH: Optional[str] = None

    # The name of a subset within `DATASET_PATH`.
    DATASET_NAME: Optional[str] = None

    OUTPUT_TYPE: Optional[OutputType] = None

    def __init__(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode: Optional[datasets.DownloadMode] = None,
        config: Optional[Mapping] = None,  # Union[dict, TaskConfig]
    ) -> None:
        """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs: Optional[list] = None
        self._fewshot_docs: Optional[list] = None
        self._instances: Optional[List[Instance]] = None

        self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()

        self._filters = [build_filter_ensemble("none", [["take_first", None]])]
        self.fewshot_rnd: Optional[random.Random] = (
            None  # purposely induce errors in case of improper usage
        )
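
    # Note: `fewshot_rnd` stays None until `set_fewshot_seed` is called, so any
    # fewshot sampling attempted before seeding fails loudly rather than silently
    # using an unseeded generator.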

    def download(
        self,
        data_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        download_mode=None,
    ) -> None:
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
            Use this to specify the path to manually downloaded data (usually when
            the dataset is not publicly accessible).
        :param cache_dir: str
            The directory to read/write the `Task` dataset. This follows the
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
            by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
            How to treat pre-existing `Task` downloads and data.
            - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
                Reuse download and reuse dataset.
            - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
                Reuse download with fresh dataset.
            - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                Fresh download and fresh dataset.
        """
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @property
    def config(self) -> TaskConfig:
        """Returns the TaskConfig associated with this class."""
        return self._config

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    def training_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def validation_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def test_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return []

    def fewshot_docs(self) -> Iterable:
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        if self.has_training_docs():
            return self.training_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            if self.config.get("num_fewshot", 0) > 0:
                eval_logger.warning(
                    f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
                    ", using test_docs as fewshot_docs but this is not recommended."
                )
            return self.test_docs()

    def _process_doc(self, doc: dict) -> dict:
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    @property
    def instances(self) -> List[Instance]:
        """After calling `task.build_all_requests()`, tasks
        maintain a list of the dataset instances which will be evaluated.
        """
        return self._instances

    def fewshot_examples(self, k, rnd):
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        return rnd.sample(self._training_docs, k)

    def doc_to_decontamination_query(self, doc):
        raise NotImplementedError(
            "Override doc_to_decontamination_query with document specific decontamination query."
        )

    @abc.abstractmethod
    def doc_to_text(self, doc):
        pass

    @abc.abstractmethod
    def doc_to_target(self, doc):
        pass

    # not an abstractmethod because not every language-only task has to implement this
    def doc_to_image(self, doc):
        raise NotImplementedError

    def doc_to_audio(self, doc):
        raise NotImplementedError

    def doc_to_prefix(self, doc):
        return ""

    def build_all_requests(
        self,
        *,
        limit: Union[int, None] = None,
        rank: int = 0,
        world_size: int = 1,
        cache_requests: bool = False,
        rewrite_requests_cache: bool = False,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        tokenizer_name: str = "",
    ) -> None:
        """Build a set of Instances for a task, and store them in task.instances"""

        # used with caching
        og_limit = limit

        cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
        cache_key += "-chat_template" if apply_chat_template else ""
        cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
        cache_key += (
            f"-system_prompt_hash{utils.hash_string(system_instruction)}"
            if system_instruction is not None
            else ""
        )
        cache_key += f"-tokenizer{tokenizer_name}"

        cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)

        if cache_requests and cached_instances and not rewrite_requests_cache:
            cached_instances = cached_instances[:limit]

            flattened_instances = [
                instance
                for instance_group in cached_instances
                for instance in instance_group
            ]

            self._instances = flattened_instances
            return

        eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

        instances = []

        # process all documents when caching is specified for simplicity
        if (
            cache_requests
            and (not cached_instances or rewrite_requests_cache)
            and limit is not None
        ):
            limit = None

        doc_id_docs = list(
            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
        )

        num_docs = len(doc_id_docs)

        for doc_id, doc in tqdm(
            doc_id_docs,
            total=num_docs,
        ):
            # sample fewshot context #TODO: need to offset doc_id by rank now!
            fewshot_ctx = self.fewshot_context(
                doc,
                0 if self.config.num_fewshot is None else self.config.num_fewshot,
                system_instruction,
                apply_chat_template,
                fewshot_as_multiturn,
                chat_template,
                gen_prefix=self.doc_to_prefix(doc),
            )

            # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
            inst = self.construct_requests(
                doc=doc,
                ctx=fewshot_ctx,
                metadata=(self.config["task"], doc_id, self.config.repeats),
                apply_chat_template=apply_chat_template,
                chat_template=chat_template,
            )

            if not isinstance(inst, list):
                inst = [inst]

            instances.append(inst)

        # now flatten, this is to allow slicing to work with pickles

        sliced_instances = instances[:og_limit]

        flattened_instances = [
            instance
            for instance_group in sliced_instances
            for instance in instance_group
        ]

        self._instances = flattened_instances

        if len(self._instances) == 0:
            raise ValueError("task.build_requests() did not find any docs!")

        if cache_requests and (not cached_instances or rewrite_requests_cache):
            save_to_cache(file_name=cache_key, obj=instances)
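
    # Caching note: `cache_key` above folds in every setting that changes the
    # rendered context (fewshot count, chat template, multiturn mode, system
    # prompt hash, tokenizer), so cached requests are only reused when the
    # prompt setup is identical.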

    @abc.abstractmethod
    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param doc_idx: int
            The index of a document within `self.test_docs()` or `self.validation_docs()`,
            whichever is the main split used.
        :param repeats: int
            TODO: update this docstring
            The number of times each instance in a dataset is inferred on. Defaults to 1,
            can be increased for techniques like majority voting.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results, and evaluate them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def get_config(self, key: str) -> Any:
        return getattr(self._config, key, None)

    @classmethod
    def count_bytes(cls, doc):
        """Used for byte-level perplexity metrics in rolling loglikelihood"""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))

    @utils.positional_deprecated
    def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        if rnd is None:
            if self.fewshot_rnd is not None:
                rnd = self.fewshot_rnd
            else:
                raise ValueError(
                    "A `random.Random` generator argument must be provided to `rnd`"
                )

        description = description if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
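
    # Sampling note: drawing `num_fewshot + 1` candidates and then filtering out
    # the evaluated doc guarantees `num_fewshot` exemplars even when the current
    # doc is drawn from the same split as the fewshot pool.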

    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def dump_config(self) -> dict:
        """Returns the config as a dictionary."""
        # TODO: this should only return the overrides applied to a non-YAML task's configuration.
        # (num_fewshot)
        return self.config.to_dict()

    def set_config(self, key: str, value: Any, update: bool = False) -> None:
        """Set or update the configuration for a given key."""
        if key is None:
            raise ValueError("Key must be provided.")

        if update:
            current_value = getattr(self._config, key, {})
            if not isinstance(current_value, dict):
                raise TypeError(
                    f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
                )
            current_value.update(value)
        else:
            setattr(self._config, key, value)

    def override_metric(self, metric_name: str) -> None:
        """
        Override the default metrics used for evaluation with custom metrics.

        Parameters:
        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
        """
        (
            self._metric_fn_list,
            self._aggregation_list,
            self._metric_fn_kwargs,
            self._higher_is_better,
        ) = ({}, {}, {}, {})
        self._metric_fn_list[metric_name] = get_metric(metric_name)
        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
        self._higher_is_better[metric_name] = is_higher_better(metric_name)
        self._metric_fn_kwargs[metric_name] = {}
        if not isinstance(self, ConfigurableTask):
            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
            self.aggregation = lambda: {
                metric_name: get_metric_aggregation(metric_name)
            }
        setattr(self._config, "metric_list", [{"metric": metric_name}])
        setattr(self._config, "process_results", None)

    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
        self.fewshot_rnd = random.Random(seed)
        if hasattr(self, "sampler"):
            self.sampler.rnd = self.fewshot_rnd

    @property
    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
        if self.has_test_docs():
            return self.test_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            raise ValueError(
                f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
            )

    def doc_iterator(
        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
    ) -> Iterator[Tuple[int, Any]]:
        limit = int(limit) if limit else None
        doc_iterator = utils.create_iterator(
            enumerate(self.eval_docs),
            rank=int(rank),
            limit=limit,
            world_size=int(world_size),
        )
        return doc_iterator


class ConfigurableTask(Task):
    VERSION = "Yaml"
    OUTPUT_TYPE = None
    CONFIG = None

    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        config: Optional[dict] = None,
    ) -> None:  # TODO no super() call here
        # Get pre-configured attributes
        self._config = self.CONFIG

        # Use new configurations if there was no preconfiguration
        if self.config is None:
            self._config = TaskConfig(**config)
        # Overwrite configs
        else:
            if config is not None:
                self._config.__dict__.update(config)

        if self.config is None:
            raise ValueError(
                "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
            )

        if isinstance(self.config.metadata, dict):
            if "version" in self.config.metadata:
                self.VERSION = self.config.metadata["version"]

        if self.config.output_type is not None:
            if self.config.output_type not in ALL_OUTPUT_TYPES:
                raise ValueError(
                    f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
                )
            self.OUTPUT_TYPE = self.config.output_type

        if self.config.doc_to_image is not None:
            # mark the task as requiring multimodality.
            self.MULTIMODAL = True

        if self.config.doc_to_audio:
            # mark the task as requiring multimodality.
            self.MULTIMODAL = True

        if self.config.unsafe_code is not False:
            self.UNSAFE_CODE = True

        if self.config.dataset_path is not None:
            self.DATASET_PATH = self.config.dataset_path

        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name

        self._metric_fn_list = {}
        self._metric_fn_kwargs = {}
        self._aggregation_list = {}
        self._higher_is_better = {}

        if self.config.metric_list is None:
            # TODO: handle this in TaskConfig.__post_init__ ?
            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]

            for metric_name in _metric_list:
                self._metric_fn_list[metric_name] = get_metric(metric_name)
                self._metric_fn_kwargs[metric_name] = {}
                self._aggregation_list[metric_name] = get_metric_aggregation(
                    metric_name
                )
                self._higher_is_better[metric_name] = is_higher_better(metric_name)
        else:
            for metric_config in self.config.metric_list:
                if "metric" not in metric_config:
                    raise ValueError(
                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
                    )
                metric_name = metric_config["metric"]
                kwargs = {
                    key: metric_config[key]
                    for key in metric_config
                    if key
                    not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
                }
                hf_evaluate_metric = (
                    "hf_evaluate" in metric_config
                    and metric_config["hf_evaluate"] is True
                )

                if self.config.process_results is not None:
                    self._metric_fn_list[metric_name] = None
                    self._metric_fn_kwargs[metric_name] = {}
                elif callable(metric_name):
                    metric_fn = metric_name.__call__
                    metric_name = metric_name.__name__
                    self._metric_fn_list[metric_name] = metric_fn
                    self._metric_fn_kwargs[metric_name] = kwargs
                else:
                    self._metric_fn_list[metric_name] = get_metric(
                        metric_name, hf_evaluate_metric
                    )
                    self._metric_fn_kwargs[metric_name] = kwargs

                if "aggregation" in metric_config:
                    agg_name = metric_config["aggregation"]
                    if isinstance(agg_name, str):
                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                    elif callable(agg_name):  # noqa: E721
                        self._aggregation_list[metric_name] = metric_config[
                            "aggregation"
                        ]
                else:
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_metric_aggregation(metric_name)
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
                        f"using default "
                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                    )
                    self._aggregation_list[metric_name] = metric_agg

                if "higher_is_better" in metric_config:
                    self._higher_is_better[metric_name] = metric_config[
                        "higher_is_better"
                    ]
                else:
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
                        f"using default "
                        f"higher_is_better={is_higher_better(metric_name)}"
                    )
                    self._higher_is_better[metric_name] = is_higher_better(metric_name)
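
        # Illustrative `metric_list` entry as it would appear in a task YAML
        # (values are examples, not defaults):
        #   metric_list:
        #     - metric: exact_match
        #       aggregation: mean
        #       higher_is_better: true
        #       ignore_case: true   # extra keys are forwarded as metric kwargs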

        self.download(self.config.dataset_kwargs)
        self._training_docs = None
        self._fewshot_docs = None

        if self.config.filter_list is not None:
            self._filters = []
            for filter_config in self.config.filter_list:
                filter_name = filter_config["name"]
                filter_functions = filter_config["filter"]
                components = []
                for function in filter_functions:
                    kwargs = {
                        key: function[key] for key in function if key != "function"
                    }
                    components.append([function["function"], kwargs])
                filter_pipeline = build_filter_ensemble(filter_name, components)
                self._filters.append(filter_pipeline)
        else:
            # TODO: handle repeats in a more general way rather than just discarding
            eval_logger.debug(
                "No custom filters defined. Using default 'take_first' filter for handling repeats."
            )
            self._filters = [build_filter_ensemble("none", [["take_first", None]])]

        if self.config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self.config.use_prompt}")
            self.prompt = get_prompt(
                self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
            )
        else:
            self.prompt = None

        if self.fewshot_docs() is not None:
            self.fewshot_rnd = (
                random.Random()
            )  # setting with no seed, to be overridden at a later time
            config_sampler: Union[str, Callable] = (
                self.config.fewshot_config.get("sampler", "default")
                if self.config.fewshot_config
                else "default"
            )
            if isinstance(config_sampler, str):
                self.sampler = samplers.get_sampler(config_sampler)(
                    list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
                )
            elif callable(config_sampler) and issubclass(
                config_sampler, samplers.ContextSampler
            ):
                self.sampler = config_sampler(
                    docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
                )
            else:
                raise TypeError(
                    f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
                    f"not {type(config_sampler)}"
                )

        self.task_docs = self.eval_docs

        # Test One Doc
        self.features = list(self.task_docs.features.keys())
        self.multiple_input = 0
        self.multiple_target = 0
        test_doc = self.task_docs[0]
        test_text = self.doc_to_text(test_doc)
        test_target = self.doc_to_target(test_doc)

        if self.config.doc_to_choice is not None:
            test_choice = self.doc_to_choice(test_doc)
            if not isinstance(test_choice, list):
                eval_logger.error("doc_to_choice must return list")
            else:
                num_choice = len(test_choice)

            if isinstance(test_text, int):
                self.multiple_input = num_choice
        else:
            test_choice = None

        if isinstance(test_target, list):
            self.multiple_target = len(test_target)
        else:
            if (isinstance(test_target, int)) and (test_choice is not None):
                test_target = test_choice[test_target]
            else:
                test_target = str(test_target)

        if test_choice is not None:
            check_choices = test_choice
        else:
            check_choices = [test_target]
        if self.config.doc_to_choice is not None:
            for choice in check_choices:
                choice_has_whitespace = True if choice[0].isspace() else False
                delimiter_has_whitespace = (
                    True
                    if self.config.target_delimiter.rstrip()
                    != self.config.target_delimiter
                    else False
                )

                if delimiter_has_whitespace and choice_has_whitespace:
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
                    )
                elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                    )
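
        # The whitespace checks above are debug-only: a doubled or missing space
        # between `target_delimiter` and a choice is a common source of misaligned
        # loglikelihood continuations, but may be intentional for languages that
        # do not use whitespace.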

    def download(
        self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
    ) -> None:
        if isinstance(self.config.custom_dataset, Callable):
            eval_logger.warning(
                f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
            )
            self.dataset = self.config.custom_dataset(
                **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
            )
        else:
            self.dataset = datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                **dataset_kwargs if dataset_kwargs is not None else {},
            )

    def has_training_docs(self) -> bool:
        if self.config.training_split is not None:
            return True
        else:
            return False

    def has_validation_docs(self) -> bool:
        if self.config.validation_split is not None:
            return True
        else:
            return False

    def has_test_docs(self) -> bool:
        if self.config.test_split is not None:
            return True
        else:
            return False

    def training_docs(self) -> datasets.Dataset:
        if self.has_training_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(
                    self.dataset[self.config.training_split]
                )
            return self.dataset[self.config.training_split]

    def validation_docs(self) -> datasets.Dataset:
        if self.has_validation_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(
                    self.dataset[self.config.validation_split]
                )
            return self.dataset[self.config.validation_split]

    def test_docs(self) -> datasets.Dataset:
        if self.has_test_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(self.dataset[self.config.test_split])
            return self.dataset[self.config.test_split]
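
    # Config sketch (illustrative values) for the split plumbing above:
    #   dataset_path: gsm8k      -> DATASET_PATH
    #   dataset_name: main       -> DATASET_NAME
    #   training_split: train
    #   test_split: test
    # `has_*_docs` simply report whether the corresponding split is configured,
    # and `process_docs` (if set) is applied to each split before use.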

    def fewshot_docs(self):
        if self.config.fewshot_split is not None:
            if self.config.process_docs is not None:
                return self.config.process_docs(self.dataset[self.config.fewshot_split])
            return self.dataset[self.config.fewshot_split]
        elif (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("samples", None) is not None
        ):
            if isinstance(self.config.fewshot_config["samples"], list):
                return self.config.fewshot_config["samples"]
            elif callable(self.config.fewshot_config["samples"]):
                return self.config.fewshot_config["samples"]()
            else:
                raise Exception(
                    "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of sample dicts or a function returning such a list."
                )
        else:
            if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
                eval_logger.warning(
                    f"[Task: {self.config.task}] "
                    "num_fewshot > 0 but fewshot_split is None. "
                    "using preconfigured rule."
                )
            return super().fewshot_docs()

    @staticmethod
    def append_target_question(
        labeled_examples: List[Dict[str, str]],
        question: str,
        fewshot_as_multiturn: bool = False,
        gen_prefix: Optional[str] = None,
    ) -> None:
        """Adds a target question to the labeled examples list.
        If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
        Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
        """
        if not fewshot_as_multiturn:
            # if no messages or last message is system, append as new user entry
            if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
                labeled_examples.append({"role": "user", "content": question})
            # if last message is user, append to it to avoid two user messages in a row
            else:
                labeled_examples[-1]["content"] += question
        else:
            # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
            labeled_examples.append({"role": "user", "content": question})
        if gen_prefix:
            labeled_examples.append({"role": "assistant", "content": gen_prefix})
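
    # Shape sketch: with fewshot_as_multiturn=False the target question is merged
    # into the single user turn; with True each shot is its own user/assistant
    # pair and the question becomes a fresh user turn, optionally followed by an
    # assistant turn pre-filled with `gen_prefix`.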

    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc: dict,
        num_fewshot: int,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
        gen_prefix: Optional[str] = None,
    ) -> Union[str, List[str]]:
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param system_instruction: str
            System instruction to be applied to the prompt.
        :param apply_chat_template: bool
            Whether to apply the chat template to the fewshot context.
        :param fewshot_as_multiturn: bool
            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
        :param chat_template:
            callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
        :param gen_prefix:
            String to append after the <|assistant|> token.
        :returns: str
            The fewshot context.
        """
        if apply_chat_template:
            labeled_examples = []
        else:
            labeled_examples = ""

        # get task description
        if description := self.config.description:
            description = utils.apply_template(self.config.description, doc)

        # create system prompt based on the provided system instruction and description
        if system_instruction is not None and description:
            system_prompt = (
                f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
            )
        elif system_instruction is not None:
            system_prompt = system_instruction
        elif description:
            system_prompt = description
        else:
            system_prompt = ""

        # add system prompt if specified
        if system_prompt:
            if apply_chat_template:
                labeled_examples.append({"role": "system", "content": system_prompt})
            else:
                labeled_examples = system_prompt
        # if few-shot - append examples after the system prompt
        if num_fewshot > 0:
            if apply_chat_template:
                labeled_examples.extend(
                    self.sampler.get_chat_context(
                        doc,
                        num_fewshot,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                )
            else:
                labeled_examples += self.sampler.get_context(
                    doc, num_fewshot, gen_prefix=gen_prefix
                )

        example = self.doc_to_text(doc)
        if apply_chat_template:
            if self.multiple_input:
                # TODO: append prefill?
                if not labeled_examples:
                    return ""
                return chat_template(labeled_examples)
            if isinstance(example, str):
                self.append_target_question(
                    labeled_examples,
                    example,
                    fewshot_as_multiturn,
                    gen_prefix=gen_prefix,
                )
            # for loglikelihood create a list of questions with appended choices
            elif isinstance(example, list):
                labeled_examples_list = []
                # copy chat history for each example and append the answer
                for ex in example:
                    chat = deepcopy(labeled_examples)
                    self.append_target_question(
                        chat,
                        ex,
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                    # TODO: append prefill?
                    labeled_examples_list.append(
                        chat_template(
                            chat,
                            add_generation_prompt=False if gen_prefix else True,
                        )
                    )
                return labeled_examples_list
            # if example is an integer, append the choice or convert to string
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    self.append_target_question(
                        labeled_examples,
                        choices[example],
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
                else:
                    self.append_target_question(
                        labeled_examples,
                        str(example),
                        fewshot_as_multiturn,
                        gen_prefix=gen_prefix,
                    )
            # return lm.apply_chat_template(labeled_examples)
            return chat_template(
                labeled_examples,
                add_generation_prompt=False if gen_prefix else True,
            )
        else:
            prefix = (
                self.config.target_delimiter + gen_prefix
                if gen_prefix is not None
                else ""
            )
            if self.multiple_input:
                return labeled_examples
            if isinstance(example, str):
                return labeled_examples + example + prefix
            elif isinstance(example, list):
                return [labeled_examples + ex + prefix for ex in example]
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    return labeled_examples + choices[example] + prefix
                else:
                    return labeled_examples + str(example) + prefix
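
    # Return-type note: with apply_chat_template=True the context comes back
    # already rendered through `chat_template`; when `doc_to_text` yields a list
    # (loglikelihood over multiple inputs), one rendered context per entry is
    # returned instead of a single string.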

    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def should_decontaminate(self):
        return self.config.should_decontaminate

    def doc_to_decontamination_query(self, doc: dict):
        if self.config.should_decontaminate:
            if self.config.doc_to_decontamination_query is None:
                return self.doc_to_text(doc)
            else:
                doc_to_decontamination_query = self.config.doc_to_decontamination_query
                if doc_to_decontamination_query in self.features:
                    return doc[doc_to_decontamination_query]
                elif callable(doc_to_decontamination_query):
                    return doc_to_decontamination_query(doc)
                else:
                    return ast.literal_eval(
                        utils.apply_template(
                            self.config.doc_to_decontamination_query, doc
                        )
                    )

    def _process_doc(self, doc: dict) -> dict:
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def doc_to_text(self, doc, doc_to_text=None):
        if self.prompt is not None:
            doc_to_text = self.prompt
        elif doc_to_text is not None:
            doc_to_text = doc_to_text
        else:
            doc_to_text = self.config.doc_to_text

        if isinstance(doc_to_text, int):
            return doc_to_text
        elif isinstance(doc_to_text, str):
            if doc_to_text in self.features:
                # if self.config.doc_to_choice is not None:
                #     return self.doc_to_choice(doc)[doc[doc_to_text]]
                # else:
                return doc[doc_to_text]
            else:
                text_string = utils.apply_template(doc_to_text, doc)
                if text_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(text_string)
                else:
                    return text_string
        elif callable(doc_to_text):
            return doc_to_text(doc)
        # Used when applying a Promptsource template
        elif hasattr(doc_to_text, "apply"):
            applied_prompt = doc_to_text.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[0]
            else:
                eval_logger.warning("Applied prompt returns empty string")
                return self.config.fewshot_delimiter
        else:
            raise TypeError(f"Unexpected doc_to_text type: {type(doc_to_text)}")

    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
        if self.prompt is not None:
            doc_to_target = self.prompt
        elif doc_to_target is not None:
            doc_to_target = doc_to_target
        else:
            doc_to_target = self.config.doc_to_target

        if isinstance(doc_to_target, int):
            return doc_to_target
        elif isinstance(doc_to_target, str):
            if doc_to_target in self.features:
                # if self.config.doc_to_choice is not None:
                #     return self.doc_to_choice(doc)[doc[doc_to_target]]
                # else:
                return doc[doc_to_target]
            else:
                target_string = utils.apply_template(doc_to_target, doc)
                if target_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(target_string)
                elif (
                    len(target_string) >= 2
                    and (target_string[0] == "[")
                    and (target_string[-1] == "]")
                ):
                    try:
                        return ast.literal_eval(target_string)
                    except (SyntaxError, ValueError):
                        return target_string
                else:
                    return target_string
        elif isinstance(doc_to_target, list):
            return doc_to_target
        elif callable(doc_to_target):
            return doc_to_target(doc)
        # Used when applying a Promptsource template
        elif hasattr(doc_to_target, "apply"):
            applied_prompt = doc_to_target.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[1]
            else:
                eval_logger.warning("Applied prompt returns empty string")
                return self.config.fewshot_delimiter
        else:
            raise TypeError(f"Unexpected doc_to_target type: {type(doc_to_target)}")

    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
        if self.prompt is not None:
            doc_to_choice = self.prompt
        elif doc_to_choice is not None:
            doc_to_choice = doc_to_choice
        elif self.config.doc_to_choice is None:
            eval_logger.error("doc_to_choice was called but not set in config")
        else:
            doc_to_choice = self.config.doc_to_choice

        if isinstance(doc_to_choice, str):
            if doc_to_choice in self.features:
                return doc[doc_to_choice]
            else:
                return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
        elif isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
            return list(doc_to_choice.values())
        elif callable(doc_to_choice):
            return doc_to_choice(doc)
        elif hasattr(doc_to_choice, "get_answer_choices_list"):
            return doc_to_choice.get_answer_choices_list(doc)
        else:
            raise TypeError(f"Unexpected doc_to_choice type: {type(doc_to_choice)}")

    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
        if doc_to_image is not None:
            doc_to_image = doc_to_image
        elif self.config.doc_to_image is not None:
            doc_to_image = self.config.doc_to_image
        else:
            return None

        if isinstance(doc_to_image, list):
            image_feature = [
                self.doc_to_image(doc, feature) for feature in doc_to_image
            ]
            return [feature for feature in image_feature if feature is not None]
        elif isinstance(doc_to_image, str):
            if doc_to_image in self.features:
                return doc[doc_to_image]
            else:
                return ast.literal_eval(utils.apply_template(doc_to_image, doc))
        elif callable(doc_to_image):
            return doc_to_image(doc)
        else:
            return None

    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
        if doc_to_audio is not None:
            doc_to_audio = doc_to_audio
        elif self.config.doc_to_audio is not None:
            doc_to_audio = self.config.doc_to_audio
        else:
            return None

        if isinstance(doc_to_audio, list):
            audio_feature = [
                self.doc_to_audio(doc, feature) for feature in doc_to_audio
            ]
            return [feature for feature in audio_feature if feature is not None]
        elif isinstance(doc_to_audio, str):
            if doc_to_audio in self.features:
                return doc[doc_to_audio]
            else:
                return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
        elif callable(doc_to_audio):
            return doc_to_audio(doc)
        else:
            return None

    def doc_to_prefix(self, doc):
        if (gen_prefix := self.config.gen_prefix) is not None:
            if gen_prefix in self.features:
                return doc[gen_prefix]
            else:
                return utils.apply_template(gen_prefix, doc)
        return None

    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        apply_chat_template = kwargs.pop("apply_chat_template", False)
        chat_template: Callable | None = kwargs.pop("chat_template", None)

        aux_arguments = None

        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self.config.target_delimiter
            if apply_chat_template:
                target_delimiter = ""
            if self.multiple_input:
                # If there are multiple inputs, choices are placed in the ctx
                # apply chat_template to choices if apply_chat_template
                cont = self.doc_to_target(doc)

                arguments = [
                    (
                        ctx
                        + (
                            chat_template([{"role": "user", "content": choice}])
                            if apply_chat_template
                            else choice
                        ),
                        f"{target_delimiter}{cont}",
                    )
                    for choice in choices
                ]
            else:
                # Otherwise they are placed in the continuation
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                aux_arguments = [("", f"{choice}") for choice in choices]

                arguments.extend(aux_arguments)

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        multimodal_arg = {}
        if (
            self.config.doc_to_image
        ):  # TODO: ensure that non-multimodal tasks aren't getting visual args
            multimodal_arg = {
                **multimodal_arg,
                **{"visual": self.doc_to_image(doc)},
            }

        if (
            self.config.doc_to_audio
        ):  # TODO: ensure that non-multimodal tasks aren't getting audio args
            multimodal_arg = {
                **multimodal_arg,
                **{"audio": self.doc_to_audio(doc)},
            }

        if bool(multimodal_arg):
            if isinstance(arguments, list):
                arguments = [arg + (multimodal_arg,) for arg in arguments]
            else:
                arguments = arguments + (multimodal_arg,)

        if self.OUTPUT_TYPE == "multiple_choice":
            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]

            return request_list

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=arguments,
            idx=0,
            **kwargs,
        )
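
    # Request fan-out: a multiple_choice doc becomes one loglikelihood Instance
    # per choice (plus one unconditional ("", choice) request per choice when
    # acc_mutual_info is requested); other output types yield a single Instance.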
| 1500 |
+
def process_results(self, doc, results):
|
| 1501 |
+
if callable(self.config.process_results):
|
| 1502 |
+
return self.config.process_results(doc, results)
|
| 1503 |
+
|
| 1504 |
+
result_dict = {}
|
| 1505 |
+
use_metric = list(self._metric_fn_list.keys())
|
| 1506 |
+
if self.OUTPUT_TYPE == "loglikelihood":
|
| 1507 |
+
results = results[0]
|
| 1508 |
+
ll, is_greedy = results
|
| 1509 |
+
return {
|
| 1510 |
+
**({"perplexity": ll} if "perplexity" in use_metric else {}),
|
| 1511 |
+
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
|
| 1512 |
+
}
|
| 1513 |
+
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
|
| 1514 |
+
(loglikelihood,) = results
|
| 1515 |
+
_words = self.count_words(self.doc_to_target(doc))
|
| 1516 |
+
_bytes = self.count_bytes(self.doc_to_target(doc))
|
| 1517 |
+
return {
|
| 1518 |
+
**(
|
| 1519 |
+
{"word_perplexity": (loglikelihood, _words)}
|
| 1520 |
+
if "word_perplexity" in use_metric
|
| 1521 |
+
else {}
|
| 1522 |
+
),
|
| 1523 |
+
**(
|
| 1524 |
+
{"byte_perplexity": (loglikelihood, _bytes)}
|
| 1525 |
+
if "byte_perplexity" in use_metric
|
| 1526 |
+
else {}
|
| 1527 |
+
),
|
| 1528 |
+
**(
|
| 1529 |
+
{"bits_per_byte": (loglikelihood, _bytes)}
|
| 1530 |
+
if "bits_per_byte" in use_metric
|
| 1531 |
+
else {}
|
| 1532 |
+
),
|
| 1533 |
+
}
|
| 1534 |
+
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 1535 |
+
lls, is_greedy = zip(*results)
|
| 1536 |
+
|
| 1537 |
+
# retrieve choices in List[str] form, to compute choice lengths, etc.
|
| 1538 |
+
choices = self.doc_to_choice(doc)
|
| 1539 |
+
completion_len = np.array([float(len(i)) for i in choices])
|
| 1540 |
+
|
| 1541 |
+
if (
|
| 1542 |
+
2 * len(choices) == len(lls)
|
| 1543 |
+
and "acc_mutual_info" in self._metric_fn_list.keys()
|
| 1544 |
+
):
|
| 1545 |
+
# then we are doing mutual info.
|
| 1546 |
+
# this stores the "dryrun" / unconditional answer loglikelihoods
|
| 1547 |
+
lls_unconditional = lls[1::2]
|
| 1548 |
+
if len(lls_unconditional) != len(choices):
|
| 1549 |
+
raise ValueError
|
| 1550 |
+
# and this stores our "regular" conditional loglikelihoods
|
| 1551 |
+
lls = lls[::2]
|
| 1552 |
+
|
| 1553 |
+
pred = np.argmax(lls)
|
| 1554 |
+
pred_norm = np.argmax(lls / completion_len)
|
| 1555 |
+
|
| 1556 |
+
if self.multiple_input:
|
| 1557 |
+
gold = self.doc_to_text(doc)
|
| 1558 |
+
else:
|
| 1559 |
+
gold = self.doc_to_target(doc)
|
| 1560 |
+
|
| 1561 |
+
gold_index_error = False
|
| 1562 |
+
if isinstance(gold, list):
|
| 1563 |
+
gold = [i if i < len(choices) else -100 for i in gold]
|
| 1564 |
+
if -100 in gold:
|
| 1565 |
+
gold_index_error = True
|
| 1566 |
+
else:
|
| 1567 |
+
if isinstance(gold, int):
|
| 1568 |
+
gold = gold if gold < len(choices) else -100
|
| 1569 |
+
elif isinstance(gold, str):
|
| 1570 |
+
gold = choices.index(gold) if gold in choices else -100
|
| 1571 |
+
|
| 1572 |
+
if gold == -100:
|
| 1573 |
+
gold_index_error = True
|
| 1574 |
+
|
| 1575 |
+
if gold_index_error:
|
| 1576 |
+
                eval_logger.warning(
                    "Label index was not within the range of available choices. "
                    f"Sample:\n\n{doc}\n\n"
                )

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                exact_match = int(is_greedy[gold]) if gold != -100 else 0

            prob_norm = utils.softmax(lls)

            # TODO use keyword arguments to the metric?
            # gold, pred, norm stuff, the original lls,
            result_dict = {
                **({"acc": acc} if "acc" in use_metric else {}),
                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
                **(
                    {"brier_score": (gold, prob_norm)}
                    if "brier_score" in use_metric
                    else {}
                ),
            }

            if "acc_mutual_info" in use_metric:
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
                acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "generate_until":
            gold = self.doc_to_target(doc)
            result = results[0]
            if self.config.doc_to_choice is not None:
                # If you set doc_to_choice,
                # it assumes that doc_to_target returns a number.
                choices = self.doc_to_choice(doc)
                gold = choices[gold]
            # we expect multiple_targets to be a list.
            elif self.multiple_target:
                gold = list(gold)
            # TODO: handle this better
            elif type(gold) is not type(result) and not (
                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
            ):
                # cast gold to the same type as result
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
                    # TODO: this may break for multiple_target, non zero-or-1 metrics
                    scores = []
                    if not isinstance(gold, list):
                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                        # print(gold)
                        gold = [gold]
                    if metric == "exact_match":
                        result = [result for _ in range(len(gold))]
                        scores = self._metric_fn_list[metric](
                            references=gold,
                            predictions=result,
                            **self._metric_fn_kwargs[metric],
                        )[metric]
                        result_score = 1.0 if scores > 0.0 else 0.0
                    else:
                        for gold_option in gold:
                            try:
                                result_score = self._metric_fn_list[metric](
                                    references=[gold_option],
                                    predictions=[result],
                                    **self._metric_fn_kwargs[metric],
                                )
                            except (
                                TypeError
                            ):  # TODO: this is hacky and I don't want to do it
                                result_score = self._metric_fn_list[metric](
                                    [gold_option, result]
                                )
                            if isinstance(result_score, dict):
                                # TODO: this handles the case where HF evaluate returns a dict.
                                result_score = result_score[metric]
                            scores.append(result_score)
                        if any(scores):
                            result_score = 1.0
                        else:
                            result_score = 0.0
                else:
                    try:
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
                    if isinstance(result_score, dict):
                        # TODO: this handles the case where HF evaluate returns a dict.
                        # This allows for multiple metrics to be returned from the same function
                        for k, v in result_score.items():
                            result_dict[k] = v
                    else:
                        result_dict[metric] = result_score
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}'! Please use one of "
                "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'"
            )

        return result_dict

    def aggregation(self) -> dict:
        return self._aggregation_list

    def higher_is_better(self) -> dict:
        return self._higher_is_better

    def get_config(self, key: str) -> Any:
        return getattr(self._config, key, None)

    @property
    def task_name(self) -> Any:
        return getattr(self.config, "task", None)

    def __repr__(self):
        return (
            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
            f"output_type={self.OUTPUT_TYPE},"
            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
            f"num_samples={len(self.eval_docs)})"
        )


class MultipleChoiceTask(Task):
    OUTPUT_TYPE = "loglikelihood"

    def doc_to_target(self, doc: dict) -> str:
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
        # TODO: add mutual info here?
        return [
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, " {}".format(choice)),
                idx=i,
                **kwargs,
            )
            for i, choice in enumerate(doc["choices"])
        ]

    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
        results = [
            res[0] for res in results
        ]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self) -> dict:
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self) -> dict:
        return {
            "acc": mean,
            "acc_norm": mean,
        }


class PerplexityTask(Task):
    OUTPUT_TYPE = "loglikelihood_rolling"

    def has_training_docs(self) -> bool:
        return False

    def fewshot_examples(self, k: int, rnd) -> List:
        if k != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return []

    def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
        if num_fewshot != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )

        return ""

    def higher_is_better(self) -> dict:
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc) -> str:
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
        if bool(ctx):
            raise ValueError

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            idx=0,
            **kwargs,
        )

    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
        (loglikelihood,) = results
        words = self.count_words(self.doc_to_target(doc))
        bytes_ = self.count_bytes(self.doc_to_target(doc))
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self) -> dict:
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc) -> int:
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc) -> int:
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
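
Aside (not part of the diff): a minimal numeric sketch of the acc_mutual_info normalization implemented above, where each choice is scored by log P(choice | ctx) - log P(choice), i.e. the conditional loglikelihood minus the unconditional "dryrun" loglikelihood. The loglikelihood values below are invented for illustration.

import numpy as np

lls_conditional = np.array([-4.2, -3.1, -5.0])    # log P(choice | ctx)
lls_unconditional = np.array([-2.0, -0.5, -4.8])  # log P(choice), from the ("", choice) requests

lls_mutual_info = lls_conditional - lls_unconditional  # [-2.2, -2.6, -0.2]
print(int(np.argmax(lls_mutual_info)))  # -> 2; raw argmax over lls_conditional would pick 1

This is why construct_requests doubles the request list when acc_mutual_info is among the metrics: even-indexed results are conditional, odd-indexed ones unconditional.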

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py ADDED
File without changes

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py ADDED
@@ -0,0 +1,59 @@
import hashlib
import logging
import os

import dill


eval_logger = logging.getLogger(__name__)


MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")


PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"

# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"

HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()

FILE_SUFFIX = f".{HASH_PREFIX}.pickle"


def load_from_cache(file_name: str, cache: bool = False):
    if not cache:
        return
    try:
        path = f"{PATH}/{file_name}{FILE_SUFFIX}"

        with open(path, "rb") as file:
            cached_task_dict = dill.loads(file.read())
            return cached_task_dict

    except Exception:
        eval_logger.debug(f"{file_name} is not cached, generating...")
        pass


def save_to_cache(file_name, obj):
    if not os.path.exists(PATH):
        os.mkdir(PATH)

    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    eval_logger.debug(f"Saving {file_path} to cache...")
    with open(file_path, "wb") as file:
        file.write(dill.dumps(obj))


# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
    files = os.listdir(PATH)

    for file in files:
        if file.startswith(key) and file.endswith(FILE_SUFFIX):
            file_path = f"{PATH}/{file}"
            os.unlink(file_path)
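
Aside (not part of the diff): a minimal usage sketch for the cache helpers above, assuming lm_eval is importable; the "demo" key name is made up.

from lm_eval.caching.cache import load_from_cache, save_to_cache

task_dict = load_from_cache("demo", cache=True)  # returns None on a cache miss
if task_dict is None:
    task_dict = {"task": "example"}  # expensive construction would go here
    save_to_cache("demo", task_dict)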

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py ADDED
File without changes

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py ADDED
@@ -0,0 +1,174 @@
import datetime
import io
import json
import mmap
import os
from pathlib import Path
from typing import Any

import jsonlines
import tqdm
import zstandard


def json_serial(obj: Any) -> str:
    """JSON serializer for objects not serializable by default json code"""

    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


# Modified version of lm_dataformat Archive for single file.
class Archive:
    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta=None) -> None:
        if meta is None:
            meta = {}
        self.compressor.write(
            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                "UTF-8"
            )
            + b"\n"
        )

    def commit(self) -> None:
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()


# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        with open(file, "rb") as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
                if isinstance(ob, str):
                    assert not get_meta
                    yield ob
                    continue

                text = ob["text"]

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob["meta"] if "meta" in ob else {})
                else:
                    yield text


class TextArchive:
    def __init__(self, file_path, mode: str = "rb+") -> None:
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data) -> None:
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        self.fh.flush()
        self.fh.close()


class TextReader:
    def __init__(self, file_path) -> None:
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        current_file_position = 0
        line_counter = 0
        with (
            open(self.file_path, "r", encoding="utf-8") as fh,
            tqdm.tqdm(
                total=os.path.getsize(self.file_path),
                dynamic_ncols=True,
                unit="byte",
                unit_scale=1,
            ) as progress,
        ):
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]

    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
                    break
                else:
                    yield line[:-1]


# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self):
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}")  # linux decompress is faster
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
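
Aside (not part of the diff): a minimal round-trip sketch for the Archive and Reader classes above; the file path is made up, and the zstandard and jsonlines packages must be installed.

from lm_eval.decontamination.archiver import Archive, Reader

archive = Archive("data/example.jsonl.zst")
archive.add_data("first document", meta={"id": 0})
archive.add_data("second document", meta={"id": 1})
archive.commit()  # flush the zstd frame and close the file handle

for text, meta in Reader().read("data/example.jsonl.zst", get_meta=True):
    print(meta["id"], text)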

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py ADDED
@@ -0,0 +1,166 @@
import collections
import glob
import json
import os
import pickle
import random
import time

from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.


# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(
                open(overlaps_dump_path, "rb")
            )
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(
                open(task_set_lookup_path, "rb")
            )
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            pickle.dump(lookup, open(task_set_lookup_path, "wb"))
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if (
                    ngram != current_ngram
                ):  # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
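
Aside (not part of the diff): a minimal sketch of step 1 of the algorithm above, building the per-task {ngram: set(doc_ids)} lookup. It uses a tiny n and made-up documents purely for illustration.

import collections

from lm_eval.decontamination.janitor import Janitor, word_ngrams

janitor = Janitor()
docs = ["the quick brown fox", "the quick red fox"]

lookup = collections.defaultdict(set)
for doc_id, document in enumerate(docs):
    for ngram in word_ngrams(janitor.normalize_string(document), 2):
        lookup[ngram].add(doc_id)

print(lookup["the quick"])  # -> {0, 1}: this bigram occurs in both documents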

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py ADDED
@@ -0,0 +1,328 @@
import pickle
import re
import string
import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar


# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False

T = TypeVar("T")


# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1
    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s: str, n: int) -> Iterator[str]:
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)


# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
#     current_word = ""
#     history = []
#     gap = False;
#     start = 0
#     end = 0
#     for character in sequence:
#         if character == " ":
#             if not gap:
#                 gap = True
#                 history.append(current_word)
#                 end += len(current_word) - 1
#                 current_word = ""
#                 if len(history) == n:
#                     yield (tuple(history), start, end)
#                     del history[0]
#                     start = end + 1
#                     end = start
#         else:
#             gap = False
#             current_word += character


# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

    # Generator of ngrams of (word, idx_pairs)
    # (
    #   [(word, (start,end)), (word, (start, end))...],
    #   [(word, (start, end)), ...],
    #   ...
    # )
    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)

    # Generator of pairs of word and index ngrams
    # (
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (
        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
    )

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return (
        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
        for ngram_seq, indices in ngram_indices_pairs
    )


class Janitor:
    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(
        self,
        ngram_n: int = 13,
        window_to_remove: int = 200,
        too_dirty_cutoff: int = 10,
        minimum_slice_length: int = 200,
        delete_chars: str = string.punctuation,
    ) -> None:
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename: str) -> None:
        with open(filename, "wb") as fp:
            pickle.dump(filename, fp)

    def load_contamination_ngrams(self, filename: str) -> None:
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string: str) -> None:
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)

    def _split_chunks(
        self, dirty_string: str, dirty_parts: Sequence[Tuple]
    ) -> List[str]:
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks

    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string) -> None:
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
    # Slow python
    ##############

    def normalize_string(self, s: str) -> str:
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string: str) -> None:
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string: str) -> List[str]:
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)


##################################################################
# Tests
#################################################################

# def print_cpp():
#     source = """ ,, I'm a very !dirty,,  ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh.  lastword """ * 2

#     for i in range(1, 10, 2):
#         pprint(janitor_util.clean_ngram(source, string.punctuation, i))
#         for ngram, start, end in \
#                 janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
#             print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))


# def test_cpp():
#     source = """ ,, I'm a very !dirty,,  ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh.  lastword """ * 2
#     contaminant = "dirty boy. Clean he he"

#     jan_python = Janitor()
#     jan_cpp = Janitor()

#     jan_python.register_contaminant_python(contaminant)
#     jan_cpp.register_contaminant(contaminant)

#     assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)

#     assert jan_python.clean_python(source) == jan_cpp.clean(source), \
#         (jan_python.clean_python(source), jan_cpp.clean(source))

#     print("Passed test, python==cpp")


# def benchmark():
#     # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
#     setup = \
#         """
#     with open("data/enwik8", "r") as f:
#         data = f.read()
#     jan = Janitor(too_dirty_cutoff=1000)
#     jan.register_contaminant('''
#     theories is that there is a connection between "geekdom" and autism.
#     This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled "
#     The [[Geek]] Syndrome", which is a point argued by many in the autism rights
#     movement{{ref|Wired}}. This article, many professionals assert, is just one example of
#     the media's application of mental disease labels to what is actually variant normal behavior
#     &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
#     interests, even when they seem unusual to others, are not in themselves signs of autism or
#     Asperger's syndrome. Others assert that it is actually the medical profession which is applying
#     mental disease labels to children who in the past would have simply been accepted as a little
#     different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
#     Due to the recent publicity surrounding autism and autis
#     ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
#     oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
#     paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
#     would last, took a cautious approach, preferring to save the revenue rather than investing it in
#     development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
#     to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
#     brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
#     with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
#     ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
#     ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
#     Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
#     [[United Arab Emirates]]. After the Emirates gained independence in 1971,
#     ''')
#     """

#     n = 1
#     print(f"Timing {n} run on 100 MB")
#     print("Register contaminant")
#     # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
#     print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))

#     print("Clean")
#     # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
#     print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))


# def test_janitor_general():
#     source = """ ,, I'm a very !dirty,,  ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh.  lastword """ * 2
#     contaminant = "dirty boy. Clean he he"

#     jan = Janitor(ngram_n=3)
#     jan.register_contaminant(contaminant)
#     cleaned = " ".join(jan.clean(source))
#     for contam in jan.dirt_ngrams:
#         assert contam not in cleaned, contam

#     filename = "data/saved_contam"
#     jan.save_contamination_ngrams(filename)

#     jan = Janitor(ngram_n=3)
#     jan.load_contamination_ngrams(filename)
#     cleaned = " ".join(jan.clean(source))
#     for contam in jan.dirt_ngrams:
#         assert contam not in cleaned, contam


# if __name__ == "__main__":
#     test()
#     # print_cpp()
#     # test_cpp()
#     # benchmark()
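
Aside (not part of the diff): a minimal usage sketch of the pure-Python Janitor path above. The strings are made up, and ngram_n, window_to_remove, and minimum_slice_length are shrunk from their defaults so the toy contaminant registers and leaves visible clean chunks.

from lm_eval.decontamination.janitor import Janitor

jan = Janitor(ngram_n=3, window_to_remove=5, minimum_slice_length=5)
jan.register_contaminant_python("the secret test answer is banana")

dirty = "intro text ... the secret test answer is banana ... closing text"
print(jan.clean_python(dirty))  # contaminated window removed: ['intro text', 'closing text']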

Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py ADDED
@@ -0,0 +1,736 @@
import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_group_results,
    consolidate_results,
    get_sample_size,
    get_subtask_list,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import (
    TaskManager,
    get_task_dict,
)
from lm_eval.utils import (
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    setup_logging,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.api.task import Task

eval_logger = logging.getLogger(__name__)


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[Union[int, str]] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    gen_kwargs: Union[str, dict, None] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity=None,
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
    fewshot_random_seed: int = 1234,
    confirm_run_unsafe_code: bool = False,
    metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
|
| 83 |
+
|
| 84 |
+
:param model: Union[str, LM]
|
| 85 |
+
Name of model or LM object, see lm_eval.models.get_model
|
| 86 |
+
:param model_args: Optional[str, dict]
|
| 87 |
+
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
|
| 88 |
+
Ignored if `model` argument is a LM object.
|
| 89 |
+
:param tasks: list[Union[str, dict, Task]]
|
| 90 |
+
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
|
| 91 |
+
:param num_fewshot: int
|
| 92 |
+
Number of examples in few-shot context
|
| 93 |
+
:param batch_size: int or str, optional
|
| 94 |
+
Batch size for model
|
| 95 |
+
:param max_batch_size: int, optional
|
| 96 |
+
Maximal batch size to try with automatic batch size detection
|
| 97 |
+
:param device: str, optional
|
| 98 |
+
PyTorch device (e.g. "cpu" or "cuda:0") for running models
|
| 99 |
+
:param use_cache: str, optional
|
| 100 |
+
A path to a sqlite db file for caching model responses. `None` if not caching.
|
| 101 |
+
:param cache_requests: bool, optional
|
| 102 |
+
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
|
| 103 |
+
:param rewrite_requests_cache: bool, optional
|
| 104 |
+
Rewrites all the request cache if set to `True`. `None` if not desired.
|
| 105 |
+
:param delete_requests_cache: bool, optional
|
| 106 |
+
Deletes all the request cache if set to `True`. `None` if not desired.
|
| 107 |
+
:param limit: int or float, optional
|
| 108 |
+
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
|
| 109 |
+
:param bootstrap_iters:
|
| 110 |
+
Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
|
| 111 |
+
:param check_integrity: bool
|
| 112 |
+
Whether to run the relevant part of the test suite for the tasks
|
| 113 |
+
:param write_out: bool
|
| 114 |
+
If True, write out an example document and model input for checking task integrity
|
| 115 |
+
:param log_samples: bool
|
| 116 |
+
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 117 |
+
:param system_instruction: str
|
| 118 |
+
System instruction to be applied to the prompt
|
| 119 |
+
:param apply_chat_template: Union[bool, str]
|
| 120 |
+
Specifies whether to apply a chat template to the prompt.
|
| 121 |
+
- If set to True, the default chat template is applied.
|
| 122 |
+
- If set to a string, applies the specified chat template by name.
|
| 123 |
+
Defaults to False (no chat template applied).
|
| 124 |
+
:param fewshot_as_multiturn: bool
|
| 125 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 126 |
+
:param gen_kwargs: dict or comma-separated string
|
| 127 |
+
Arguments for model generation
|
| 128 |
+
Ignored for all tasks with loglikelihood output_type
|
| 129 |
+
:param verbosity: str
|
| 130 |
+
Verbosity level for logging
|
| 131 |
+
:param predict_only: bool
|
| 132 |
+
If true only model outputs will be generated and returned. Metrics will not be evaluated
|
| 133 |
+
:param random_seed: int
|
| 134 |
+
Random seed for python's random module. If set to None, the seed will not be set.
|
| 135 |
+
:param numpy_random_seed: int
|
| 136 |
+
Random seed for numpy. If set to None, the seed will not be set.
|
| 137 |
+
:param torch_random_seed: int
|
| 138 |
+
Random seed for torch. If set to None, the seed will not be set.
|
| 139 |
+
:param fewshot_random_seed: int
|
| 140 |
+
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
|
| 141 |
+
:param metadata: dict
|
| 142 |
+
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
|
| 143 |
+
|
| 144 |
+
return
|
| 145 |
+
Dictionary of results
|
| 146 |
+
"""
|
| 147 |
+
if verbosity is not None:
|
| 148 |
+
setup_logging(verbosity=verbosity)
|
| 149 |
+
start_date = time.time()
|
| 150 |
+
|
| 151 |
+
if isinstance(model_args, str) and (
|
| 152 |
+
"instruct" in model_args and not apply_chat_template
|
| 153 |
+
):
|
| 154 |
+
eval_logger.warning(
|
| 155 |
+
"Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if delete_requests_cache:
|
| 159 |
+
eval_logger.info("Deleting requests cache...")
|
| 160 |
+
delete_cache()
|
| 161 |
+
|
| 162 |
+
seed_message = []
|
| 163 |
+
if random_seed is not None:
|
| 164 |
+
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
|
| 165 |
+
seed_message.append(f"Setting random seed to {random_seed}")
|
| 166 |
+
random.seed(random_seed)
|
| 167 |
+
|
| 168 |
+
if numpy_random_seed is not None:
|
| 169 |
+
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
|
| 170 |
+
np.random.seed(numpy_random_seed)
|
| 171 |
+
|
| 172 |
+
if torch_random_seed is not None:
|
| 173 |
+
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
|
| 174 |
+
torch.manual_seed(torch_random_seed)
|
| 175 |
+
|
| 176 |
+
if fewshot_random_seed is not None:
|
| 177 |
+
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
|
| 178 |
+
|
| 179 |
+
if seed_message:
|
| 180 |
+
eval_logger.info(" | ".join(seed_message))
|
| 181 |
+
|
| 182 |
+
if tasks is None:
|
| 183 |
+
tasks = []
|
| 184 |
+
if len(tasks) == 0:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
"No tasks specified, or no tasks found. Please verify the task names."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if gen_kwargs is not None:
|
| 190 |
+
if isinstance(gen_kwargs, str):
|
| 191 |
+
gen_kwargs = simple_parse_args_string(gen_kwargs)
|
| 192 |
+
eval_logger.warning(
|
| 193 |
+
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
|
| 194 |
+
"Ensure 'do_sample=True' for non-greedy decoding!"
|
| 195 |
+
)
|
| 196 |
+
if not gen_kwargs:
|
| 197 |
+
gen_kwargs = None
|
| 198 |
+
|
| 199 |
+
if isinstance(model, str):
|
| 200 |
+
if model_args is None:
|
| 201 |
+
eval_logger.warning("model_args not specified. Using defaults.")
|
| 202 |
+
model_args = ""
|
| 203 |
+
|
| 204 |
+
if isinstance(model_args, dict):
|
| 205 |
+
eval_logger.info(
|
| 206 |
+
f"Initializing {model} model, with arguments: {model_args}"
|
| 207 |
+
)
|
| 208 |
+
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
|
| 209 |
+
model_args,
|
| 210 |
+
{
|
| 211 |
+
"batch_size": batch_size,
|
| 212 |
+
"max_batch_size": max_batch_size,
|
| 213 |
+
"device": device,
|
| 214 |
+
},
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
else:
|
| 218 |
+
eval_logger.info(
|
| 219 |
+
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
|
| 220 |
+
)
|
| 221 |
+
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
|
| 222 |
+
model_args,
|
| 223 |
+
{
|
| 224 |
+
"batch_size": batch_size,
|
| 225 |
+
"max_batch_size": max_batch_size,
|
| 226 |
+
"device": device,
|
| 227 |
+
},
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
if not isinstance(model, lm_eval.api.model.LM):
|
| 231 |
+
raise TypeError(
|
| 232 |
+
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
|
| 233 |
+
)
|
| 234 |
+
eval_logger.info("Using pre-initialized model")
|
| 235 |
+
lm = model
|
| 236 |
+
|
| 237 |
+
if use_cache is not None:
|
| 238 |
+
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
|
| 239 |
+
lm = lm_eval.api.model.CachingLM(
|
| 240 |
+
lm,
|
| 241 |
+
use_cache
|
| 242 |
+
# each rank receives a different cache db.
|
| 243 |
+
# necessary to avoid multiple writes to cache at once
|
| 244 |
+
+ "_rank"
|
| 245 |
+
+ str(lm.rank)
|
| 246 |
+
+ ".db",
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if task_manager is None:
|
| 250 |
+
metadata = (
|
| 251 |
+
simple_parse_args_string(model_args)
|
| 252 |
+
if isinstance(model_args, str)
|
| 253 |
+
else model_args
|
| 254 |
+
if isinstance(model_args, dict)
|
| 255 |
+
else {}
|
| 256 |
+
) | (metadata or {})
|
| 257 |
+
task_manager = TaskManager(metadata=metadata)
|
| 258 |
+
|
| 259 |
+
task_dict = get_task_dict(
|
| 260 |
+
tasks,
|
| 261 |
+
task_manager,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
|
| 265 |
+
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
|
| 266 |
+
def _adjust_config(task_dict):
|
| 267 |
+
adjusted_task_dict = {}
|
| 268 |
+
for task_name, task_obj in task_dict.items():
|
| 269 |
+
if isinstance(task_obj, dict):
|
| 270 |
+
adjusted_task_dict = {
|
| 271 |
+
**adjusted_task_dict,
|
| 272 |
+
**{task_name: _adjust_config(task_obj)},
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
else:
|
| 276 |
+
if task_obj.get_config("output_type") == "generate_until":
|
| 277 |
+
if gen_kwargs is not None:
|
| 278 |
+
task_obj.set_config(
|
| 279 |
+
key="generation_kwargs", value=gen_kwargs, update=True
|
| 280 |
+
)
|
| 281 |
+
eval_logger.info(
|
| 282 |
+
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
if predict_only:
|
| 286 |
+
eval_logger.info(
|
| 287 |
+
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
|
| 288 |
+
)
|
| 289 |
+
# we have to change the class properties post-hoc. This is pretty hacky.
|
| 290 |
+
task_obj.override_metric(metric_name="bypass")
|
| 291 |
+
|
| 292 |
+
# override tasks' fewshot values to the provided num_fewshot arg value
|
| 293 |
+
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
|
| 294 |
+
if num_fewshot is not None:
|
| 295 |
+
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
|
| 296 |
+
eval_logger.info(
|
| 297 |
+
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
eval_logger.warning(
|
| 301 |
+
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
|
| 302 |
+
)
|
| 303 |
+
task_obj.set_config(key="num_fewshot", value=num_fewshot)
|
| 304 |
+
else:
|
| 305 |
+
# if num_fewshot not provided, and the task does not define a default one, default to 0
|
| 306 |
+
if (
|
| 307 |
+
default_num_fewshot := task_obj.get_config("num_fewshot")
|
| 308 |
+
) is None:
|
| 309 |
+
task_obj.set_config(key="num_fewshot", value=0)
|
| 310 |
+
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
|
| 311 |
+
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
|
| 312 |
+
|
| 313 |
+
adjusted_task_dict[task_name] = task_obj
|
| 314 |
+
|
| 315 |
+
return adjusted_task_dict
|
| 316 |
+
|
| 317 |
+
task_dict = _adjust_config(task_dict)
|
| 318 |
+
|
| 319 |
+
if check_integrity:
|
| 320 |
+
run_task_tests(task_list=tasks)
|
| 321 |
+
|
| 322 |
+
if evaluation_tracker is not None:
|
| 323 |
+
evaluation_tracker.general_config_tracker.log_experiment_args(
|
| 324 |
+
model_source=model,
|
| 325 |
+
model_args=model_args,
|
| 326 |
+
system_instruction=system_instruction,
|
| 327 |
+
chat_template=lm.chat_template(apply_chat_template)
|
| 328 |
+
if apply_chat_template
|
| 329 |
+
else None,
|
| 330 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
results = evaluate(
|
| 334 |
+
lm=lm,
|
| 335 |
+
task_dict=task_dict,
|
| 336 |
+
limit=limit,
|
| 337 |
+
cache_requests=cache_requests,
|
| 338 |
+
rewrite_requests_cache=rewrite_requests_cache,
|
| 339 |
+
bootstrap_iters=bootstrap_iters,
|
| 340 |
+
write_out=write_out,
|
| 341 |
+
log_samples=True if predict_only else log_samples,
|
| 342 |
+
system_instruction=system_instruction,
|
| 343 |
+
apply_chat_template=apply_chat_template,
|
| 344 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 345 |
+
verbosity=verbosity,
|
| 346 |
+
confirm_run_unsafe_code=confirm_run_unsafe_code,
|
| 347 |
+
)
|
| 348 |
+
if verbosity is not None:
|
| 349 |
+
setup_logging(verbosity=verbosity)
|
| 350 |
+
|
| 351 |
+
if lm.rank == 0:
|
| 352 |
+
if isinstance(model, str):
|
| 353 |
+
model_name = model
|
| 354 |
+
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
|
| 355 |
+
model_name = model.config._name_or_path
|
| 356 |
+
else:
|
| 357 |
+
model_name = type(model).__name__
|
| 358 |
+
|
| 359 |
+
# add info about the model and few shot config
|
| 360 |
+
results["config"] = {
|
| 361 |
+
"model": model_name,
|
| 362 |
+
"model_args": model_args,
|
| 363 |
+
}
|
| 364 |
+
# add more detailed model info if available
|
| 365 |
+
if isinstance(lm, lm_eval.models.huggingface.HFLM):
|
| 366 |
+
results["config"].update(lm.get_model_info())
|
| 367 |
+
# add info about execution
|
| 368 |
+
results["config"].update(
|
| 369 |
+
{
|
| 370 |
+
"batch_size": batch_size,
|
| 371 |
+
"batch_sizes": (
|
| 372 |
+
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
|
| 373 |
+
),
|
| 374 |
+
"device": device,
|
| 375 |
+
"use_cache": use_cache,
|
| 376 |
+
"limit": limit,
|
| 377 |
+
"bootstrap_iters": bootstrap_iters,
|
| 378 |
+
"gen_kwargs": gen_kwargs,
|
| 379 |
+
"random_seed": random_seed,
|
| 380 |
+
"numpy_seed": numpy_random_seed,
|
| 381 |
+
"torch_seed": torch_random_seed,
|
| 382 |
+
"fewshot_seed": fewshot_random_seed,
|
| 383 |
+
}
|
| 384 |
+
)
|
| 385 |
+
results["git_hash"] = get_git_commit_hash()
|
| 386 |
+
results["date"] = start_date
|
| 387 |
+
add_env_info(results) # additional environment info to results
|
| 388 |
+
add_tokenizer_info(results, lm) # additional info about tokenizer
|
| 389 |
+
return results
|
| 390 |
+
else:
|
| 391 |
+
return None
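
# A minimal usage sketch (illustrative, not part of this module). The model
# type, model_args, and task name below are assumptions for demonstration,
# not taken from this repository's configs.
#
#     results = simple_evaluate(
#         model="hf",
#         model_args="pretrained=gpt2",
#         tasks=["gsm8k"],
#         num_fewshot=5,
#         batch_size=8,
#     )
#     if results is not None:  # non-zero ranks return None
#         print(results["results"])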


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
    confirm_run_unsafe_code: bool = False,
):
    """Evaluate an instantiated model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests.
    :param rewrite_requests_cache: bool, optional
        Rewrites the entire request cache if set to `True`.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param verbosity: str
        Verbosity level for logging
    :param confirm_run_unsafe_code: bool
        Whether to confirm running tasks marked as unsafe.
    :return
        Dictionary of results
    """

    if apply_chat_template:
        eval_logger.warning(
            "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
        )

    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

    # validation checks:
    # 1. are we running a multimodal task <-> non-multimodal model class, or vice-versa.
    # 2. are we running code that is marked as unsafe.
    incompatible_tasks = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
            incompatible_tasks.append(task_output.task_name)
        elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
            raise ValueError(
                f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
            )
    if len(incompatible_tasks) > 0:
        if not getattr(lm, "MULTIMODAL", False):
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model types."
            )
        else:
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
            )
    # end validation check

    # Cache the limit arg.
    limit_arg = limit
    limits = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        limit = get_sample_size(task, limit_arg)
        limits.append(limit)
        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=bool(apply_chat_template),
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=getattr(lm, "apply_chat_template")
            if apply_chat_template
            else None,
            tokenizer_name=getattr(lm, "tokenizer_name", "")
            if apply_chat_template
            else "",
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output, limit in zip(eval_tasks, limits):
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            doc_iterator = task.doc_iterator(
                rank=RANK, limit=limit, world_size=WORLD_SIZE
            )
            for doc_id, doc in doc_iterator:
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "filter": filter_key,
                        "metrics": list(metrics.keys()),
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
            results, versions, show_group_table, *_ = consolidate_group_results(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        subtask_list = get_subtask_list(task_dict)

        # collect all higher_is_better values for metrics
        # in the group's subtasks.
        # TODO: clean this up ; unify with the below metric_list loop?
        _higher_is_better = {}
        for group, task_list in subtask_list.items():
            if (
                len(task_list) != 0
            ):  # subtask list will list "task_name": [] for solo tasks
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h

                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(subtask_list.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output, limit in zip(eval_tasks, limits)
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
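
# A minimal sketch (illustrative only, not part of this module) of the
# rank-padding arithmetic used above: each rank pads its request list up to
# the maximum count across ranks so that FSDP/DDP ranks run the same number
# of forward passes.
#
#     counts = [7, 5, 6]                  # requests per rank, after gather
#     numpad = max(counts) - counts[1]    # rank 1 pads with 2 duplicate requests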


def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args
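
# Illustrative mapping (follows directly from the set/equality checks above):
#
#     request_caching_arg_to_dict("refresh")
#     # -> {"cache_requests": True,
#     #     "rewrite_requests_cache": True,
#     #     "delete_requests_cache": False}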
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py
ADDED
@@ -0,0 +1,554 @@
import collections
import logging
import math
import pathlib
import sys
from typing import List, Optional, Tuple, Union

from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
    aggregate_subtask_metrics,
    mean,
    pooled_sample_stderr,
    stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.utils import positional_deprecated


eval_logger = logging.getLogger(__name__)


class TaskOutput:
    """
    Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.

    Attributes:
        task (object): The task object.
        task_name (str): The name of the task.
        task_config (dict): The configuration of the task.
        version (str): The version of the task.
        group_name (str): The name of the task group.
        n_shot (int): The number of shots for the task.
        task_alias (str): The alias of the task.
        group_alias (str): The alias of the task group.
        is_group (bool): Indicates if the task is a group.
        logged_samples (list): The list of logged samples.
        sample_len (int): The number of samples.
        sample_metrics (defaultdict): The dictionary of per-sample metrics.
        agg_metrics (defaultdict): The dictionary of aggregate metrics.

    Methods:
        from_taskdict(cls, task_name: str, task):
            Creates a TaskOutput instance from a task dictionary.

        calculate_aggregate_metric(bootstrap_iters=100000) -> None:
            Calculates the aggregate metrics for the task.
    """

    def __init__(
        self,
        task=None,
        task_name=None,
        task_config=None,
        version=None,
        group_name=None,
        n_shot=None,
        task_alias=None,
        group_alias=None,
        is_group=None,
    ):
        self.task = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
        self.version = version
        self.n_shot = n_shot
        self.task_alias = task_alias
        self.group_alias = group_alias
        self.is_group = is_group
        self.logged_samples = []
        self.sample_len = None
        self.sample_metrics = collections.defaultdict(list)
        self.agg_metrics = collections.defaultdict(list)

    @classmethod
    def from_taskdict(cls, task_name: str, task):
        if isinstance(task, tuple):
            group_name, task = task
        else:
            group_name = None
        if not task:
            # these get filtered out in get_task_list
            # once they are added to group hierarchy
            is_group = True
            return cls(
                task=task, task_name=task_name, is_group=is_group, group_name=group_name
            )
        version = task.VERSION
        task_config = dict(task.dump_config())
        if (n_shot := task_config.get("num_fewshot")) == 0:
            n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
        task_alias = task_config.get("alias")
        group_alias = task_config.get("group_alias")
        return cls(
            task=task,
            task_name=task_name,
            task_config=task_config,
            group_name=group_name,
            version=version,
            n_shot=n_shot,
            task_alias=task_alias,
            group_alias=group_alias,
        )

    def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
        for (metric, filter_key), items in self.sample_metrics.items():
            try:
                agg_fn = self.task.aggregation()[metric]
            except KeyError:
                # This is when process_results outputs an arbitrary metric
                # TODO: Handle this better and allow aggregate functions other than mean.
                agg_fn = mean
            metric_key = f"{metric},{filter_key}"
            self.agg_metrics[metric_key] = agg_fn(items)
            self.sample_len = len(items)  # TODO: same sample size for each metric?
            if isinstance(bootstrap_iters, int):
                stderr_fn = stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )
            else:
                raise ValueError(
                    f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
                )

    def __repr__(self):
        return (
            f"TaskOutput(task_name={self.task_name}, "
            f"group_name={self.group_name}, "
            f"version={self.version}, "
            f"n_shot={self.n_shot}, "
            f"task_alias={self.task_alias}, "
            f"group_alias={self.group_alias})"
        )
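
# A minimal sketch (illustrative): per-sample scores accumulate under
# (metric, filter) keys and are later reduced into "metric,filter" entries of
# agg_metrics by calculate_aggregate_metric (falling back to mean when the
# task does not register an aggregation for the metric).
#
#     out = TaskOutput(task_name="demo")
#     out.sample_metrics[("acc", "none")].extend([1.0, 0.0, 1.0, 1.0])
#     # on a real task, calculate_aggregate_metric() would then set
#     # out.agg_metrics["acc,none"] == 0.75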


def get_task_list(task_dict: dict) -> List[TaskOutput]:
    outputs = []
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):
            _outputs = get_task_list(task_obj)
            outputs.extend(_outputs)
        else:
            task_output = TaskOutput.from_taskdict(task_name, task_obj)
            outputs.append(task_output)

    return outputs
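
# Illustrative: get_task_list recurses through group sub-dicts and returns a
# flat list of TaskOutput wrappers, e.g.
#
#     # {"my_group": {"task_a": task_a, "task_b": task_b}}
#     # -> [TaskOutput(task_a), TaskOutput(task_b)]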


def get_subtask_list(task_dict, task_root=None, depth=0):
    subtask_list = {}
    for group_obj, task_obj in task_dict.items():
        if isinstance(group_obj, ConfigurableGroup):
            group_name = group_obj.group_name
        else:
            group_name = group_obj
        if isinstance(task_obj, dict):
            _subtask_list = get_subtask_list(
                task_obj, task_root=group_name, depth=depth + 1
            )
            if task_root:
                subtask_list.setdefault((task_root, depth), []).extend(
                    [
                        _task
                        for (_task, _depth) in _subtask_list.keys()
                        if (_depth - 1) == depth
                    ]
                )

            subtask_list = {**subtask_list, **_subtask_list}
        else:
            if isinstance(task_obj, ConfigurableGroup):
                group_or_task_name = task_obj.group_name
            elif isinstance(task_obj, Task):
                group_or_task_name = task_obj.task_name

            if task_root is None:
                subtask_list.setdefault((group_or_task_name, depth), [])
            else:
                subtask_list.setdefault((task_root, depth), []).append(
                    group_or_task_name
                )

    if depth == 0:
        _subtask_list = {}
        for group_key, task_list in subtask_list.items():
            group_name, depth = group_key
            _subtask_list[group_name] = task_list
        subtask_list = _subtask_list

    return subtask_list


def print_writeout(task) -> None:
    for inst in task.instances:
        # print the prompt for the first few documents
        if inst.doc_id < 1:
            eval_logger.info(
                f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
            )
            eval_logger.info(f"Request: {str(inst)}")


def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
    if limit is not None:
        limit = (
            int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
        )
    return limit
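
# Illustrative: fractional limits are interpreted as a share of the eval set,
# integers as absolute counts. With len(task.eval_docs) == 200:
#
#     # get_sample_size(task, 0.1)  -> 20
#     # get_sample_size(task, 50)   -> 50
#     # get_sample_size(task, None) -> None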


def prepare_print_tasks(
    task_dict: dict,
    results: dict,
    task_depth=0,
    group_depth=0,
) -> Tuple[dict, dict]:
    """
    @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
        value is a list of task names.
    @param results: Dictionary containing the results of each task. Each key is a
        group name and its value is a dictionary of task results.
    @param task_depth: The indentation level for printing the task
        hierarchy. Default is 0.
    @param group_depth: The indentation level for printing the group
        hierarchy. Default is 0.
    @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
        aggregated results for each task, and groups_agg contains aggregated results for each group.

    Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
    """

    def _sort_task_dict(task_dict):
        """
        Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
        Required so that we end up sorting within each sub-header correctly.
        """

        return dict(
            sorted(
                task_dict.items(),
                key=lambda item: item[0].group_name
                if isinstance(item[0], ConfigurableGroup)
                else item[0],
            )
        )

    task_agg = collections.defaultdict(dict)
    group_agg = collections.defaultdict(dict)
    task_dict = _sort_task_dict(task_dict)
    for task_or_group_name, task_or_group_obj in task_dict.items():
        tab_string = " " * task_depth + "- " if task_depth > 0 else ""
        if isinstance(task_or_group_name, ConfigurableGroup):
            name = task_or_group_name.group_name
            from_configurable_group = True
            task_or_group_obj = _sort_task_dict(task_or_group_obj)
        elif isinstance(task_or_group_name, str):
            name = task_or_group_name
            if isinstance(task_or_group_obj, Task):
                name = task_or_group_obj.task_name
            from_configurable_group = False

        task_agg[name] = results[name].copy()
        if from_configurable_group:
            if task_or_group_name.group_alias is not None:
                alias = task_or_group_name.group_alias
            else:
                alias = task_or_group_name.group
        else:
            if "alias" in task_agg[name]:
                alias = task_agg[name]["alias"]
            else:
                alias = name

        task_agg[name]["alias"] = tab_string + alias
        if "samples" in task_agg[name]:
            task_agg[name].pop("samples")

        if from_configurable_group and (" " not in results[name]):
            group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
            group_agg[name] = results[name].copy()
            group_agg[name]["alias"] = group_tab_string + alias
            if "samples" in group_agg[name]:
                group_agg[name].pop("samples")

        if isinstance(task_or_group_obj, dict):
            task_depth += 1
            group_depth += 1
            _task_agg, _group_agg = prepare_print_tasks(
                task_or_group_obj, results, task_depth, group_depth
            )
            task_agg = {
                **task_agg,
                **_task_agg,
            }
            group_agg = {**group_agg, **_group_agg}
            task_depth -= 1
            group_depth -= 1
    return task_agg, group_agg


def consolidate_results(
    eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict, dict]:
    """
    @param eval_tasks: list(TaskOutput).
    @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.

    Consolidates the results of multiple evaluation tasks into a single structure.

    The method iterates over each evaluation instance and extracts relevant information to create the consolidated
    results structure. The consolidated results structure has the following properties:

    - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
        metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
        aliases specified in the task configuration.
    - samples: A defaultdict with task names as keys and lists of log samples as values.
    - configs: A defaultdict with task names as keys and task configurations as values.
    - versions: A defaultdict with task names as keys and task versions as values.
    - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
    - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
        for each metric as values.

    The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
    """
    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
    # Tracks the YAML configs of all chosen tasks
    configs = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Track `higher_is_better` for each metric
    higher_is_better = collections.defaultdict(dict)

    for task_output in eval_tasks:
        if "task_alias" in (task_config := task_output.task_config):
            results[task_output.task_name]["alias"] = task_config["task_alias"]
        else:
            results[task_output.task_name]["alias"] = task_output.task_name
        if group_alias := task_output.group_alias:
            if group_alias not in results and (group_name := task_output.group_name):
                results[group_name]["alias"] = group_alias
        num_fewshot[task_output.task_name] = task_output.n_shot
        configs[task_output.task_name] = task_output.task_config
        versions[task_output.task_name] = task_output.version
        samples[task_output.task_name] = task_output.logged_samples
        higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
        for (metric, filter_key), items in task_output.sample_metrics.items():
            metric_key = f"{metric},{filter_key}"
            results[task_output.task_name][metric_key] = task_output.agg_metrics[
                metric_key
            ]
            results[task_output.task_name]["samples"] = task_output.sample_len
            results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
                task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
            )
    return results, samples, configs, versions, num_fewshot, higher_is_better
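
# Illustrative: per-task entries in `results` are keyed "<metric>,<filter>"
# with a parallel "<metric>_stderr,<filter>" entry. For example (task and
# metric names assumed for demonstration):
#
#     # results["gsm8k"]["exact_match,strict-match"]
#     # results["gsm8k"]["exact_match_stderr,strict-match"]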


def consolidate_group_results(
    results,
    versions,
    task_dict,
    task_root=None,
    show_group_table=False,
    task_aggregation_list=None,
) -> Tuple[dict, dict, bool, Optional[dict]]:
    """
    (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.

    @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:

    - results: A defaultdict with task names (and, after this function is called, group names of
        groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
    - versions: A defaultdict with task names (and, after this function is called, group names of
        groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified (defaulting to None).
    - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
    - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.

    The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
    In the top-level invocation of this function, task_aggregation_list is ignored.
    """
    if task_root is None:
        task_root = {}

    if task_aggregation_list is None:
        task_aggregation_list = {}

    for group_or_task, group_or_task_info in task_dict.items():
        # Convert to string
        if isinstance(group_or_task, ConfigurableGroup):
            group_config = group_or_task.config
            group_or_task = group_or_task.group_name
        else:
            group_config = None

        if isinstance(group_or_task_info, Task):
            if task_root:
                task_aggregation_list.setdefault(task_root, []).append(
                    group_or_task_info.task_name
                )
        else:
            (
                results,
                versions,
                show_group_table,
                _task_aggregation_list,
            ) = consolidate_group_results(
                results,
                versions,
                group_or_task_info,
                group_or_task,
                show_group_table,
                task_aggregation_list,
            )
            if task_root:
                task_aggregation_list.setdefault(task_root, []).extend(
                    task_aggregation_list.get(group_or_task, [])
                )

            if (group_config is None) or (
                group_config["aggregate_metric_list"] is None
            ):
                results[group_or_task][" "] = " "
                continue

            if "aggregate_metric_list" in group_config:
                agg_metric_list = group_config["aggregate_metric_list"]

            show_group_table = show_group_table | bool(
                group_config["aggregate_metric_list"]
            )

            task_list = _task_aggregation_list[group_or_task]

            metric_list = list(
                {
                    key
                    for task in task_list
                    for key in results[task].keys()
                    if "_stderr" not in key and key not in ["task", "alias", "samples"]
                }
            )
            for metric in metric_list:
                stderr = "_stderr,".join(metric.split(","))

                # gather metrics, sizes, and stderrs from subtasks
                metrics = [
                    results[task][metric]
                    for task in task_list
                    if metric in results[task]
                ]  # TODO: copy?
                stderrs = [
                    results[task][stderr]
                    for task in task_list
                    if stderr in results[task]
                ]
                sizes = [
                    results[task]["samples"]
                    for task in task_list
                    if metric in results[task]
                ]

                for metric_config in agg_metric_list:
                    for filter_name in metric_config["filter_list"]:
                        if metric != ",".join([metric_config["metric"], filter_name]):
                            continue

                        # compute group's pooled metric and stderr
                        if metric_config["aggregation"] == "mean":
                            aggregate_fn = aggregate_subtask_metrics
                        elif callable(metric_config["aggregation"]):
                            aggregate_fn = metric_config["aggregation"]
                        else:
                            raise ValueError(
                                f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
                            )

                        results[group_or_task][metric] = aggregate_fn(
                            metrics,
                            sizes,
                            metric_config["weight_by_size"],
                        )
                        # TODO: calculate groups' metrics using arbitrary agg fns
                        if "N/A" in stderrs:
                            results[group_or_task][stderr] = "N/A"
                        else:
                            # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
                            results[group_or_task][stderr] = pooled_sample_stderr(
                                stderrs, sizes
                            )

            results[group_or_task]["samples"] = sum(sizes)
            group_metadata = group_config.get("metadata", None)
            if group_metadata is not None:
                versions[group_or_task] = group_metadata.get("version", None)
    return results, versions, show_group_table, task_aggregation_list
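
# Illustrative arithmetic for the "mean" aggregation path above (assuming the
# usual size-weighted-mean behavior of aggregate_subtask_metrics): with two
# subtasks scoring [0.60, 0.80] on [100, 300] samples,
#
#     # aggregate_subtask_metrics([0.60, 0.80], [100, 300], True)
#     # -> 0.75 (size-weighted mean); with weight_by_size=False -> 0.70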


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py
ADDED
@@ -0,0 +1,25 @@
from functools import partial
from typing import List

from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter

from . import custom, extraction, selection, transformation


def build_filter_ensemble(
    filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
    """
    Create a filtering pipeline.
    """
    filters = []
    for function, kwargs in components:
        if kwargs is None:
            kwargs = {}
        # create a filter given its name in the registry
        f = partial(get_filter(function), **kwargs)
        # add the filter as a pipeline step
        filters.append(f)

    return FilterEnsemble(name=filter_name, filters=filters)
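
# A minimal usage sketch (illustrative; "regex" and "take_first" are filter
# names registered elsewhere in lm_eval, and the pattern shown is just an
# example, not taken from a specific task config):
#
#     ensemble = build_filter_ensemble(
#         "get-answer",
#         [["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}], ["take_first", None]],
#     )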
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py
ADDED
@@ -0,0 +1,17 @@
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("custom")
class CustomFilter(Filter):
    """
    Custom filter that applies a custom, user-defined function to the model responses.
    """

    def __init__(self, **kwargs) -> None:
        self.filter_fn = kwargs.pop("filter_fn")

        super().__init__(**kwargs)

    def apply(self, resps, docs):
        return self.filter_fn(resps, docs)
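
# A minimal sketch (illustrative): wrapping a user-defined function. `resps`
# is a list of per-document response lists, matching Filter.apply.
#
#     f = CustomFilter(
#         filter_fn=lambda resps, docs: [[r.upper() for r in rs] for rs in resps]
#     )
#     f.apply([["hello"]], [{}])  # -> [["HELLO"]]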
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py
ADDED
@@ -0,0 +1,25 @@
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("decontaminate")
class DecontaminationFilter(Filter):
    """
    A filter which evaluates contamination of documents (currently a stub).
    """

    name = "track_decontamination"

    def __init__(self, path) -> None:
        """
        TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
        should further cache result on a given (task_name, doc_id)
        """
        self._decontam_results = None

    def apply(self, resps, docs) -> None:
        """
        Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
        """
        pass
|
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py
ADDED
@@ -0,0 +1,188 @@
+import re
+import sys
+import unicodedata
+
+from lm_eval.api.filter import Filter
+from lm_eval.api.registry import register_filter
+
+
+@register_filter("regex")
+class RegexFilter(Filter):
+    """A filter that extracts values from text using regex pattern matching.
+
+    This filter applies a regex pattern to each model response and extracts matched values.
+    If no match is found, returns a fallback value. Useful for extracting structured data
+    (like numbers) from unstructured model outputs.
+    """
+
+    def __init__(
+        self,
+        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
+        group_select: int = 0,
+        fallback: str = "[invalid]",
+    ) -> None:
+        """
+        pass a string `regex` to run `re.compile(r"regex")` on.
+        `fallback` defines the output returned if no matches for the regex are located.
+        """
+        self.regex_pattern = regex_pattern
+        self.regex = re.compile(regex_pattern)
+        self.group_select = group_select
+        self.fallback = fallback
+
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+        # here, we assume we have a list, in which each element is
+        # a list of model responses for some particular input/target pair.
+        # so we process each of these (same input/target response sets)
+        # independently (and keep them a list.)
+        def filter_set(inst):
+            filtered = []
+            for resp in inst:
+                match = self.regex.findall(resp)
+                if match:
+                    match = match[self.group_select]
+                    if isinstance(match, tuple):
+                        match = [m for m in match if m]
+                        if match:
+                            match = match[0]
+                        else:
+                            match = self.fallback
+                    match = match.strip()
+                else:
+                    match = self.fallback
+                filtered.append(match)
+            return filtered
+
+        filtered_resps = list(map(lambda x: filter_set(x), resps))
+
+        return filtered_resps
+
+
+@register_filter("remove_whitespace")
+class WhitespaceFilter(Filter):
+    """Filters out leading whitespace from responses."""
+
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+        def filter_set(inst):
+            filtered_resp = []
+            for resp in inst:
+                resp = resp.lstrip()
+                filtered_resp.append(resp)
+            return filtered_resp
+
+        filtered_resps = [filter_set(resp) for resp in resps]
+
+        return filtered_resps
+
+
+@register_filter("multi_choice_regex")
+class MultiChoiceRegexFilter(RegexFilter):
+    """
+    A filter used to extract a model's answer on multiple choice questions with
+    letter answers. assumes each document has a "choices" field
+    containing the list of answer choices and that the answer label symbols
+    are of the form (A), (B), (C), ... or A, B, C.
+    """
+
+    def __init__(
+        self,
+        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
+        group_select=0,
+        fallback: str = "[invalid]",
+        ignore_case=False,
+        ignore_punctuation=False,
+        regexes_to_ignore=None,
+    ) -> None:
+        """
+        regex_pattern: The basic regex pattern to use. If it fails to match, we will use the customized match procedure
+            - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
+            - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
+        group_select: Selects the (group_select)th match from the findall result.
+        ignore_case: Ignores the case during step 1 matching
+        ignore_punctuation: Remove the punctuation during step 1 matching
+        regexes_to_ignore: Remove these regexes during step 1 matching
+        """
+        super().__init__(regex_pattern, group_select, fallback)
+        self.ignore_case = ignore_case
+        self.ignore_punctuation = ignore_punctuation
+        self.regexes_to_ignore = regexes_to_ignore
+
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+        # here, we assume we have a list, in which each element is
+        # a list of model responses for some particular input/target pair.
+        # so we process each of these (same input/target response sets)
+        # independently (and keep them a list.)
+
+        def find_match(regex, resp, convert_dict={}):
+            match = regex.findall(resp)
+            if match:
+                match = match[self.group_select]
+                if isinstance(match, tuple):
+                    match = [m for m in match if m][0]
+                match = match.strip()
+                if match and match in convert_dict:
+                    match = convert_dict[match]
+            return match
+
+        punct_tbl = dict.fromkeys(
+            i
+            for i in range(sys.maxunicode)
+            if unicodedata.category(chr(i)).startswith("P")
+        )
+
+        def filter_ignores(st):
+            if self.regexes_to_ignore is not None:
+                for s in self.regexes_to_ignore:
+                    st = re.sub(s, "", st)
+
+            if self.ignore_case:
+                st = st.lower()
+
+            if self.ignore_punctuation:
+                # https://stackoverflow.com/a/266162
+                st = st.translate(punct_tbl)
+            return st
+
+        filtered_resps = []
+
+        for r, doc in zip(resps, docs):
+            fallback_regexes = []
+            choice_to_alpha = {}
+            next_alpha = "A"
+
+            without_paren_fallback_regexes = []
+            without_paren_to_target = {}
+
+            choices = doc["choices"]
+            for c in choices:
+                m = filter_ignores(c.strip())
+                fallback_regexes.append(f"{re.escape(m)}")
+                choice_to_alpha[m] = f"({next_alpha})"
+
+                without_paren_fallback_regexes.append(next_alpha)
+                without_paren_to_target[next_alpha] = f"({next_alpha})"
+
+                next_alpha = chr(ord(next_alpha) + 1)
+            fallback_regex = re.compile("|".join(fallback_regexes))
+            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
+            without_paren_fallback_regex = re.compile(
+                rf":[\s]*({without_paren_fallback_regex})"
+            )
+
+            filtered = []
+            for resp in r:
+                match = find_match(self.regex, resp)
+                if not match:
+                    match = find_match(
+                        fallback_regex, filter_ignores(resp), choice_to_alpha
+                    )
+                    if not match:
+                        match = find_match(
+                            without_paren_fallback_regex, resp, without_paren_to_target
+                        )
+                if not match:
+                    match = self.fallback
+                filtered.append(match)
+            filtered_resps.append(filtered)
+
+        return filtered_resps
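A quick worked example of the default `RegexFilter` behavior (illustrative, not part of the diff): the GSM8K-style pattern pulls the number after `#### `, falling back to `[invalid]` when nothing matches.

f = RegexFilter()  # default pattern r"#### (\-?[0-9\.\,]+)"
resps = [["... so the total is #### 42", "no final answer here"]]
print(f.apply(resps, docs=[{}]))  # [["42", "[invalid]"]]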
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py
ADDED
@@ -0,0 +1,61 @@
+from collections import Counter
+
+from lm_eval.api.filter import Filter
+from lm_eval.api.registry import register_filter
+
+
+# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
+# that takes an input and returns a scalar and then should select the max reward,
+# or should implement different filters for different ways of handling a reward model's inference.
+
+
+@register_filter("take_first")
+class TakeFirstFilter(Filter):
+    def __init__(self) -> None:
+        """
+        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+        """
+
+    def apply(self, resps, docs):
+        """
+        Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
+        """
+        return map(lambda r: r[0], resps)
+
+
+@register_filter("take_first_k")
+class TakeKFilter(Filter):
+    def __init__(self, **kwargs) -> None:
+        self.k = kwargs.pop("k")
+
+        super().__init__(**kwargs)
+
+    def apply(self, resps, docs):
+        # need resp to be subscriptable to check below
+        resps = list(resps)
+        # check we have at least k responses per doc, else we can't take the first k
+        assert len(resps[0]) >= self.k, (
+            f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        )
+        return map(lambda r: r[: self.k], resps)
+
+
+@register_filter("majority_vote")
+class MajorityVoteFilter(Filter):
+    def __init__(self) -> None:
+        """
+        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+        """
+
+    def apply(self, resps, docs):
+        """
+        Each entry of `resps` is a list of model responses.
+        We select the response that occurs most frequently in each entry of `resps`.
+        """
+
+        def select_majority(resp):
+            counts = Counter(resp)
+            vote = counts.most_common(1)[0][0]
+            return vote
+
+        return map(lambda r: [select_majority(r)], resps)
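To see the selection filters in action (illustrative): `MajorityVoteFilter` reduces each response set to its most frequent element, the standard self-consistency reduction.

mv = MajorityVoteFilter()
votes = list(mv.apply([["4", "5", "4"], ["x", "x", "y"]], docs=None))
print(votes)  # [["4"], ["x"]]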
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py
ADDED
@@ -0,0 +1,56 @@
+from lm_eval.api.filter import Filter
+from lm_eval.api.registry import register_filter
+
+
+@register_filter("lowercase")
+class LowercaseFilter(Filter):
+    def __init__(self) -> None:
+        pass
+
+    def apply(self, resps, docs):
+        def filter_set(inst):
+            return [resp.lower() for resp in inst]
+
+        return [filter_set(resp) for resp in resps]
+
+
+@register_filter("uppercase")
+class UppercaseFilter(Filter):
+    def __init__(self) -> None:
+        pass
+
+    def apply(self, resps, docs):
+        def filter_set(inst):
+            return [resp.upper() for resp in inst]
+
+        return [filter_set(resp) for resp in resps]
+
+
+@register_filter("map")
+class MapFilter(Filter):
+    def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
+        """
+        Initializes the MapFilter with a given mapping dictionary and default value.
+
+        Args:
+        - mapping_dict (dict): A dictionary containing the key-value mappings.
+          Default is an empty dictionary.
+        - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
+          Default is None.
+
+        Example:
+        mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
+        """
+        if mapping_dict is None:
+            mapping_dict = {}
+        assert isinstance(mapping_dict, dict), (
+            "Provided mapping_dict is not a dictionary"
+        )
+        self.mapping_dict = mapping_dict
+        self.default_value = default_value
+
+    def apply(self, resps, docs):
+        def filter_set(inst):
+            return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
+
+        return [filter_set(resp) for resp in resps]
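Extending the docstring's own example (illustrative): `apply` maps every response through the dictionary, substituting `default_value` for unseen keys.

mapper = MapFilter({"A": 1, "B": 2}, default_value=0)
print(mapper.apply([["A", "B", "C"]], docs=None))  # [[1, 2, 0]]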
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from . import (
+    diffllm,
+    huggingface,
+)
+
+
+# TODO: implement __all__
+
+
+try:
+    # enable hf hub transfer if available
+    import hf_transfer  # type: ignore # noqa
+    import huggingface_hub.constants  # type: ignore
+
+    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+except ImportError:
+    pass
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py
ADDED
@@ -0,0 +1,563 @@
+import logging
+import gc
+import random
+import json
+import os
+import time
+from datetime import timedelta
+from typing import List, Optional, Tuple, Type, TypeVar, Union
+
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import (
+    Accelerator,
+    InitProcessGroupKwargs,
+    find_executable_batch_size,
+)
+from datasets import Dataset
+from packaging import version
+from tqdm import tqdm
+
+from lm_eval import utils
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+from lm_eval.models.utils import Collator, get_dtype
+
+eval_logger = logging.getLogger(__name__)
+T = TypeVar("T", bound="LM")
+
+
+def empty_cache_by_memory(threshold_gb=70):
+    """
+    Empty CUDA cache if allocated memory exceeds threshold
+    Args:
+        threshold_gb: Memory threshold in GB
+    """
+    if torch.cuda.is_available():
+        # Get current memory allocated
+        allocated = torch.cuda.memory_allocated() / 1024**3  # Convert to GB
+
+        if allocated > threshold_gb:
+            # Clear cache
+            gc.collect()
+            torch.cuda.empty_cache()
+            print(f"Cache cleared. Memory freed: {allocated:.2f} GB")
+
+
+@register_model("diffllm")
+class DiffLLM(LM):
+    def __init__(
+        self,
+        pretrained: Union[str, transformers.PreTrainedModel],
+        batch_size: Optional[Union[int, str]] = 1,
+        device: Optional[str] = "cuda",
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
+        max_prompt_len: Optional[int] = 1024,
+        max_new_tokens: Optional[int] = 128,
+        nll_type: Optional[str] = "mc",
+        log_type: Optional[str] = "ftb",
+        classifier_free_guidance: Optional[float] = 1.0,
+        pad_to_max_len: Optional[bool] = False,
+        sampling_eps: Optional[float] = 1e-3,
+        diffusion_steps: Optional[int] = 32,
+        trust_remote_code: Optional[bool] = True,
+        parallelize: Optional[bool] = False,
+        autogptq: Optional[Union[bool, str]] = False,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        # prepare for parallelism
+        assert isinstance(device, str)
+        assert isinstance(pretrained, str)
+        assert isinstance(batch_size, (int, str))
+
+        gpus = torch.cuda.device_count()
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+
+        self.accelerator = accelerator
+
+        if "npu" in accelerator.device.type:
+            gpus = torch.npu.device_count()
+
+        # using one process with no model parallelism
+        if not (parallelize or accelerator.num_processes > 1):
+            # use user-passed device
+            device_list = set(
+                ["cuda", "cpu"]
+                + [f"cuda:{i}" for i in range(gpus)]
+                + ["mps", "mps:0"]
+                + [f"npu:{i}" for i in range(gpus)]
+            )
+            if device and device in device_list:
+                self._device = torch.device(device)
+                eval_logger.info(f"Using device '{device}'")
+                if device in ("mps", "mps:0") and version.parse(
+                    torch.__version__
+                ) < version.parse("2.1"):
+                    raise RuntimeError(
+                        f"mps requires torch >= 2.1. You have {torch.__version__}"
+                    )
+            else:
+                eval_logger.info("Device not specified")
+                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
+                self._device = (
+                    torch.device("cuda")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+        else:
+            if device != "cuda":
+                eval_logger.info(
+                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
+                )
+            self._device = self.accelerator.device
+
+        self.batch_size_per_gpu = batch_size
+        if isinstance(batch_size, str):
+            self.batch_size_per_gpu = int(batch_size)
+        self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
+
+        if isinstance(pretrained, str):
+            if gpus >= 1 or str(self.device) == "mps":
+                if not (
+                    parallelize
+                    or autogptq
+                    or (hasattr(self, "accelerator") and self.accelerator.num_processes > 1)
+                ):
+                    try:
+                        self.model.to(self.device)
+                    except ValueError:
+                        eval_logger.debug(
+                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
+                        )
+            if gpus > 1:
+                if self.accelerator.num_processes > 1:
+                    self._device = torch.device(f"{accelerator.device}")
+                    self._rank = self.accelerator.local_process_index
+                    self._world_size = self.accelerator.num_processes
+                else:
+                    self._rank = 0
+                    self._world_size = 1
+            else:
+                self._rank = 0
+                self._world_size = 1
+        else:
+            eval_logger.warning(
+                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
+            )
+            self._rank = 0
+            self._world_size = 1
+
+        self.max_prompt_len = max_prompt_len
+        self.max_new_tokens = max_new_tokens
+        self.diffusion_steps = diffusion_steps
+        self.temperature = kwargs.get("temperature", 0.7)
+        self.top_p = kwargs.get("top_p", 0.95)
+        self.alg = kwargs.get("alg", "entropy")
+        self.alg_temp = kwargs.get("alg_temp", 0.0)
+        self.top_k = kwargs.get("top_k", None)
+
+        self.nll_type = nll_type
+        self.log_type = log_type
+        self.classifier_free_guidance = classifier_free_guidance
+        self.pad_to_max_len = pad_to_max_len
+        self.sampling_eps = sampling_eps
+
+        self.mask_id = 151666
+        self.eos_id = 151643
+
+        raw_use_hts = kwargs.get("use_hts", False)
+        if isinstance(raw_use_hts, str):
+            self.use_hts = raw_use_hts.lower() == "true"
+        else:
+            self.use_hts = bool(raw_use_hts)
+
+        self.realtime_output = kwargs.get("realtime_output", "eval_results.jsonl")
+
+        if self.use_hts:
+            from .hts_sampler import HTSSampler
+
+            self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
+            eval_logger.info(f"Rank {self.rank}: HTS Sampler initialized for Dream.")
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
+        self.model = (
+            transformers.AutoModel.from_pretrained(
+                pretrained,
+                torch_dtype=get_dtype(dtype),
+                trust_remote_code=trust_remote_code,
+            ).eval()
+        ).to(self.device)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            pretrained, trust_remote_code=trust_remote_code
+        )
+
+    def tok_decode(self, tokens, skip_special_tokens=True):
+        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+
+    def tok_encode(self, text, add_special_tokens=True):
+        return self.tokenizer(
+            text, return_tensors="pt", add_special_tokens=add_special_tokens
+        ).input_ids
+
+    @classmethod
+    def create_from_arg_string(
+        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+    ) -> T:
+        additional_config = {} if additional_config is None else additional_config
+        args = utils.simple_parse_args_string(arg_string)
+        args2 = {k: v for k, v in additional_config.items() if v is not None}
+        return cls(**args, **args2)
+
+    def apply_chat_template(
+        self, chat_history, add_generation_prompt: bool = True
+    ) -> str:
+        chat_templated = self.tokenizer.apply_chat_template(
+            chat_history,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
+        )
+        return chat_templated
+
+    @property
+    def tokenizer_name(self) -> str:
+        return self.tokenizer.name_or_path.replace("/", "__")
+
+    def _generate_batch(self, prompts: List[str], gen_kwargs: dict = None) -> Tuple[List[str], List[dict]]:
+        raw_val = gen_kwargs.get("use_hts", self.use_hts)
+        use_hts_now = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
+
+        all_stats = []
+        if not use_hts_now:
+            prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids
+            prompt_ids = prompt_ids[:, -self.max_prompt_len:]
+            attn_mask = prompt_ids.ne(self.tokenizer.pad_token_id).to(self.device)
+            prompt_ids = prompt_ids.to(device=self.device)
+
+            generation_ids = self.model.diffusion_generate(
+                prompt_ids,
+                attention_mask=attn_mask,
+                max_new_tokens=self.max_new_tokens,
+                output_history=False,
+                return_dict_in_generate=True,
+                steps=self.diffusion_steps,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                alg=self.alg,
+                alg_temp=self.alg_temp,
+            )
+            responses = [
+                self.tokenizer.decode(g[len(p):].tolist()).split(self.tokenizer.eos_token)[0]
+                for p, g in zip(prompt_ids, generation_ids.sequences)
+            ]
+            all_stats = [{} for _ in responses]
+            return responses, all_stats
+        else:
+            if not hasattr(self, "hts_sampler"):
+                from .hts_sampler import HTSSampler
+
+                self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
+
+            results = []
+            for prompt in prompts:
+                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
+
+                final_codes, stats = self.hts_sampler.generate_hts(
+                    prompt_text=prompt,
+                    input_ids=input_ids,
+                    initial_N=int(gen_kwargs.get("initial_N", 4)),
+                    final_K=int(gen_kwargs.get("final_K", 1)),
+                    hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)),
+                    reward_mode=gen_kwargs.get("reward_mode", "svf"),
+                    task_type=gen_kwargs.get("task_type", "code"),
+                    steps=self.diffusion_steps,
+                    gen_length=self.max_new_tokens,
+                    temperature=float(gen_kwargs.get("temperature", self.temperature)),
+                    top_p=float(gen_kwargs.get("top_p", self.top_p)),
+                    top_k=gen_kwargs.get("top_k", self.top_k),
+                    until=gen_kwargs.get("until", []),
+                    hts_mode=True,
+                    mask_id=self.mask_id,
+                    eos_id=self.eos_id,
+                )
+
+                results.append(final_codes[0])
+                all_stats.append(stats)
+            return results, all_stats
+
+    def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
+        res = []
+
+        gen_kwargs_first = requests[0].args[1]
+        actual_output_path = gen_kwargs_first.get("realtime_output", self.realtime_output)
+
+        raw_val = gen_kwargs_first.get("use_hts", self.use_hts)
+        self.use_hts = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
+
+        rank_tmp_file = actual_output_path.replace(".jsonl", f"_rank{self.rank}.tmp")
+
+        output_dir = os.path.dirname(rank_tmp_file)
+        if output_dir and not os.path.exists(output_dir):
+            os.makedirs(output_dir, exist_ok=True)
+
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running generate_until",
+        )
+
+        for batch_idx in range(0, len(requests), self.batch_size_per_gpu):
+            batch_requests = requests[batch_idx : batch_idx + self.batch_size_per_gpu]
+            contexts, task_gen_args = zip(*[req.arguments for req in batch_requests])
+
+            responses, stats_list = self._generate_batch(contexts, gen_kwargs=task_gen_args[0])
+
+            for i, r in enumerate(responses):
+                r = r.replace("```python", "").replace("```", "")
+
+                for s in task_gen_args[0].get('until', []):
+                    r = r.split(s)[0]
+
+                target_val = getattr(batch_requests[i], "target", None)
+                if target_val is None or target_val == "N/A":
+                    target_val = batch_requests[i].doc.get("answer", batch_requests[i].doc.get("solution", "N/A"))
+
+                save_data = {
+                    "doc": batch_requests[i].doc,
+                    "target": target_val,
+                    "prompt": contexts[i],
+                    "response": r,
+                }
+
+                if self.use_hts:
+                    save_data.update(stats_list[i])
+
+                with open(rank_tmp_file, "a", encoding="utf-8") as f:
+                    f.write(json.dumps(save_data, ensure_ascii=False) + "\n")
+                    f.flush()
+
+                responses[i] = r
+
+            if self.rank == 0 and batch_idx == 0:
+                print(f"Sample Response:\n{responses[0]}\n")
+
+            res.extend(responses)
+            pbar.update(len(batch_requests))
+
+        pbar.close()
+
+        self.accelerator.wait_for_everyone()
+
+        if self.rank == 0:
+            eval_logger.info(f"Merging rank files into {actual_output_path}...")
+            with open(actual_output_path, "w", encoding="utf-8") as final_f:
+                for r in range(self.world_size):
+                    temp_f = actual_output_path.replace(".jsonl", f"_rank{r}.tmp")
+                    if os.path.exists(temp_f):
+                        with open(temp_f, "r", encoding="utf-8") as tf:
+                            for line in tf:
+                                final_f.write(line)
+                        os.remove(temp_f)
+            eval_logger.info("Merge completed.")
+
+        return res
+
+    def _forward_process(self, batch):
+        b, l = batch.shape
+        u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
+        indices = torch.arange(b, device=batch.device).float()
+        t = (u0 + indices / b) % 1
+        p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
+        p_mask = p_mask[:, None].repeat(1, l)
+        mask_indices = torch.rand((b, l), device=batch.device) < p_mask
+        mask_indices[:, 0] = False
+        mask_indices[:, -1] = False
+        noisy_batch = torch.where(mask_indices, self.mask_id, batch)
+        return noisy_batch, p_mask
+
+    @torch.no_grad()
+    def get_logits(self, batch, prompt_index):
+        if self.classifier_free_guidance > 1.:
+            assert len(prompt_index) == batch.shape[1]
+            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
+            un_batch = batch.clone()
+            un_batch[prompt_index] = self.mask_id
+            batch = torch.cat([batch, un_batch])
+
+        if self.pad_to_max_len:
+            raise NotImplementedError
+        else:
+            input = batch
+
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            logits = self.model(input, 'full').logits
+
+        if self.classifier_free_guidance > 1.:
+            logits, un_logits = torch.chunk(logits, 2, dim=0)
+            logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
+        return logits[:, :batch.shape[1]]
+
+    @torch.no_grad()
+    def _eval_target_nll_mc(self, prefix, target):
+        if prefix is None:
+            seq = target[None, :]
+        else:
+            seq = torch.concatenate([prefix, target])[None, :]
+        seq = seq.repeat((self.batch_size, 1)).to(self.device)
+
+        if self.log_type == 'ftb':
+            prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
+        else:
+            prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
+
+        loss_acc = []
+        mc_num = self.diffusion_steps
+        for _ in range(max(mc_num // self.batch_size, 1)):
+            perturbed_seq = seq.clone()
+            perturbed_seq_, p_mask = self._forward_process(seq)
+            if self.log_type == 'ftb':
+                perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
+            elif self.log_type == 'btf':
+                perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
+            elif self.log_type == 'union':
+                perturbed_seq = perturbed_seq_
+            else:
+                raise NotImplementedError(self.log_type)
+
+            mask_indices = perturbed_seq == self.mask_id
+
+            logits = self.get_logits(perturbed_seq, prompt_index)
+
+            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
+            loss = loss.sum() / self.batch_size
+            loss_acc.append(loss.item())
+            del logits, loss, perturbed_seq, perturbed_seq_, p_mask, mask_indices
+            empty_cache_by_memory(threshold_gb=70)
+
+        return sum(loss_acc) / len(loss_acc)
+
+    @torch.no_grad()
+    def _eval_target_nll_ar(self, prefix, target):
+        prefix, target = prefix.unsqueeze(0), target.unsqueeze(0)  # 1*l1, 1*l2
+        assert self.log_type in ['ftb', 'btf']
+        assert self.nll_type in ['ar_ftb', 'ar_btf']
+
+        if self.log_type == 'ftb':
+            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
+        else:
+            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
+
+        if self.log_type == 'ftb':
+            perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous()  # l2*l2
+        else:
+            perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous()  # l1*l1
+
+        mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
+        if self.nll_type == 'ar_ftb':
+            mask_index = torch.triu(mask_index)
+        else:
+            mask_index = torch.tril(mask_index)
+        perturbed_[mask_index] = self.mask_id
+        if self.log_type == 'ftb':
+            perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
+        else:
+            perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
+
+        logits_ = []
+        num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
+        for i in range(num):
+            end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
+            perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
+            perturbed_seq_ = perturbed_seq_.to(self.device)
+            if len(perturbed_seq_.shape) == 1:
+                perturbed_seq_ = perturbed_seq_.unsqueeze(0)
+            logits = self.get_logits(perturbed_seq_, prompt_index)
+            logits_.append(logits.cpu())
+        logits = torch.cat(logits_, dim=0)
+
+        temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
+        if self.nll_type == 'ar_ftb':
+            temp_index = torch.triu(temp_index, diagonal=1)
+        else:
+            temp_index = torch.tril(temp_index, diagonal=-1)
+        mask_index[temp_index] = False
+        if self.log_type == 'ftb':
+            logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
+        else:
+            logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
+
+        if self.log_type == 'ftb':
+            loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
+        else:
+            loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
+        return loss
+
+    def _encode_pair(self, context, continuation):
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+
+        whole_enc = self.tokenizer.encode(context + continuation) + [
+            self.tokenizer.eos_token_id
+        ]
+        context_enc = self.tokenizer.encode(context)
+
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        def _tokenize(e):
+            prefix, target = self._encode_pair(e["prefix"], e["target"])
+            return {
+                "prefix_text": e["prefix"],
+                "target_text": e["target"],
+                "prefix": prefix,
+                "target": target,
+            }
+
+        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
+        ds = Dataset.from_list(ds)
+        ds = ds.map(_tokenize)
+        ds = ds.with_format("torch")
+
+        out = []
+        with torch.no_grad():
+            for elem in tqdm(ds, desc="Computing likelihood..."):
+                prefix = elem["prefix"]
+                target = elem["target"]
+
+                if self.nll_type == 'mc':
+                    ll = -self._eval_target_nll_mc(prefix, target)
+                    if self.log_type == 'union':
+                        ll = ll / (len(target) + len(prefix))
+                elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
+                    ll = -self._eval_target_nll_ar(prefix, target)
+                else:
+                    raise NotImplementedError(self.nll_type)
+
+                is_target_greedy_dec = False
+                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
+        return out
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        raise NotImplementedError
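A hypothetical invocation (placeholders throughout, assuming the harness keeps lm-eval's usual `simple_parse_args_string` coercions for `--model_args`): the model is registered as `diffllm`, and `diffusion_steps`, `max_new_tokens`, and `use_hts` are consumed by `__init__` above.

from lm_eval.models.diffllm import DiffLLM

# illustrative only; `pretrained` is a placeholder hub id or local path
lm = DiffLLM.create_from_arg_string(
    "pretrained=<hf-model-or-path>,diffusion_steps=32,max_new_tokens=256,use_hts=true"
)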
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py
ADDED
@@ -0,0 +1,41 @@
+import random
+
+from tqdm import tqdm
+
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+
+
+@register_model("dummy")
+class DummyLM(LM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @classmethod
+    def create_from_arg_string(cls, arg_string, additional_config=None):
+        return cls()
+
+    def loglikelihood(self, requests, disable_tqdm: bool = False):
+        res = []
+
+        for _ in tqdm(requests, disable=disable_tqdm):
+            res.append((-random.random(), False))
+
+        return res
+
+    def generate_until(self, requests, disable_tqdm: bool = False):
+        res = []
+
+        for request in tqdm(requests, disable=disable_tqdm):
+            res.append("lol")
+            assert request.arguments[0].strip() != ""
+
+        return res
+
+    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+        res = []
+
+        for _ in tqdm(requests, disable=disable_tqdm):
+            res.append(-random.random())
+
+        return res
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from .verifier import CodeVerifier
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class HTSSampler:
|
| 12 |
+
def __init__(self, model, tokenizer, device="cuda"):
|
| 13 |
+
self.model = model
|
| 14 |
+
self.tokenizer = tokenizer
|
| 15 |
+
self.device = device
|
| 16 |
+
self.verifier = CodeVerifier(model, tokenizer, device)
|
| 17 |
+
|
| 18 |
+
def _get_num_transfer_tokens(self, block_length, steps):
|
| 19 |
+
if steps == 0: return torch.tensor([], dtype=torch.int64)
|
| 20 |
+
base = block_length // steps
|
| 21 |
+
remainder = block_length % steps
|
| 22 |
+
num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
|
| 23 |
+
num_transfer_tokens[:remainder] += 1
|
| 24 |
+
return num_transfer_tokens
|
| 25 |
+
|
| 26 |
+
def _sample_with_temperature(self, logits, temperature, top_k, top_p):
|
| 27 |
+
logits = logits.to(torch.float32)
|
| 28 |
+
orig_probs = torch.softmax(logits, dim=-1)
|
| 29 |
+
x0_p, _ = torch.max(orig_probs, dim=-1)
|
| 30 |
+
|
| 31 |
+
if temperature > 0.0:
|
| 32 |
+
noise = torch.rand_like(logits, dtype=torch.float32)
|
| 33 |
+
gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
|
| 34 |
+
logits = logits / temperature + gumbel_noise
|
| 35 |
+
|
| 36 |
+
if top_k is not None and top_k > 0:
|
| 37 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 38 |
+
logits[indices_to_remove] = -float('Inf')
|
| 39 |
+
|
| 40 |
+
x0 = torch.argmax(logits, dim=-1)
|
| 41 |
+
return x0, x0_p
|
| 42 |
+
|
| 43 |
+
def _safe_scalar(self, val):
|
| 44 |
+
if isinstance(val, torch.Tensor):
|
| 45 |
+
if val.numel() > 1: return val.mean().item()
|
| 46 |
+
return val.item()
|
| 47 |
+
return float(val)
|
| 48 |
+
|
| 49 |
+
def _analyze_structure(self, text, task_type="code"):
|
| 50 |
+
score = 0.0
|
| 51 |
+
stripped = text.strip()
|
| 52 |
+
if task_type == "code":
|
| 53 |
+
if len(stripped) < 5: return -0.1
|
| 54 |
+
keywords = ["return", "print", "yield", "lambda", "class ", "def "]
|
| 55 |
+
if any(k in stripped for k in keywords): score += 0.05
|
| 56 |
+
if ":" in stripped: score += 0.02
|
| 57 |
+
if " " in text: score += 0.03
|
| 58 |
+
elif task_type == "math":
|
| 59 |
+
if "\\boxed{" in stripped: score += 0.1
|
| 60 |
+
if "The answer is" in stripped: score += 0.05
|
| 61 |
+
return score
|
| 62 |
+
|
| 63 |
+
def _chunked_forward(self, x, chunk_size=32, slice_indices=None):
|
| 64 |
+
total_batch = x.shape[0]
|
| 65 |
+
logits_list = []
|
| 66 |
+
for i in range(0, total_batch, chunk_size):
|
| 67 |
+
end_idx = min(i + chunk_size, total_batch)
|
| 68 |
+
sub_x = x[i:end_idx]
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
|
| 71 |
+
outputs = self.model(sub_x, 'full')
|
| 72 |
+
sub_logits = outputs.logits
|
| 73 |
+
sub_logits = torch.cat([sub_logits[:, :1, :], sub_logits[:, :-1, :]], dim=1)
|
| 74 |
+
if slice_indices is not None:
|
| 75 |
+
s_start, s_end = slice_indices
|
| 76 |
+
sub_logits = sub_logits[:, s_start:s_end, :]
|
| 77 |
+
logits_list.append(sub_logits.detach().clone())
|
| 78 |
+
return torch.cat(logits_list, dim=0)
|
| 79 |
+
|
| 80 |
+
def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id,
|
| 81 |
+
prompt_length, resample_window=6, task_type="code"):
|
| 82 |
+
num_survivors = len(survivor_indices)
|
| 83 |
+
if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone()
|
| 84 |
+
|
| 85 |
+
base_repeat = target_width // num_survivors
|
| 86 |
+
remainder = target_width % num_survivors
|
| 87 |
+
new_x_list, new_conf_list = [], []
|
| 88 |
+
|
| 89 |
+
for i in range(num_survivors):
|
| 90 |
+
count = base_repeat + (1 if i < remainder else 0)
|
| 91 |
+
if count == 0: continue
|
| 92 |
+
survivor_x = x[survivor_indices[i]]
|
| 93 |
+
survivor_conf = conf_scores[survivor_indices[i]]
|
| 94 |
+
|
| 95 |
+
new_x_list.append(survivor_x.unsqueeze(0))
|
| 96 |
+
new_conf_list.append(survivor_conf.unsqueeze(0))
|
| 97 |
+
|
| 98 |
+
if count > 1:
|
| 99 |
+
gen_part = survivor_x[prompt_length:]
|
| 100 |
+
gen_conf = survivor_conf[prompt_length:]
|
| 101 |
+
non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0]
|
| 102 |
+
for _ in range(count - 1):
|
| 103 |
+
perturbed_x = survivor_x.clone()
|
| 104 |
+
perturbed_conf = survivor_conf.clone()
|
| 105 |
+
if len(non_mask_indices) > 0:
|
| 106 |
+
pool_size = min(resample_window * 2, len(non_mask_indices))
|
| 107 |
+
current_token_confs = gen_conf[non_mask_indices]
|
| 108 |
+
_, candidate_pool = torch.topk(current_token_confs, k=pool_size, largest=False)
|
| 109 |
+
|
| 110 |
+
num_to_perturb = min(resample_window, pool_size)
|
| 111 |
+
rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb]
|
| 112 |
+
selected_sub_indices = candidate_pool[rand_indices]
|
| 113 |
+
|
| 114 |
+
target_idx_in_x = prompt_length + non_mask_indices[selected_sub_indices]
|
| 115 |
+
perturbed_x[target_idx_in_x] = mask_id
|
| 116 |
+
perturbed_conf[target_idx_in_x] = 0.0
|
| 117 |
+
new_x_list.append(perturbed_x.unsqueeze(0))
|
| 118 |
+
new_conf_list.append(perturbed_conf.unsqueeze(0))
|
| 119 |
+
|
| 120 |
+
return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0)
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
|
| 123 |
+
def generate_hts(self, prompt_text, input_ids, problem_data=None,
|
| 124 |
+
initial_N=1, final_K=1, survivor_K=None,
|
| 125 |
+
prune_step_pct=0.0, reward_mode="confidence",
|
| 126 |
+
temperature=0.7, block_length=32, steps=64, gen_length=1024,
|
| 127 |
+
top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9,
|
| 128 |
+
eos_id=151643, mask_id=151666,
|
| 129 |
+
hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.2,
|
| 130 |
+
hts_survivor_k=4, task_type="code", until=None, pruning_interval=20):
|
| 131 |
+
|
| 132 |
+
input_ids = input_ids.to(self.device)
|
| 133 |
+
prompt_length = input_ids.shape[1]
|
| 134 |
+
total_length = prompt_length + gen_length
|
| 135 |
+
|
| 136 |
+
x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device)
|
| 137 |
+
x[:, :prompt_length] = input_ids.repeat(initial_N, 1)
|
| 138 |
+
conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device)
|
| 139 |
+
conf_scores[:, :prompt_length] = 1.0
|
| 140 |
+
|
| 141 |
+
schedule = self._get_num_transfer_tokens(gen_length, steps)
|
| 142 |
+
current_bsz = initial_N
|
| 143 |
+
schedule_map = {}
|
| 144 |
+
ts_start, tr_end = 0, 0
|
| 145 |
+
|
| 146 |
+
if hts_mode:
|
| 147 |
+
ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)
|
| 148 |
+
else:
|
| 149 |
+
        # Build the pruning schedule: map step index -> target beam width.
        final_K_list = [final_K] if not isinstance(final_K, list) else final_K
        prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct
        for pct, width in zip(prune_pct_list, final_K_list):
            if pct > 0:
                schedule_map[int(steps * pct)] = width

        stats = {
            "initial_n": initial_N,
            "final_k": final_K if not isinstance(final_K, list) else final_K[-1],
            "nfe": 0,
            "svf_calls": 0,
            "pruning_history": [],
            "entropy_history": [],
            "final_scores": []
        }

        next_allowed_pruning_step = ts_start

        for step in range(steps):
            perform_pruning = False
            num_parents_to_select = hts_survivor_k

            # Decide whether to prune at this step: either via the HTS decay
            # schedule (within [ts_start, tr_end)) or via the fixed schedule_map.
            if hts_mode and ts_start <= step < tr_end and step >= next_allowed_pruning_step:
                target_width = max(stats["final_k"], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
                if current_bsz > target_width:
                    perform_pruning = True
            elif not hts_mode and step in schedule_map:
                target_width = schedule_map[step]
                num_parents_to_select = target_width
                if current_bsz > target_width:
                    perform_pruning = True

            if perform_pruning:
                # Score a rough greedy decode of every live trajectory with the verifier.
                stats["svf_calls"] += current_bsz
                full_logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))
                rough_ids = torch.argmax(full_logits, dim=-1)
                rough_codes = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True)

                candidates = []
                for i in range(current_bsz):
                    s = self._safe_scalar(self.verifier.get_reward(prompt_text, rough_codes[i], mode=reward_mode, current_logits=full_logits[i], task_type=task_type))
                    s += self._analyze_structure(rough_codes[i], task_type=task_type)
                    # Key on the whitespace-stripped head and tail of the text
                    # so near-identical trajectories deduplicate to one key.
                    clean_text = rough_codes[i].strip().replace(" ", "").replace("\n", "")
                    content_key = hash(clean_text[:150] + clean_text[-150:]) if clean_text else i
                    candidates.append({'score': s, 'idx': i, 'key': content_key})

                stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]})
                candidates.sort(key=lambda c: c['score'], reverse=True)

                # First pass: keep the highest-scoring distinct candidates;
                # second pass: top up with duplicates if distinct ones run out.
                selected_indices, seen_keys = [], set()
                for cand in candidates:
                    if len(selected_indices) >= num_parents_to_select:
                        break
                    if cand['key'] not in seen_keys:
                        selected_indices.append(cand['idx'])
                        seen_keys.add(cand['key'])
                for cand in candidates:
                    if len(selected_indices) >= num_parents_to_select:
                        break
                    if cand['idx'] not in selected_indices:
                        selected_indices.append(cand['idx'])

                top_indices = torch.tensor(selected_indices, device=self.device)
                x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type)

                current_bsz = target_width
                next_allowed_pruning_step = step + pruning_interval

            active_mask = (x[:current_bsz, prompt_length:] == mask_id)
            if active_mask.sum() == 0:
                break

            stats["nfe"] += current_bsz
            logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))

            with torch.no_grad():
                probs = torch.softmax(logits.float(), dim=-1)
                token_entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
                sample_entropy = token_entropy.mean(dim=-1)
                stats["entropy_history"].append(sample_entropy.tolist())

            # Unmask either every position above the confidence threshold or,
            # failing that, the top-k most confident masked positions.
            x0, x0_p = self._sample_with_temperature(logits, temperature, top_k, top_p)
            num_transfer = schedule[step].item()
            confidence = torch.where(active_mask, x0_p, -torch.inf)
            transfer_idx = torch.zeros_like(x0, dtype=torch.bool)

            for b in range(current_bsz):
                k = min(num_transfer, active_mask[b].sum().item())
                if k <= 0:
                    continue
                high_conf_mask = (confidence[b] > threshold)
                if high_conf_mask.sum() >= k:
                    transfer_idx[b] = high_conf_mask
                else:
                    _, topk_ids = torch.topk(confidence[b], k=k)
                    transfer_idx[b, topk_ids] = True

            if transfer_idx.any():
                x[:current_bsz, prompt_length:][transfer_idx] = x0[transfer_idx]
                conf_scores[:current_bsz, prompt_length:][transfer_idx] = x0_p[transfer_idx]

        # Decode, truncate at stop sequences, and rank the surviving trajectories.
        final_codes = self.tokenizer.batch_decode(x[:current_bsz, prompt_length:], skip_special_tokens=True)
        final_candidates = []
        for i, code in enumerate(final_codes):
            txt = code.split(self.tokenizer.eos_token)[0]
            if until:
                for term in until:
                    if term in txt:
                        txt = txt.split(term)[0]
            s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type))
            final_candidates.append({'resp': txt, 'score': s})

        final_candidates.sort(key=lambda c: c['score'], reverse=True)
        stats["final_scores"] = [c['score'] for c in final_candidates]
        stats["all_trajectories"] = [{"rank": i + 1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)]

        return [c['resp'] for c in final_candidates], stats
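A minimal sketch (illustrative, not from this repo) of the HTS width schedule used by the pruning loop above, assuming placeholder values for `initial_N`, `final_k`, `decay_factor`, and `ts_start` (this hunk does not show the repo's defaults). Note that under the formula `initial_N * decay_factor ** -(step - ts_start)`, the live width only contracts when `decay_factor > 1`.

import math

initial_N, final_k = 16, 2        # hypothetical starting and minimum beam widths
decay_factor, ts_start = 1.25, 4  # hypothetical: >1 so the width shrinks; first pruning step

for step in range(ts_start, 16):
    # mirrors: max(stats["final_k"], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
    width = max(final_k, math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
    print(step, width)  # 16, 13, 11, 9, ... decaying toward final_k
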
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py
ADDED
@@ -0,0 +1,1459 @@
import copy
import logging
import os
from datetime import timedelta
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union

import jinja2
import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from accelerate.utils import get_max_memory
from huggingface_hub import HfApi
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    configure_pad_token,
    get_dtype,
    handle_stop_sequences,
    pad_and_concat,
    stop_sequences_criteria,
)


eval_logger = logging.getLogger(__name__)


@register_model("hf-auto", "hf", "huggingface")
class HFLM(TemplateLM):
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

    AUTO_MODEL_CLASS = None
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Literal["default", "causal", "seq2seq"] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, (
                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one
            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
            )

        # determine which of 'causal' and 'seq2seq' backends to use for HF models
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
            gguf_file=gguf_file,
            add_bos_token=add_bos_token,
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
                gguf_file=gguf_file,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        self.add_bos_token = add_bos_token
        if "gemma" in getattr(self.config, "model_type", ""):
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

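A minimal usage sketch (illustrative, not from this repo) of the constructor above, assuming a single local GPU and the public "gpt2" checkpoint; any Hub id or local path works the same way.

from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="gpt2",  # Hub id or local path; an already-loaded PreTrainedModel also works
    device="cuda",
    batch_size=8,       # or "auto" / "auto:N" to detect the largest workable batch size
)
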
    def _get_accelerate_args(
        self,
        parallelize: Optional[bool] = None,
        device_map: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        gpus: Optional[int] = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
        if (
            num_machines == 0
            and hasattr(self, "accelerator")
            and self.accelerator is not None
        ):
            eval_logger.info(
                "We are not in a distributed setting for accelerate. Setting model_parallel to False."
            )
            parallelize = False

        if parallelize is None:
            # If parallelism is unset by the user, we automatically assign model parallelism
            # if enough extra GPUs are available
            max_memory_all_gpus = get_max_memory()
            # We just want gpu, not cpu, max memory
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            parallelize = bool(num_local_processes < len(max_memory_all_gpus))
            eval_logger.info(
                f"Setting model parallel to {parallelize} since "
                f"the number of local processes is {num_local_processes} "
                f"and the number of GPUs is {len(max_memory_all_gpus)}"
            )

        args = {}
        if parallelize:  # Model parallelism will be used
            max_memory = {}
            if max_memory_per_gpu is not None:  # Using the provided memory requirements
                max_memory_per_gpu_map = {
                    device_idx: max_memory_per_gpu for device_idx in range(gpus)
                }
            else:  # Estimating the possible memory requirements
                max_memory_all_gpus = get_max_memory()
                if "cpu" in max_memory_all_gpus:
                    del max_memory_all_gpus["cpu"]
                if not hasattr(self, "accelerator"):
                    max_memory_per_gpu_map = {
                        k: v for k, v in max_memory_all_gpus.items()
                    }
                else:
                    # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                    max_memory_per_gpu_map = {
                        k: v
                        for k, v in max_memory_all_gpus.items()
                        if k % num_local_processes
                        == (self.accelerator.process_index % num_local_processes)
                    }
            args["max_memory"] = max_memory_per_gpu_map
            args["device_map"] = "auto" if device_map is None else device_map
            eval_logger.info(
                f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
            )

            if max_cpu_memory is not None:
                max_memory["cpu"] = max_cpu_memory

            args["offload_folder"] = offload_folder
        elif (
            device_map is None
        ):  # No model parallelism, we use the default provided device for our model
            if hasattr(self, "accelerator"):
                device_map = {"": f"{self.accelerator.device}"}
            else:
                device_map = {"": str(self.device)}
            args["max_memory"] = None
            args["device_map"] = device_map
            eval_logger.info(
                f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
            )
        else:
            args["max_memory"] = None
            args["device_map"] = None
            eval_logger.info("Model parallel was set to False.")

        return args

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    def _get_backend(
        self,
        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Literal["default", "causal", "seq2seq"] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.
        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
        Sets `self.AUTO_MODEL_CLASS` appropriately if not already set.

        **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
        user must set `self.backend` to be either "causal" or "seq2seq" manually!**
        """

        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.backend = backend
            elif backend == "seq2seq":
                self.backend = backend
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{self.backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            if (
                getattr(config, "model_type")
                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
            ):
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.backend = "seq2seq"
                eval_logger.debug(f"Using model type '{self.backend}'")
            elif (
                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            ):
                self.backend = "causal"
                eval_logger.debug(f"Using model type '{self.backend}'")
            else:
                if not trust_remote_code:
                    eval_logger.warning(
                        "HF model type is marked as neither CausalLM nor Seq2SeqLM. "
                        "This is expected if your model requires `trust_remote_code=True` but may be an error otherwise. "
                        "Setting backend to causal."
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to assuming AutoModelForCausalLM
                self.backend = "causal"
                eval_logger.info(
                    f"Model type cannot be determined. Using default model type '{self.backend}'"
                )

        if self.AUTO_MODEL_CLASS is None:
            if self.backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif self.backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

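The default backend resolution above simply checks the config's `model_type` against the two transformers registries; a quick illustration (not part of the diff) with two well-known model types:

from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

print("t5" in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)  # True -> backend "seq2seq"
print("llama" in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)          # True -> backend "causal"
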
    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
        gguf_file: Optional[str] = None,
    ) -> None:
        """Return the model config for HuggingFace models"""
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
            gguf_file=gguf_file,
        )

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        gpus: Optional[int] = None,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
        gguf_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
        inside HFLM, using the kwargs passed into self.__init__().

        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.

        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
        (such as PyTorch models that are nearly, but not quite, fully mirroring
        HF's public interface relied on in this HFLM class)
        please consider subclassing HFLM and overriding this and other methods as needed.
        """

        model_kwargs = kwargs if kwargs else {}

        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
                device_map=kwargs.get("device_map", None),
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
            )
        )

        if not autogptq and not gptqmodel:
            if model_kwargs.get("load_in_4bit", None):
                assert transformers.__version__ >= "4.30.0", (
                    "load_in_4bit requires transformers >= 4.30.0"
                )
            if transformers.__version__ >= "4.30.0":
                if model_kwargs.get("load_in_4bit", None):
                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
                        model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                            model_kwargs["bnb_4bit_compute_dtype"]
                        )

            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                gguf_file=gguf_file,
                **model_kwargs,
            )
        else:
            if autogptq and gptqmodel:
                raise ValueError(
                    "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
                )

            if autogptq:
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load auto_gptq, but auto-gptq is not installed ",
                        "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                    )

                self._model = AutoGPTQForCausalLM.from_quantized(
                    pretrained,
                    trust_remote_code=trust_remote_code,
                    model_basename=None if autogptq is True else Path(autogptq).stem,
                    use_safetensors=True
                    if autogptq is True
                    else autogptq.endswith(".safetensors"),
                    **model_kwargs,
                )

            if gptqmodel:
                try:
                    from gptqmodel import GPTQModel
                except ModuleNotFoundError as exception:
                    raise type(exception)(
                        "Tried to load gptqmodel, but gptqmodel is not installed ",
                        "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
                    )

                self._model = GPTQModel.from_quantized(
                    pretrained, trust_remote_code=trust_remote_code, **model_kwargs
                )

        if peft and delta:
            raise ValueError(
                "Cannot use both 'peft' and 'delta' options at the same time."
            )

        if peft:
            if model_kwargs.get("load_in_4bit", None):
                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
                eval_logger.info(
                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
                self._model.resize_token_embeddings(len(self.tokenizer))
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
        elif delta:
            if autogptq:
                eval_logger.warning(
                    "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
                )
            _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
                delta,
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            for name, param in self._model.state_dict().items():
                try:
                    param.data += _model_delta.state_dict()[name]
                except KeyError:
                    raise KeyError(f"Delta model is missing weights for layer: {name}")
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to add delta weights to layer {name}. Error: {e}"
                    )

            del _model_delta

        return None

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        gguf_file: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.

        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
        """
        kwargs = {
            "revision": revision,
            "trust_remote_code": trust_remote_code,
        }

        # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
        if gguf_file is not None:
            kwargs["gguf_file"] = gguf_file
        else:
            kwargs["use_fast"] = use_fast_tokenizer

        if add_bos_token:
            kwargs["add_bos_token"] = True

        if tokenizer:
            if isinstance(tokenizer, str):
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer, **kwargs
                )
            else:
                assert isinstance(
                    tokenizer, transformers.PreTrainedTokenizer
                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
                self.tokenizer = tokenizer
        else:
            # Get tokenizer based on 'pretrained'
            if isinstance(pretrained, str):
                model_name = pretrained
            else:
                # get the HF hub name via accessor on model
                model_name = self.model.name_or_path
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name, **kwargs
            )
        return None

    def _detect_batch_size(self, requests=None, pos: int = 0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length
            max_context_enc = max_length
            max_cont_enc = max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.backend == "seq2seq":
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841

            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as e:
            if "No executable batch size found" in str(e):
                batch_size = 1
            else:
                raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            clear_torch_cache()
            return batch_size

        clear_torch_cache()
        return batch_size

    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """ """
        # default for None - empty dict, use predefined tokenizer param
        # used for all models except for CausalLM or predefined value
        special_tokens_kwargs = {}

        # by default for CausalLM - false or self.add_bos_token is set
        if add_special_tokens is None:
            if self.backend == "causal":
                special_tokens_kwargs = {
                    "add_special_tokens": False or self.add_bos_token
                }
        # otherwise the method explicitly defines the value
        else:
            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}

        encoding = self.tokenizer.encode(string, **special_tokens_kwargs)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: int = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        add_special_tokens = {}
        if self.backend == "causal":
            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            **add_special_tokens,
        )
        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                eval_logger.warn(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]

    def tok_decode(self, tokens, skip_special_tokens=True):
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

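A round-trip through the tokenization helpers above (illustrative, not part of the diff), assuming an already-constructed `lm = HFLM(pretrained="gpt2")`:

ids = lm.tok_encode("The quick brown fox")   # List[int]; no special tokens for causal models by default
text = lm.tok_decode(ids)                    # back to a string
batch, attn = lm.tok_batch_encode(["a", "a longer example"])  # left-padded input_ids and attention_mask tensors
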
    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :param labels: torch.Tensor, optional
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
            (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder
        """
        with torch.no_grad():
            if attn_mask is not None or labels is not None:
                assert attn_mask is not None and labels is not None
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")
        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )

    def _select_cont_toks(
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
        if self.backend == "causal":
            assert contlen and inplen, (
                "Must pass input len and cont. len to select scored logits for causal LM"
            )
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.backend == "seq2seq":
            assert contlen and not inplen, (
                "Selecting scored logits for Seq2SeqLM requires only cont. len"
            )
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]

        return logits

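The causal branch of `_select_cont_toks` above keeps exactly the logit positions that score the continuation; a toy-shaped check (illustrative, not part of the diff; the lengths are hypothetical):

import torch

logits = torch.randn(9, 50257)                   # one sequence's logits, [seq, vocab]
inplen, contlen = 9, 6                           # hypothetical input and continuation lengths
cont_logits = logits[inplen - contlen : inplen]  # the positions that predict the continuation tokens
assert cont_logits.shape == (contlen, 50257)
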
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        # Handle distributed case padding
        pad_amnt = 0
        if self.world_size > 1:
            mytensor = torch.tensor(len(all_windows), device=self.device)
            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
            pad_amnt = max(gathered) - gathered[self.rank]
            if pad_amnt > 0:
                all_windows += pad_amnt * [all_windows[0]]

        all_nlls = []
        batch_size = adaptive_batch_size or self.batch_size
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
                override_bs=len(batch_windows),
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Remove padding if necessary
        if (self.world_size > 1) and (pad_amnt > 0):
            all_nlls = all_nlls[:-pad_amnt]

        # Reconstruct per-request loglikelihoods
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # Get all nlls for this request
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

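The windowing that `loglikelihood_rolling` relies on, run on toy token ids (illustrative, not part of the diff; `utils.get_rolling_token_windows` and `utils.make_disjoint_window` are the same lm_eval helpers imported at the top of this file):

from lm_eval import utils

windows = list(
    map(
        utils.make_disjoint_window,
        utils.get_rolling_token_windows(
            token_list=list(range(10)),  # stand-in token ids
            prefix_token=-1,             # stand-in BOS/prefix id
            max_seq_len=4,
            context_len=1,
        ),
    )
)
# each entry is (context_tokens, continuation_tokens); the continuations
# cover the 10 tokens exactly once, so summing their NLLs gives the full NLL
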
    def _batch_scheduler(self, pos, n_reordered_requests):
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if (len(self.batch_sizes) > 1) and (
            self.batch_sizes[sched - 1] == self.max_batch_size
        ):
            # if previous batch size is already maximal, skip recomputation
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        print(
            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
        )
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
        return self.batch_sizes[sched]

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.backend == "causal" and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.backend == "causal":
                    total_length = len(context_enc) + len(continuation_enc)
                    if total_length > self.max_length + 1:
                        eval_logger.warn(
                            f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                            f"exceeds model's maximum length ({self.max_length}). "
                            f"Truncating {total_length - self.max_length + 1} tokens from the left."
                        )
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.backend == "seq2seq":
|
| 1140 |
+
inp = torch.tensor(
|
| 1141 |
+
(context_enc)[-self.max_length :],
|
| 1142 |
+
dtype=torch.long,
|
| 1143 |
+
device=self.device,
|
| 1144 |
+
)
|
| 1145 |
+
(inplen,) = inp.shape
|
| 1146 |
+
|
| 1147 |
+
# build encoder attn masks
|
| 1148 |
+
encoder_attns.append(torch.ones_like(inp))
|
| 1149 |
+
|
| 1150 |
+
cont = torch.tensor(
|
| 1151 |
+
(continuation_enc)[-self.max_length :],
|
| 1152 |
+
# TODO: left-shift these?
|
| 1153 |
+
# TODO: our code assumes we never end up truncating conts for either model type
|
| 1154 |
+
dtype=torch.long,
|
| 1155 |
+
device=self.device,
|
| 1156 |
+
)
|
| 1157 |
+
(contlen,) = cont.shape
|
| 1158 |
+
|
| 1159 |
+
conts.append(cont)
|
| 1160 |
+
|
| 1161 |
+
padding_len_cont = (
|
| 1162 |
+
max(padding_len_cont, contlen)
|
| 1163 |
+
if padding_len_cont is not None
|
| 1164 |
+
else contlen
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
padding_len_inp = (
|
| 1168 |
+
max(padding_len_inp, inplen)
|
| 1169 |
+
if padding_len_inp is not None
|
| 1170 |
+
else inplen
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
inps.append(inp) # [1, inp_length]
|
| 1174 |
+
cont_toks_list.append(continuation_enc)
|
| 1175 |
+
inplens.append(inplen)
|
| 1176 |
+
|
| 1177 |
+
# create encoder attn mask and batched conts, if seq2seq
|
| 1178 |
+
call_kwargs = {}
|
| 1179 |
+
if self.backend == "causal":
|
| 1180 |
+
batched_inps = pad_and_concat(
|
| 1181 |
+
padding_len_inp, inps, padding_side="right"
|
| 1182 |
+
) # [batch, padding_len_inp]
|
| 1183 |
+
elif self.backend == "seq2seq":
|
| 1184 |
+
# TODO: left-pad encoder inps and mask?
|
| 1185 |
+
batched_inps = pad_and_concat(
|
| 1186 |
+
padding_len_inp, inps
|
| 1187 |
+
) # [batch, padding_len_inp]
|
| 1188 |
+
batched_conts = pad_and_concat(
|
| 1189 |
+
padding_len_cont, conts
|
| 1190 |
+
) # [batch, padding_len_cont]
|
| 1191 |
+
batched_encoder_mask = pad_and_concat(
|
| 1192 |
+
padding_len_inp, encoder_attns
|
| 1193 |
+
) # [batch, padding_len_inp]
|
| 1194 |
+
call_kwargs = {
|
| 1195 |
+
"attn_mask": batched_encoder_mask,
|
| 1196 |
+
"labels": batched_conts,
|
| 1197 |
+
}
|
| 1198 |
+
|
| 1199 |
+
multi_logits = F.log_softmax(
|
| 1200 |
+
self._model_call(batched_inps, **call_kwargs), dim=-1
|
| 1201 |
+
) # [batch, padding_length (inp or cont), vocab]
|
| 1202 |
+
|
| 1203 |
+
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
|
| 1204 |
+
chunk, multi_logits, inplens, cont_toks_list
|
| 1205 |
+
):
|
| 1206 |
+
# Slice to original seq length
|
| 1207 |
+
contlen = len(cont_toks)
|
| 1208 |
+
# take only logits in the continuation
|
| 1209 |
+
# (discard context toks if decoder-only ; discard right-padding)
|
| 1210 |
+
# also discards + checks for "virtual tokens" in the causal LM's input window
|
| 1211 |
+
# from prompt/prefix tuning tokens, if applicable
|
| 1212 |
+
ctx_len = (
|
| 1213 |
+
inplen + (logits.shape[0] - padding_len_inp)
|
| 1214 |
+
if self.backend == "causal"
|
| 1215 |
+
else None
|
| 1216 |
+
)
|
| 1217 |
+
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
|
| 1218 |
+
logits = logits.unsqueeze(0) # [1, seq, vocab]
|
| 1219 |
+
|
| 1220 |
+
# Check if per-token argmax is exactly equal to continuation
|
| 1221 |
+
greedy_tokens = logits.argmax(dim=-1)
|
| 1222 |
+
|
| 1223 |
+
# check for one-token continuation cache hits.
|
| 1224 |
+
# noop in case group_by != "contexts" or no cache hit and returns the
|
| 1225 |
+
# original args. Otherwise, expands the logits batch dimension and yields each
|
| 1226 |
+
# batch along with matching continuation tokens and prompt strings.
|
| 1227 |
+
# logits -> [1, seq, vocab]
|
| 1228 |
+
for request_str, cont_toks, logits in re_ord.get_cache(
|
| 1229 |
+
req_str=request_str,
|
| 1230 |
+
cxt_toks=ctx_tokens,
|
| 1231 |
+
cont_toks=cont_toks,
|
| 1232 |
+
logits=logits,
|
| 1233 |
+
):
|
| 1234 |
+
cont_toks = torch.tensor(
|
| 1235 |
+
cont_toks, dtype=torch.long, device=self.device
|
| 1236 |
+
).unsqueeze(0) # [1, seq]
|
| 1237 |
+
max_equal = (greedy_tokens == cont_toks).all()
|
| 1238 |
+
|
| 1239 |
+
# Obtain log-probs at the corresponding continuation token indices
|
| 1240 |
+
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
| 1241 |
+
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
| 1242 |
+
-1
|
| 1243 |
+
) # [1, seq]
|
| 1244 |
+
|
| 1245 |
+
# Answer: (log prob, is-exact-match)
|
| 1246 |
+
answer = (float(logits.sum()), bool(max_equal))
|
| 1247 |
+
|
| 1248 |
+
res.append(answer)
|
| 1249 |
+
|
| 1250 |
+
if request_str is not None:
|
| 1251 |
+
# special case: loglikelihood_rolling produces a number of loglikelihood requests
|
| 1252 |
+
# all with cache key None. instead do add_partial on the per-example level
|
| 1253 |
+
# in the loglikelihood_rolling() function for those.
|
| 1254 |
+
self.cache_hook.add_partial(
|
| 1255 |
+
"loglikelihood", request_str, answer
|
| 1256 |
+
)
|
| 1257 |
+
pbar.update(1)
|
| 1258 |
+
|
| 1259 |
+
pbar.close()
|
| 1260 |
+
|
| 1261 |
+
return re_ord.get_original(res)
|
| 1262 |
+
|
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            if self.backend == "causal":
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
                assert max_ctx_len > 0, (
                    f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
                )
            elif self.backend == "seq2seq":
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.backend == "causal":
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        try:
            chat_templated = self.tokenizer.apply_chat_template(
                chat_history,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
            )
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
                "Failed to apply chat template. Removing the system role from chat history."
            )
            chat_history = [msg for msg in chat_history if msg["role"] != "system"]
            chat_templated = self.tokenizer.apply_chat_template(
                chat_history,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
            )

        return chat_templated
    def get_model_info(self) -> dict:
        """
        Method to get Hugging Face model information for experiment reproducibility.
        """

        def get_model_num_params(model) -> int:
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            else:
                return -1

        def get_model_dtype(model) -> str:
            if hasattr(model, "dtype"):
                return model.dtype
            else:
                return ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            try:
                model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return model_info.sha
            except Exception as e:
                eval_logger.debug(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info
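The slice-and-gather step at the heart of `_loglikelihood_tokens` above can be exercised in isolation. Below is a minimal sketch with toy tensors, mirroring the `[1, seq, vocab]` shape convention from the method but without any `HFLM` instance; the values are illustrative only:

```python
import torch
import torch.nn.functional as F

# Toy setup: vocab of 10, log-probs for a 6-token continuation, shape [1, seq, vocab].
logits = F.log_softmax(torch.randn(1, 6, 10), dim=-1)
cont_toks = torch.tensor([[4, 5, 6, 7, 8, 9]])  # reference continuation, [1, seq]

# Exact-match check: per-position argmax vs. the reference continuation.
greedy_tokens = logits.argmax(dim=-1)
is_greedy = bool((greedy_tokens == cont_toks).all())

# Gather the log-prob of each reference token and sum -> continuation loglikelihood.
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
answer = (float(token_logprobs.sum()), is_greedy)
print(answer)
```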
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py
ADDED
@@ -0,0 +1,731 @@
import collections
import fnmatch
import gc
import itertools
import logging
import time
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import transformers


eval_logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase
    from transformers.configuration_utils import PretrainedConfig


def chunks(iter, n: int = 0, fn=None):
    """
    Divides an iterable into chunks of specified size or based on a given function.
    Useful for batching

    Parameters:
    - iter: The input iterable to be divided into chunks.
    - n: An integer representing the size of each chunk. Default is 0.
    - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

    Returns:
    An iterator that yields chunks of the input iterable.

    Example usage:
    ```
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for chunk in chunks(data, 3):
        print(chunk)
    ```
    Output:
    ```
    [1, 2, 3]
    [4, 5, 6]
    [7, 8, 9]
    [10]
    ```
    """
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []

    if arr:
        yield arr


class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice
class Grouper:
    """
    takes an array `arr` and function `fn` and returns a dictionary
    with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
    objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # self.orig_arr = arr
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents but not indices for our grouped dict.
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with e.g. results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as `self.orig_arr`.
        res = [None] * self.size
        cov = [False] * self.size
        # orig = [None] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True
                # orig[ind] = _

        assert all(cov)
        # assert orig == self.orig_arr

        return res


def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert padding_side == "left" or padding_side == "right", (
        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    )

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # squeeze, in case passed [1, seq] size
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)


def clear_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # print(sequence, self.sequence_ids)
        # we look back for 2 more tokens than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization

        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]

        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(
                    sequence, tokenizer, initial_decoder_input_length, batch_size
                )
                for sequence in stop_sequences
            ],
        ]
    )


def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]
        >>> undistribute([group_1, group_2])
        [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]
        >>> undistribute(children)
        [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]
        >>> undistribute(children)
        [1, 2, 3]

    """

    return [
        x
        for x in itertools.chain.from_iterable(
            itertools.zip_longest(*[list(x) for x in iterable])
        )
        if x is not None
    ]


def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1

        return wrapper

    return decorator
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.

    Objects of this class have the group_by attribute which determines the method for grouping
    the data while batching it. Three options include "gen_kwargs", "contexts", or None:
        If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
        If group_by == "contexts" then requests will be grouped by context + cont[:-1]
        If None then requests will just be reordered by length descending.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable = lambda x: x,
        group_fn: Callable = lambda x: x[1],
        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
    ) -> None:
        self._group_by = group_by
        # 0 indices are enumerated indices. Apply functions to original arr.
        self._sort_fn = lambda x: sort_fn(x[1])
        self._group_fn = lambda x: group_fn(x[1])
        self._reorder_indices: List = []
        self._size = len(arr)
        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
            enumerate(arr)
        )  # [indices, (arr)]
        if self._group_by == "contexts":
            self._group_by_context()
        elif self._group_by == "gen_kwargs":
            self._group_by_index()

    def _group_by_index(self) -> None:
        """Group the elements of a list based on their indices."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
        )

    def _group_by_context(self) -> None:
        """Group the array with indices by context."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array. The method of grouping and batching
        depends on the parameter `group_by`.
        If `group_by` is set to "gen_kwargs", it will batch the
        re-ordered values with same gen_kwargs for each batch.
        If `group_by` is "contexts", it caches the requests by context before batching.
        If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
          each batch. Optional, defaults to None.

        Returns:
        Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
          attribute.

        Yields:
        List of batched elements according to the `group_by` attribute.
        """
        if self._group_by == "gen_kwargs":
            for (
                key,
                values,
            ) in self._arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        elif self._group_by == "contexts":
            # Get one sample from each key
            values = self._reorder(
                [value[0] for value in self._arr_with_indices.values()]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        else:
            values = self._reorder(self._arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def get_cache(
        self,
        req_str: Tuple[str, str] = None,
        cxt_toks: List[int] = None,
        cont_toks: List[int] = None,
        logits: torch.Tensor = None,
    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
        """
        Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.

        The behavior of this function varies depending on how the `group_by` attribute is set:

        - When `group_by` is "contexts":
            The function identifies single-token continuations by checking for keys that equate to
            [context+continuation][-1] and logs the indices for re-ordering.
            In this mode, this function can work in two scenarios:

            1. Cache Hit - Single Match:
                If a single matching context-continuation pair is found in the cache,
                the function yields the original arguments.

            2. Cache Hit - Multiple Matches:
                If multiple matching context-continuation pairs are found in the cache,
                the function expands the logits batch dimension to match the number of cache hits.
                It updates the original requests and continuation tokens.

        - When `group_by` is not set to "contexts":
            This method yields the original arguments, logits and continuation tokens,
            without checking for one-token continuations.

        Parameters:
        - req_str (tuple[str, str]): Original strings used for CachingLM.
        - cxt_toks (list[int]): Full context tokens used for lookup.
        - cont_toks (list[int]): Continuation tokens for which logits were generated.
        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.

        Yields:
        - Iterator:
            - req_str (tuple[str, str]): strings used for CachingLM.
            - cont_toks (list[int]): continuation tokens.
            - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
        """
        if self._group_by == "contexts":
            cache_hit: List[
                Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
            ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
            if (cache_size := len(cache_hit)) == 1:
                self._reorder_indices.extend(x[0] for x in cache_hit)
                yield req_str, cont_toks, logits
            else:
                # If we have matching requests then expand the batch dimension (no-op) and
                # yield each along with its corresponding args.
                multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
                indices, req_str, cont_toks = zip(
                    *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
                )
                self._reorder_indices.extend(indices)
                for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
                    yield c_key, cont_tok, logit
        else:
            yield req_str, cont_toks, logits

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (list | tuple[tuple[int, Any], ...]): The array or iterable to be reordered.

        Yields:
        Iterator
        """
        arr = sorted(arr, key=self._sort_fn)
        if not self._group_by == "contexts":
            # If grouped by contexts then indices will be set in get_cache()
            self._reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (list): The reordered array.

        Returns:
        list: The array with elements restored to their original order.
        """
        res = [None] * self._size
        cov = [False] * self._size

        for ind, v in zip(self._reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self._size

    @staticmethod
    def group(
        arr: Iterable,
        fn: Callable,
        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
    ) -> dict:
        """
        Groups elements of an iterable based on a provided function.

        The `group_by` parameter determines the method of grouping.
        If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
        If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - group_by (Literal["gen_kwargs", "contexts"]): The method of grouping.

        Returns:
        dict: A dictionary mapping group keys to lists of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            # where ob == [context + cont]
            if group_by == "contexts":
                res[tuple(fn(ob))].append(ob)
            else:
                try:
                    hashable_dict = tuple(
                        (
                            key,
                            tuple(value)
                            if isinstance(value, collections.abc.Iterable)
                            else value,
                        )
                        for key, value in sorted(fn(ob).items())
                    )
                    res[hashable_dict].append(ob)
                except (TypeError, AttributeError):
                    res[tuple(fn(ob))].append(ob)
        return res

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - _iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
def configure_pad_token(
    tokenizer: "PreTrainedTokenizerBase",
    model_config: Optional["PretrainedConfig"] = None,
) -> "PreTrainedTokenizerBase":
    """
    This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present.
    Some tokenizers require special handling.

    Args:
        tokenizer: The tokenizer for which the padding token is to be handled.
        model_config: The configuration of the model. Default is None.

    Returns:
        The tokenizer after the padding token has been handled.

    Raises:
        AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
    """
    if tokenizer.pad_token:
        pass
    elif tokenizer.unk_token:
        tokenizer.pad_token_id = tokenizer.unk_token_id
    elif tokenizer.eos_token:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        # handle special cases
        if model_config and getattr(model_config, "model_type", None) == "qwen":
            # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
            tokenizer.pad_token = "<|endoftext|>"
        elif (
            tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
            or tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
        ):
            # The RWKV world tokenizer does not allow for adding special tokens / setting the pad token (which is set as 0)
            # The additional tokenizer name check is needed, as there exist rwkv4 models with a neox tokenizer
            # ---
            # Note that the world tokenizer class name might change in the future for the final huggingface merge
            # https://github.com/huggingface/transformers/pull/26963
            assert tokenizer.pad_token_id == 0
        else:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    return tokenizer


def replace_placeholders(
    string: str, default_placeholder: str, image_token: str, max_images: int
):
    """
    A utility function used for local multimodal models. It locates all `default_placeholder`
    occurrences in the given input `string` and replaces the first `max_images` instances with
    `image_token`, and all subsequent occurrences with the empty string.

    This is used to replace <image> placeholder tags by model-specific image tokens like <|image_pad|>
    and to allow for only the first `max_images` images to be passed to a model if desired.

    :param string: The original string containing placeholders.
    :param default_placeholder: The placeholder text to be replaced.
    :param image_token: The token to replace the placeholder with.
    :param max_images: The maximum number of replacements to make.
    :return: The string with placeholders replaced.
    """
    count = 0
    result = []

    parts = string.split(default_placeholder)
    for part in parts[:-1]:  # Iterate through all but the last part
        result.append(part)
        if count < max_images:
            result.append(image_token)
            count += 1
        elif default_placeholder != image_token:
            result.append(default_placeholder)

    # Add the last part of the string
    result.append(parts[-1])
    return "".join(result)


def flatten_image_list(images: List[List]):
    """
    Takes in a list of lists of images, and returns a single list of all images in order.
    Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor.

    :param images: A list of lists of PIL images.
    :return: a list of PIL images, via concatenating all the sub-lists in order.
    """
    return [image for image_list in images for image in image_list]


def handle_stop_sequences(
    until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
    """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
    if isinstance(until, str):
        until = [until]
    elif until is None:
        until = []
    elif not isinstance(until, list):
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
        )

    if eos is not None and eos not in until:
        until.append(eos)
    return until
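A quick sanity sketch of the `Collator` round trip defined above: sort-by-length batching followed by `get_original` restoring submission order. Toy requests only, no model calls; this assumes the package is importable as `lm_eval` from the repo layout:

```python
from lm_eval.models.utils import Collator

# Toy requests: (context, gen_kwargs) pairs, as generate_until passes them.
reqs = [("a" * n, {"temperature": 0.0}) for n in (3, 1, 2)]

collator = Collator(
    reqs,
    sort_fn=lambda x: -len(x[0]),      # longest context first
    group_by="gen_kwargs",
    group_fn=lambda x: x[1],           # group by the kwargs dict
)

outs = []
for batch in collator.get_batched(n=2):
    # Pretend "generation": echo each context's length.
    outs.extend(len(ctx) for ctx, _ in batch)

# Results come back in the original (unsorted) request order: [3, 1, 2].
print(collator.get_original(outs))
```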
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py
ADDED
@@ -0,0 +1,155 @@
| 1 |
+
import torch
|
| 2 |
+
import logging
|
| 3 |
+
import ast
|
| 4 |
+
import re
|
| 5 |
+
import numpy as np
|
| 6 |
+
import textwrap
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class CodeVerifier:
|
| 11 |
+
def __init__(self, model, tokenizer, device="cuda"):
|
| 12 |
+
self.model = model
|
| 13 |
+
self.tokenizer = tokenizer
|
| 14 |
+
self.device = device
|
| 15 |
+
|
| 16 |
+
self.yes_ids, self.no_ids = [], []
|
| 17 |
+
for t in ["Yes", " Yes", "YES"]:
|
| 18 |
+
ids = self.tokenizer.encode(t, add_special_tokens=False)
|
| 19 |
+
if len(ids) > 0: self.yes_ids.append(ids[-1])
|
| 20 |
+
for t in ["No", " No", "NO"]:
|
| 21 |
+
ids = self.tokenizer.encode(t, add_special_tokens=False)
|
| 22 |
+
if len(ids) > 0: self.no_ids.append(ids[-1])
|
| 23 |
+
|
| 24 |
+
self.yes_ids = list(set(self.yes_ids))
|
| 25 |
+
self.no_ids = list(set(self.no_ids))
|
| 26 |
+
|
| 27 |
+
def _extract_python_code(self, text):
|
| 28 |
+
text = text.strip()
|
| 29 |
+
        match = re.search(r"```python\s*(.*?)```", text, re.DOTALL)
        if match: return match.group(1)
        match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL)
        if match_generic: return match_generic.group(1)
        return text

    def check_syntax(self, code_str):
        clean_code = self._extract_python_code(code_str)
        try:
            if len(clean_code.strip()) < 5: return False
            ast.parse(clean_code)
            return True
        except:
            return False

    def compute_confidence(self, logits):
        # Geometric mean of the per-position max probabilities.
        if logits is None: return 0.0
        probs = torch.softmax(logits, dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        log_probs = torch.log(max_probs + 1e-10)
        return torch.exp(torch.mean(log_probs)).item()

    def svf_score(self, prompt, code_str, task_type="code"):
        max_len = 2000
        if len(code_str) > max_len:
            if task_type == "reasoning":
                truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):]
            else:
                truncated_code = code_str[-max_len:]
        else:
            truncated_code = code_str

        if task_type == "code":
            prompt_template = f"""
You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.

[Problem Statement]
{prompt}
[/Problem Statement]

[Proposed Python Solution]
```python
{truncated_code}
```
[/Proposed Python Solution]

**Analysis Steps:**
1. Correctness: Does the core algorithm correctly solve the problem?
2. Efficiency: Is the time complexity acceptable for the given constraints?
3. Edge Cases & Constraints: Does the code handle all rules and edge cases?

**Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "math":
            prompt_template = f"""
You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.

[Math Problem]
{prompt}
[/Math Problem]

[Proposed Mathematical Solution]
{truncated_code}
[/Proposed Mathematical Solution]

**Analysis Steps:**
1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly?
2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate?
3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem?

**Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
**Answer:** """

        elif task_type == "reasoning":
            prompt_template = f"""
You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.

[Context and Question]
{prompt}
[/Context and Question]

[Proposed Answer]
{truncated_code}
[/Proposed Answer]

**Analysis Steps:**
1. Faithfulness: Is the answer an exact, literal span from the context?
2. Relevance: Does the answer directly address the specific question asked without hallucinating external information?
3. Accuracy: Does the provided context strictly support this answer?

**Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
**Answer:** """

        else:
            prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"

        verify_text = textwrap.dedent(prompt_template).strip()
        input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)

        max_pos = getattr(self.model.config, "max_position_embeddings",
                          getattr(self.model.config, "n_positions",
                                  getattr(self.model.config, "max_sequence_length", 20480)))

        if input_ids.shape[1] > max_pos - 16:
            logger.warning("Verifier input is too long; truncating from the left.")
            input_ids = input_ids[:, -(max_pos - 16):]

        with torch.no_grad():
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                outputs = self.model(input_ids, 'full')
                logits = outputs.logits[0, -1, :]

        yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf'))
        no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf'))

        if yes_score == -float('inf') and no_score == -float('inf'): return 0.5

        probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
        return probs[0].item()

    def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"):
        if mode == "svf":
            return self.svf_score(prompt, code_str, task_type=task_type)
        else:
            return self.compute_confidence(current_logits)
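The two reward modes above reduce to small pieces of tensor math. Below is a minimal standalone sketch (hypothetical logits; it does not load the model or tokenizer) of the arithmetic in `compute_confidence` and of the Yes/No softmax at the end of `svf_score`:

```python
import torch

# Hypothetical logits for a 5-token sequence over a 10-word vocabulary.
logits = torch.randn(5, 10)

# compute_confidence: geometric mean of the per-position max probabilities.
probs = torch.softmax(logits, dim=-1)
max_probs, _ = torch.max(probs, dim=-1)  # probability of the argmax token at each position
confidence = torch.exp(torch.mean(torch.log(max_probs + 1e-10))).item()
print(f"confidence reward: {confidence:.4f}")  # in (0, 1]; 1.0 only if every position is certain

# svf_score tail: a two-way softmax over the strongest "Yes" and "No" logits,
# equivalent to sigmoid(yes_score - no_score).
yes_score, no_score = 3.2, 1.1  # assumed verifier logits, for illustration
p_yes = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)[0].item()
print(f"svf reward: {p_yes:.4f}")  # ~0.89 here
```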
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py
ADDED
@@ -0,0 +1,552 @@
import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
import json
import logging
import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined


SPACING = " " * 47

HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}


def setup_logging(verbosity=logging.INFO):
    # Configure the root logger
    class CustomFormatter(logging.Formatter):
        def format(self, record):
            if record.name.startswith("lm_eval."):
                record.name = record.name[len("lm_eval.") :]
            return super().format(record)

    formatter = CustomFormatter(
        "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
    )

    log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity

    level_map = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    log_level = level_map.get(str(log_level).upper(), logging.INFO)

    if not logging.root.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)

        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        if log_level == logging.DEBUG:
            third_party_loggers = ["urllib3", "filelock", "fsspec"]
            for logger_name in third_party_loggers:
                logging.getLogger(logger_name).setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(log_level)


def hash_string(string: str) -> str:
    return hashlib.sha256(string.encode("utf-8")).hexdigest()


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert len(sep_char) == 1, (
        "separation string must be a single character for escaped splitting"
    )

    if maxsplit == 0:
        return text
    maxsplit = max(0, maxsplit)

    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)


def handle_arg_string(arg):
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def handle_non_serializable(o):
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def sanitize_list(sub):
    """
    Takes possible nested list and recursively converts all inner components to strings
    """
    if isinstance(sub, list):
        return [sanitize_list(item) for item in sub]
    if isinstance(sub, tuple):
        return tuple(sanitize_list(item) for item in sub)
    else:
        return str(sub)


def simple_parse_args_string(args_string: Optional[str]) -> dict:
    """
    Parses something like
        args1=val1,arg2=val2
    into a dictionary
    """
    if args_string is None:
        return {}
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = [arg for arg in args_string.split(",") if arg]
    args_dict = {
        kv[0]: handle_arg_string("=".join(kv[1:]))
        for kv in [arg.split("=") for arg in arg_list]
    }
    return args_dict


def join_iters(iters):
    for iter in iters:
        yield from iter


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))


def softmax(x) -> np.ndarray:
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


def general_detokenize(string) -> str:
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string


def get_file_task_name(filename: str) -> str:
    """
    Given the sample results filenames, extracts and returns the task name.
    """
    return filename[filename.find("_") + 1 : filename.rfind("_")]


def get_file_datetime(filename: str) -> str:
    """
    Given the results and sample results filenames, extracts and returns the datetime.
    """
    return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")


def sanitize_model_name(model_name: str) -> str:
    """
    Given the model name, returns a sanitized version of it.
    """
    return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)


def sanitize_task_name(task_name: str) -> str:
    """
    Given the task name, returns a sanitized version of it.
    """
    return re.sub(r"\W", "_", task_name)


def get_latest_filename(filenames: List[str]) -> str:
    """
    Given a list of filenames, returns the filename with the latest datetime.
    """
    return max(filenames, key=lambda f: get_file_datetime(f))


def get_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to aggregated results.
    """
    return [f for f in filenames if "/results_" in f and ".json" in f]


def get_sample_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to sample results.
    """
    return [f for f in filenames if "/samples_" in f and ".json" in f]


def get_rolling_token_windows(
    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]:
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(
    pair: Tuple[List[int], List[int]],
) -> Tuple[List[int], List[int]]:
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


class EnhancedJSONEncoder(json.JSONEncoder):
    """
    Provides a proper json encoding for the loggers and trackers json dumps.
    Notably manages the json encoding of dataclasses.
    """

    def default(self, o):
        if is_dataclass(o):
            return asdict(o)
        return super().default(o)


class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res


def make_table(result_dict, column: str = "results", sort_results: bool = False):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    keys = result_dict[column].keys()
    if sort_results:
        # sort entries alphabetically by task or group name.
        # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
        # sorting here would mess that up
        keys = sorted(keys)
    for k in keys:
        dic = result_dict[column][k]
        version = result_dict["versions"].get(k, "    N/A")
        n = str(result_dict.get("n-shot", " ").get(k, " "))
        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

        if "alias" in dic:
            k = dic.pop("alias")

        metric_items = dic.items()
        metric_items = sorted(metric_items)

        for (mf), v in metric_items:
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")

            v = "%.4f" % v if isinstance(v, float) else v

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                se = "   N/A" if se == "N/A" else "%.4f" % se
                values.append([k, version, f, n, m, hib, v, "±", se])
            else:
                values.append([k, version, f, n, m, hib, v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != 1 if inspect.ismethod(fn) else 0:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


def ignore_constructor(loader, node):
    return node


def import_function(loader: yaml.Loader, node, yaml_path: Path):
    function_name = loader.construct_scalar(node)

    *module_name, function_name = function_name.split(".")
    if isinstance(module_name, list):
        module_name = ".".join(module_name)
    module_path = yaml_path.parent / f"{module_name}.py"

    spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())

    if spec is None:
        raise ImportError(f"Could not import module {module_name} from {module_path}.")
    module = importlib.util.module_from_spec(spec)

    if spec.loader is None:
        raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        if yaml_path is None:
            raise ValueError("yaml_path must be provided if mode is 'full'.")
        # Attach yaml_path to the import function so that it can be used later
        constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))

    loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn, Loader=loader)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.load(file, Loader=loader)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                raise ex

        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(
    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)


def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)


def weighted_f1_score(items):
    from sklearn.metrics import f1_score

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds, average="weighted")
    return fscore
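As a usage note for the rolling-window helpers above, here is a minimal sketch on toy token ids (the numbers are illustrative; it assumes `lm_eval` is importable after `pip install -e .`):

```python
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

tokens = list(range(1, 11))  # ten token ids to score
windows = [
    make_disjoint_window(pair)
    for pair in get_rolling_token_windows(
        tokens, prefix_token=0, max_seq_len=5, context_len=2
    )
]
for ctx, pred in windows:
    print(ctx, "->", pred)
# ([0], [1, 2, 3, 4, 5])      first window predicts a full block off the prefix token
# ([4, 5], [6, 7, 8, 9])      pred_len = max_seq_len - context_len + 1 = 4 new tokens
# ([5, 6, 7, 8, 9], [10])     final window predicts only the leftover token
```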
Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml
ADDED
@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "lm_eval"
version = "0.4.8"
authors = [
    {name="EleutherAI", email="contact@eleuther.ai"}
]
description = "A framework for evaluating language models"
readme = "README.md"
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
requires-python = ">=3.9"
license = { "text" = "MIT" }
dependencies = [
    "accelerate>=0.26.0",
    "evaluate",
    "datasets>=2.16.0",
    "evaluate>=0.4.0",
    "jsonlines",
    "numexpr",
    "peft>=0.2.0",
    "pybind11>=2.6.2",
    "pytablewriter",
    "rouge-score>=0.0.4",
    "sacrebleu>=1.5.0",
    "scikit-learn>=0.24.1",
    "sqlitedict",
    "torch>=1.8",
    "tqdm-multiprocess",
    "transformers>=4.1",
    "zstandard",
    "dill",
    "word2number",
    "more_itertools",
]

[tool.setuptools.packages.find]
include = ["lm_eval*"]

# required to include yaml files in pip installation
[tool.setuptools.package-data]
lm_eval = ["**/*.yaml", "tasks/**/*"]

[project.scripts]
lm-eval = "lm_eval.__main__:cli_evaluate"
lm_eval = "lm_eval.__main__:cli_evaluate"

[project.urls]
Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

[project.optional-dependencies]
api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"]
audiolm_qwen = ["librosa", "soundfile"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"]
gptq = ["auto-gptq[triton]>=0.6.0"]
gptqmodel = ["gptqmodel>=1.0.9"]
hf_transfer = ["hf_transfer"]
ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
ipex = ["optimum"]
japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
longbench = ["jieba", "fuzzywuzzy", "rouge"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
neuronx = ["optimum[neuronx]"]
optimum = ["optimum[openvino]"]
promptsource = ["promptsource>=0.2.3"]
ruler = ["nltk", "wonderwords", "scipy"]
sae_lens = ["sae_lens"]
sentencepiece = ["sentencepiece>=0.1.98"]
sparseml = ["sparseml-nightly[llm]>=1.8.0.20240404"]
sparsify = ["sparsify"]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
vllm = ["vllm>=0.4.2"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
zeno = ["pandas", "zeno-client"]
all = [
    "lm_eval[api]",
    "lm_eval[audiolm_qwen]",
    "lm_eval[deepsparse]",
    "lm_eval[dev]",
    "lm_eval[gptq]",
    "lm_eval[gptqmodel]",
    "lm_eval[hf_transfer]",
    "lm_eval[ibm_watsonx_ai]",
    "lm_eval[ifeval]",
    "lm_eval[ipex]",
    "lm_eval[japanese_leaderboard]",
    "lm_eval[longbench]",
    "lm_eval[mamba]",
    "lm_eval[math]",
    "lm_eval[multilingual]",
    "lm_eval[neuronx]",
    "lm_eval[optimum]",
    "lm_eval[promptsource]",
    "lm_eval[ruler]",
    "lm_eval[sae_lens]",
    "lm_eval[sentencepiece]",
    "lm_eval[sparseml]",
    "lm_eval[sparsify]",
    "lm_eval[testing]",
    "lm_eval[vllm]",
    "lm_eval[wandb]",
    "lm_eval[zeno]",
]

[tool.pymarkdown]
plugins.md013.enabled = false # line-length
plugins.md024.allow_different_nesting = true # no-duplicate-headers
plugins.md025.enabled = false # single-header
plugins.md028.enabled = false # no-blanks-blockquote
plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls

[tool.ruff.lint]
extend-select = ["I"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["lm_eval"]

[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"]
"utils.py" = ["F401"]
Prism/Dream/Dream_Prism/eval_instruct/requirements.txt
ADDED
@@ -0,0 +1 @@
-e .
Prism/Dream/Dream_Prism/eval_instruct/setup.py
ADDED
@@ -0,0 +1,5 @@
import setuptools


# This is to make sure that the package supports editable installs
setuptools.setup()
Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py
ADDED
@@ -0,0 +1,188 @@
import json
import re
import os
import glob
import math
import argparse
from collections import Counter


RES_PATH = "<PATH_TO_RESULTS_JSONL>"


def last_boxed_only_string(string):
    if not string: return None
    idx = string.rfind("\\boxed")
    if "\\boxed " in string:
        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0: return None
    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1
    return string[idx : right_brace_idx + 1] if right_brace_idx else None

def remove_boxed(s):
    if not s: return None
    if "\\boxed " in s: return s[len("\\boxed ") :]
    if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1]
    return s

def strip_string(string):
    if string is None: return ""
    string = str(string).strip()
    while re.search(r"(\d),(\d{3})", string):
        string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
    string = string.replace("\n", "").replace("\\!", "")
    string = string.replace("tfrac", "frac").replace("dfrac", "frac")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("^{\\circ}", "").replace("^\\circ", "")
    string = string.replace("\\$", "").replace("\\%", "")
    if "=" in string and len(string.split("=")[0]) <= 3:
        string = string.split("=")[1].strip()
    string = string.replace(" ", "")
    return string

def extract_answer_gsm8k(text):
    if not text: return ""
    boxed = last_boxed_only_string(text)
    if boxed:
        ans = remove_boxed(boxed)
        if ans: return strip_string(ans)

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        return strip_string(tag_match.group(1))

    nums = re.findall(r"-?\d+\.?\d*", text[-50:])
    if nums:
        return strip_string(nums[-1])

    return ""

def extract_gold_gsm8k(target_str):
    if "####" in target_str:
        return strip_string(target_str.split("####")[-1])
    return strip_string(target_str)

def is_equiv(pred, gold):
    p = strip_string(pred)
    g = strip_string(gold)
    try:
        return math.isclose(float(p), float(g), rel_tol=1e-4)
    except:
        return p == g

def run_evaluation(target_path):
    if os.path.isdir(target_path):
        jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
    else:
        jsonl_files = [target_path]

    for file_path in jsonl_files:
        print(f">>> Evaluating: {file_path}")
        detailed_results = []
        correct_count = 0
        total_count = 0
        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): continue
                item = json.loads(line)
                doc = item.get("doc", {})

                ground_truth = extract_gold_gsm8k(str(item.get("target", "")))
                nfe_list.append(item.get("nfe", 0))
                svf_list.append(item.get("svf_calls", 0))

                ans_stats = {}

                trajectories = item.get("all_trajectories", [])
                if not trajectories:
                    resps = item.get("resps", [])
                    for r in resps:
                        text = r[0] if isinstance(r, list) else r
                        trajectories.append({"resp": text, "score": 0.0})

                for traj in trajectories:
                    raw_text = traj.get("resp", "")
                    score = traj.get("score", -float('inf'))
                    extracted = extract_answer_gsm8k(raw_text)

                    if not extracted: continue

                    norm = strip_string(extracted)
                    if norm not in ans_stats:
                        ans_stats[norm] = {"count": 0, "max_score": -float('inf'), "original": extracted}

                    ans_stats[norm]["count"] += 1
                    if score > ans_stats[norm]["max_score"]:
                        ans_stats[norm]["max_score"] = score
                        ans_stats[norm]["original"] = extracted

                if not ans_stats:
                    best_pred = ""
                else:
                    # Majority vote first, best verifier score as tie-breaker.
                    sorted_norms = sorted(
                        ans_stats.keys(),
                        key=lambda x: (ans_stats[x]["count"], ans_stats[x]["max_score"]),
                        reverse=True
                    )
                    best_norm = sorted_norms[0]
                    best_pred = ans_stats[best_norm]["original"]

                ans_correct = is_equiv(best_pred, ground_truth)
                if ans_correct:
                    correct_count += 1
                total_count += 1

                detailed_results.append({
                    "question": doc.get("question", "N/A"),
                    "final_voted_answer": best_pred,
                    "ground_truth": ground_truth,
                    "is_correct": ans_correct,
                    "nfe": item.get("nfe", 0),
                    "svf_calls": item.get("svf_calls", 0)
                })

        accuracy = (correct_count / total_count * 100) if total_count > 0 else 0

        avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
        avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0

        print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")

        output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        final_report = {
            "summary": {
                "accuracy": f"{accuracy:.2f}%",
                "correct": correct_count,
                "total": total_count,
                "nfe": avg_nfe,
                "svf_calls": avg_svf
            },
            "details": detailed_results
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    args = parser.parse_args()
    run_evaluation(args.res_path)
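A few spot checks of the normalization and equivalence helpers above, with illustrative inputs (assumes the script is importable as a module from its directory):

```python
from gsmk8_eval import extract_answer_gsm8k, strip_string, is_equiv

print(strip_string("1,234,567"))                     # "1234567": thousands separators removed
print(extract_answer_gsm8k("so \\boxed{42} holds"))  # "42": boxed answers take priority
print(is_equiv("42.0000", "42"))                     # True: numeric compare with rel_tol=1e-4
print(is_equiv("x=5", "5"))                          # True: short "lhs=" prefixes are stripped
```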
Prism/Dream/Dream_Prism/metrics/humaneval_eval.py
ADDED
@@ -0,0 +1,234 @@
import os
import sys
import json
import ast
import re
import glob
import argparse
import textwrap
import evaluate as hf_evaluate
from collections import Counter

os.environ["HF_ALLOW_CODE_EVAL"] = "1"

RES_PATH = "<PATH_TO_RESULTS_JSONL>"

def strict_dedent(text: str) -> str:
    lines = text.split('\n')
    while lines and not lines[0].strip(): lines.pop(0)
    while lines and not lines[-1].strip(): lines.pop()

    if not lines:
        return ""

    min_indent = None
    for line in lines:
        if line.strip():
            indent = len(line) - len(line.lstrip())
            if min_indent is None or indent < min_indent:
                min_indent = indent

    if min_indent is None:
        min_indent = 0

    dedented_lines = []
    for line in lines:
        if line.strip():
            if len(line) >= min_indent:
                dedented_lines.append(line[min_indent:])
            else:
                dedented_lines.append(line.lstrip())
        else:
            dedented_lines.append("")

    return "\n".join(dedented_lines)

def extract_python_code(text: str) -> str:
    if not text:
        return ""

    text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        text = tag_match.group(1)

    code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL)
    match = code_block_pattern.search(text)

    if match:
        content = match.group(1)
    else:
        if "```" in text:
            content = text.split("```")[0]
        else:
            lines = text.split('\n')
            cleaned_lines = []
            stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Here are the tests:"]
            for line in lines:
                if any(sw in line for sw in stop_words):
                    break
                cleaned_lines.append(line)
            content = "\n".join(cleaned_lines)

    return strict_dedent(content)

def normalize_code(code: str) -> str:
    # Strip docstrings via the AST so formatting variants hash to the same key.
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
                if (node.body and isinstance(node.body[0], ast.Expr) and
                        isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
                    node.body.pop(0)
        return ast.unparse(tree).strip()
    except:
        return re.sub(r"\s+", "", code)

def sanitize(prompt: str, completion: str, entry_point: str) -> str:
    if f"def {entry_point}" in completion:
        imports = [line for line in prompt.split("\n") if line.startswith("import ") or line.startswith("from ")]
        return "\n".join(imports) + "\n" + completion

    clean_body = strict_dedent(completion)
    if not clean_body:
        return prompt

    indented_body = "\n".join(["    " + line if line.strip() else "" for line in clean_body.split('\n')])
    return prompt.strip() + "\n" + indented_body

def perform_majority_voting(trajectories, prompt, entry_point):
    candidate_stats = {}

    for item in trajectories:
        if isinstance(item, dict):
            raw_text = item.get("resp", "")
            score = item.get("score", 0.0)
        else:
            raw_text = str(item[0] if isinstance(item, list) else item)
            score = 0.0

        extracted_code = extract_python_code(raw_text)
        full_code = sanitize(prompt, extracted_code, entry_point)

        is_valid = False
        try:
            ast.parse(full_code)
            is_valid = True
        except:
            is_valid = False

        norm_key = normalize_code(full_code)
        if not norm_key: continue

        if norm_key not in candidate_stats:
            candidate_stats[norm_key] = {
                "count": 0,
                "max_score": -float("inf"),
                "code": full_code,
                "is_valid": is_valid
            }

        candidate_stats[norm_key]["count"] += 1
        candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score)

    if not candidate_stats:
        return prompt

    # Prefer syntactically valid candidates, then vote count, then best score.
    sorted_candidates = sorted(
        candidate_stats.values(),
        key=lambda x: (x["is_valid"], x["count"], x["max_score"]),
        reverse=True
    )

    return sorted_candidates[0]["code"]

def run_evaluation(target_path):
    if os.path.isdir(target_path):
        jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
    else:
        jsonl_files = [target_path]

    try:
        code_eval = hf_evaluate.load("code_eval")
    except Exception as e:
        print(f"Error loading code_eval: {e}")
        return

    for file_path in jsonl_files:
        print(f">>> Evaluating file: {file_path}")

        all_voted_predictions = []
        all_references = []
        detailed_logs = []

        nfe_sum = 0
        svf_sum = 0
        valid_samples = 0

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): continue
                try:
                    data = json.loads(line)
                except:
                    continue

                doc = data.get("doc", {})
                task_id = doc.get("task_id", f"Task_{valid_samples}")
                prompt = doc.get("prompt", "")
                entry_point = doc.get("entry_point", "solution")
                test_code = doc.get("test", "") + f"\ncheck({entry_point})"

                nfe_sum += data.get("nfe", 0)
                svf_sum += data.get("svf_calls", 0)
                valid_samples += 1

                trajectories = data.get("all_trajectories", data.get("resps", []))
                voted_code = perform_majority_voting(trajectories, prompt, entry_point)

                all_voted_predictions.append([voted_code])
                all_references.append(test_code)

                detailed_logs.append({
                    "task_id": task_id,
                    "entry_point": entry_point,
                    "final_code": voted_code,
                    "nfe": data.get("nfe", 0),
                    "svf": data.get("svf_calls", 0),
                })

        if not all_voted_predictions: continue

        print("Running tests...")
        pass_at_k, exec_results = code_eval.compute(
            references=all_references,
            predictions=all_voted_predictions,
            k=[1]
        )

        accuracy = pass_at_k.get("pass@1", 0.0) * 100
        avg_nfe = nfe_sum / valid_samples if valid_samples > 0 else 0
        avg_svf = svf_sum / valid_samples if valid_samples > 0 else 0

        for i, log in enumerate(detailed_logs):
            res = exec_results.get(i, [])
            log["passed"] = res[0][1].get("passed", False) if res else False
            log["exec_msg"] = res[0][1].get("result", "failed") if res else "failed"

        output_path = file_path.replace(".jsonl", "_voted_result.json")
        final_report = {
            "meta": {"file": file_path, "total_samples": valid_samples},
            "metrics": {"accuracy": f"{accuracy:.2f}%", "avg_nfe": avg_nfe, "avg_svf": avg_svf},
            "details": detailed_logs
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
        print(f"Accuracy: {accuracy:.2f}% | SVF: {avg_svf:.1f}\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    args = parser.parse_args()
    run_evaluation(args.res_path)
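To illustrate the tie-breaking order in `perform_majority_voting` (syntactic validity first, then vote count, then max score), here is a small made-up run; it assumes `humaneval_eval.py` is importable from its directory, which requires the `evaluate` package to be installed:

```python
from humaneval_eval import perform_majority_voting

prompt = 'def add(a, b):\n    """Return a + b."""\n'
trajectories = [
    {"resp": "```python\ndef add(a, b):\n    return a + b\n```", "score": 0.7},
    {"resp": "```python\ndef add(a, b):\n    return a + b\n```", "score": 0.9},
    # Higher-scored but syntactically invalid (missing the colon), so it loses the vote.
    {"resp": "```python\ndef add(a, b)\n    return a - b\n```", "score": 0.99},
]
print(perform_majority_voting(trajectories, prompt, entry_point="add"))
# Prints the valid completion: two identical valid candidates outvote one broken one.
```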
Prism/Dream/Dream_Prism/metrics/math500_eval.py
ADDED
@@ -0,0 +1,205 @@
import json
import re
import os
import math
import argparse
from collections import Counter

RES_PATH = "<PATH_TO_RESULTS_JSONL>"

def extract_answer(text):
    if not text:
        return "", False
    text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip()

    boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    all_boxes = re.findall(boxed_pattern, text)
    if all_boxes:
        return all_boxes[-1], True

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        return tag_match.group(1).strip(), True

    marker = "the answer is"
    if marker in text.lower():
        pos = text.lower().rfind(marker)
        after_text = text[pos + len(marker):].strip()
        after_text = re.sub(r"^[:\s]+", "", after_text)
        return after_text.split('\n')[0].split('$')[0].strip(), True

    tail = text[-50:].strip()
    nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail)
    if nums:
        return nums[-1], False
    return "", False

def normalize_math(string):
    if not string: return ""
    string = str(string).lower().strip()

    string = string.replace("</reasoning>", "").replace("</answer>", "").replace("<answer>", "")
    string = string.replace("...", "").replace("cannot be determined", "")

    string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string)
    string = re.sub(r"\\text\{([^}]*)\}", r"\1", string)
    string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname)\{([^}]*)\}", r"\2", string)
    string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string)
    string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "")

    units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)"
    string = re.sub(units_pattern, "", string)
    string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "")
    string = string.replace("\\pi", "pi")
    string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
    string = string.rstrip(".:,; ").replace(" ", "")

    if "=" in string:
        string = string.split("=")[-1]

    return string

def is_equiv(pred, gold):
    if not pred: return False
    p, g = normalize_math(pred), normalize_math(gold)
    if p == g: return True

    if "=" in pred:
        if normalize_math(pred.split("=")[-1]) == g:
            return True

    try:
        def to_float(s):
            if '/' in s and s.count('/') == 1:
                parts = s.split('/')
                return float(parts[0]) / float(parts[1])
            if '_' in s: s = s.split('_')[0]
            return float(s)
        return math.isclose(to_float(p), to_float(g), rel_tol=1e-4)
    except:
        p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p)
        g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g)
        return p_fuzzy == g_fuzzy if p_fuzzy else False

def run_evaluation(target_path):
    jsonl_files = []
    if os.path.isdir(target_path):
        for root, dirs, files in os.walk(target_path):
            for file in files:
                if file.endswith(".jsonl") and not file.startswith("eval_voted_"):
                    jsonl_files.append(os.path.join(root, file))
    else:
        jsonl_files = [target_path]

    for file_path in jsonl_files:
        print(f">>> Evaluating: {file_path}")
        detailed_results = []

        voted_correct_count = 0
        total_count = 0

        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): continue
                try:
                    item = json.loads(line)
                except:
                    continue

                doc = item.get("doc", {})
                ground_truth = str(item.get("target", doc.get("answer", "")))

                current_nfe = item.get("nfe", 0)
                nfe_list.append(current_nfe)
                current_svf = item.get("svf_calls", 0)
                svf_list.append(current_svf)

                ans_stats = {}
                trajectories = item.get("all_trajectories", [])

                for traj in trajectories:
                    raw_text = traj.get("resp", "")
                    score = traj.get("score", 0)

                    extracted, _ = extract_answer(raw_text)
                    if not extracted: continue

                    norm = normalize_math(extracted)
                    if norm not in ans_stats:
                        ans_stats[norm] = {
                            "count": 0,
                            "max_score": -float('inf'),
                            "total_weight": 0.0,
                            "original": extracted
                        }

                    ans_stats[norm]["count"] += 1
                    if score > ans_stats[norm]["max_score"]:
                        ans_stats[norm]["max_score"] = score

                    # Weighted voting: each trajectory contributes exp(score).
                    try:
                        weight = math.exp(score)
                    except OverflowError:
                        weight = float('inf')
                    ans_stats[norm]["total_weight"] += weight

                if not ans_stats:
                    best_pred = ""
                else:
                    sorted_norms = sorted(
                        ans_stats.keys(),
                        key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]),
                        reverse=True
                    )
                    best_norm = sorted_norms[0]
                    best_pred = ans_stats[best_norm]["original"]

                is_voted_correct = False
                if best_pred and is_equiv(best_pred, ground_truth):
                    voted_correct_count += 1
                    is_voted_correct = True

                total_count += 1

                detailed_results.append({
                    "question": doc.get("problem", "N/A"),
                    "final_voted_answer": best_pred,
                    "ground_truth": ground_truth,
                    "is_voted_correct": is_voted_correct,
                    "nfe": current_nfe,
                    "svf_calls": current_svf
                })

        accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0

        avg_nfe = sum(nfe_list) / len(nfe_list) if nfe_list else 0
        avg_svf = sum(svf_list) / len(svf_list) if svf_list else 0

        print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f} ---")

        output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        final_report = {
            "summary": {
                "Accuracy": f"{accuracy:.2f}%",
                "correct_voted_count": voted_correct_count,
                "total": total_count,
                "avg_nfe": avg_nfe,
                "avg_svf": avg_svf
            },
            "details": detailed_results
        }
        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
|
| 201 |
+
if __name__ == "__main__":
|
| 202 |
+
parser = argparse.ArgumentParser()
|
| 203 |
+
parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
|
| 204 |
+
args = parser.parse_args()
|
| 205 |
+
run_evaluation(args.res_path)
|
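For intuition, a minimal sketch of how the normalizer and equivalence check above behave on typical MATH-style answer strings (the inputs are illustrative, not taken from the dataset):

# Assumes normalize_math / is_equiv as defined above.
print(normalize_math(r"\frac{3}{4} \text{ units}"))  # -> "3/4"
print(is_equiv("0.75", r"\frac{3}{4}"))              # True: compared numerically via to_float
print(is_equiv("x = 12", "12"))                      # True: the assignment prefix is stripped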
Prism/Dream/Dream_Prism/metrics/mbpp_eval.py
ADDED
@@ -0,0 +1,281 @@
import os
import sys
import json
import ast
import re
import glob
import argparse
import textwrap
import evaluate as hf_evaluate
from collections import Counter

os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

RES_PATH = "<PATH_TO_RESULTS_JSONL>"

def strict_dedent(text: str) -> str:
    lines = text.split('\n')
    while lines and not lines[0].strip(): lines.pop(0)
    while lines and not lines[-1].strip(): lines.pop()

    if not lines:
        return ""

    min_indent = None
    for line in lines:
        if line.strip():
            indent = len(line) - len(line.lstrip())
            if min_indent is None or indent < min_indent:
                min_indent = indent

    if min_indent is None:
        min_indent = 0

    dedented_lines = []
    for line in lines:
        if line.strip():
            if len(line) >= min_indent:
                dedented_lines.append(line[min_indent:])
            else:
                dedented_lines.append(line.lstrip())
        else:
            dedented_lines.append("")

    return "\n".join(dedented_lines)

def extract_python_code(text: str) -> str:
    if not text:
        return ""

    text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")
    text = text.replace("[DONE]", "")

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        text = tag_match.group(1)

    code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL)
    match = code_block_pattern.search(text)

    if match:
        content = match.group(1)
    else:
        if "```" in text:
            content = text.split("```")[0]
        else:
            lines = text.split('\n')
            start_idx = 0
            stop_words = ["Here is", "Explanation", "Example", "Note", "python", "The code"]

            for i, line in enumerate(lines):
                stripped = line.strip()
                if stripped.startswith(("def ", "import ", "from ", "class ")):
                    start_idx = i
                    break
                if any(sw in line for sw in stop_words) and not stripped.endswith(":"):
                    continue

            content = "\n".join(lines[start_idx:])

    return strict_dedent(content)

def normalize_code(code: str) -> str:
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
                # Strip leading docstrings so they don't affect the vote key.
                if (node.body and isinstance(node.body[0], ast.Expr) and
                        isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
                    node.body.pop(0)
        return ast.unparse(tree).strip()
    except Exception:
        return re.sub(r"\s+", "", code)

def perform_majority_voting(trajectories):
    candidate_stats = {}

    for item in trajectories:
        if isinstance(item, dict):
            raw_text = item.get("resp", "")
            score = item.get("score", 0.0)
        elif isinstance(item, (list, tuple)):
            raw_text = item[0]
            score = 0.0
        else:
            raw_text = str(item)
            score = 0.0

        extracted_code = extract_python_code(raw_text)

        if not extracted_code.strip():
            continue

        is_valid = False
        try:
            ast.parse(extracted_code)
            is_valid = True
        except SyntaxError:
            is_valid = False

        norm_key = normalize_code(extracted_code)
        if not norm_key: continue

        if norm_key not in candidate_stats:
            candidate_stats[norm_key] = {
                "count": 0,
                "max_score": -float("inf"),
                "code": extracted_code,
                "is_valid": is_valid
            }

        candidate_stats[norm_key]["count"] += 1
        candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score)

    if not candidate_stats:
        return ""

    sorted_candidates = sorted(
        candidate_stats.values(),
        key=lambda x: (x["is_valid"], x["count"], x["max_score"]),
        reverse=True
    )

    return sorted_candidates[0]["code"]

def run_evaluation(target_path):
    if os.path.isdir(target_path):
        jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
    else:
        jsonl_files = [target_path]

    try:
        code_eval = hf_evaluate.load("code_eval")
    except Exception as e:
        print(f"Error loading code_eval: {e}")
        return

    for file_path in jsonl_files:
        print(f"\n>>> Evaluating MBPP file: {file_path}")

        all_voted_predictions = []
        all_references = []
        detailed_logs = []

        nfe_total = 0
        svf_total = 0
        count_valid_samples = 0

        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        for idx, line in enumerate(lines):
            if not line.strip(): continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue

            doc = data.get("doc", {})
            task_id = doc.get("task_id", f"MBPP_{idx}")

            test_list = doc.get("test_list", [])
            test_setup = doc.get("test_setup_code", "")
            challenge_tests = doc.get("challenge_test_list", [])

            full_test_code = ""
            if test_setup:
                full_test_code += test_setup + "\n"
            if test_list:
                full_test_code += "\n".join(test_list) + "\n"
            if challenge_tests:
                full_test_code += "\n".join(challenge_tests)

            current_nfe = data.get("nfe", 0)
            current_svf = data.get("svf_calls", 0)

            nfe_total += current_nfe
            svf_total += current_svf
            count_valid_samples += 1

            trajectories = data.get("all_trajectories", [])
            if not trajectories:
                resps = data.get("resps", [])
                trajectories = [{"resp": r} for r in resps]

            voted_code = perform_majority_voting(trajectories)

            if not voted_code:
                voted_code = "def placeholder(): pass"

            all_voted_predictions.append([voted_code])
            all_references.append(full_test_code)

            detailed_logs.append({
                "task_id": task_id,
                "final_code": voted_code,
                "reference": full_test_code,
                "nfe": current_nfe,
                "svf": current_svf,
                "traj_count": len(trajectories)
            })

        if not all_voted_predictions:
            print("No valid data found.")
            continue

        print(f"Running code execution tests ({len(all_voted_predictions)} problems)...")

        pass_at_k, exec_results = code_eval.compute(
            references=all_references,
            predictions=all_voted_predictions,
            k=[1],
            num_workers=4
        )

        accuracy = pass_at_k.get("pass@1", 0.0) * 100
        avg_nfe = nfe_total / count_valid_samples if count_valid_samples > 0 else 0
        avg_svf = svf_total / count_valid_samples if count_valid_samples > 0 else 0
        print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f}")

        for i, log in enumerate(detailed_logs):
            res = exec_results.get(i, [])
            if res and len(res) > 0:
                is_passed = res[0][1].get("passed", False)
                eval_result_str = res[0][1].get("result", "passed") if not is_passed else "passed"
            else:
                is_passed = False
                eval_result_str = "Execution Failed"

            log["passed"] = is_passed
            log["exec_msg"] = eval_result_str

        output_name = f"eval_mbpp_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        final_report = {
            "meta": {
                "file": file_path,
                "total_samples": count_valid_samples
            },
            "metrics": {
                "accuracy": f"{accuracy:.2f}%",
                "avg_nfe": avg_nfe,
                "avg_svf": avg_svf
            },
            "details": detailed_logs
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
        print(f"Results saved to: {output_path}\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MBPP Metrics Evaluation Script")
    parser.add_argument("-r", "--res_path", type=str, default=RES_PATH, help="Path to jsonl result file or directory")
    args = parser.parse_args()

    if os.path.exists(args.res_path):
        run_evaluation(args.res_path)
    else:
        print(f"Path not found: {args.res_path}")
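A small sanity check on the voting key used above: normalize_code strips docstrings and re-serializes the AST, so superficially different renderings of the same function collapse into one vote bucket (the snippets are illustrative):

# Assumes normalize_code as defined above.
a = 'def f(x):\n    """doc"""\n    return x + 1'
b = "def f(x):\n    return x+1"
assert normalize_code(a) == normalize_code(b)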
Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
set -e
set -x

PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_gsm8k"

cd "${PROJECT_ROOT}"
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com
export HF_ALLOW_CODE_EVAL=1
export PYTHONPATH=.

TASK="gsm8k"
LENGTH=256
STEPS=256
PORT=12334
NAME="win_0.1-0.6_s2_k4"

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch --main_process_port ${PORT} -m lm_eval \
    --model diffllm \
    --tasks ${TASK} \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
    --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_humaneval.sh
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
set -e
set -x

PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_humaneval"

cd "${PROJECT_ROOT}"
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com
export HF_ALLOW_CODE_EVAL=1
export PYTHONPATH=.

TASK="humaneval_instruct"
LENGTH=512
STEPS=512
PORT=12334
NAME="win_0.1-0.6_s2_k4"

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch --main_process_port ${PORT} -m lm_eval \
    --model diffllm \
    --tasks ${TASK} \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
    --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=20,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_math500.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash
set -e
set -x

PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_math500"

cd "${PROJECT_ROOT}"
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com
export HF_ALLOW_CODE_EVAL=1
export PYTHONPATH=.

TASK="math500"
LENGTH=256
STEPS=256
PORT=12334
NAME="win_0.1-0.6_s2_k4"

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch --main_process_port ${PORT} -m lm_eval \
    --model diffllm \
    --tasks ${TASK} \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
    --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=10,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
    --num_fewshot 0 \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_mbpp.sh
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
set -e
set -x

PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_mbpp"

cd "${PROJECT_ROOT}"
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com
export HF_ALLOW_CODE_EVAL=1
export PYTHONPATH=.

TASK="mbpp"
LENGTH=512
STEPS=512
PORT=12334
NAME="win_0.1-0.6_s2_k4"

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch --main_process_port ${PORT} -m lm_eval \
    --model diffllm \
    --tasks ${TASK} \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
    --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
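The four scripts above differ only in TASK, generation LENGTH/STEPS, pruning_interval, and task_type. A minimal sketch of driving one run end to end from Python rather than the shell, assuming the working directory is the project root and the placeholder paths in the scripts have been filled in:

# Runs generation, then scores the realtime_output file with the matching evaluator.
import subprocess
subprocess.run(["bash", "Prism/Dream/Dream_Prism/scripts/run_mbpp.sh"], check=True)
subprocess.run([
    "python", "Prism/Dream/Dream_Prism/metrics/mbpp_eval.py",
    "-r", "outputs/dream_mbpp/win_0.1-0.6_s2_k4/res.jsonl",
], check=True)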
Prism/Dream/Dream_Prism/src/__init__.py
ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc
ADDED
Binary file (659 Bytes).
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py
ADDED
@@ -0,0 +1,71 @@
"""Semiconnectedness."""

import networkx as nx
from networkx.utils import not_implemented_for, pairwise

__all__ = ["is_semiconnected"]


@not_implemented_for("undirected")
@nx._dispatchable
def is_semiconnected(G):
    r"""Returns True if the graph is semiconnected, False otherwise.

    A graph is semiconnected if and only if for any pair of nodes, either one
    is reachable from the other, or they are mutually reachable.

    This function uses a theorem that states that a DAG is semiconnected
    if, for any topological sort, every pair of consecutive nodes $v_i$,
    $v_{i+1}$ in that sort is joined by an edge $(v_i, v_{i+1})$. That allows
    us to check if a non-DAG `G` is semiconnected by condensing the graph:
    i.e. constructing a new graph `H` with nodes being the strongly connected
    components of `G`, and edges (scc_1, scc_2) if there is an edge
    $(v_1, v_2)$ in `G` for some $v_1 \in scc_1$ and $v_2 \in scc_2$. That
    results in a DAG, so we compute the topological sort of `H` and check
    if for every $n$ there is an edge $(scc_n, scc_{n+1})$.

    Parameters
    ----------
    G : NetworkX graph
        A directed graph.

    Returns
    -------
    semiconnected : bool
        True if the graph is semiconnected, False otherwise.

    Raises
    ------
    NetworkXNotImplemented
        If the input graph is undirected.

    NetworkXPointlessConcept
        If the graph is empty.

    Examples
    --------
    >>> G = nx.path_graph(4, create_using=nx.DiGraph())
    >>> print(nx.is_semiconnected(G))
    True
    >>> G = nx.DiGraph([(1, 2), (3, 2)])
    >>> print(nx.is_semiconnected(G))
    False

    See Also
    --------
    is_strongly_connected
    is_weakly_connected
    is_connected
    is_biconnected
    """
    if len(G) == 0:
        raise nx.NetworkXPointlessConcept(
            "Connectivity is undefined for the null graph."
        )

    if not nx.is_weakly_connected(G):
        return False

    H = nx.condensation(G)

    return all(H.has_edge(u, v) for u, v in pairwise(nx.topological_sort(H)))
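To make the condensation step from the docstring concrete, a short illustrative run (not part of the library file): a 2-cycle collapses to a single strongly connected component, and the condensed DAG's consecutive topological pairs are then checked for edges.

import networkx as nx
G = nx.DiGraph([(0, 1), (1, 0), (1, 2)])  # nodes {0, 1} form one SCC
H = nx.condensation(G)
order = list(nx.topological_sort(H))
print(all(H.has_edge(u, v) for u, v in zip(order, order[1:])))  # True
print(nx.is_semiconnected(G))  # True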
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py
ADDED
@@ -0,0 +1,4 @@
from networkx.algorithms.operators.all import *
from networkx.algorithms.operators.binary import *
from networkx.algorithms.operators.product import *
from networkx.algorithms.operators.unary import *
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py
ADDED
@@ -0,0 +1,324 @@
"""Operations on many graphs."""

from itertools import chain, repeat

import networkx as nx

__all__ = ["union_all", "compose_all", "disjoint_union_all", "intersection_all"]


@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
def union_all(graphs, rename=()):
    """Returns the union of all graphs.

    The graphs must be disjoint, otherwise an exception is raised.

    Parameters
    ----------
    graphs : iterable
       Iterable of NetworkX graphs

    rename : iterable , optional
       Node names of graphs can be changed by specifying the tuple
       rename=('G-','H-') (for example). Node "u" in G is then renamed
       "G-u" and "v" in H is renamed "H-v". Infinite generators (like itertools.count)
       are also supported.

    Returns
    -------
    U : a graph with the same type as the first graph in list

    Raises
    ------
    ValueError
       If `graphs` is an empty list.

    NetworkXError
        In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.

    Notes
    -----
    For operating on mixed type graphs, they should be converted to the same type.
    >>> G = nx.Graph()
    >>> H = nx.DiGraph()
    >>> GH = union_all([nx.DiGraph(G), H])

    To force a disjoint union with node relabeling, use
    disjoint_union_all(G, H) or convert_node_labels_to_integers().

    Graph, edge, and node attributes are propagated to the union graph.
    If a graph attribute is present in multiple graphs, then the value
    from the last graph in the list with that attribute is used.

    Examples
    --------
    >>> G1 = nx.Graph([(1, 2), (2, 3)])
    >>> G2 = nx.Graph([(4, 5), (5, 6)])
    >>> result_graph = nx.union_all([G1, G2])
    >>> result_graph.nodes()
    NodeView((1, 2, 3, 4, 5, 6))
    >>> result_graph.edges()
    EdgeView([(1, 2), (2, 3), (4, 5), (5, 6)])

    See Also
    --------
    union
    disjoint_union_all
    """
    R = None
    seen_nodes = set()

    # rename graph to obtain disjoint node labels
    def add_prefix(graph, prefix):
        if prefix is None:
            return graph

        def label(x):
            return f"{prefix}{x}"

        return nx.relabel_nodes(graph, label)

    rename = chain(rename, repeat(None))
    graphs = (add_prefix(G, name) for G, name in zip(graphs, rename))

    for i, G in enumerate(graphs):
        G_nodes_set = set(G.nodes)
        if i == 0:
            # Union is the same type as first graph
            R = G.__class__()
        elif G.is_directed() != R.is_directed():
            raise nx.NetworkXError("All graphs must be directed or undirected.")
        elif G.is_multigraph() != R.is_multigraph():
            raise nx.NetworkXError("All graphs must be graphs or multigraphs.")
        elif not seen_nodes.isdisjoint(G_nodes_set):
            raise nx.NetworkXError(
                "The node sets of the graphs are not disjoint.\n"
                "Use `rename` to specify prefixes for the graphs or use\n"
                "disjoint_union(G1, G2, ..., GN)."
            )

        seen_nodes |= G_nodes_set
        R.graph.update(G.graph)
        R.add_nodes_from(G.nodes(data=True))
        R.add_edges_from(
            G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
        )

    if R is None:
        raise ValueError("cannot apply union_all to an empty list")

    return R


@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
def disjoint_union_all(graphs):
    """Returns the disjoint union of all graphs.

    This operation forces distinct integer node labels starting with 0
    for the first graph in the list and numbering consecutively.

    Parameters
    ----------
    graphs : iterable
       Iterable of NetworkX graphs

    Returns
    -------
    U : A graph with the same type as the first graph in list

    Raises
    ------
    ValueError
       If `graphs` is an empty list.

    NetworkXError
        In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.

    Examples
    --------
    >>> G1 = nx.Graph([(1, 2), (2, 3)])
    >>> G2 = nx.Graph([(4, 5), (5, 6)])
    >>> U = nx.disjoint_union_all([G1, G2])
    >>> list(U.nodes())
    [0, 1, 2, 3, 4, 5]
    >>> list(U.edges())
    [(0, 1), (1, 2), (3, 4), (4, 5)]

    Notes
    -----
    For operating on mixed type graphs, they should be converted to the same type.

    Graph, edge, and node attributes are propagated to the union graph.
    If a graph attribute is present in multiple graphs, then the value
    from the last graph in the list with that attribute is used.
    """

    def yield_relabeled(graphs):
        first_label = 0
        for G in graphs:
            yield nx.convert_node_labels_to_integers(G, first_label=first_label)
            first_label += len(G)

    R = union_all(yield_relabeled(graphs))

    return R


@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
def compose_all(graphs):
    """Returns the composition of all graphs.

    Composition is the simple union of the node sets and edge sets.
    The node sets of the supplied graphs need not be disjoint.

    Parameters
    ----------
    graphs : iterable
       Iterable of NetworkX graphs

    Returns
    -------
    C : A graph with the same type as the first graph in list

    Raises
    ------
    ValueError
       If `graphs` is an empty list.

    NetworkXError
        In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.

    Examples
    --------
    >>> G1 = nx.Graph([(1, 2), (2, 3)])
    >>> G2 = nx.Graph([(3, 4), (5, 6)])
    >>> C = nx.compose_all([G1, G2])
    >>> list(C.nodes())
    [1, 2, 3, 4, 5, 6]
    >>> list(C.edges())
    [(1, 2), (2, 3), (3, 4), (5, 6)]

    Notes
    -----
    For operating on mixed type graphs, they should be converted to the same type.

    Graph, edge, and node attributes are propagated to the union graph.
    If a graph attribute is present in multiple graphs, then the value
    from the last graph in the list with that attribute is used.
    """
    R = None

    # add graph attributes, H attributes take precedent over G attributes
    for i, G in enumerate(graphs):
        if i == 0:
            # create new graph
            R = G.__class__()
        elif G.is_directed() != R.is_directed():
            raise nx.NetworkXError("All graphs must be directed or undirected.")
        elif G.is_multigraph() != R.is_multigraph():
            raise nx.NetworkXError("All graphs must be graphs or multigraphs.")

        R.graph.update(G.graph)
        R.add_nodes_from(G.nodes(data=True))
        R.add_edges_from(
            G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
        )

    if R is None:
        raise ValueError("cannot apply compose_all to an empty list")

    return R


@nx._dispatchable(graphs="[graphs]", returns_graph=True)
def intersection_all(graphs):
    """Returns a new graph that contains only the nodes and the edges that exist in
    all graphs.

    Parameters
    ----------
    graphs : iterable
       Iterable of NetworkX graphs

    Returns
    -------
    R : A new graph with the same type as the first graph in list

    Raises
    ------
    ValueError
       If `graphs` is an empty list.

    NetworkXError
        In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.

    Notes
    -----
    For operating on mixed type graphs, they should be converted to the same type.

    Attributes from the graph, nodes, and edges are not copied to the new
    graph.

    The resulting graph can be updated with attributes if desired.
    For example, code which adds the minimum attribute for each node across all
    graphs could work::

        >>> g = nx.Graph()
        >>> g.add_node(0, capacity=4)
        >>> g.add_node(1, capacity=3)
        >>> g.add_edge(0, 1)

        >>> h = g.copy()
        >>> h.nodes[0]["capacity"] = 2

        >>> gh = nx.intersection_all([g, h])

        >>> new_node_attr = {
        ...     n: min(*(anyG.nodes[n].get("capacity", float("inf")) for anyG in [g, h]))
        ...     for n in gh
        ... }
        >>> nx.set_node_attributes(gh, new_node_attr, "new_capacity")
        >>> gh.nodes(data=True)
        NodeDataView({0: {'new_capacity': 2}, 1: {'new_capacity': 3}})

    Examples
    --------
    >>> G1 = nx.Graph([(1, 2), (2, 3)])
    >>> G2 = nx.Graph([(2, 3), (3, 4)])
    >>> R = nx.intersection_all([G1, G2])
    >>> list(R.nodes())
    [2, 3]
    >>> list(R.edges())
    [(2, 3)]

    """
    R = None

    for i, G in enumerate(graphs):
        G_nodes_set = set(G.nodes)
        G_edges_set = set(G.edges)
        if not G.is_directed():
            if G.is_multigraph():
                G_edges_set.update((v, u, k) for u, v, k in list(G_edges_set))
            else:
                G_edges_set.update((v, u) for u, v in list(G_edges_set))
        if i == 0:
            # create new graph
            R = G.__class__()
            node_intersection = G_nodes_set
            edge_intersection = G_edges_set
        elif G.is_directed() != R.is_directed():
            raise nx.NetworkXError("All graphs must be directed or undirected.")
        elif G.is_multigraph() != R.is_multigraph():
            raise nx.NetworkXError("All graphs must be graphs or multigraphs.")
        else:
            node_intersection &= G_nodes_set
            edge_intersection &= G_edges_set

    if R is None:
        raise ValueError("cannot apply intersection_all to an empty list")

    R.add_nodes_from(node_intersection)
    R.add_edges_from(edge_intersection)

    return R
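One detail worth calling out (this note is not part of the library file): union_all's docstring says rename may be an infinite generator, which makes prefixing an arbitrary list of graphs by index a one-liner.

import itertools
import networkx as nx
graphs = [nx.path_graph(2), nx.path_graph(2), nx.path_graph(2)]
U = nx.union_all(graphs, rename=(f"{i}-" for i in itertools.count()))
print(sorted(U.nodes()))  # ['0-0', '0-1', '1-0', '1-1', '2-0', '2-1']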
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Operations on graphs including union, intersection, difference.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import networkx as nx
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"union",
|
| 9 |
+
"compose",
|
| 10 |
+
"disjoint_union",
|
| 11 |
+
"intersection",
|
| 12 |
+
"difference",
|
| 13 |
+
"symmetric_difference",
|
| 14 |
+
"full_join",
|
| 15 |
+
]
|
| 16 |
+
_G_H = {"G": 0, "H": 1}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
|
| 20 |
+
def union(G, H, rename=()):
|
| 21 |
+
"""Combine graphs G and H. The names of nodes must be unique.
|
| 22 |
+
|
| 23 |
+
A name collision between the graphs will raise an exception.
|
| 24 |
+
|
| 25 |
+
A renaming facility is provided to avoid name collisions.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Parameters
|
| 29 |
+
----------
|
| 30 |
+
G, H : graph
|
| 31 |
+
A NetworkX graph
|
| 32 |
+
|
| 33 |
+
rename : iterable , optional
|
| 34 |
+
Node names of G and H can be changed by specifying the tuple
|
| 35 |
+
rename=('G-','H-') (for example). Node "u" in G is then renamed
|
| 36 |
+
"G-u" and "v" in H is renamed "H-v".
|
| 37 |
+
|
| 38 |
+
Returns
|
| 39 |
+
-------
|
| 40 |
+
U : A union graph with the same type as G.
|
| 41 |
+
|
| 42 |
+
See Also
|
| 43 |
+
--------
|
| 44 |
+
compose
|
| 45 |
+
:func:`~networkx.Graph.update`
|
| 46 |
+
disjoint_union
|
| 47 |
+
|
| 48 |
+
Notes
|
| 49 |
+
-----
|
| 50 |
+
To combine graphs that have common nodes, consider compose(G, H)
|
| 51 |
+
or the method, Graph.update().
|
| 52 |
+
|
| 53 |
+
disjoint_union() is similar to union() except that it avoids name clashes
|
| 54 |
+
by relabeling the nodes with sequential integers.
|
| 55 |
+
|
| 56 |
+
Edge and node attributes are propagated from G and H to the union graph.
|
| 57 |
+
Graph attributes are also propagated, but if they are present in both G and H,
|
| 58 |
+
then the value from H is used.
|
| 59 |
+
|
| 60 |
+
Examples
|
| 61 |
+
--------
|
| 62 |
+
>>> from pprint import pprint
|
| 63 |
+
>>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
|
| 64 |
+
>>> H = nx.Graph([(0, 1), (0, 3), (1, 3), (1, 2)])
|
| 65 |
+
>>> U = nx.union(G, H, rename=("G", "H"))
|
| 66 |
+
>>> U.nodes
|
| 67 |
+
NodeView(('G0', 'G1', 'G2', 'H0', 'H1', 'H3', 'H2'))
|
| 68 |
+
>>> edgelist = list(U.edges)
|
| 69 |
+
>>> pprint(edgelist)
|
| 70 |
+
[('G0', 'G1'),
|
| 71 |
+
('G0', 'G2'),
|
| 72 |
+
('G1', 'G2'),
|
| 73 |
+
('H0', 'H1'),
|
| 74 |
+
('H0', 'H3'),
|
| 75 |
+
('H1', 'H3'),
|
| 76 |
+
('H1', 'H2')]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
"""
|
| 80 |
+
return nx.union_all([G, H], rename)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
|
| 84 |
+
def disjoint_union(G, H):
|
| 85 |
+
"""Combine graphs G and H. The nodes are assumed to be unique (disjoint).
|
| 86 |
+
|
| 87 |
+
This algorithm automatically relabels nodes to avoid name collisions.
|
| 88 |
+
|
| 89 |
+
Parameters
|
| 90 |
+
----------
|
| 91 |
+
G,H : graph
|
| 92 |
+
A NetworkX graph
|
| 93 |
+
|
| 94 |
+
Returns
|
| 95 |
+
-------
|
| 96 |
+
U : A union graph with the same type as G.
|
| 97 |
+
|
| 98 |
+
See Also
|
| 99 |
+
--------
|
| 100 |
+
union
|
| 101 |
+
compose
|
| 102 |
+
:func:`~networkx.Graph.update`
|
| 103 |
+
|
| 104 |
+
Notes
|
| 105 |
+
-----
|
| 106 |
+
A new graph is created, of the same class as G. It is recommended
|
| 107 |
+
that G and H be either both directed or both undirected.
|
| 108 |
+
|
| 109 |
+
The nodes of G are relabeled 0 to len(G)-1, and the nodes of H are
|
| 110 |
+
relabeled len(G) to len(G)+len(H)-1.
|
| 111 |
+
|
| 112 |
+
Renumbering forces G and H to be disjoint, so no exception is ever raised for a name collision.
|
| 113 |
+
To preserve the check for common nodes, use union().
|
| 114 |
+
|
| 115 |
+
Edge and node attributes are propagated from G and H to the union graph.
|
| 116 |
+
Graph attributes are also propagated, but if they are present in both G and H,
|
| 117 |
+
then the value from H is used.
|
| 118 |
+
|
| 119 |
+
To combine graphs that have common nodes, consider compose(G, H)
|
| 120 |
+
or the method, Graph.update().
|
| 121 |
+
|
| 122 |
+
Examples
|
| 123 |
+
--------
|
| 124 |
+
>>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
|
| 125 |
+
>>> H = nx.Graph([(0, 3), (1, 2), (2, 3)])
|
| 126 |
+
>>> G.nodes[0]["key1"] = 5
|
| 127 |
+
>>> H.nodes[0]["key2"] = 10
|
| 128 |
+
>>> U = nx.disjoint_union(G, H)
|
| 129 |
+
>>> U.nodes(data=True)
|
| 130 |
+
NodeDataView({0: {'key1': 5}, 1: {}, 2: {}, 3: {'key2': 10}, 4: {}, 5: {}, 6: {}})
|
| 131 |
+
>>> U.edges
|
| 132 |
+
EdgeView([(0, 1), (0, 2), (1, 2), (3, 4), (4, 6), (5, 6)])
|
| 133 |
+
"""
|
| 134 |
+
return nx.disjoint_union_all([G, H])
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@nx._dispatchable(graphs=_G_H, returns_graph=True)
|
| 138 |
+
def intersection(G, H):
|
| 139 |
+
"""Returns a new graph that contains only the nodes and the edges that exist in
|
| 140 |
+
both G and H.
|
| 141 |
+
|
| 142 |
+
Parameters
|
| 143 |
+
----------
|
| 144 |
+
G,H : graph
|
| 145 |
+
A NetworkX graph. G and H can have different node sets but must be both graphs or both multigraphs.
|
| 146 |
+
|
| 147 |
+
Raises
|
| 148 |
+
------
|
| 149 |
+
NetworkXError
|
| 150 |
+
If one is a MultiGraph and the other one is a graph.
|
| 151 |
+
|
| 152 |
+
Returns
|
| 153 |
+
-------
|
| 154 |
+
GH : A new graph with the same type as G.
|
| 155 |
+
|
| 156 |
+
Notes
|
| 157 |
+
-----
|
| 158 |
+
Attributes from the graph, nodes, and edges are not copied to the new
|
| 159 |
+
graph. If you want a new graph of the intersection of G and H
|
| 160 |
+
with the attributes (including edge data) from G use remove_nodes_from()
|
| 161 |
+
as follows
|
| 162 |
+
|
| 163 |
+
>>> G = nx.path_graph(3)
|
| 164 |
+
>>> H = nx.path_graph(5)
|
| 165 |
+
>>> R = G.copy()
|
| 166 |
+
>>> R.remove_nodes_from(n for n in G if n not in H)
|
| 167 |
+
>>> R.remove_edges_from(e for e in G.edges if e not in H.edges)
|
| 168 |
+
|
| 169 |
+
Examples
|
| 170 |
+
--------
|
| 171 |
+
>>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
|
| 172 |
+
>>> H = nx.Graph([(0, 3), (1, 2), (2, 3)])
|
| 173 |
+
>>> R = nx.intersection(G, H)
|
| 174 |
+
>>> R.nodes
|
| 175 |
+
NodeView((0, 1, 2))
|
| 176 |
+
>>> R.edges
|
| 177 |
+
EdgeView([(1, 2)])
|
| 178 |
+
"""
|
| 179 |
+
return nx.intersection_all([G, H])
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@nx._dispatchable(graphs=_G_H, returns_graph=True)
|
| 183 |
+
def difference(G, H):
|
| 184 |
+
"""Returns a new graph that contains the edges that exist in G but not in H.
|
| 185 |
+
|
| 186 |
+
The node sets of H and G must be the same.
|
| 187 |
+
|
| 188 |
+
Parameters
|
| 189 |
+
----------
|
| 190 |
+
G,H : graph
|
| 191 |
+
A NetworkX graph. G and H must have the same node sets.
|
| 192 |
+
|
| 193 |
+
Returns
|
| 194 |
+
-------
|
| 195 |
+
D : A new graph with the same type as G.
|
| 196 |
+
|
| 197 |
+
Notes
|
| 198 |
+
-----
|
| 199 |
+
Attributes from the graph, nodes, and edges are not copied to the new
|
| 200 |
+
graph. If you want a new graph of the difference of G and H with
|
| 201 |
+
the attributes (including edge data) from G use remove_nodes_from()
|
| 202 |
+
as follows:
|
| 203 |
+
|
| 204 |
+
>>> G = nx.path_graph(3)
|
| 205 |
+
>>> H = nx.path_graph(5)
|
| 206 |
+
>>> R = G.copy()
|
| 207 |
+
>>> R.remove_nodes_from(n for n in G if n in H)
|
| 208 |
+
|
| 209 |
+
Examples
|
| 210 |
+
--------
|
| 211 |
+
>>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)])
|
| 212 |
+
>>> H = nx.Graph([(0, 1), (1, 2), (0, 3)])
|
| 213 |
+
>>> R = nx.difference(G, H)
|
| 214 |
+
>>> R.nodes
|
| 215 |
+
NodeView((0, 1, 2, 3))
|
| 216 |
+
>>> R.edges
|
| 217 |
+
EdgeView([(0, 2), (1, 3)])
|
| 218 |
+
"""
|
| 219 |
+
# create new graph
|
| 220 |
+
if not G.is_multigraph() == H.is_multigraph():
|
| 221 |
+
raise nx.NetworkXError("G and H must both be graphs or multigraphs.")
|
| 222 |
+
R = nx.create_empty_copy(G, with_data=False)
|
| 223 |
+
|
| 224 |
+
if set(G) != set(H):
|
| 225 |
+
raise nx.NetworkXError("Node sets of graphs not equal")
|
| 226 |
+
|
| 227 |
+
if G.is_multigraph():
|
| 228 |
+
edges = G.edges(keys=True)
|
| 229 |
+
else:
|
| 230 |
+
edges = G.edges()
|
| 231 |
+
for e in edges:
|
| 232 |
+
if not H.has_edge(*e):
|
| 233 |
+
R.add_edge(*e)
|
| 234 |
+
return R
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@nx._dispatchable(graphs=_G_H, returns_graph=True)
|
| 238 |
+
def symmetric_difference(G, H):
|
| 239 |
+
"""Returns new graph with edges that exist in either G or H but not both.
|
| 240 |
+
|
| 241 |
+
The node sets of H and G must be the same.
|
| 242 |
+
|
| 243 |
+
Parameters
|
| 244 |
+
----------
|
| 245 |
+
G,H : graph
|
| 246 |
+
A NetworkX graph. G and H must have the same node sets.
|
| 247 |
+
|
| 248 |
+
Returns
|
| 249 |
+
-------
|
| 250 |
+
D : A new graph with the same type as G.
|
| 251 |
+
|
| 252 |
+
Notes
|
| 253 |
+
-----
|
| 254 |
+
Attributes from the graph, nodes, and edges are not copied to the new
|
| 255 |
+
graph.
|
| 256 |
+
|
| 257 |
+
Examples
|
| 258 |
+
--------
|
| 259 |
+
>>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)])
|
| 260 |
+
>>> H = nx.Graph([(0, 1), (1, 2), (0, 3)])
|
| 261 |
+
>>> R = nx.symmetric_difference(G, H)
|
| 262 |
+
>>> R.nodes
|
| 263 |
+
NodeView((0, 1, 2, 3))
|
| 264 |
+
>>> R.edges
|
| 265 |
+
EdgeView([(0, 2), (0, 3), (1, 3)])
|
| 266 |
+
"""
|
| 267 |
+
# create new graph
|
| 268 |
+
if not G.is_multigraph() == H.is_multigraph():
|
| 269 |
+
raise nx.NetworkXError("G and H must both be graphs or multigraphs.")
|
| 270 |
+
R = nx.create_empty_copy(G, with_data=False)
|
| 271 |
+
|
| 272 |
+
if set(G) != set(H):
|
| 273 |
+
raise nx.NetworkXError("Node sets of graphs not equal")
|
| 274 |
+
|
| 275 |
+
gnodes = set(G) # set of nodes in G
|
| 276 |
+
hnodes = set(H) # set of nodes in H
|
| 277 |
+
nodes = gnodes.symmetric_difference(hnodes)
|
| 278 |
+
R.add_nodes_from(nodes)
|
| 279 |
+
|
| 280 |
+
if G.is_multigraph():
|
| 281 |
+
edges = G.edges(keys=True)
|
| 282 |
+
else:
|
| 283 |
+
edges = G.edges()
|
| 284 |
+
# we could copy the data here but then this function doesn't
|
| 285 |
+
# match intersection and difference
|
| 286 |
+
for e in edges:
|
| 287 |
+
if not H.has_edge(*e):
|
| 288 |
+
R.add_edge(*e)
|
| 289 |
+
|
| 290 |
+
if H.is_multigraph():
|
| 291 |
+
edges = H.edges(keys=True)
|
| 292 |
+
else:
|
| 293 |
+
edges = H.edges()
|
| 294 |
+
for e in edges:
|
| 295 |
+
if not G.has_edge(*e):
|
| 296 |
+
R.add_edge(*e)
|
| 297 |
+
return R
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
|
| 301 |
+
def compose(G, H):
|
| 302 |
+
"""Compose graph G with H by combining nodes and edges into a single graph.
|
| 303 |
+
|
| 304 |
+
The node sets and edges sets do not need to be disjoint.
|
| 305 |
+
|
| 306 |
+
Composing preserves the attributes of nodes and edges.
|
| 307 |
+
Attribute values from H take precedent over attribute values from G.
|
| 308 |
+
|
| 309 |
+
Parameters
|
| 310 |
+
----------
|
| 311 |
+
G, H : graph
|
| 312 |
+
A NetworkX graph
|
| 313 |
+
|
| 314 |
+
Returns
|
| 315 |
+
-------
|
| 316 |
+
C: A new graph with the same type as G
|
| 317 |
+
|
| 318 |
+
See Also
|
| 319 |
+
--------
|
| 320 |
+
:func:`~networkx.Graph.update`
|
| 321 |
+
union
|
| 322 |
+
disjoint_union
|
| 323 |
+
|
| 324 |
+
Notes
|
| 325 |
+
-----
|
| 326 |
+
It is recommended that G and H be either both directed or both undirected.
|
| 327 |
+
|
| 328 |
+
For MultiGraphs, the edges are identified by incident nodes AND edge-key.
|
| 329 |
+
This can cause surprises (i.e., edge `(1, 2)` may or may not be the same
|
| 330 |
+
in two graphs) if you use MultiGraph without keeping track of edge keys.
|
| 331 |
+
|
| 332 |
+
If combining the attributes of common nodes is not desired, consider union(),
|
| 333 |
+
which raises an exception for name collisions.
|
| 334 |
+
|
| 335 |
+
Examples
|
| 336 |
+
--------
|
| 337 |
+
>>> G = nx.Graph([(0, 1), (0, 2)])
|
| 338 |
+
>>> H = nx.Graph([(0, 1), (1, 2)])
|
| 339 |
+
>>> R = nx.compose(G, H)
|
| 340 |
+
>>> R.nodes
|
| 341 |
+
NodeView((0, 1, 2))
|
| 342 |
+
>>> R.edges
|
| 343 |
+
EdgeView([(0, 1), (0, 2), (1, 2)])
|
| 344 |
+
|
| 345 |
+
By default, the attributes from `H` take precedent over attributes from `G`.
|
| 346 |
+
If you prefer another way of combining attributes, you can update them after the compose operation:
|
| 347 |
+
|
| 348 |
+
>>> G = nx.Graph([(0, 1, {"weight": 2.0}), (3, 0, {"weight": 100.0})])
|
| 349 |
+
>>> H = nx.Graph([(0, 1, {"weight": 10.0}), (1, 2, {"weight": -1.0})])
|
| 350 |
+
>>> nx.set_node_attributes(G, {0: "dark", 1: "light", 3: "black"}, name="color")
|
| 351 |
+
>>> nx.set_node_attributes(H, {0: "green", 1: "orange", 2: "yellow"}, name="color")
|
| 352 |
+
>>> GcomposeH = nx.compose(G, H)
|
| 353 |
+
|
| 354 |
+
Normally, color attribute values of nodes of GcomposeH come from H. We can workaround this as follows:
|
| 355 |
+
|
| 356 |
+
>>> node_data = {
|
| 357 |
+
... n: G.nodes[n]["color"] + " " + H.nodes[n]["color"]
|
| 358 |
+
... for n in G.nodes & H.nodes
|
| 359 |
+
... }
|
| 360 |
+
>>> nx.set_node_attributes(GcomposeH, node_data, "color")
|
| 361 |
+
>>> print(GcomposeH.nodes[0]["color"])
|
| 362 |
+
dark green
|
| 363 |
+
|
| 364 |
+
>>> print(GcomposeH.nodes[3]["color"])
|
| 365 |
+
black
|
| 366 |
+
|
| 367 |
+
Similarly, we can update edge attributes after the compose operation in a way we prefer:
|
| 368 |
+
|
| 369 |
+
>>> edge_data = {
|
| 370 |
+
... e: G.edges[e]["weight"] * H.edges[e]["weight"] for e in G.edges & H.edges
|
| 371 |
+
... }
|
| 372 |
+
>>> nx.set_edge_attributes(GcomposeH, edge_data, "weight")
|
| 373 |
+
>>> print(GcomposeH.edges[(0, 1)]["weight"])
|
| 374 |
+
20.0
|
| 375 |
+
|
| 376 |
+
>>> print(GcomposeH.edges[(3, 0)]["weight"])
|
| 377 |
+
100.0
|
| 378 |
+
"""
|
| 379 |
+
return nx.compose_all([G, H])
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
|
| 383 |
+
def full_join(G, H, rename=(None, None)):
|
| 384 |
+
"""Returns the full join of graphs G and H.
|
| 385 |
+
|
| 386 |
+
Full join is the union of G and H in which all edges between
|
| 387 |
+
G and H are added.
|
| 388 |
+
The node sets of G and H must be disjoint,
|
| 389 |
+
otherwise an exception is raised.
|
| 390 |
+
|
| 391 |
+
Parameters
|
| 392 |
+
----------
|
| 393 |
+
G, H : graph
|
| 394 |
+
A NetworkX graph
|
| 395 |
+
|
| 396 |
+
rename : tuple , default=(None, None)
|
| 397 |
+
Node names of G and H can be changed by specifying the tuple
|
| 398 |
+
rename=('G-','H-') (for example). Node "u" in G is then renamed
|
| 399 |
+
"G-u" and "v" in H is renamed "H-v".
|
| 400 |
+
|
| 401 |
+
Returns
|
| 402 |
+
-------
|
| 403 |
+
U : The full join graph with the same type as G.
|
| 404 |
+
|
| 405 |
+
Notes
|
| 406 |
+
-----
|
| 407 |
+
It is recommended that G and H be either both directed or both undirected.
|
| 408 |
+
|
| 409 |
+
If G is directed, then edges from G to H are added as well as from H to G.
|
| 410 |
+
|
| 411 |
+
Note that full_join() does not produce parallel edges for MultiGraphs.
|
| 412 |
+
|
| 413 |
+
The full join operation of graphs G and H is the same as getting
|
| 414 |
+
their complement, performing a disjoint union, and finally getting
|
| 415 |
+
the complement of the resulting graph.
|
| 416 |
+
|
| 417 |
+
Graph, edge, and node attributes are propagated from G and H
|
| 418 |
+
to the union graph. If a graph attribute is present in both
|
| 419 |
+
G and H the value from H is used.
|
| 420 |
+
|
| 421 |
+
Examples
|
| 422 |
+
--------
|
| 423 |
+
>>> from pprint import pprint
|
| 424 |
+
>>> G = nx.Graph([(0, 1), (0, 2)])
|
| 425 |
+
>>> H = nx.Graph([(3, 4)])
|
| 426 |
+
>>> R = nx.full_join(G, H, rename=("G", "H"))
|
| 427 |
+
>>> R.nodes
|
| 428 |
+
NodeView(('G0', 'G1', 'G2', 'H3', 'H4'))
|
| 429 |
+
>>> edgelist = list(R.edges)
|
| 430 |
+
>>> pprint(edgelist)
|
| 431 |
+
[('G0', 'G1'),
|
| 432 |
+
('G0', 'G2'),
|
| 433 |
+
('G0', 'H3'),
|
| 434 |
+
('G0', 'H4'),
|
| 435 |
+
('G1', 'H3'),
|
| 436 |
+
('G1', 'H4'),
|
| 437 |
+
('G2', 'H3'),
|
| 438 |
+
('G2', 'H4'),
|
| 439 |
+
('H3', 'H4')]
|
| 440 |
+
|
| 441 |
+
See Also
|
| 442 |
+
--------
|
| 443 |
+
union
|
| 444 |
+
disjoint_union
|
| 445 |
+
"""
|
| 446 |
+
R = union(G, H, rename)
|
| 447 |
+
|
| 448 |
+
def add_prefix(graph, prefix):
|
| 449 |
+
if prefix is None:
|
| 450 |
+
return graph
|
| 451 |
+
|
| 452 |
+
def label(x):
|
| 453 |
+
return f"{prefix}{x}"
|
| 454 |
+
|
| 455 |
+
return nx.relabel_nodes(graph, label)
|
| 456 |
+
|
| 457 |
+
G = add_prefix(G, rename[0])
|
| 458 |
+
H = add_prefix(H, rename[1])
|
| 459 |
+
|
| 460 |
+
for i in G:
|
| 461 |
+
for j in H:
|
| 462 |
+
R.add_edge(i, j)
|
| 463 |
+
if R.is_directed():
|
| 464 |
+
for i in H:
|
| 465 |
+
for j in G:
|
| 466 |
+
R.add_edge(i, j)
|
| 467 |
+
|
| 468 |
+
return R
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py
ADDED
@@ -0,0 +1,633 @@
"""
|
| 2 |
+
Graph products.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from itertools import product
|
| 6 |
+
|
| 7 |
+
import networkx as nx
|
| 8 |
+
from networkx.utils import not_implemented_for
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"tensor_product",
|
| 12 |
+
"cartesian_product",
|
| 13 |
+
"lexicographic_product",
|
| 14 |
+
"strong_product",
|
| 15 |
+
"power",
|
| 16 |
+
"rooted_product",
|
| 17 |
+
"corona_product",
|
| 18 |
+
"modular_product",
|
| 19 |
+
]
|
| 20 |
+
_G_H = {"G": 0, "H": 1}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _dict_product(d1, d2):
|
| 24 |
+
return {k: (d1.get(k), d2.get(k)) for k in set(d1) | set(d2)}
|
| 25 |
+
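

# Illustrative sketch (a hypothetical helper, not part of the public API):
# `_dict_product` pairs attribute values key-by-key and fills missing keys
# with None, which is why absent node/edge attributes surface as None in the
# product graphs below.
def _dict_product_demo():
    combined = _dict_product({"weight": 1, "color": "red"}, {"weight": 2})
    assert combined == {"weight": (1, 2), "color": ("red", None)}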


# Generators for producing graph products
def _node_product(G, H):
    for u, v in product(G, H):
        yield ((u, v), _dict_product(G.nodes[u], H.nodes[v]))


def _directed_edges_cross_edges(G, H):
    if not G.is_multigraph() and not H.is_multigraph():
        for u, v, c in G.edges(data=True):
            for x, y, d in H.edges(data=True):
                yield (u, x), (v, y), _dict_product(c, d)
    if not G.is_multigraph() and H.is_multigraph():
        for u, v, c in G.edges(data=True):
            for x, y, k, d in H.edges(data=True, keys=True):
                yield (u, x), (v, y), k, _dict_product(c, d)
    if G.is_multigraph() and not H.is_multigraph():
        for u, v, k, c in G.edges(data=True, keys=True):
            for x, y, d in H.edges(data=True):
                yield (u, x), (v, y), k, _dict_product(c, d)
    if G.is_multigraph() and H.is_multigraph():
        for u, v, j, c in G.edges(data=True, keys=True):
            for x, y, k, d in H.edges(data=True, keys=True):
                yield (u, x), (v, y), (j, k), _dict_product(c, d)


def _undirected_edges_cross_edges(G, H):
    if not G.is_multigraph() and not H.is_multigraph():
        for u, v, c in G.edges(data=True):
            for x, y, d in H.edges(data=True):
                yield (v, x), (u, y), _dict_product(c, d)
    if not G.is_multigraph() and H.is_multigraph():
        for u, v, c in G.edges(data=True):
            for x, y, k, d in H.edges(data=True, keys=True):
                yield (v, x), (u, y), k, _dict_product(c, d)
    if G.is_multigraph() and not H.is_multigraph():
        for u, v, k, c in G.edges(data=True, keys=True):
            for x, y, d in H.edges(data=True):
                yield (v, x), (u, y), k, _dict_product(c, d)
    if G.is_multigraph() and H.is_multigraph():
        for u, v, j, c in G.edges(data=True, keys=True):
            for x, y, k, d in H.edges(data=True, keys=True):
                yield (v, x), (u, y), (j, k), _dict_product(c, d)


def _edges_cross_nodes(G, H):
    if G.is_multigraph():
        for u, v, k, d in G.edges(data=True, keys=True):
            for x in H:
                yield (u, x), (v, x), k, d
    else:
        for u, v, d in G.edges(data=True):
            for x in H:
                if H.is_multigraph():
                    yield (u, x), (v, x), None, d
                else:
                    yield (u, x), (v, x), d


def _nodes_cross_edges(G, H):
    if H.is_multigraph():
        for x in G:
            for u, v, k, d in H.edges(data=True, keys=True):
                yield (x, u), (x, v), k, d
    else:
        for x in G:
            for u, v, d in H.edges(data=True):
                if G.is_multigraph():
                    yield (x, u), (x, v), None, d
                else:
                    yield (x, u), (x, v), d


def _edges_cross_nodes_and_nodes(G, H):
    if G.is_multigraph():
        for u, v, k, d in G.edges(data=True, keys=True):
            for x in H:
                for y in H:
                    yield (u, x), (v, y), k, d
    else:
        for u, v, d in G.edges(data=True):
            for x in H:
                for y in H:
                    if H.is_multigraph():
                        yield (u, x), (v, y), None, d
                    else:
                        yield (u, x), (v, y), d


def _init_product_graph(G, H):
    if G.is_directed() != H.is_directed():
        msg = "G and H must be both directed or both undirected"
        raise nx.NetworkXError(msg)
    if G.is_multigraph() or H.is_multigraph():
        GH = nx.MultiGraph()
    else:
        GH = nx.Graph()
    if G.is_directed():
        GH = GH.to_directed()
    return GH
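

# Illustrative sketch (a hypothetical helper, not part of the public API):
# `_init_product_graph` promotes the product to a multigraph when either
# factor is one, and keeps directedness when the factors agree on it.
def _init_product_graph_demo():
    assert _init_product_graph(nx.Graph(), nx.MultiGraph()).is_multigraph()
    assert _init_product_graph(nx.DiGraph(), nx.DiGraph()).is_directed()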


@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
def tensor_product(G, H):
    r"""Returns the tensor product of G and H.

    The tensor product $P$ of the graphs $G$ and $H$ has a node set that
    is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
    $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$
    and $(v,y)$ is an edge in $H$.

    The tensor product is sometimes also referred to as the categorical
    product, direct product, cardinal product or conjunction.

    Parameters
    ----------
    G, H : graphs
        NetworkX graphs.

    Returns
    -------
    P : NetworkX graph
        The tensor product of G and H. P will be a multigraph if either G
        or H is a multigraph, directed if G and H are directed,
        and undirected if G and H are undirected.

    Raises
    ------
    NetworkXError
        If G and H are not both directed or both undirected.

    Notes
    -----
    Node attributes in P are two-tuples of the G and H node attributes.
    Missing attributes are assigned None.

    Examples
    --------
    >>> G = nx.Graph()
    >>> H = nx.Graph()
    >>> G.add_node(0, a1=True)
    >>> H.add_node("a", a2="Spam")
    >>> P = nx.tensor_product(G, H)
    >>> list(P)
    [(0, 'a')]

    Edge attributes and edge keys (for multigraphs) are also copied to the
    new product graph.
    """
    GH = _init_product_graph(G, H)
    GH.add_nodes_from(_node_product(G, H))
    GH.add_edges_from(_directed_edges_cross_edges(G, H))
    if not GH.is_directed():
        GH.add_edges_from(_undirected_edges_cross_edges(G, H))
    return GH
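

# Illustrative sketch (a hypothetical helper, not part of the public API): for
# simple undirected factors, every pair of factor edges contributes exactly
# two tensor-product edges, so |E(P)| == 2 * |E(G)| * |E(H)|.
def _tensor_product_demo():
    G, H = nx.path_graph(3), nx.path_graph(4)
    P = tensor_product(G, H)
    assert P.number_of_nodes() == len(G) * len(H)
    assert P.number_of_edges() == 2 * G.number_of_edges() * H.number_of_edges()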


@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
def cartesian_product(G, H):
    r"""Returns the Cartesian product of G and H.

    The Cartesian product $P$ of the graphs $G$ and $H$ has a node set that
    is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
    $P$ has an edge $((u,v),(x,y))$ if and only if either $u$ is equal to $x$
    and both $v$ and $y$ are adjacent in $H$, or $v$ is equal to $y$ and
    both $u$ and $x$ are adjacent in $G$.

    Parameters
    ----------
    G, H : graphs
        NetworkX graphs.

    Returns
    -------
    P : NetworkX graph
        The Cartesian product of G and H. P will be a multigraph if either G
        or H is a multigraph, directed if G and H are directed,
        and undirected if G and H are undirected.

    Raises
    ------
    NetworkXError
        If G and H are not both directed or both undirected.

    Notes
    -----
    Node attributes in P are two-tuples of the G and H node attributes.
    Missing attributes are assigned None.

    Examples
    --------
    >>> G = nx.Graph()
    >>> H = nx.Graph()
    >>> G.add_node(0, a1=True)
    >>> H.add_node("a", a2="Spam")
    >>> P = nx.cartesian_product(G, H)
    >>> list(P)
    [(0, 'a')]

    Edge attributes and edge keys (for multigraphs) are also copied to the
    new product graph.
    """
    GH = _init_product_graph(G, H)
    GH.add_nodes_from(_node_product(G, H))
    GH.add_edges_from(_edges_cross_nodes(G, H))
    GH.add_edges_from(_nodes_cross_edges(G, H))
    return GH
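

# Illustrative sketch (a hypothetical helper, not part of the public API): the
# Cartesian product of two path graphs is the familiar two-dimensional grid.
def _cartesian_product_demo():
    P = cartesian_product(nx.path_graph(3), nx.path_graph(4))
    assert nx.is_isomorphic(P, nx.grid_2d_graph(3, 4))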


@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
def lexicographic_product(G, H):
    r"""Returns the lexicographic product of G and H.

    The lexicographical product $P$ of the graphs $G$ and $H$ has a node set
    that is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
    $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$,
    or $u==x$ and $(v,y)$ is an edge in $H$.

    Parameters
    ----------
    G, H : graphs
        NetworkX graphs.

    Returns
    -------
    P : NetworkX graph
        The lexicographic product of G and H. P will be a multigraph if either
        G or H is a multigraph, directed if G and H are directed,
        and undirected if G and H are undirected.

    Raises
    ------
    NetworkXError
        If G and H are not both directed or both undirected.

    Notes
    -----
    Node attributes in P are two-tuples of the G and H node attributes.
    Missing attributes are assigned None.

    Examples
    --------
    >>> G = nx.Graph()
    >>> H = nx.Graph()
    >>> G.add_node(0, a1=True)
    >>> H.add_node("a", a2="Spam")
    >>> P = nx.lexicographic_product(G, H)
    >>> list(P)
    [(0, 'a')]

    Edge attributes and edge keys (for multigraphs) are also copied to the
    new product graph.
    """
    GH = _init_product_graph(G, H)
    GH.add_nodes_from(_node_product(G, H))
    # Edges in G regardless of H designation
    GH.add_edges_from(_edges_cross_nodes_and_nodes(G, H))
    # For each x in G, only if there is an edge in H
    GH.add_edges_from(_nodes_cross_edges(G, H))
    return GH
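

# Illustrative sketch (a hypothetical helper, not part of the public API): for
# simple undirected factors, each G-edge contributes |V(H)|**2 product edges
# and each G-node contributes a copy of H's edges, so
# |E(P)| == |E(G)| * |V(H)|**2 + |V(G)| * |E(H)|.
def _lexicographic_product_demo():
    G, H = nx.path_graph(3), nx.path_graph(4)
    P = lexicographic_product(G, H)
    expected = G.number_of_edges() * len(H) ** 2 + len(G) * H.number_of_edges()
    assert P.number_of_edges() == expected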


@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
def strong_product(G, H):
    r"""Returns the strong product of G and H.

    The strong product $P$ of the graphs $G$ and $H$ has a node set that
    is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
    $P$ has an edge $((u,x), (v,y))$ if any of the following conditions
    are met:

    - $u=v$ and $(x,y)$ is an edge in $H$
    - $x=y$ and $(u,v)$ is an edge in $G$
    - $(u,v)$ is an edge in $G$ and $(x,y)$ is an edge in $H$

    Parameters
    ----------
    G, H : graphs
        NetworkX graphs.

    Returns
    -------
    P : NetworkX graph
        The strong product of G and H. P will be a multigraph if either G
        or H is a multigraph, directed if G and H are directed,
        and undirected if G and H are undirected.

    Raises
    ------
    NetworkXError
        If G and H are not both directed or both undirected.

    Notes
    -----
    Node attributes in P are two-tuples of the G and H node attributes.
    Missing attributes are assigned None.

    Examples
    --------
    >>> G = nx.Graph()
    >>> H = nx.Graph()
    >>> G.add_node(0, a1=True)
    >>> H.add_node("a", a2="Spam")
    >>> P = nx.strong_product(G, H)
    >>> list(P)
    [(0, 'a')]

    Edge attributes and edge keys (for multigraphs) are also copied to the
    new product graph.
    """
    GH = _init_product_graph(G, H)
    GH.add_nodes_from(_node_product(G, H))
    GH.add_edges_from(_nodes_cross_edges(G, H))
    GH.add_edges_from(_edges_cross_nodes(G, H))
    GH.add_edges_from(_directed_edges_cross_edges(G, H))
    if not GH.is_directed():
        GH.add_edges_from(_undirected_edges_cross_edges(G, H))
    return GH
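

# Illustrative sketch (a hypothetical helper, not part of the public API): the
# strong product's edge set is the disjoint union of Cartesian-type and
# tensor-type edges, so for simple loopless factors the counts simply add up.
def _strong_product_demo():
    G, H = nx.path_graph(3), nx.cycle_graph(4)
    S = strong_product(G, H)
    n_cartesian = cartesian_product(G, H).number_of_edges()
    n_tensor = tensor_product(G, H).number_of_edges()
    assert S.number_of_edges() == n_cartesian + n_tensor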


@not_implemented_for("directed")
@not_implemented_for("multigraph")
@nx._dispatchable(returns_graph=True)
def power(G, k):
    """Returns the specified power of a graph.

    The $k$th power of a simple graph $G$, denoted $G^k$, is a
    graph on the same set of nodes in which two distinct nodes $u$ and
    $v$ are adjacent in $G^k$ if and only if the shortest path
    distance between $u$ and $v$ in $G$ is at most $k$.

    Parameters
    ----------
    G : graph
        A NetworkX simple graph object.

    k : positive integer
        The power to which to raise the graph `G`.

    Returns
    -------
    NetworkX simple graph
        `G` to the power `k`.

    Raises
    ------
    ValueError
        If the exponent `k` is not positive.

    NetworkXNotImplemented
        If `G` is not a simple graph.

    Examples
    --------
    The number of edges will never decrease when taking successive
    powers:

    >>> G = nx.path_graph(4)
    >>> list(nx.power(G, 2).edges)
    [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
    >>> list(nx.power(G, 3).edges)
    [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

    The `k`th power of a cycle graph on *n* nodes is the complete graph
    on *n* nodes, if `k` is at least ``n // 2``:

    >>> G = nx.cycle_graph(5)
    >>> H = nx.complete_graph(5)
    >>> nx.is_isomorphic(nx.power(G, 2), H)
    True
    >>> G = nx.cycle_graph(8)
    >>> H = nx.complete_graph(8)
    >>> nx.is_isomorphic(nx.power(G, 4), H)
    True

    References
    ----------
    .. [1] J. A. Bondy, U. S. R. Murty, *Graph Theory*. Springer, 2008.

    Notes
    -----
    This definition of "power graph" comes from Exercise 3.1.6 of
    *Graph Theory* by Bondy and Murty [1]_.

    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    H = nx.Graph()
    H.add_nodes_from(G)
    # update BFS code to ignore self loops.
    for n in G:
        seen = {}  # level (number of hops) when seen in BFS
        level = 1  # the current level
        nextlevel = G[n]
        while nextlevel:
            thislevel = nextlevel  # advance to next level
            nextlevel = {}  # and start a new list (fringe)
            for v in thislevel:
                if v == n:  # avoid self loop
                    continue
                if v not in seen:
                    seen[v] = level  # set the level of vertex v
                    nextlevel.update(G[v])  # add neighbors of v
            if k <= level:
                break
            level += 1
        H.add_edges_from((n, nbr) for nbr in seen)
    return H
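

# Illustrative sketch (a hypothetical helper, not part of the public API):
# `power(G, k)` links every pair of distinct nodes at shortest-path distance
# at most k, so it can be cross-checked against
# `nx.all_pairs_shortest_path_length`.
def _power_demo():
    G, k = nx.cycle_graph(7), 2
    expected = {
        frozenset((u, v))
        for u, lengths in nx.all_pairs_shortest_path_length(G, cutoff=k)
        for v, dist in lengths.items()
        if 0 < dist <= k
    }
    assert {frozenset(e) for e in power(G, k).edges()} == expected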


@not_implemented_for("multigraph")
@nx._dispatchable(graphs=_G_H, returns_graph=True)
def rooted_product(G, H, root):
    """Return the rooted product of graphs G and H rooted at root in H.

    A new graph is constructed representing the rooted product of
    the input graphs, G and H, with a root in H.
    A rooted product duplicates H for each node in G, with the root
    of H corresponding to the node in G. Nodes are renamed as the direct
    product of G and H. The result is a subgraph of the Cartesian product.

    Parameters
    ----------
    G, H : graph
        A NetworkX graph
    root : node
        A node in H

    Returns
    -------
    R : The rooted product of G and H with a specified root in H

    Notes
    -----
    The nodes of R are the Cartesian product of the nodes of G and H.
    The nodes of G and H are not relabeled.
    """
    if root not in H:
        raise nx.NodeNotFound("root must be a vertex in H")

    R = nx.Graph()
    R.add_nodes_from(product(G, H))

    R.add_edges_from(((e[0], root), (e[1], root)) for e in G.edges())
    R.add_edges_from(((g, e[0]), (g, e[1])) for g in G for e in H.edges())

    return R
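

# Illustrative sketch (a hypothetical helper, not part of the public API): the
# rooted product keeps one copy of G's edges (at the root) plus one copy of
# H's edges per node of G.
def _rooted_product_demo():
    G, H = nx.cycle_graph(4), nx.path_graph(3)
    R = rooted_product(G, H, 0)
    assert R.number_of_nodes() == len(G) * len(H)
    assert R.number_of_edges() == G.number_of_edges() + len(G) * H.number_of_edges()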


@not_implemented_for("directed")
@not_implemented_for("multigraph")
@nx._dispatchable(graphs=_G_H, returns_graph=True)
def corona_product(G, H):
    r"""Returns the corona product of G and H.

    The corona product of $G$ and $H$ is the graph $C = G \circ H$ obtained by
    taking one copy of $G$, called the center graph, $|V(G)|$ copies of $H$,
    called the outer graph, and making the $i$-th vertex of $G$ adjacent to
    every vertex of the $i$-th copy of $H$, where $1 \leq i \leq |V(G)|$.

    Parameters
    ----------
    G, H : NetworkX graphs
        The graphs to take the corona product of.
        `G` is the center graph and `H` is the outer graph.

    Returns
    -------
    C : NetworkX graph
        The corona product of G and H.

    Raises
    ------
    NetworkXError
        If G and H are not both directed or both undirected.

    Examples
    --------
    >>> G = nx.cycle_graph(4)
    >>> H = nx.path_graph(2)
    >>> C = nx.corona_product(G, H)
    >>> list(C)
    [0, 1, 2, 3, (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
    >>> print(C)
    Graph with 12 nodes and 16 edges

    References
    ----------
    .. [1] M. Tavakoli, F. Rahbarnia, and A. R. Ashrafi,
       "Studying the corona product of graphs under some graph invariants,"
       Transactions on Combinatorics, vol. 3, no. 3, pp. 43–49, Sep. 2014,
       doi: 10.22108/toc.2014.5542.
    .. [2] A. Faraji, "Corona Product in Graph Theory," Ali Faraji, May 11, 2021.
       https://blog.alifaraji.ir/math/graph-theory/corona-product.html
       (accessed Dec. 07, 2021).
    """
    GH = _init_product_graph(G, H)
    GH.add_nodes_from(G)
    GH.add_edges_from(G.edges)

    for G_node in G:
        # copy the nodes of H into GH, forming the copy H_i
        GH.add_nodes_from((G_node, v) for v in H)

        # copy the edges of H_i based on H
        GH.add_edges_from(
            ((G_node, e0), (G_node, e1), d) for e0, e1, d in H.edges.data()
        )

        # create the new edges between H_i and G's node
        GH.add_edges_from((G_node, (G_node, H_node)) for H_node in H)

    return GH
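

# Illustrative sketch (a hypothetical helper, not part of the public API): the
# corona product has |V(G)| * (1 + |V(H)|) nodes, and its edge count matches
# the docstring example above: G's edges, plus a copy of H's edges and |V(H)|
# spokes per G-node.
def _corona_product_demo():
    G, H = nx.cycle_graph(4), nx.path_graph(2)
    C = corona_product(G, H)
    assert C.number_of_nodes() == len(G) * (1 + len(H))  # 12
    assert C.number_of_edges() == G.number_of_edges() + len(G) * (
        H.number_of_edges() + len(H)
    )  # 16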


@nx._dispatchable(
    graphs=_G_H, preserve_edge_attrs=True, preserve_node_attrs=True, returns_graph=True
)
def modular_product(G, H):
    r"""Returns the modular product of G and H.

    The modular product of `G` and `H` is the graph $M = G \nabla H$,
    consisting of the node set $V(M) = V(G) \times V(H)$ that is the Cartesian
    product of the node sets of `G` and `H`. Further, M contains an edge ((u, v), (x, y)):

    - if u is adjacent to x in `G` and v is adjacent to y in `H`, or
    - if u is not adjacent to x in `G` and v is not adjacent to y in `H`.

    More formally::

        E(M) = {((u, v), (x, y)) | ((u, x) in E(G) and (v, y) in E(H)) or
                                   ((u, x) not in E(G) and (v, y) not in E(H))}

    Parameters
    ----------
    G, H : NetworkX graphs
        The graphs to take the modular product of.

    Returns
    -------
    M : NetworkX graph
        The modular product of `G` and `H`.

    Raises
    ------
    NetworkXNotImplemented
        If `G` or `H` is not a simple graph.

    Examples
    --------
    >>> G = nx.cycle_graph(4)
    >>> H = nx.path_graph(2)
    >>> M = nx.modular_product(G, H)
    >>> list(M)
    [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
    >>> print(M)
    Graph with 8 nodes and 8 edges

    Notes
    -----
    The *modular product* is defined in [1]_ and was first
    introduced as the *weak modular product*.

    The modular product reduces the problem of counting isomorphic subgraphs
    in `G` and `H` to the problem of counting cliques in M. The subgraphs of
    `G` and `H` that are induced by the nodes of a clique in M are
    isomorphic [2]_ [3]_.

    References
    ----------
    .. [1] R. Hammack, W. Imrich, and S. Klavžar,
       "Handbook of Product Graphs", CRC Press, 2011.

    .. [2] H. G. Barrow and R. M. Burstall,
       "Subgraph isomorphism, matching relational structures and maximal
       cliques", Information Processing Letters, vol. 4, issue 4, pp. 83-84,
       1976, https://doi.org/10.1016/0020-0190(76)90049-1.

    .. [3] V. G. Vizing, "Reduction of the problem of isomorphism and isomorphic
       entrance to the task of finding the nondensity of a graph." Proc. Third
       All-Union Conference on Problems of Theoretical Cybernetics. 1974.
    """
    if G.is_directed() or H.is_directed():
        raise nx.NetworkXNotImplemented(
            "Modular product not implemented for directed graphs"
        )
    if G.is_multigraph() or H.is_multigraph():
        raise nx.NetworkXNotImplemented(
            "Modular product not implemented for multigraphs"
        )

    GH = _init_product_graph(G, H)
    GH.add_nodes_from(_node_product(G, H))

    for u, v, c in G.edges(data=True):
        for x, y, d in H.edges(data=True):
            GH.add_edge((u, x), (v, y), **_dict_product(c, d))
            GH.add_edge((v, x), (u, y), **_dict_product(c, d))

    G = nx.complement(G)
    H = nx.complement(H)

    for u, v, c in G.edges(data=True):
        for x, y, d in H.edges(data=True):
            GH.add_edge((u, x), (v, y), **_dict_product(c, d))
            GH.add_edge((v, x), (u, y), **_dict_product(c, d))

    return GH
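

# Illustrative sketch (a hypothetical helper, not part of the public API): as
# the Notes above describe, a clique in the modular product pairs up nodes of
# G and H whose induced subgraphs are isomorphic.
def _modular_product_demo():
    G, H = nx.path_graph(3), nx.cycle_graph(4)
    M = modular_product(G, H)
    clique = max(nx.find_cliques(M), key=len)  # one largest clique
    g_side = [u for u, v in clique]
    h_side = [v for u, v in clique]
    assert nx.is_isomorphic(G.subgraph(g_side), H.subgraph(h_side))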