BryanW committed on
Commit c9c6027 · verified · 1 Parent(s): 2ee4cd6

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. Prism/Dream/Dream_Baseline/LICENSE +201 -0
  2. Prism/Dream/Dream_Prism/LICENSE +201 -0
  3. Prism/Dream/Dream_Prism/eval_instruct/.gitignore +26 -0
  4. Prism/Dream/Dream_Prism/eval_instruct/README.md +16 -0
  5. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py +7 -0
  6. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py +512 -0
  7. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py +56 -0
  8. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py +493 -0
  9. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py +232 -0
  10. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py +1839 -0
  11. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py +0 -0
  12. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py +59 -0
  13. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py +0 -0
  14. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py +174 -0
  15. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py +166 -0
  16. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py +328 -0
  17. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py +736 -0
  18. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py +554 -0
  19. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py +25 -0
  20. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py +17 -0
  21. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py +25 -0
  22. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py +188 -0
  23. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py +61 -0
  24. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py +56 -0
  25. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py +17 -0
  26. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py +563 -0
  27. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py +41 -0
  28. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py +257 -0
  29. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py +1459 -0
  30. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py +731 -0
  31. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py +155 -0
  32. Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py +552 -0
  33. Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml +134 -0
  34. Prism/Dream/Dream_Prism/eval_instruct/requirements.txt +1 -0
  35. Prism/Dream/Dream_Prism/eval_instruct/setup.py +5 -0
  36. Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py +188 -0
  37. Prism/Dream/Dream_Prism/metrics/humaneval_eval.py +234 -0
  38. Prism/Dream/Dream_Prism/metrics/math500_eval.py +205 -0
  39. Prism/Dream/Dream_Prism/metrics/mbpp_eval.py +281 -0
  40. Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh +31 -0
  41. Prism/Dream/Dream_Prism/scripts/run_humaneval.sh +31 -0
  42. Prism/Dream/Dream_Prism/scripts/run_math500.sh +30 -0
  43. Prism/Dream/Dream_Prism/scripts/run_mbpp.sh +31 -0
  44. Prism/Dream/Dream_Prism/src/__init__.py +0 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc +0 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py +71 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py +4 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py +324 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py +468 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py +633 -0
Prism/Dream/Dream_Baseline/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Prism/Dream/Dream_Prism/LICENSE ADDED
@@ -0,0 +1,201 @@
(Identical to Prism/Dream/Dream_Baseline/LICENSE above: the full Apache License, Version 2.0 text, 201 lines.)
Prism/Dream/Dream_Prism/eval_instruct/.gitignore ADDED
@@ -0,0 +1,26 @@
1
+ env
2
+ *.pyc
3
+ output/
4
+ output5/
5
+ data/
6
+ lm_cache
7
+ .idea
8
+ build
9
+ dist
10
+ *.egg-info
11
+ venv
12
+ .venv/
13
+ .vscode/
14
+ temp
15
+ __pycache__
16
+ .ipynb_checkpoints
17
+ temp
18
+ test_logs/
19
+ # IPython
20
+ profile_default/
21
+ ipython_config.py
22
+ # don't track (the default location of) the cached requests
23
+ lm_eval/caching/.cache
24
+ # don't track files created by wandb
25
+ wandb
26
+ examples/wandb
Prism/Dream/Dream_Prism/eval_instruct/README.md ADDED
@@ -0,0 +1,16 @@
1
+ # Dream-Instruct Evaluation Toolkit
2
+ This toolkit contains the code used to evaluate Dream-Instruct models.
3
+
4
+ ## Quickstart
5
+ To install the toolkit, run:
6
+ ```
7
+ pip install -e ".[ifeval,math]"
8
+ ```
9
+
10
+ We provide a script to evaluate [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B):
11
+ ```
12
+ bash eval.sh
13
+ ```
14
+
15
+ ## Acknowledgement
16
+ This is a fork of [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
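For readers who prefer to drive the harness from Python rather than `eval.sh`, here is a minimal sketch using the `simple_evaluate` entry point re-exported by `lm_eval/__init__.py` (shown next). The model type, checkpoint, and task name are illustrative assumptions, not necessarily what `eval.sh` passes; the keyword arguments mirror those that `cli_evaluate` forwards in `lm_eval/__main__.py` further down.

```python
# Illustrative sketch only: the model type, checkpoint, and task below are
# assumptions, not necessarily what this commit's eval.sh uses.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="hf",  # default HF wrapper; this commit also ships lm_eval/models/diffllm.py
    model_args="pretrained=Dream-org/Dream-v0-Instruct-7B,trust_remote_code=True",
    tasks=["gsm8k"],            # hypothetical task choice
    num_fewshot=0,
    batch_size=1,
    apply_chat_template=True,
)

if results is not None:
    # top-level keys include per-task metrics and the run configuration
    print(sorted(results.keys()))
```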
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ import logging
2
+ import os
3
+
4
+ from .evaluator import evaluate, simple_evaluate
5
+
6
+
7
+ __version__ = "0.4.8"
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py ADDED
@@ -0,0 +1,512 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from functools import partial
7
+ from typing import Union
8
+
9
+ from lm_eval import evaluator, utils
10
+ from lm_eval.evaluator import request_caching_arg_to_dict
11
+ from lm_eval.loggers import EvaluationTracker, WandbLogger
12
+ from lm_eval.tasks import TaskManager
13
+ from lm_eval.utils import (
14
+ handle_non_serializable,
15
+ make_table,
16
+ simple_parse_args_string,
17
+ )
18
+
19
+
20
+ def try_parse_json(value: str) -> Union[str, dict, None]:
21
+ if value is None:
22
+ return None
23
+ try:
24
+ return json.loads(value)
25
+ except json.JSONDecodeError:
26
+ if "{" in value:
27
+ raise argparse.ArgumentTypeError(
28
+ f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
29
+ )
30
+ return value
31
+
32
+
33
+ def _int_or_none_list_arg_type(
34
+ min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
35
+ ):
36
+ def parse_value(item):
37
+ item = item.strip().lower()
38
+ if item == "none":
39
+ return None
40
+ try:
41
+ return int(item)
42
+ except ValueError:
43
+ raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
44
+
45
+ items = [parse_value(v) for v in value.split(split_char)]
46
+ num_items = len(items)
47
+
48
+ if num_items == 1:
49
+ # Makes downstream handling the same for single and multiple values
50
+ items = items * max_len
51
+ elif num_items < min_len or num_items > max_len:
52
+ raise argparse.ArgumentTypeError(
53
+ f"Argument requires between {min_len} and {max_len} integers or None, separated by '{split_char}'"
54
+ )
55
+ elif num_items != max_len:
56
+ logging.warning(
57
+ f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
58
+ "Missing values will be filled with defaults."
59
+ )
60
+ default_items = [parse_value(v) for v in defaults.split(split_char)]
61
+ items.extend(
62
+ default_items[num_items:]
63
+ ) # extend items list with missing defaults
64
+
65
+ return items
66
+
67
+
68
+ def check_argument_types(parser: argparse.ArgumentParser):
69
+ """
70
+ Check to make sure all CLI args are typed, raises error if not
71
+ """
72
+ for action in parser._actions:
73
+ if action.dest != "help" and not action.const:
74
+ if action.type is None:
75
+ raise ValueError(
76
+ f"Argument '{action.dest}' doesn't have a type specified."
77
+ )
78
+ else:
79
+ continue
80
+
81
+
82
+ def setup_parser() -> argparse.ArgumentParser:
83
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
84
+ parser.add_argument(
85
+ "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
86
+ )
87
+ parser.add_argument(
88
+ "--tasks",
89
+ "-t",
90
+ default=None,
91
+ type=str,
92
+ metavar="task1,task2",
93
+ help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
94
+ )
95
+ parser.add_argument(
96
+ "--model_args",
97
+ "-a",
98
+ default="",
99
+ type=try_parse_json,
100
+ help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
101
+ )
102
+ parser.add_argument(
103
+ "--num_fewshot",
104
+ "-f",
105
+ type=int,
106
+ default=None,
107
+ metavar="N",
108
+ help="Number of examples in few-shot context",
109
+ )
110
+ parser.add_argument(
111
+ "--batch_size",
112
+ "-b",
113
+ type=str,
114
+ default=1,
115
+ metavar="auto|auto:N|N",
116
+ help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
117
+ )
118
+ parser.add_argument(
119
+ "--max_batch_size",
120
+ type=int,
121
+ default=None,
122
+ metavar="N",
123
+ help="Maximal batch size to try with --batch_size auto.",
124
+ )
125
+ parser.add_argument(
126
+ "--device",
127
+ type=str,
128
+ default=None,
129
+ help="Device to use (e.g. cuda, cuda:0, cpu).",
130
+ )
131
+ parser.add_argument(
132
+ "--output_path",
133
+ "-o",
134
+ default=None,
135
+ type=str,
136
+ metavar="DIR|DIR/file.json",
137
+ help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
138
+ )
139
+ parser.add_argument(
140
+ "--limit",
141
+ "-L",
142
+ type=float,
143
+ default=None,
144
+ metavar="N|0<N<1",
145
+ help="Limit the number of examples per task. "
146
+ "If <1, limit is a percentage of the total number of examples.",
147
+ )
148
+ parser.add_argument(
149
+ "--use_cache",
150
+ "-c",
151
+ type=str,
152
+ default=None,
153
+ metavar="DIR",
154
+ help="A path to a sqlite db file for caching model responses. `None` if not caching.",
155
+ )
156
+ parser.add_argument(
157
+ "--cache_requests",
158
+ type=str,
159
+ default=None,
160
+ choices=["true", "refresh", "delete"],
161
+ help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
162
+ )
163
+ parser.add_argument(
164
+ "--check_integrity",
165
+ action="store_true",
166
+ help="Whether to run the relevant part of the test suite for the tasks.",
167
+ )
168
+ parser.add_argument(
169
+ "--write_out",
170
+ "-w",
171
+ action="store_true",
172
+ default=False,
173
+ help="Prints the prompt for the first few documents.",
174
+ )
175
+ parser.add_argument(
176
+ "--log_samples",
177
+ "-s",
178
+ action="store_true",
179
+ default=False,
180
+ help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
181
+ )
182
+ parser.add_argument(
183
+ "--system_instruction",
184
+ type=str,
185
+ default=None,
186
+ help="System instruction to be used in the prompt",
187
+ )
188
+ parser.add_argument(
189
+ "--apply_chat_template",
190
+ type=str,
191
+ nargs="?",
192
+ const=True,
193
+ default=False,
194
+ help=(
195
+ "If True, apply chat template to the prompt. "
196
+ "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
197
+ "To apply a specific template from the available list of templates, provide the template name as an argument. "
198
+ "E.g. `--apply_chat_template template_name`"
199
+ ),
200
+ )
201
+ parser.add_argument(
202
+ "--fewshot_as_multiturn",
203
+ action="store_true",
204
+ default=False,
205
+ help="If True, uses the fewshot as a multi-turn conversation",
206
+ )
207
+ parser.add_argument(
208
+ "--show_config",
209
+ action="store_true",
210
+ default=False,
211
+ help="If True, shows the full config of all tasks at the end of the evaluation.",
212
+ )
213
+ parser.add_argument(
214
+ "--include_path",
215
+ type=str,
216
+ default=None,
217
+ metavar="DIR",
218
+ help="Additional path to include if there are external tasks to include.",
219
+ )
220
+ parser.add_argument(
221
+ "--gen_kwargs",
222
+ type=try_parse_json,
223
+ default=None,
224
+ help=(
225
+ "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
226
+ """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
227
+ ),
228
+ )
229
+ parser.add_argument(
230
+ "--verbosity",
231
+ "-v",
232
+ type=str.upper,
233
+ default=None,
234
+ metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
235
+ help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
236
+ )
237
+ parser.add_argument(
238
+ "--wandb_args",
239
+ type=str,
240
+ default="",
241
+ help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
242
+ )
243
+ parser.add_argument(
244
+ "--wandb_config_args",
245
+ type=str,
246
+ default="",
247
+ help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default, e.g. `lr=0.01,repeats=3`",
248
+ )
249
+ parser.add_argument(
250
+ "--hf_hub_log_args",
251
+ type=str,
252
+ default="",
253
+ help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
254
+ )
255
+ parser.add_argument(
256
+ "--predict_only",
257
+ "-x",
258
+ action="store_true",
259
+ default=False,
260
+ help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
261
+ )
262
+ default_seed_string = "0,1234,1234,1234"
263
+ parser.add_argument(
264
+ "--seed",
265
+ type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
266
+ default=default_seed_string, # for backward compatibility
267
+ help=(
268
+ "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
269
+ "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
270
+ "respectively, or a single integer to set the same seed for all four.\n"
271
+ f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
272
+ "(for backward compatibility).\n"
273
+ "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
274
+ "Here numpy's seed is not set since the second value is `None`.\n"
275
+ "E.g, `--seed 42` sets all four seeds to 42."
276
+ ),
277
+ )
278
+ parser.add_argument(
279
+ "--trust_remote_code",
280
+ action="store_true",
281
+ help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
282
+ )
283
+ parser.add_argument(
284
+ "--confirm_run_unsafe_code",
285
+ action="store_true",
286
+ help="Confirm that you understand the risks of running unsafe code for tasks that require it",
287
+ )
288
+ parser.add_argument(
289
+ "--metadata",
290
+ type=json.loads,
291
+ default=None,
292
+ help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
293
+ )
294
+ return parser
295
+
296
+
297
+ def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
298
+ check_argument_types(parser)
299
+ return parser.parse_args()
300
+
301
+
302
+ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
303
+ if not args:
304
+ # we allow for args to be passed externally, else we parse them ourselves
305
+ parser = setup_parser()
306
+ args = parse_eval_args(parser)
307
+
308
+ if args.wandb_args:
309
+ wandb_args_dict = simple_parse_args_string(args.wandb_args)
310
+ wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
311
+ wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
312
+
313
+ utils.setup_logging(args.verbosity)
314
+ eval_logger = logging.getLogger(__name__)
315
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
316
+
317
+ # update the evaluation tracker args with the output path and the HF token
318
+ if args.output_path:
319
+ args.hf_hub_log_args += f",output_path={args.output_path}"
320
+ if os.environ.get("HF_TOKEN", None):
321
+ args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
322
+ evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
323
+ evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
324
+
325
+ if args.predict_only:
326
+ args.log_samples = True
327
+ if (args.log_samples or args.predict_only) and not args.output_path:
328
+ raise ValueError(
329
+ "Specify --output_path if providing --log_samples or --predict_only"
330
+ )
331
+
332
+ if args.fewshot_as_multiturn and args.apply_chat_template is False:
333
+ raise ValueError(
334
+ "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
335
+ )
336
+
337
+ if args.include_path is not None:
338
+ eval_logger.info(f"Including path: {args.include_path}")
339
+ metadata = (
340
+ simple_parse_args_string(args.model_args)
341
+ if isinstance(args.model_args, str)
342
+ else args.model_args
343
+ if isinstance(args.model_args, dict)
344
+ else {}
345
+ ) | (
346
+ args.metadata
347
+ if isinstance(args.metadata, dict)
348
+ else simple_parse_args_string(args.metadata)
349
+ )
350
+
351
+ task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
352
+
353
+ if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
354
+ eval_logger.warning(
355
+ "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
356
+ )
357
+
358
+ if args.limit:
359
+ eval_logger.warning(
360
+ " --limit SHOULD ONLY BE USED FOR TESTING."
361
+ " REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
362
+ )
363
+
364
+ if args.tasks is None:
365
+ eval_logger.error("Need to specify task to evaluate.")
366
+ sys.exit()
367
+ elif args.tasks == "list":
368
+ print(task_manager.list_all_tasks())
369
+ sys.exit()
370
+ elif args.tasks == "list_groups":
371
+ print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
372
+ sys.exit()
373
+ elif args.tasks == "list_tags":
374
+ print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
375
+ sys.exit()
376
+ elif args.tasks == "list_subtasks":
377
+ print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
378
+ sys.exit()
379
+ else:
380
+ if os.path.isdir(args.tasks):
381
+ import glob
382
+
383
+ task_names = []
384
+ yaml_path = os.path.join(args.tasks, "*.yaml")
385
+ for yaml_file in glob.glob(yaml_path):
386
+ config = utils.load_yaml_config(yaml_file)
387
+ task_names.append(config)
388
+ else:
389
+ task_list = args.tasks.split(",")
390
+ task_names = task_manager.match_tasks(task_list)
391
+ for task in [task for task in task_list if task not in task_names]:
392
+ if os.path.isfile(task):
393
+ config = utils.load_yaml_config(task)
394
+ task_names.append(config)
395
+ task_missing = [
396
+ task for task in task_list if task not in task_names and "*" not in task
397
+ ] # we don't want errors if a wildcard ("*") task name was used
398
+
399
+ if task_missing:
400
+ missing = ", ".join(task_missing)
401
+ eval_logger.error(
402
+ f"Tasks were not found: {missing}\n"
403
+ f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
404
+ )
405
+ raise ValueError(
406
+ f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
407
+ )
408
+
409
+ # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
410
+ if args.trust_remote_code:
411
+ eval_logger.info(
412
+ "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
413
+ )
414
+ # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
415
+ # because it's already been determined based on the prior env var before launching our
416
+ # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
417
+ import datasets
418
+
419
+ datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
420
+
421
+ args.model_args = args.model_args + ",trust_remote_code=True"
422
+ eval_logger.info(
423
+ f"Selected Tasks: {task_names}"
424
+ ) if eval_logger.getEffectiveLevel() >= logging.INFO else print(
425
+ f"Selected Tasks: {task_names}"
426
+ )
427
+
428
+ request_caching_args = request_caching_arg_to_dict(
429
+ cache_requests=args.cache_requests
430
+ )
431
+
432
+ results = evaluator.simple_evaluate(
433
+ model=args.model,
434
+ model_args=args.model_args,
435
+ tasks=task_names,
436
+ num_fewshot=args.num_fewshot,
437
+ batch_size=args.batch_size,
438
+ max_batch_size=args.max_batch_size,
439
+ device=args.device,
440
+ use_cache=args.use_cache,
441
+ limit=args.limit,
442
+ check_integrity=args.check_integrity,
443
+ write_out=args.write_out,
444
+ log_samples=args.log_samples,
445
+ evaluation_tracker=evaluation_tracker,
446
+ system_instruction=args.system_instruction,
447
+ apply_chat_template=args.apply_chat_template,
448
+ fewshot_as_multiturn=args.fewshot_as_multiturn,
449
+ gen_kwargs=args.gen_kwargs,
450
+ task_manager=task_manager,
451
+ predict_only=args.predict_only,
452
+ random_seed=args.seed[0],
453
+ numpy_random_seed=args.seed[1],
454
+ torch_random_seed=args.seed[2],
455
+ fewshot_random_seed=args.seed[3],
456
+ confirm_run_unsafe_code=args.confirm_run_unsafe_code,
457
+ metadata=metadata,
458
+ **request_caching_args,
459
+ )
460
+
461
+ if results is not None:
462
+ if args.log_samples:
463
+ samples = results.pop("samples")
464
+ dumped = json.dumps(
465
+ results, indent=2, default=handle_non_serializable, ensure_ascii=False
466
+ )
467
+ if args.show_config:
468
+ print(dumped)
469
+
470
+ batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
471
+
472
+ # Add W&B logging
473
+ if args.wandb_args:
474
+ try:
475
+ wandb_logger.post_init(results)
476
+ wandb_logger.log_eval_result()
477
+ if args.log_samples:
478
+ wandb_logger.log_eval_samples(samples)
479
+ except Exception as e:
480
+ eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
481
+
482
+ evaluation_tracker.save_results_aggregated(
483
+ results=results, samples=samples if args.log_samples else None
484
+ )
485
+
486
+ if args.log_samples:
487
+ for task_name, config in results["configs"].items():
488
+ evaluation_tracker.save_results_samples(
489
+ task_name=task_name, samples=samples[task_name]
490
+ )
491
+
492
+ if (
493
+ evaluation_tracker.push_results_to_hub
494
+ or evaluation_tracker.push_samples_to_hub
495
+ ):
496
+ evaluation_tracker.recreate_metadata_card()
497
+
498
+ print(
499
+ f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
500
+ f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
501
+ )
502
+ print(make_table(results))
503
+ if "groups" in results:
504
+ print(make_table(results, "groups"))
505
+
506
+ if args.wandb_args:
507
+ # Tear down wandb run once all the logging is done.
508
+ wandb_logger.run.finish()
509
+
510
+
511
+ if __name__ == "__main__":
512
+ cli_evaluate()
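To make the `--seed` semantics described in the help text above concrete, the following sketch shows how `_int_or_none_list_arg_type` resolves a few example values. It assumes the package from this commit is installed so that `lm_eval.__main__` is importable; the inputs are illustrative only.

```python
# Illustrative check of the --seed parsing behavior documented above.
from functools import partial

from lm_eval.__main__ import _int_or_none_list_arg_type

# mirrors the parser definition: min_len=3, max_len=4, defaults "0,1234,1234,1234"
seed_type = partial(_int_or_none_list_arg_type, 3, 4, "0,1234,1234,1234")

print(seed_type("42"))           # [42, 42, 42, 42]   -> one value is broadcast to all four seeds
print(seed_type("0,None,8,52"))  # [0, None, 8, 52]   -> numpy seed left unset
print(seed_type("0,1234,1234"))  # [0, 1234, 1234, 1234] -> missing fewshot seed filled from defaults
```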
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py ADDED
@@ -0,0 +1,56 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Iterable, List, Union
4
+
5
+ from lm_eval.api.instance import Instance
6
+
7
+
8
+ class Filter(ABC):
9
+ """
10
+ Filter classes operate on a per-task level.
11
+ They take all model outputs (`instance.resps` for all `task.instances`)
12
+ across all instances of a task, and perform operations.
13
+ In a single run, one can configure any number of separate filters or lists of filters.
14
+
15
+ """
16
+
17
+ def __init__(self, **kwargs) -> None:
18
+ """
19
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
20
+ """
21
+
22
+ @abstractmethod
23
+ def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
24
+ """
25
+ Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
26
+ Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
27
+ if passed [<inst.resps for instance 0>, <inst.resps for instance 1>], it should return
28
+ [<filtered resps for instance 0>, <filtered resps for instance 1>]
29
+ """
30
+ return resps
31
+
32
+
33
+ @dataclass
34
+ class FilterEnsemble:
35
+ """
36
+ FilterEnsemble creates a pipeline applying multiple filters.
37
+ Its intended usage is to stack multiple post-processing steps in order.
38
+ `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
39
+ pipeline separately.
40
+ """
41
+
42
+ name: str
43
+ filters: List[Callable[[], Filter]]
44
+
45
+ def apply(self, instances: List[Instance]) -> None:
46
+ resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
47
+ resps, docs = list(resps), list(docs)
48
+
49
+ for f in self.filters:
50
+ # apply filters in sequence
51
+ resps = f().apply(resps, docs)
52
+
53
+ # add the end results after filtering to filtered_requests of their respective source instances.
54
+ # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
55
+ for inst, resp in zip(instances, resps):
56
+ inst.filtered_resps[self.name] = resp
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py ADDED
@@ -0,0 +1,493 @@
1
+ import abc
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
7
+
8
+ import transformers
9
+ from sqlitedict import SqliteDict
10
+ from tqdm import tqdm
11
+
12
+ from lm_eval import utils
13
+
14
+
15
+ eval_logger = logging.getLogger(__name__)
16
+
17
+ T = TypeVar("T", bound="LM")
18
+
19
+
20
+ class LM(abc.ABC):
21
+ def __init__(self) -> None:
22
+ """Defines the interface that should be implemented by all LM subclasses.
23
+ LMs are assumed to take text (strings) as input and yield strings as output
24
+ (inputs/outputs should be tokenization-agnostic.)
25
+
26
+ """
27
+ # set rank and world size to a single process, by default.
28
+ self._rank = 0
29
+ self._world_size = 1
30
+ self.cache_hook = CacheHook(None)
31
+
32
+ @abc.abstractmethod
33
+ def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
34
+ """Compute log-likelihood of generating a continuation from a context.
35
+ Downstream tasks should attempt to use loglikelihood instead of other
36
+ LM calls whenever possible.
37
+
38
+ :param requests: list[Instance]
39
+ A list of Instance objects, with property `args` which returns a tuple (context, continuation).
40
+ `context: str`
41
+ Context string. Implementations of LM must be able to handle an
42
+ empty context string.
43
+ `continuation: str`
44
+ The continuation over which log likelihood will be calculated. If
45
+ there is a word boundary, the space should be in the continuation.
46
+ For example, context="hello" continuation=" world" is correct.
47
+
48
+ :return: list[tuple[float, bool]]
49
+ A list of pairs (logprob, isgreedy)
50
+ `logprob: float`
51
+ The log probability of `continuation`.
52
+ `isgreedy`:
53
+ Whether `continuation` would be generated by greedy sampling from `context`.
54
+ """
55
+ pass
56
+
57
+ @abc.abstractmethod
58
+ def loglikelihood_rolling(self, requests) -> List[float]:
59
+ """Compute full log-likelihood of a string, with no truncation, for perplexity computation
60
+ - We will use the full max context length of the model.
61
+ - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
62
+ the max context length.
63
+ - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
64
+ which may simply concatenate multiple documents together.
65
+ - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
66
+ multiple chunks, the last input will still have a full-sized context.
67
+ Example:
68
+ Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
69
+ Prefix: BOS/EOS
70
+ Max context length: 4
71
+ Resulting input/prediction pairs:
72
+
73
+ INPUT: BOS 0 1 2
74
+ PRED: 0 1 2 3
75
+
76
+ INPUT: 3 4 5 6
77
+ PRED: 4 5 6 7
78
+
79
+ INPUT: 5 6 7 8
80
+ PRED: 8 9
81
+
82
+ Observe that:
83
+ 1. Each token is predicted exactly once
84
+ 2. For the last pair, we provide the full context, but only score the last two tokens
85
+
86
+ :param requests: list[Instance]
87
+ A list of Instance objects with property `args` which returns a tuple (context,).
88
+ string: str
89
+ String for which we are computing overall loglikelihood
90
+ :return: list[tuple[float]]
91
+ A list of tuples (logprob,)
92
+ logprob: float
93
+ The log probability of `context` conditioned on the BOS/EOS token.
94
+ Can also be overridden for custom cases by `prefix_token_id`.
95
+ """
96
+ pass
97
+
98
+ # TODO: Add an optional max length
99
+ @abc.abstractmethod
100
+ def generate_until(self, requests) -> List[str]:
101
+ """Generate greedily until a stopping sequence
102
+
103
+ :param requests: list[Instance]
104
+ A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
105
+ context: str
106
+ Context string
107
+ gen_kwargs: dict
108
+ A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
109
+ :return: list[str]
110
+ A list of model generated continuations.
111
+ continuation: str
112
+ The generated continuation.
113
+ """
114
+ pass
115
+
116
+ def apply_chat_template(
117
+ self, chat_history: List[Dict[str, str]], add_generation_prompt=True
118
+ ) -> str:
119
+ """
120
+ Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
121
+
122
+ :param chat_history: list[dict[str, str]]
123
+ A list of dictionaries with keys 'role' and 'content'.
124
+ Values are strings representing the role name and the content of the message, respectively.
125
+ :param add_generation_prompt: bool
126
+ Whether to append an assistant generation prefix (e.g. <|assistant|>) to the end of the chat history. False if prefilling an assistant message.
127
+ :return: str
128
+ A string representing the chat history in a format that can be used as input to the LM.
129
+ """
130
+ raise NotImplementedError(
131
+ "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
132
+ )
133
+
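+ # Illustrative note (not part of the original file): `chat_history` is a list of
+ # {"role", "content"} dicts, e.g.
+ #     [{"role": "user", "content": "What is 2+2?"},
+ #      {"role": "assistant", "content": "4"}]
+ # A concrete subclass would typically delegate to its tokenizer's chat templating.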
134
+ @classmethod
135
+ def create_from_arg_string(
136
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
137
+ ) -> T:
138
+ """
139
+ Creates an instance of the LM class using the given argument string and additional config.
140
+
141
+ Parameters:
142
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
143
+ - additional_config: Optional dictionary containing additional configuration parameters.
144
+
145
+ Returns:
146
+ - Instance of the LM class.
147
+ """
148
+ additional_config = {} if additional_config is None else additional_config
149
+ args = utils.simple_parse_args_string(arg_string)
150
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
151
+ return cls(**args, **args2)
152
+
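+ # Example usage (illustrative; `MyLM` is a hypothetical LM subclass):
+ #     MyLM.create_from_arg_string("pretrained=my/model,batch_size=8", {"device": "cuda"})
+ # is equivalent to MyLM(pretrained="my/model", batch_size=8, device="cuda"), since
+ # `utils.simple_parse_args_string` splits the string on commas and "=" signs.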
153
+ @classmethod
154
+ def create_from_arg_obj(
155
+ cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
156
+ ) -> T:
157
+ """
158
+ Creates an instance of the LM class using the given arg_dict.
159
+
160
+ Parameters:
161
+ - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
162
+ - additional_config: Optional dictionary containing additional configuration parameters.
163
+
164
+ Returns:
165
+ - Instance of the LM class.
166
+ """
167
+
168
+ additional_config = {} if additional_config is None else additional_config
169
+ additional_config = {
170
+ k: v for k, v in additional_config.items() if v is not None
171
+ }
172
+
173
+ return cls(**arg_dict, **additional_config)
174
+
175
+ @property
176
+ def rank(self):
177
+ # used in the case of parallelism. Hardcoded to
178
+ # ensure no errors arise using API models which do
179
+ # not support multi-device parallelism nor expect it.
180
+ return self._rank
181
+
182
+ @property
183
+ def world_size(self):
184
+ # used in the case of parallelism. Hardcoded to
185
+ # ensure no errors arise using API models which do
186
+ # not support multi-device parallelism nor expect it.
187
+ return self._world_size
188
+
189
+ @property
190
+ def tokenizer_name(self) -> str:
191
+ """Must be defined for LM subclasses which implement Chat Templating.
192
+ Should return the name of the tokenizer or chat template used.
193
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
194
+ """
195
+ raise NotImplementedError(
196
+ "To use this model with chat templates, please implement the 'tokenizer_name' property."
197
+ )
198
+
199
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
200
+ """Returns the chat template structure for user/assistant messages if a template is provided.
201
+ This method is intended to be overridden in a subclass to define a specific chat template format.
202
+ For models that do not support chat templates, this method returns an empty string by default.
203
+ """
204
+
205
+ return ""
206
+
207
+ def set_cache_hook(self, cache_hook) -> None:
208
+ self.cache_hook = cache_hook
209
+
210
+
211
+ ### SQLite-based caching of LM responses
212
+ def hash_args(attr, args):
213
+ dat = json.dumps([attr] + list(args))
214
+ return hashlib.sha256(dat.encode("utf-8")).hexdigest()
215
+
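+ # Illustrative sketch: a cache key is the SHA-256 hex digest of the method name plus the
+ # JSON-serialized request arguments, e.g.
+ #     hash_args("generate_until", ["Q: 1+1=", {"until": ["\n"]}])
+ # so any change to the context or generation kwargs yields a different key.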
216
+
217
+ class CacheHook:
218
+ def __init__(self, cachinglm) -> None:
219
+ if cachinglm is None:
220
+ self.dbdict = None
221
+ return
222
+
223
+ self.dbdict = cachinglm.dbdict
224
+
225
+ def add_partial(self, attr, req, res) -> None:
226
+ if self.dbdict is None:
227
+ return
228
+ hsh = hash_args(attr, req)
229
+ self.dbdict[hsh] = res
230
+
231
+
232
+ class CachingLM:
233
+ def __init__(self, lm, cache_db) -> None:
234
+ """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
235
+
236
+ :param lm: LM
237
+ Underlying LM
238
+ :param cache_db: str
239
+ Path to cache db
240
+ """
241
+ self.lm = lm
242
+ self.cache_db = cache_db
243
+ if os.path.dirname(cache_db):
244
+ os.makedirs(os.path.dirname(cache_db), exist_ok=True)
245
+ self.dbdict = SqliteDict(cache_db, autocommit=True)
246
+
247
+ # add hook to lm
248
+ lm.set_cache_hook(self.get_cache_hook())
249
+
250
+ def __getattr__(self, attr: str):
251
+ lm_attr = getattr(self.lm, attr)
252
+ if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
253
+ eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
254
+ return lm_attr
255
+
256
+ def fn(requests):
257
+ res = []
258
+ remaining_reqs = []
259
+ warned = False
260
+ # figure out which ones are cached and which ones are new
261
+ eval_logger.info(
262
+ f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
263
+ )
264
+ for req in tqdm(requests, desc="Checking cached requests"):
265
+ hsh = hash_args(attr, req.args)
266
+ if attr == "generate_until" and req.args[1].get("do_sample", False):
267
+ # when we are doing non-greedy generation, don't use the cache
268
+ # (else every "randomly sampled" generation would be identical for repeats > 1).
269
+ if not warned:
270
+ eval_logger.warning(
271
+ f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
272
+ )
273
+ warned = True
274
+ res.append(None)
275
+ remaining_reqs.append(req)
276
+ elif hsh in self.dbdict:
277
+ ob = self.dbdict[hsh]
278
+
279
+ assert ob is not None
280
+
281
+ res.append(ob)
282
+ else:
283
+ res.append(None)
284
+ remaining_reqs.append(req)
285
+ eval_logger.info(
286
+ f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
287
+ )
288
+ if remaining_reqs:
289
+ # actually run the LM on the requests that do not have cached results
290
+ rem_res = getattr(self.lm, attr)(remaining_reqs)
291
+ else:
292
+ rem_res = []
293
+
294
+ # stick the new ones back into the list and also cache any of the new ones
295
+ resptr = 0
296
+ for req, r in zip(remaining_reqs, rem_res):
297
+ while res[resptr] is not None:
298
+ resptr += 1
299
+
300
+ res[resptr] = r
301
+
302
+ # caching
303
+ hsh = hash_args(attr, req.args)
304
+ self.dbdict[hsh] = r
305
+ self.dbdict.commit()
306
+
307
+ return res
308
+
309
+ return fn
310
+
311
+ def get_cache_hook(self):
312
+ return CacheHook(self)
313
+
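+ # Example usage (illustrative; `base_lm` is any LM instance and the path is hypothetical):
+ #     lm = CachingLM(base_lm, "cache/lm_cache.db")
+ #     outputs = lm.generate_until(requests)
+ # Cache misses are forwarded to `base_lm`; deterministic results are written back to the
+ # SqliteDict so repeated runs can skip inference.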
314
+
315
+ class TemplateLM(LM):
316
+ """
317
+ A class acting as intermediary between the LM base class
318
+ and boilerplate often included in other LM subclasses.
319
+ """
320
+
321
+ tokenizer = None
322
+
323
+ @property
324
+ @abc.abstractmethod
325
+ def eot_token_id(self):
326
+ pass
327
+
328
+ @property
329
+ def prefix_token_id(self):
330
+ # used as the prefix token id for loglikelihood requests with an empty context
331
+ return self.eot_token_id
332
+
333
+ @abc.abstractmethod
334
+ def tok_encode(self, string: str, **kwargs) -> List[int]:
335
+ """
336
+ Tokenize a string using the model's tokenizer and return a list of token IDs.
337
+ """
338
+ pass
339
+
340
+ @abc.abstractmethod
341
+ def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
342
+ pass
343
+
344
+ def _encode_pair(
345
+ self, context: str, continuation: str
346
+ ) -> Tuple[List[int], List[int]]:
347
+ n_spaces = len(context) - len(context.rstrip())
348
+ if n_spaces > 0:
349
+ continuation = context[-n_spaces:] + continuation
350
+ context = context[:-n_spaces]
351
+
352
+ model_class = getattr(self, "AUTO_MODEL_CLASS", None)
353
+
354
+ if model_class == transformers.AutoModelForSeq2SeqLM:
355
+ context_enc = self.tok_encode(context)
356
+ continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
357
+ else:
358
+ whole_enc = self.tok_encode(context + continuation)
359
+ context_enc = self.tok_encode(context)
360
+
361
+ context_enc_len = len(context_enc)
362
+ continuation_enc = whole_enc[context_enc_len:]
363
+
364
+ return context_enc, continuation_enc
365
+
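+ # Illustrative sketch of the whitespace handling above: trailing spaces are moved from the
+ # context onto the continuation before tokenization, so
+ #     _encode_pair("Question: 2+2 = ", "4")
+ # is treated as context "Question: 2+2 =" and continuation " 4"; the pair is tokenized
+ # jointly and split at the context length to recover the continuation token ids.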
366
+ def loglikelihood(
367
+ self, requests, disable_tqdm: bool = False
368
+ ) -> List[Tuple[float, bool]]:
369
+ new_reqs = []
370
+ for context, continuation in [req.args for req in requests]:
371
+ if context == "":
372
+ # BOS or EOS as context
373
+ context_enc, continuation_enc = (
374
+ [self.prefix_token_id],
375
+ self.tok_encode(continuation),
376
+ )
377
+ else:
378
+ context_enc, continuation_enc = self._encode_pair(context, continuation)
379
+
380
+ new_reqs.append(((context, continuation), context_enc, continuation_enc))
381
+
382
+ return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
383
+
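+ # Illustrative sketch: for a request with an empty context, the prefix token id (the EOT
+ # token by default) stands in as context, so ("", " Paris") is scored roughly as
+ # P(" Paris" | <eot>); non-empty contexts go through _encode_pair() as above.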
384
+ @abc.abstractmethod
385
+ def loglikelihood_rolling(
386
+ self, requests, disable_tqdm: bool = False
387
+ ) -> List[float]:
388
+ pass
389
+
390
+ @abc.abstractmethod
391
+ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
392
+ pass
393
+
394
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
395
+ """
396
+ Set and get the appropriate chat template for the model.
397
+ This method sets the tokenizer's chat_template and returns the template string for reproducibility.
398
+
399
+ The template selection logic is adapted from the Transformers library's `apply_chat_template`
400
+ method in the Tokenizer class. The original implementation can be found at:
401
+ https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
402
+
403
+ This method ensures that the right template is chosen based on the following:
404
+ 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
405
+ 1. If the model's tokenizer has multiple templates:
406
+ a. Use the specified template if it exists in the dictionary.
407
+ b. Use the default template from the list if no specific template is provided.
408
+ c. Raise an error if no default template exists and no specific template is provided.
409
+ 2. If the model's tokenizer has a single template or no template:
410
+ a. Use the tokenizer's chat template if available.
411
+ b. Fall back to the default chat template if no tokenizer chat template exists.
412
+
413
+ Args:
414
+ chat_template (Union[bool, str]): Specifies the chat template to use.
415
+ - If False or None, no template is applied.
416
+ - If True, the default or only available template is used.
417
+ - If a string, the template with the matching name is used.
418
+
419
+ Returns:
420
+ Optional[str]: The selected chat template, or None if no template is applied.
421
+ """
422
+ if self.tokenizer is None:
423
+ return ""
424
+
425
+ if chat_template is False or chat_template is None:
426
+ eval_logger.warning(
427
+ "model.chat_template was called with the chat_template set to False or None. "
428
+ "Therefore no chat template will be applied. Make sure this is an intended behavior."
429
+ )
430
+ return None
431
+
432
+ # Convert boolean chat_template to None to ensure compatibility with the adapted logic
433
+ if isinstance(chat_template, bool):
434
+ chat_template = None
435
+ using_default_template = False
436
+
437
+ # First, handle the cases when the model has a dict of multiple templates
438
+ try:
439
+ template = (
440
+ self.tokenizer.chat_template or self.tokenizer.default_chat_template
441
+ )
442
+ except AttributeError:
443
+ return None
444
+
445
+ if isinstance(template, dict):
446
+ using_default_dict = self.tokenizer.chat_template is None
447
+
448
+ if chat_template is not None:
449
+ if chat_template in template:
450
+ selected_template = template[chat_template]
451
+ if using_default_dict:
452
+ using_default_template = True
453
+ else:
454
+ raise ValueError(
455
+ f"The specified chat template '{chat_template}' is not available. "
456
+ f"Available template names are {sorted(template.keys())}."
457
+ )
458
+ else:
459
+ # If user didn't pass a chat template, use the default template from the dict
460
+ if "default" in template:
461
+ selected_template = template["default"]
462
+ using_default_template = True
463
+ else:
464
+ raise ValueError(
465
+ "This model has multiple chat templates with no default specified! Please either pass a chat "
466
+ "template or the name of the template you wish to use to the `chat_template` argument. Available "
467
+ f"template names are {sorted(template.keys())}."
468
+ )
469
+
470
+ # Cases when the model has a single template or no template
471
+ else:
472
+ # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
473
+ if isinstance(chat_template, str):
474
+ eval_logger.warning(
475
+ "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
476
+ "Using the tokenizer's chat template or the default template instead."
477
+ )
478
+ if self.tokenizer.chat_template is not None:
479
+ selected_template = self.tokenizer.chat_template
480
+ else:
481
+ selected_template = self.tokenizer.default_chat_template
482
+ using_default_template = True
483
+
484
+ if using_default_template:
485
+ eval_logger.warning(
486
+ "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
487
+ "very error-prone, because models are often trained with templates different from the class default! "
488
+ "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
489
+ "point any code depending on them will stop working. We recommend setting a valid chat template before "
490
+ "then to ensure that this model continues working without issues."
491
+ )
492
+
493
+ return selected_template
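+ # Example (illustrative): for a tokenizer whose chat_template is a dict such as
+ #     {"default": "...", "tool_use": "..."}
+ # lm.chat_template("tool_use") returns the named entry, lm.chat_template(True) falls back
+ # to the "default" entry, and passing False/None skips templating and returns None.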
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py ADDED
@@ -0,0 +1,232 @@
1
+ import logging
2
+ import warnings
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Iterable, Optional, Union
5
+
6
+ import datasets
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from random import Random
11
+
12
+ from lm_eval.api.task import ConfigurableTask, Task
13
+
14
+ eval_logger = logging.getLogger("lm-eval")
15
+
16
+
17
+ class ContextSampler:
18
+ def __init__(
19
+ self,
20
+ docs: list[dict],
21
+ task: Union["Task", "ConfigurableTask"],
22
+ fewshot_indices: Optional[Iterable] = None,
23
+ rnd: Optional["Random"] = None,
24
+ ) -> None:
25
+ self.rnd = rnd
26
+ if not self.rnd:
27
+ raise ValueError(
28
+ "A `random.Random` generator argument must be provided to `rnd` of ContextSampler!"
29
+ )
30
+
31
+ self.task = task
32
+ self.config = task._config
33
+
34
+ self.target_delimiter = self.config.target_delimiter
35
+ self.fewshot_delimiter = self.config.fewshot_delimiter
36
+
37
+ if (
38
+ self.config.fewshot_config is not None
39
+ and self.config.fewshot_config.get("doc_to_text", None) is not None
40
+ ):
41
+ self.doc_to_text = partial(
42
+ self.task.doc_to_text,
43
+ doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
44
+ )
45
+ else:
46
+ self.doc_to_text = self.task.doc_to_text
47
+
48
+ if (
49
+ self.config.fewshot_config is not None
50
+ and self.config.fewshot_config.get("doc_to_target", None) is not None
51
+ ):
52
+ self.doc_to_target = partial(
53
+ self.task.doc_to_target,
54
+ doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
55
+ )
56
+ else:
57
+ self.doc_to_target = self.task.doc_to_target
58
+
59
+ if (
60
+ self.config.fewshot_config is not None
61
+ and self.config.fewshot_config.get("doc_to_choice", None) is not None
62
+ ):
63
+ self.doc_to_choice = partial(
64
+ self.task.doc_to_choice,
65
+ doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
66
+ )
67
+ else:
68
+ self.doc_to_choice = self.task.doc_to_choice
69
+
70
+ self.docs = docs # HF dataset split, provided by task._fewshot_docs()
71
+ if fewshot_indices: # subset few-shot docs from the provided indices
72
+ if not isinstance(self.docs, datasets.Dataset):
73
+ raise ValueError(
74
+ "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
75
+ )
76
+ self.docs = self.docs.select(fewshot_indices)
77
+
78
+ def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
79
+ # draw an extra fewshot sample if using same split as evaluating on
80
+ prefix = gen_prefix + " " if gen_prefix else ""
81
+ n_samples = (
82
+ num_fewshot + 1
83
+ if self.config.fewshot_split == self.config.test_split
84
+ else num_fewshot
85
+ )
86
+
87
+ # draw `n_samples` docs from fewshot_docs
88
+ fewshotex = self.sample(n_samples)
89
+
90
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
91
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
92
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
93
+
94
+ labeled_examples = ""
95
+ for doc in selected_docs:
96
+ doc_content = self.doc_to_text(doc)
97
+ doc_target = self.doc_to_target(doc)
98
+ if self.config.doc_to_choice is None or isinstance(doc_content, str):
99
+ labeled_examples += doc_content
100
+ else:
101
+ labeled_examples += self.doc_to_choice(doc)[doc_content]
102
+
103
+ if doc_target != "":
104
+ if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
105
+ # TODO: add logger warn once here.
106
+ warnings.warn(
107
+ "Both target_delimiter and target start with a space. This may cause issues.",
108
+ Warning,
109
+ stacklevel=2,
110
+ )
111
+ labeled_examples += self.target_delimiter
112
+ labeled_examples += prefix
113
+ labeled_examples += (
114
+ str(doc_target[0])
115
+ if isinstance(doc_target, list)
116
+ else doc_target
117
+ if self.config.doc_to_choice is None or isinstance(doc_target, str)
118
+ else str(self.doc_to_choice(doc)[doc_target])
119
+ )
120
+ labeled_examples += self.fewshot_delimiter
121
+
122
+ return labeled_examples
123
+
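+ # Illustrative sketch: with target_delimiter " " and fewshot_delimiter "\n\n", two sampled
+ # docs render roughly as
+ #     "Q: ...?" + " " + "A1" + "\n\n" + "Q: ...?" + " " + "A2" + "\n\n"
+ # i.e. question, a space, the target answer, then a blank line between examples.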
124
+ def get_chat_context(
125
+ self,
126
+ doc: dict,
127
+ num_fewshot: int,
128
+ fewshot_as_multiturn: bool = False,
129
+ gen_prefix: Optional[str] = None,
130
+ ):
131
+ # TODO: Do we need any other delimiter
132
+ prefix = gen_prefix + " " if gen_prefix else ""
133
+ chat_history = []
134
+ # draw an extra fewshot sample if using same split as evaluating on
135
+ n_samples = (
136
+ num_fewshot + 1
137
+ if self.config.fewshot_split == self.config.test_split
138
+ else num_fewshot
139
+ )
140
+ # draw `n_samples` docs from fewshot_docs
141
+ fewshotex = self.sample(n_samples)
142
+
143
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
144
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
145
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
146
+
147
+ if fewshot_as_multiturn:
148
+ for doc in selected_docs:
149
+ doc_content = self.doc_to_text(doc)
150
+ doc_target = self.doc_to_target(doc)
151
+ chat_history.append(
152
+ {
153
+ "role": "user",
154
+ "content": doc_content
155
+ if self.config.doc_to_choice is None
156
+ or isinstance(doc_content, str)
157
+ else self.doc_to_choice(doc)[doc_content],
158
+ }
159
+ )
160
+ chat_history.append(
161
+ {
162
+ "role": "assistant",
163
+ "content": prefix + str(doc_target[0])
164
+ if isinstance(doc_target, list)
165
+ else prefix + doc_target
166
+ if self.config.doc_to_choice is None
167
+ or isinstance(doc_target, str)
168
+ else prefix + str(self.doc_to_choice(doc)[doc_target]),
169
+ }
170
+ )
171
+ else:
172
+ # get fewshot context as one user turn
173
+ chat_history.append(
174
+ {
175
+ "role": "user",
176
+ "content": self.get_context(
177
+ doc, num_fewshot, gen_prefix=gen_prefix
178
+ ),
179
+ }
180
+ )
181
+
182
+ return chat_history
183
+
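+ # Illustrative sketch: with fewshot_as_multiturn=True each sampled doc becomes a
+ # user/assistant message pair, e.g.
+ #     [{"role": "user", "content": "Q1"}, {"role": "assistant", "content": "A1"}, ...]
+ # otherwise the entire few-shot block from get_context() is packed into a single user turn.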
184
+ def sample(self, n: int):
185
+ """
186
+ Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
187
+ """
188
+
189
+ return self.rnd.sample(self.docs, n)
190
+
191
+
192
+ class FirstNSampler(ContextSampler):
193
+ def sample(self, n: int):
194
+ """
195
+ Draw the first `n` samples in order from the specified split.
196
+ Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
197
+ """
198
+ assert n <= len(self.docs), (
199
+ f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
200
+ )
201
+ return self.docs[:n]
202
+
203
+
204
+ class BalancedSampler(ContextSampler):
205
+ def sample(self, n: int) -> None:
206
+ """
207
+ TODO: this should return approximately class-balanced samples from our fewshot examples.
208
+ TODO: what order should they be in? maybe random?
209
+ """
210
+
211
+ pass
212
+
213
+
214
+ class ManualSampler(ContextSampler):
215
+ def sample(self, n: int) -> None:
216
+ """ """
217
+ pass
218
+
219
+
220
+ SAMPLER_REGISTRY = {
221
+ "default": ContextSampler,
222
+ "first_n": FirstNSampler,
223
+ }
224
+
225
+
226
+ def get_sampler(name: str):
227
+ try:
228
+ return SAMPLER_REGISTRY[name]
229
+ except KeyError:
230
+ raise ValueError(
231
+ f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
232
+ )
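+ # Example usage (illustrative): get_sampler("first_n") returns the FirstNSampler class,
+ # which a task then instantiates with its few-shot docs; an unknown name raises ValueError.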
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py ADDED
@@ -0,0 +1,1839 @@
1
+ import abc
2
+ import ast
3
+ import logging
4
+ import random
5
+ import re
6
+ from collections.abc import Callable
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from inspect import getsource
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Mapping,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import datasets
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+
27
+ from lm_eval import utils
28
+ from lm_eval.api import samplers
29
+ from lm_eval.api.instance import Instance, OutputType
30
+ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
31
+ from lm_eval.api.registry import (
32
+ AGGREGATION_REGISTRY,
33
+ DEFAULT_METRIC_REGISTRY,
34
+ get_aggregation,
35
+ get_metric,
36
+ get_metric_aggregation,
37
+ is_higher_better,
38
+ )
39
+ from lm_eval.caching.cache import load_from_cache, save_to_cache
40
+ from lm_eval.filters import build_filter_ensemble
41
+ from lm_eval.prompts import get_prompt
42
+
43
+
44
+ ALL_OUTPUT_TYPES = [
45
+ "loglikelihood",
46
+ "multiple_choice",
47
+ "loglikelihood_rolling",
48
+ "generate_until",
49
+ ]
50
+
51
+ eval_logger = logging.getLogger(__name__)
52
+
53
+
54
+ @dataclass
55
+ class TaskConfig(dict):
56
+ # task naming/registry
57
+ task: Optional[str] = None
58
+ task_alias: Optional[str] = None
59
+ tag: Optional[Union[str, list]] = None
60
+ # HF dataset options.
61
+ # which dataset to use,
62
+ # and what splits for what purpose
63
+ custom_dataset: Optional[Callable] = None
64
+ dataset_path: Optional[str] = None
65
+ dataset_name: Optional[str] = None
66
+ dataset_kwargs: Optional[dict] = None
67
+ training_split: Optional[str] = None
68
+ validation_split: Optional[str] = None
69
+ test_split: Optional[str] = None
70
+ fewshot_split: Optional[str] = (
71
+ None # TODO: assert that this is not None if num_fewshot > 0. (?) assert if this is the same split as the one being evaluated (?)
72
+ )
73
+ # formatting / prompting options.
74
+ # see docs/advanced_task_guide.md for more info
75
+ process_docs: Optional[Callable] = None
76
+ doc_to_text: Optional[Union[Callable, str]] = None
77
+ doc_to_target: Optional[Union[Callable, str]] = None
78
+ doc_to_image: Union[Callable, str] = None
79
+ doc_to_audio: Union[Callable, str] = None
80
+ unsafe_code: bool = False
81
+ doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
82
+ process_results: Optional[Union[Callable, str]] = None
83
+ use_prompt: Optional[str] = None
84
+ description: str = ""
85
+ target_delimiter: str = " "
86
+ fewshot_delimiter: str = "\n\n"
87
+ fewshot_config: Optional[dict] = None
88
+ # runtime configuration options
89
+ num_fewshot: Optional[int] = None
90
+ # scoring options
91
+ metric_list: Optional[list] = None
92
+ output_type: OutputType = "generate_until"
93
+ generation_kwargs: Optional[dict] = None
94
+ repeats: int = 1
95
+ filter_list: Optional[Union[str, list]] = None
96
+ should_decontaminate: bool = False
97
+ doc_to_decontamination_query: Optional[str] = None
98
+ gen_prefix: Optional[str] = None
99
+ metadata: Optional[dict] = (
100
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
101
+ )
102
+
103
+ def __post_init__(self) -> None:
104
+ if self.generation_kwargs is not None:
105
+ if self.output_type != "generate_until":
106
+ eval_logger.warning(
107
+ f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
108
+ )
109
+
110
+ if "temperature" in self.generation_kwargs:
111
+ self.generation_kwargs["temperature"] = float(
112
+ self.generation_kwargs["temperature"]
113
+ )
114
+
115
+ if "until" not in self.generation_kwargs:
116
+ self.generation_kwargs["until"] = [self.fewshot_delimiter]
117
+ else:
118
+ if self.output_type == "generate_until":
119
+ # ensure that we greedily generate in absence of explicit arguments otherwise
120
+ self.generation_kwargs = {
121
+ "until": (
122
+ None
123
+ if self.fewshot_delimiter is None
124
+ else [self.fewshot_delimiter]
125
+ ),
126
+ "do_sample": False,
127
+ }
128
+
129
+ def __getitem__(self, item):
130
+ return getattr(self, item)
131
+
132
+ def __setitem__(self, item, value):
133
+ return setattr(self, item, value)
134
+
135
+ def to_dict(self, keep_callable: bool = False) -> dict:
136
+ """dumps the current config as a dictionary object, as a printable format.
137
+ null fields will not be printed.
138
+ Used for dumping results alongside full task configuration
139
+
140
+ :return: dict
141
+ A printable dictionary version of the TaskConfig object.
142
+
143
+ # TODO: should any default value in the TaskConfig not be printed?
144
+ """
145
+ cfg_dict = asdict(self)
146
+ # remove values that are `None`
147
+ for k, v in list(cfg_dict.items()):
148
+ if v is None:
149
+ cfg_dict.pop(k)
150
+ elif k == "metric_list":
151
+ for metric_dict in v:
152
+ for metric_key, metric_value in metric_dict.items():
153
+ if callable(metric_value):
154
+ metric_dict[metric_key] = self.serialize_function(
155
+ metric_value, keep_callable=keep_callable
156
+ )
157
+ cfg_dict[k] = v
158
+ elif callable(v):
159
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
160
+ return cfg_dict
161
+
162
+ def serialize_function(
163
+ self, value: Union[Callable, str], keep_callable=False
164
+ ) -> Union[Callable, str]:
165
+ """Serializes a given function or string.
166
+
167
+ If 'keep_callable' is True, the original callable is returned.
168
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
169
+ """
170
+ if keep_callable:
171
+ return value
172
+ else:
173
+ try:
174
+ return getsource(value)
175
+ except (TypeError, OSError):
176
+ return str(value)
177
+
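+ # Illustrative sketch: a minimal generate_until task configuration (all field values here
+ # are hypothetical):
+ #     TaskConfig(task="my_task", dataset_path="org/dataset", test_split="test",
+ #                doc_to_text="{{question}}", doc_to_target="{{answer}}",
+ #                metric_list=[{"metric": "exact_match"}])
+ # __post_init__ then fills in default generation_kwargs: greedy decoding with the
+ # fewshot_delimiter as the stop sequence.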
178
+
179
+ class Task(abc.ABC):
180
+ """A task represents an entire benchmark including its dataset, problems,
181
+ answers, and evaluation methods. See BoolQ for a simple example implementation
182
+
183
+ A `doc` can be any python object which represents one instance of evaluation.
184
+ This is usually a dictionary e.g.
185
+ {"question": ..., "answer": ...} or
186
+ a tuple such as (question, answer).
187
+ """
188
+
189
+ VERSION: Optional[Union[int, str]] = None
190
+
191
+ # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
192
+ # or a path to a custom `datasets` loading script.
193
+ DATASET_PATH: Optional[str] = None
194
+
195
+ # The name of a subset within `DATASET_PATH`.
196
+ DATASET_NAME: Optional[str] = None
197
+
198
+ OUTPUT_TYPE: Optional[OutputType] = None
199
+
200
+ def __init__(
201
+ self,
202
+ data_dir: Optional[str] = None,
203
+ cache_dir: Optional[str] = None,
204
+ download_mode: Optional[datasets.DownloadMode] = None,
205
+ config: Optional[Mapping] = None, # Union[dict, TaskConfig]
206
+ ) -> None:
207
+ """
208
+ :param data_dir: str
209
+ Stores the path to a local folder containing the `Task`'s data files.
210
+ Use this to specify the path to manually downloaded data (usually when
211
+ the dataset is not publicly accessible).
212
+ :param cache_dir: str
213
+ The directory to read/write the `Task` dataset. This follows the
214
+ HuggingFace `datasets` API with the default cache directory located at:
215
+ `~/.cache/huggingface/datasets`
216
+ NOTE: You can change the cache location globally for a given process
217
+ to another directory:
218
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
219
+ :param download_mode: datasets.DownloadMode
220
+ How to treat pre-existing `Task` downloads and data.
221
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
222
+ Reuse download and reuse dataset.
223
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
224
+ Reuse download with fresh dataset.
225
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
226
+ Fresh download and fresh dataset.
227
+ """
228
+ self.download(data_dir, cache_dir, download_mode)
229
+ self._training_docs: Optional[list] = None
230
+ self._fewshot_docs: Optional[list] = None
231
+ self._instances: Optional[List[Instance]] = None
232
+
233
+ self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
234
+
235
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
236
+ self.fewshot_rnd: Optional[random.Random] = (
237
+ None # purposely induce errors in case of improper usage
238
+ )
239
+
240
+ def download(
241
+ self,
242
+ data_dir: Optional[str] = None,
243
+ cache_dir: Optional[str] = None,
244
+ download_mode=None,
245
+ ) -> None:
246
+ """Downloads and returns the task dataset.
247
+ Override this method to download the dataset from a custom API.
248
+
249
+ :param data_dir: str
250
+ Stores the path to a local folder containing the `Task`'s data files.
251
+ Use this to specify the path to manually downloaded data (usually when
252
+ the dataset is not publicly accessible).
253
+ :param cache_dir: str
254
+ The directory to read/write the `Task` dataset. This follows the
255
+ HuggingFace `datasets` API with the default cache directory located at:
256
+ `~/.cache/huggingface/datasets`
257
+ NOTE: You can change the cache location globally for a given process
258
+ by setting the shell environment variable, `HF_DATASETS_CACHE`,
259
+ to another directory:
260
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
261
+ :param download_mode: datasets.DownloadMode
262
+ How to treat pre-existing `Task` downloads and data.
263
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
264
+ Reuse download and reuse dataset.
265
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
266
+ Reuse download with fresh dataset.
267
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
268
+ Fresh download and fresh dataset.
269
+ """
270
+ self.dataset = datasets.load_dataset(
271
+ path=self.DATASET_PATH,
272
+ name=self.DATASET_NAME,
273
+ data_dir=data_dir,
274
+ cache_dir=cache_dir,
275
+ download_mode=download_mode,
276
+ )
277
+
278
+ @property
279
+ def config(self) -> TaskConfig:
280
+ """Returns the TaskConfig associated with this class."""
281
+ return self._config
282
+
283
+ @abc.abstractmethod
284
+ def has_training_docs(self):
285
+ """Whether the task has a training set"""
286
+ pass
287
+
288
+ @abc.abstractmethod
289
+ def has_validation_docs(self):
290
+ """Whether the task has a validation set"""
291
+ pass
292
+
293
+ @abc.abstractmethod
294
+ def has_test_docs(self):
295
+ """Whether the task has a test set"""
296
+ pass
297
+
298
+ def training_docs(self) -> Iterable:
299
+ """
300
+ :return: Iterable[obj]
301
+ An iterable of any object that doc_to_text can handle
302
+ """
303
+ return []
304
+
305
+ def validation_docs(self) -> Iterable:
306
+ """
307
+ :return: Iterable[obj]
308
+ An iterable of any object that doc_to_text can handle
309
+ """
310
+ return []
311
+
312
+ def test_docs(self) -> Iterable:
313
+ """
314
+ :return: Iterable[obj]
315
+ An iterable of any object that doc_to_text can handle
316
+ """
317
+ return []
318
+
319
+ def fewshot_docs(self) -> Iterable:
320
+ """
321
+ :return: Iterable[obj]
322
+ An iterable of any object that doc_to_text can handle
323
+ """
324
+ if self.has_training_docs():
325
+ return self.training_docs()
326
+ elif self.has_validation_docs():
327
+ return self.validation_docs()
328
+ else:
329
+ if self.config.get("num_fewshot", 0) > 0:
330
+ eval_logger.warning(
331
+ f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
332
+ ", using test_docs as fewshot_docs but this is not recommended."
333
+ )
334
+ return self.test_docs()
335
+
336
+ def _process_doc(self, doc: dict) -> dict:
337
+ """
338
+ Override this to process (detokenize, strip, replace, etc.) individual
339
+ documents. This can be used in a map over documents of a data split.
340
+ E.g. `map(self._process_doc, self.dataset["validation"])`
341
+
342
+ :return: dict
343
+ The processed version of the specified `doc`.
344
+ """
345
+ return doc
346
+
347
+ @property
348
+ def instances(self) -> List[Instance]:
349
+ """After calling `task.build_all_requests()`, tasks
350
+ maintain a list of the dataset instances which will be evaluated.
351
+ """
352
+ return self._instances
353
+
354
+ def fewshot_examples(self, k, rnd):
355
+ if self._training_docs is None:
356
+ self._training_docs = list(self.training_docs())
357
+
358
+ return rnd.sample(self._training_docs, k)
359
+
360
+ def doc_to_decontamination_query(self, doc):
361
+ raise NotImplementedError(
362
+ "Override doc_to_decontamination_query with document specific decontamination query."
363
+ )
364
+
365
+ @abc.abstractmethod
366
+ def doc_to_text(self, doc):
367
+ pass
368
+
369
+ @abc.abstractmethod
370
+ def doc_to_target(self, doc):
371
+ pass
372
+
373
+ # not an abstractmethod because not every language-only task has to implement this
374
+ def doc_to_image(self, doc):
375
+ raise NotImplementedError
376
+
377
+ def doc_to_audio(self, doc):
378
+ raise NotImplementedError
379
+
380
+ def doc_to_prefix(self, doc):
381
+ return ""
382
+
383
+ def build_all_requests(
384
+ self,
385
+ *,
386
+ limit: Union[int, None] = None,
387
+ rank: int = 0,
388
+ world_size: int = 1,
389
+ cache_requests: bool = False,
390
+ rewrite_requests_cache: bool = False,
391
+ system_instruction: Optional[str] = None,
392
+ apply_chat_template: bool = False,
393
+ fewshot_as_multiturn: bool = False,
394
+ chat_template: Optional[Callable] = None,
395
+ tokenizer_name: str = "",
396
+ ) -> None:
397
+ """Build a set of Instances for a task, and store them in task.instances"""
398
+
399
+ # used with caching
400
+ og_limit = limit
401
+
402
+ cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
403
+ cache_key += "-chat_template" if apply_chat_template else ""
404
+ cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
405
+ cache_key += (
406
+ f"-system_prompt_hash{utils.hash_string(system_instruction)}"
407
+ if system_instruction is not None
408
+ else ""
409
+ )
410
+ cache_key += f"-tokenizer{tokenizer_name}"
411
+
412
+ cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
413
+
414
+ if cache_requests and cached_instances and not rewrite_requests_cache:
415
+ cached_instances = cached_instances[:limit]
416
+
417
+ flattened_instances = [
418
+ instance
419
+ for instance_group in cached_instances
420
+ for instance in instance_group
421
+ ]
422
+
423
+ self._instances = flattened_instances
424
+ return
425
+
426
+ eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
427
+
428
+ instances = []
429
+
430
+ # process all documents when caching is specified for simplicity
431
+ if (
432
+ cache_requests
433
+ and (not cached_instances or rewrite_requests_cache)
434
+ and limit is not None
435
+ ):
436
+ limit = None
437
+
438
+ doc_id_docs = list(
439
+ self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
440
+ )
441
+
442
+ num_docs = len(doc_id_docs)
443
+
444
+ for doc_id, doc in tqdm(
445
+ doc_id_docs,
446
+ total=num_docs,
447
+ ):
448
+ # sample fewshot context #TODO: need to offset doc_id by rank now!
449
+ fewshot_ctx = self.fewshot_context(
450
+ doc,
451
+ 0 if self.config.num_fewshot is None else self.config.num_fewshot,
452
+ system_instruction,
453
+ apply_chat_template,
454
+ fewshot_as_multiturn,
455
+ chat_template,
456
+ gen_prefix=self.doc_to_prefix(doc),
457
+ )
458
+
459
+ # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
460
+ inst = self.construct_requests(
461
+ doc=doc,
462
+ ctx=fewshot_ctx,
463
+ metadata=(self.config["task"], doc_id, self.config.repeats),
464
+ apply_chat_template=apply_chat_template,
465
+ chat_template=chat_template,
466
+ )
467
+
468
+ if not isinstance(inst, list):
469
+ inst = [inst]
470
+
471
+ instances.append(inst)
472
+
473
+ # now flatten, this is to allow slicing to work with pickles
474
+
475
+ sliced_instances = instances[:og_limit]
476
+
477
+ flattened_instances = [
478
+ instance
479
+ for instance_group in sliced_instances
480
+ for instance in instance_group
481
+ ]
482
+
483
+ self._instances = flattened_instances
484
+
485
+ if len(self._instances) == 0:
486
+ raise ValueError("task.build_requests() did not find any docs!")
487
+
488
+ if cache_requests and (not cached_instances or rewrite_requests_cache):
489
+ save_to_cache(file_name=cache_key, obj=instances)
490
+
491
+ @abc.abstractmethod
492
+ def construct_requests(self, doc, ctx, **kwargs):
493
+ """Uses RequestFactory to construct Requests and returns an iterable of
494
+ Requests which will be sent to the LM.
495
+
496
+ :param doc:
497
+ The document as returned from training_docs, validation_docs, or test_docs.
498
+ :param ctx: str
499
+ The context string, generated by fewshot_context. This includes the natural
500
+ language description, as well as the few shot examples, and the question
501
+ part of the document for `doc`.
502
+ :param doc_idx: int
503
+ The index of a document within `self.test_docs()` or `self.validation_docs()`,
504
+ whichever is the main split used.
505
+ :param repeats: int
506
+ TODO: update this docstring
507
+ The number of times each instance in a dataset is inferred on. Defaults to 1,
508
+ can be increased for techniques like majority voting.
509
+ """
510
+ pass
511
+
512
+ @abc.abstractmethod
513
+ def process_results(self, doc, results):
514
+ """Take a single document and the LM results and evaluates, returning a
515
+ dict where keys are the names of submetrics and values are the values of
516
+ the metric for that one document
517
+
518
+ :param doc:
519
+ The document as returned from training_docs, validation_docs, or test_docs.
520
+ :param results:
521
+ The results of the requests created in construct_requests.
522
+ """
523
+ pass
524
+
525
+ @abc.abstractmethod
526
+ def aggregation(self):
527
+ """
528
+ :returns: {str: [metric_score] -> float}
529
+ A dictionary where keys are the names of submetrics and values are
530
+ functions that aggregate a list of metric scores
531
+ """
532
+ pass
533
+
534
+ @abc.abstractmethod
535
+ def higher_is_better(self):
536
+ """
537
+ :returns: {str: bool}
538
+ A dictionary where keys are the names of submetrics and values are
539
+ whether a higher value of the submetric is better
540
+ """
541
+ pass
542
+
543
+ def get_config(self, key: str) -> Any:
544
+ return getattr(self._config, key, None)
545
+
546
+ @classmethod
547
+ def count_bytes(cls, doc):
548
+ """Used for byte-level perplexity metrics in rolling loglikelihood"""
549
+ return len(doc.encode("utf-8"))
550
+
551
+ @classmethod
552
+ def count_words(cls, doc):
553
+ """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
554
+ return len(re.split(r"\s+", doc))
555
+
556
+ @utils.positional_deprecated
557
+ def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
558
+ """Returns a fewshot context string that is made up of a prepended description
559
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
560
+
561
+ :param doc: str
562
+ The document as returned from training_docs, validation_docs, or test_docs.
563
+ :param num_fewshot: int
564
+ The number of fewshot examples to provide in the returned context string.
565
+ :param rnd: random.Random
566
+ The pseudo-random number generator used to randomly sample examples.
567
+ WARNING: This is currently a required arg although it's optionalized with a default `None`.
568
+ :param description: str
569
+ The task's description that will be prepended to the fewshot examples.
570
+ :returns: str
571
+ The fewshot context.
572
+ """
573
+ if rnd is None:
574
+ if self.fewshot_rnd is not None:
575
+ rnd = self.fewshot_rnd
576
+ else:
577
+ raise ValueError(
578
+ "A `random.Random` generator argument must be provided to `rnd`"
579
+ )
580
+
581
+ description = description if description else ""
582
+
583
+ if num_fewshot == 0:
584
+ labeled_examples = ""
585
+ else:
586
+ # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
587
+ if self.has_training_docs():
588
+ fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
589
+ else:
590
+ if self._fewshot_docs is None:
591
+ self._fewshot_docs = list(
592
+ self.validation_docs()
593
+ if self.has_validation_docs()
594
+ else self.test_docs()
595
+ )
596
+
597
+ fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
598
+
599
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
600
+ fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
601
+
602
+ labeled_examples = (
603
+ "\n\n".join(
604
+ [
605
+ self.doc_to_text(doc) + self.doc_to_target(doc)
606
+ for doc in fewshotex
607
+ ]
608
+ )
609
+ + "\n\n"
610
+ )
611
+
612
+ example = self.doc_to_text(doc)
613
+ return description + labeled_examples + example
614
+
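+ # Illustrative sketch: the returned string is
+ #     description + doc_to_text(ex1) + doc_to_target(ex1) + "\n\n" + ... + doc_to_text(doc)
+ # i.e. num_fewshot labelled examples joined by blank lines, followed by the unanswered query.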
615
+ def apply_filters(self) -> Optional[List[Instance]]:
616
+ """Iterates over FilterEnsembles and applies them to instances"""
617
+ if hasattr(self, "_filters"):
618
+ for f in self._filters:
619
+ f.apply(self._instances)
620
+ else:
621
+ eval_logger.warning("No filter defined, passing through instances")
622
+ return self._instances
623
+
624
+ def dump_config(self) -> dict:
625
+ """Returns the config as a dictionary."""
626
+ # TODO: this should only return the overrides applied to a non-YAML task's configuration.
627
+ # (num_fewshot)
628
+ return self.config.to_dict()
629
+
630
+ def set_config(self, key: str, value: Any, update: bool = False) -> None:
631
+ """Set or update the configuration for a given key."""
632
+ if key is None:
633
+ raise ValueError("Key must be provided.")
634
+
635
+ if update:
636
+ current_value = getattr(self._config, key, {})
637
+ if not isinstance(current_value, dict):
638
+ raise TypeError(
639
+ f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
640
+ )
641
+ current_value.update(value)
642
+ else:
643
+ setattr(self._config, key, value)
644
+
645
+ def override_metric(self, metric_name: str) -> None:
646
+ """
647
+ Override the default metrics used for evaluation with custom metrics.
648
+
649
+ Parameters:
650
+ - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
651
+ """
652
+ (
653
+ self._metric_fn_list,
654
+ self._aggregation_list,
655
+ self._metric_fn_kwargs,
656
+ self._higher_is_better,
657
+ ) = ({}, {}, {}, {})
658
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
659
+ self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
660
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
661
+ self._metric_fn_kwargs[metric_name] = {}
662
+ if not isinstance(self, ConfigurableTask):
663
+ self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
664
+ self.aggregation = lambda: {
665
+ metric_name: get_metric_aggregation(metric_name)
666
+ }
667
+ setattr(self._config, "metric_list", [{"metric": metric_name}])
668
+ setattr(self._config, "process_results", None)
669
+
670
+ def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
671
+ self.fewshot_rnd = random.Random(seed)
672
+ if hasattr(self, "sampler"):
673
+ self.sampler.rnd = self.fewshot_rnd
674
+
675
+ @property
676
+ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
677
+ if self.has_test_docs():
678
+ return self.test_docs()
679
+ elif self.has_validation_docs():
680
+ return self.validation_docs()
681
+ else:
682
+ raise ValueError(
683
+ f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
684
+ )
685
+
686
+ def doc_iterator(
687
+ self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
688
+ ) -> Iterator[Tuple[int, Any]]:
689
+ limit = int(limit) if limit else None
690
+ doc_iterator = utils.create_iterator(
691
+ enumerate(self.eval_docs),
692
+ rank=int(rank),
693
+ limit=limit,
694
+ world_size=int(world_size),
695
+ )
696
+ return doc_iterator
697
+
698
+
699
+ class ConfigurableTask(Task):
700
+ VERSION = "Yaml"
701
+ OUTPUT_TYPE = None
702
+ CONFIG = None
703
+
704
+ def __init__(
705
+ self,
706
+ data_dir=None,
707
+ cache_dir=None,
708
+ download_mode=None,
709
+ config: Optional[dict] = None,
710
+ ) -> None: # TODO no super() call here
711
+ # Get pre-configured attributes
712
+ self._config = self.CONFIG
713
+
714
+ # Use new configurations if there was no preconfiguration
715
+ if self.config is None:
716
+ self._config = TaskConfig(**config)
717
+ # Overwrite configs
718
+ else:
719
+ if config is not None:
720
+ self._config.__dict__.update(config)
721
+
722
+ if self.config is None:
723
+ raise ValueError(
724
+ "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
725
+ )
726
+
727
+ if isinstance(self.config.metadata, dict):
728
+ if "version" in self.config.metadata:
729
+ self.VERSION = self.config.metadata["version"]
730
+
731
+ if self.config.output_type is not None:
732
+ if self.config.output_type not in ALL_OUTPUT_TYPES:
733
+ raise ValueError(
734
+ f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
735
+ )
736
+ self.OUTPUT_TYPE = self.config.output_type
737
+
738
+ if self.config.doc_to_image is not None:
739
+ # mark the task as requiring multimodality.
740
+ self.MULTIMODAL = True
741
+
742
+ if self.config.doc_to_audio:
743
+ # mark the task as requiring multimodality.
744
+ self.MULTIMODAL = True
745
+
746
+ if self.config.unsafe_code is not False:
747
+ self.UNSAFE_CODE = True
748
+
749
+ if self.config.dataset_path is not None:
750
+ self.DATASET_PATH = self.config.dataset_path
751
+
752
+ if self.config.dataset_name is not None:
753
+ self.DATASET_NAME = self.config.dataset_name
754
+
755
+ self._metric_fn_list = {}
756
+ self._metric_fn_kwargs = {}
757
+ self._aggregation_list = {}
758
+ self._higher_is_better = {}
759
+
760
+ if self.config.metric_list is None:
761
+ # TODO: handle this in TaskConfig.__post_init__ ?
762
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
763
+
764
+ for metric_name in _metric_list:
765
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
766
+ self._metric_fn_kwargs[metric_name] = {}
767
+ self._aggregation_list[metric_name] = get_metric_aggregation(
768
+ metric_name
769
+ )
770
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
771
+ else:
772
+ for metric_config in self.config.metric_list:
773
+ if "metric" not in metric_config:
774
+ raise ValueError(
775
+ "'metric' key not provided for an entry in 'metric_list', must be specified!"
776
+ )
777
+ metric_name = metric_config["metric"]
778
+ kwargs = {
779
+ key: metric_config[key]
780
+ for key in metric_config
781
+ if key
782
+ not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
783
+ }
784
+ hf_evaluate_metric = (
785
+ "hf_evaluate" in metric_config
786
+ and metric_config["hf_evaluate"] is True
787
+ )
788
+
789
+ if self.config.process_results is not None:
790
+ self._metric_fn_list[metric_name] = None
791
+ self._metric_fn_kwargs[metric_name] = {}
792
+ elif callable(metric_name):
793
+ metric_fn = metric_name.__call__
794
+ metric_name = metric_name.__name__
795
+ self._metric_fn_list[metric_name] = metric_fn
796
+ self._metric_fn_kwargs[metric_name] = kwargs
797
+ else:
798
+ self._metric_fn_list[metric_name] = get_metric(
799
+ metric_name, hf_evaluate_metric
800
+ )
801
+ self._metric_fn_kwargs[metric_name] = kwargs
802
+
803
+ if "aggregation" in metric_config:
804
+ agg_name = metric_config["aggregation"]
805
+ if isinstance(agg_name, str):
806
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
807
+ elif callable(agg_name): # noqa: E721
808
+ self._aggregation_list[metric_name] = metric_config[
809
+ "aggregation"
810
+ ]
811
+ else:
812
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
813
+ metric_agg = get_metric_aggregation(metric_name)
814
+ eval_logger.warning(
815
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
816
+ f"using default "
817
+ f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
818
+ )
819
+ self._aggregation_list[metric_name] = metric_agg
820
+
821
+ if "higher_is_better" in metric_config:
822
+ self._higher_is_better[metric_name] = metric_config[
823
+ "higher_is_better"
824
+ ]
825
+ else:
826
+ eval_logger.warning(
827
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
828
+ f"using default "
829
+ f"higher_is_better={is_higher_better(metric_name)}"
830
+ )
831
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
832
+
833
+ self.download(self.config.dataset_kwargs)
834
+ self._training_docs = None
835
+ self._fewshot_docs = None
836
+
837
+ if self.config.filter_list is not None:
838
+ self._filters = []
839
+ for filter_config in self.config.filter_list:
840
+ filter_name = filter_config["name"]
841
+ filter_functions = filter_config["filter"]
842
+ components = []
843
+ for function in filter_functions:
844
+ kwargs = {
845
+ key: function[key] for key in function if key != "function"
846
+ }
847
+ components.append([function["function"], kwargs])
848
+ filter_pipeline = build_filter_ensemble(filter_name, components)
849
+ self._filters.append(filter_pipeline)
850
+ else:
851
+ # TODO: handle repeats in a more general way rather than just discarding
852
+ eval_logger.debug(
853
+ "No custom filters defined. Using default 'take_first' filter for handling repeats."
854
+ )
855
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
856
+
857
+ if self.config.use_prompt is not None:
858
+ eval_logger.info(f"loading prompt {self.config.use_prompt}")
859
+ self.prompt = get_prompt(
860
+ self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
861
+ )
862
+ else:
863
+ self.prompt = None
864
+
865
+ if self.fewshot_docs() is not None:
866
+ self.fewshot_rnd = (
867
+ random.Random()
868
+ ) # setting with no seed, to be overridden at a later time
869
+ config_sampler: Union[str, Callable] = (
870
+ self.config.fewshot_config.get("sampler", "default")
871
+ if self.config.fewshot_config
872
+ else "default"
873
+ )
874
+ if isinstance(config_sampler, str):
875
+ self.sampler = samplers.get_sampler(config_sampler)(
876
+ list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
877
+ )
878
+ elif callable(config_sampler) and issubclass(
879
+ config_sampler, samplers.ContextSampler
880
+ ):
881
+ self.sampler = config_sampler(
882
+ docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
883
+ )
884
+ else:
885
+ raise TypeError(
886
+ f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
887
+ f"not {type(config_sampler)}"
888
+ )
889
+
890
+ self.task_docs = self.eval_docs
891
+
892
+ # Test One Doc
893
+ self.features = list(self.task_docs.features.keys())
894
+ self.multiple_input = 0
895
+ self.multiple_target = 0
896
+ test_doc = self.task_docs[0]
897
+ test_text = self.doc_to_text(test_doc)
898
+ test_target = self.doc_to_target(test_doc)
899
+
900
+ if self.config.doc_to_choice is not None:
901
+ test_choice = self.doc_to_choice(test_doc)
902
+ if not isinstance(test_choice, list):
903
+ eval_logger.error("doc_to_choice must return list")
904
+ else:
905
+ num_choice = len(test_choice)
906
+
907
+ if isinstance(test_text, int):
908
+ self.multiple_input = num_choice
909
+ else:
910
+ test_choice = None
911
+
912
+ if isinstance(test_target, list):
913
+ self.multiple_target = len(test_target)
914
+ else:
915
+ if (isinstance(test_target, int)) and (test_choice is not None):
916
+ test_target = test_choice[test_target]
917
+ else:
918
+ test_target = str(test_target)
919
+
920
+ if test_choice is not None:
921
+ check_choices = test_choice
922
+ else:
923
+ check_choices = [test_target]
924
+ if self.config.doc_to_choice is not None:
925
+ for choice in check_choices:
926
+ choice_has_whitespace = True if choice[0].isspace() else False
927
+ delimiter_has_whitespace = (
928
+ True
929
+ if self.config.target_delimiter.rstrip()
930
+ != self.config.target_delimiter
931
+ else False
932
+ )
933
+
934
+ if delimiter_has_whitespace and choice_has_whitespace:
935
+ eval_logger.debug(
936
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
937
+ )
938
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
939
+ eval_logger.debug(
940
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
941
+ )
942
+
943
+ def download(
944
+ self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
945
+ ) -> None:
946
+ if isinstance(self.config.custom_dataset, Callable):
947
+ eval_logger.warning(
948
+ f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
949
+ + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
950
+ )
951
+ self.dataset = self.config.custom_dataset(
952
+ **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
953
+ )
954
+ else:
955
+ self.dataset = datasets.load_dataset(
956
+ path=self.DATASET_PATH,
957
+ name=self.DATASET_NAME,
958
+ **dataset_kwargs if dataset_kwargs is not None else {},
959
+ )
960
+
961
+ def has_training_docs(self) -> bool:
962
+ if self.config.training_split is not None:
963
+ return True
964
+ else:
965
+ return False
966
+
967
+ def has_validation_docs(self) -> bool:
968
+ if self.config.validation_split is not None:
969
+ return True
970
+ else:
971
+ return False
972
+
973
+ def has_test_docs(self) -> bool:
974
+ if self.config.test_split is not None:
975
+ return True
976
+ else:
977
+ return False
978
+
979
+ def training_docs(self) -> datasets.Dataset:
980
+ if self.has_training_docs():
981
+ if self.config.process_docs is not None:
982
+ return self.config.process_docs(
983
+ self.dataset[self.config.training_split]
984
+ )
985
+ return self.dataset[self.config.training_split]
986
+
987
+ def validation_docs(self) -> datasets.Dataset:
988
+ if self.has_validation_docs():
989
+ if self.config.process_docs is not None:
990
+ return self.config.process_docs(
991
+ self.dataset[self.config.validation_split]
992
+ )
993
+ return self.dataset[self.config.validation_split]
994
+
995
+ def test_docs(self) -> datasets.Dataset:
996
+ if self.has_test_docs():
997
+ if self.config.process_docs is not None:
998
+ return self.config.process_docs(self.dataset[self.config.test_split])
999
+ return self.dataset[self.config.test_split]
1000
+
1001
+ def fewshot_docs(self):
1002
+ if self.config.fewshot_split is not None:
1003
+ if self.config.process_docs is not None:
1004
+ return self.config.process_docs(self.dataset[self.config.fewshot_split])
1005
+ return self.dataset[self.config.fewshot_split]
1006
+ elif (
1007
+ self.config.fewshot_config is not None
1008
+ and self.config.fewshot_config.get("samples", None) is not None
1009
+ ):
1010
+ if isinstance(self.config.fewshot_config["samples"], list):
1011
+ return self.config.fewshot_config["samples"]
1012
+ elif callable(self.config.fewshot_config["samples"]):
1013
+ return self.config.fewshot_config["samples"]()
1014
+ else:
1015
+ raise Exception(
1016
+ "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
1017
+ )
1018
+ else:
1019
+ if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
1020
+ eval_logger.warning(
1021
+ f"[Task: {self.config.task}] "
1022
+ "num_fewshot > 0 but fewshot_split is None. "
1023
+ "using preconfigured rule."
1024
+ )
1025
+ return super().fewshot_docs()
1026
+
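As a quick sketch of the two accepted forms of `fewshot_config["samples"]` handled above: a literal list of sample dicts, or a zero-argument callable returning such a list (the sample contents here are made up):

    samples_as_list = [{"question": "2+2?", "answer": "4"}]

    def samples_as_callable():
        return [{"question": "2+2?", "answer": "4"}]

    for samples in (samples_as_list, samples_as_callable):
        # Mirrors the branching above: lists are used as-is, callables are invoked.
        docs = samples if isinstance(samples, list) else samples()
        assert docs == [{"question": "2+2?", "answer": "4"}]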
1027
+ @staticmethod
1028
+ def append_target_question(
1029
+ labeled_examples: List[Dict[str, str]],
1030
+ question: str,
1031
+ fewshot_as_multiturn: bool = False,
1032
+ gen_prefix: Optional[str] = None,
1033
+ ) -> None:
1034
+ """Adds a target question to the labeled examples list.
1035
+ If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
1036
+ Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
1037
+ """
1038
+ if not fewshot_as_multiturn:
1039
+ # if no messages or last message is system, append as new user entry
1040
+ if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
1041
+ labeled_examples.append({"role": "user", "content": question})
1042
+ # if last message is user, append to it to avoid two user messages in a row
1043
+ else:
1044
+ labeled_examples[-1]["content"] += question
1045
+ else:
1046
+ # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
1047
+ labeled_examples.append({"role": "user", "content": question})
1048
+ if gen_prefix:
1049
+ labeled_examples.append({"role": "assistant", "content": gen_prefix})
1050
+
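A self-contained sketch of the appending rule implemented by `append_target_question`, using a toy transcript (the content strings are hypothetical):

    question = "Q2: what is 2+2?"

    # Single-turn fewshot: the new question is merged into the trailing user message.
    transcript = [{"role": "user", "content": "Q1: ...\nA1: ...\n"}]
    if len(transcript) == 0 or transcript[-1]["role"] == "system":
        transcript.append({"role": "user", "content": question})
    else:
        transcript[-1]["content"] += question
    assert len(transcript) == 1 and transcript[-1]["content"].endswith(question)

    # Multiturn fewshot: the question always becomes a fresh user turn.
    multiturn = [{"role": "user", "content": "Q1"}, {"role": "assistant", "content": "A1"}]
    multiturn.append({"role": "user", "content": question})
    assert multiturn[-1] == {"role": "user", "content": question}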
1051
+ @utils.positional_deprecated
1052
+ def fewshot_context(
1053
+ self,
1054
+ doc: dict,
1055
+ num_fewshot: int,
1056
+ system_instruction: Optional[str] = None,
1057
+ apply_chat_template: bool = False,
1058
+ fewshot_as_multiturn: bool = False,
1059
+ chat_template: Optional[Callable] = None,
1060
+ gen_prefix: Optional[str] = None,
1061
+ ) -> Union[str, List[str]]:
1062
+ """Returns a fewshot context string that is made up of a prepended description
1063
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
1064
+
1065
+ :param doc: dict
1066
+ The document as returned from training_docs, validation_docs, or test_docs.
1067
+ :param num_fewshot: int
1068
+ The number of fewshot examples to provide in the returned context string.
1069
+ :param system_instruction: str
1070
+ System instruction to be applied to the prompt.
1071
+ :param apply_chat_template: bool
1072
+ Whether to apply the chat template to the fewshot context.
1073
+ :param fewshot_as_multiturn: bool
1074
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
1075
+ :param chat_template:
1076
+ callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
1077
+ :param gen_prefix:
1078
+ String to append after the <|assistant|> token.
1079
+ :returns: Union[str, List[str]]
1080
+ The fewshot context.
1081
+ """
1082
+ if apply_chat_template:
1083
+ labeled_examples = []
1084
+ else:
1085
+ labeled_examples = ""
1086
+
1087
+ # get task description
1088
+ if description := self.config.description:
1089
+ description = utils.apply_template(self.config.description, doc)
1090
+
1091
+ # create system prompt based on the provided system instruction and description
1092
+ if system_instruction is not None and description:
1093
+ system_prompt = (
1094
+ f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
1095
+ )
1096
+ elif system_instruction is not None:
1097
+ system_prompt = system_instruction
1098
+ elif description:
1099
+ system_prompt = description
1100
+ else:
1101
+ system_prompt = ""
1102
+
1103
+ # add system prompt if specified
1104
+ if system_prompt:
1105
+ if apply_chat_template:
1106
+ labeled_examples.append({"role": "system", "content": system_prompt})
1107
+ else:
1108
+ labeled_examples = system_prompt
1109
+ # if few-shot - append examples after the system prompt
1110
+ if num_fewshot > 0:
1111
+ if apply_chat_template:
1112
+ labeled_examples.extend(
1113
+ self.sampler.get_chat_context(
1114
+ doc,
1115
+ num_fewshot,
1116
+ fewshot_as_multiturn,
1117
+ gen_prefix=gen_prefix,
1118
+ )
1119
+ )
1120
+ else:
1121
+ labeled_examples += self.sampler.get_context(
1122
+ doc, num_fewshot, gen_prefix=gen_prefix
1123
+ )
1124
+
1125
+ example = self.doc_to_text(doc)
1126
+ if apply_chat_template:
1127
+ if self.multiple_input:
1128
+ # TODO: append prefill?
1129
+ if not labeled_examples:
1130
+ return ""
1131
+ return chat_template(labeled_examples)
1132
+ if isinstance(example, str):
1133
+ self.append_target_question(
1134
+ labeled_examples,
1135
+ example,
1136
+ fewshot_as_multiturn,
1137
+ gen_prefix=gen_prefix,
1138
+ )
1139
+ # for loglikelihood create a list of questions with appended choices
1140
+ elif isinstance(example, list):
1141
+ labeled_examples_list = []
1142
+ # copy chat history for each example and append the answer
1143
+ for ex in example:
1144
+ chat = deepcopy(labeled_examples)
1145
+ self.append_target_question(
1146
+ chat,
1147
+ ex,
1148
+ fewshot_as_multiturn,
1149
+ gen_prefix=gen_prefix,
1150
+ )
1151
+ # TODO: append prefill?
1152
+ labeled_examples_list.append(
1153
+ chat_template(
1154
+ chat,
1155
+ add_generation_prompt=False if gen_prefix else True,
1156
+ )
1157
+ )
1158
+ return labeled_examples_list
1159
+ # if example is an integer, append the choice or convert to string
1160
+ elif isinstance(example, int):
1161
+ if self.config.doc_to_choice is not None:
1162
+ choices = self.doc_to_choice(doc)
1163
+ self.append_target_question(
1164
+ labeled_examples,
1165
+ choices[example],
1166
+ fewshot_as_multiturn,
1167
+ gen_prefix=gen_prefix,
1168
+ )
1169
+ else:
1170
+ self.append_target_question(
1171
+ labeled_examples,
1172
+ str(example),
1173
+ fewshot_as_multiturn,
1174
+ gen_prefix=gen_prefix,
1175
+ )
1176
+ # return lm.apply_chat_template(labeled_examples)
1177
+ return chat_template(
1178
+ labeled_examples,
1179
+ add_generation_prompt=False if gen_prefix else True,
1180
+ )
1181
+ else:
1182
+ prefix = (
1183
+ self.config.target_delimiter + gen_prefix
1184
+ if gen_prefix is not None
1185
+ else ""
1186
+ )
1187
+ if self.multiple_input:
1188
+ return labeled_examples
1189
+ if isinstance(example, str):
1190
+ return labeled_examples + example + prefix
1191
+ elif isinstance(example, list):
1192
+ return [labeled_examples + ex + prefix for ex in example]
1193
+ elif isinstance(example, int):
1194
+ if self.config.doc_to_choice is not None:
1195
+ choices = self.doc_to_choice(doc)
1196
+ return labeled_examples + choices[example] + prefix
1197
+ else:
1198
+ return labeled_examples + str(example) + prefix
1199
+
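To make the non-chat-template branch above concrete, a tiny sketch of how the returned context string is assembled from the fewshot block, the current example, and an optional generation prefix (all strings here are hypothetical):

    labeled_examples = "Q: 1+1?\nA: 2\n\n"   # what sampler.get_context(...) might return
    example = "Q: 2+3?\nA:"                  # what doc_to_text(doc) might return
    target_delimiter = " "
    gen_prefix = "Let's think step by step."

    prefix = target_delimiter + gen_prefix if gen_prefix is not None else ""
    context = labeled_examples + example + prefix
    print(context)
    # Q: 1+1?
    # A: 2
    #
    # Q: 2+3?
    # A: Let's think step by step.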
1200
+ def apply_filters(self) -> Optional[List[Instance]]:
1201
+ """Iterates over FilterEnsembles and applies them to instances"""
1202
+ if hasattr(self, "_filters"):
1203
+ for f in self._filters:
1204
+ f.apply(self._instances)
1205
+ else:
1206
+ eval_logger.warning("No filter defined, passing through instances")
1207
+ return self._instances
1208
+
1209
+ def should_decontaminate(self):
1210
+ return self.config.should_decontaminate
1211
+
1212
+ def doc_to_decontamination_query(self, doc: dict):
1213
+ if self.config.should_decontaminate:
1214
+ if self.config.doc_to_decontamination_query is None:
1215
+ return self.doc_to_text(doc)
1216
+ else:
1217
+ doc_to_decontamination_query = self.config.doc_to_decontamination_query
1218
+ if doc_to_decontamination_query in self.features:
1219
+ return doc[doc_to_decontamination_query]
1220
+ elif callable(doc_to_decontamination_query):
1221
+ return doc_to_decontamination_query(doc)
1222
+ else:
1223
+ return ast.literal_eval(
1224
+ utils.apply_template(
1225
+ self.config.doc_to_decontamination_query, doc
1226
+ )
1227
+ )
1228
+
1229
+ def _process_doc(self, doc: dict) -> dict:
1230
+ """
1231
+ Override this to process (detokenize, strip, replace, etc.) individual
1232
+ documents. This can be used in a map over documents of a data split.
1233
+ E.g. `map(self._process_doc, self.dataset["validation"])`
1234
+
1235
+ :return: dict
1236
+ The processed version of the specified `doc`.
1237
+ """
1238
+ return doc
1239
+
1240
+ def doc_to_text(self, doc, doc_to_text=None):
1241
+ if self.prompt is not None:
1242
+ doc_to_text = self.prompt
1243
+ elif doc_to_text is not None:
1244
+ doc_to_text = doc_to_text
1245
+ else:
1246
+ doc_to_text = self.config.doc_to_text
1247
+
1248
+ if isinstance(doc_to_text, int):
1249
+ return doc_to_text
1250
+ elif isinstance(doc_to_text, str):
1251
+ if doc_to_text in self.features:
1252
+ # if self.config.doc_to_choice is not None:
1253
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
1254
+ # else:
1255
+ return doc[doc_to_text]
1256
+ else:
1257
+ text_string = utils.apply_template(doc_to_text, doc)
1258
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
1259
+ return ast.literal_eval(text_string)
1260
+ else:
1261
+ return text_string
1262
+ elif callable(doc_to_text):
1263
+ return doc_to_text(doc)
1264
+ # Used when applying a Promptsource template
1265
+ elif hasattr(doc_to_text, "apply"):
1266
+ applied_prompt = doc_to_text.apply(doc)
1267
+ if len(applied_prompt) == 2:
1268
+ return applied_prompt[0]
1269
+ else:
1270
+ eval_logger.warning("Applied prompt returns empty string")
1271
+ return self.config.fewshot_delimiter
1272
+ else:
1273
+ print(type(doc_to_text))
1274
+ raise TypeError
1275
+
1276
+ def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
1277
+ if self.prompt is not None:
1278
+ doc_to_target = self.prompt
1279
+ elif doc_to_target is not None:
1280
+ doc_to_target = doc_to_target
1281
+ else:
1282
+ doc_to_target = self.config.doc_to_target
1283
+
1284
+ if isinstance(doc_to_target, int):
1285
+ return doc_to_target
1286
+ elif isinstance(doc_to_target, str):
1287
+ if doc_to_target in self.features:
1288
+ # if self.config.doc_to_choice is not None:
1289
+ # return self.doc_to_choice(doc)[doc[doc_to_target]]
1290
+ # else:
1291
+ return doc[doc_to_target]
1292
+ else:
1293
+ target_string = utils.apply_template(doc_to_target, doc)
1294
+ if target_string.isdigit() and self._config.doc_to_choice is not None:
1295
+ return ast.literal_eval(target_string)
1296
+ elif (
1297
+ len(target_string) >= 2
1298
+ and (target_string[0] == "[")
1299
+ and (target_string[-1] == "]")
1300
+ ):
1301
+ try:
1302
+ return ast.literal_eval(target_string)
1303
+ except (SyntaxError, ValueError):
1304
+ return target_string
1305
+ else:
1306
+ return target_string
1307
+ elif isinstance(doc_to_target, list):
1308
+ return doc_to_target
1309
+ elif callable(doc_to_target):
1310
+ return doc_to_target(doc)
1311
+ # Used when applying a Promptsource template
1312
+ elif hasattr(doc_to_target, "apply"):
1313
+ applied_prompt = doc_to_target.apply(doc)
1314
+ if len(applied_prompt) == 2:
1315
+ return applied_prompt[1]
1316
+ else:
1317
+ eval_logger.warning("Applied prompt returns empty string")
1318
+ return self.config.fewshot_delimiter
1319
+ else:
1320
+ raise TypeError
1321
+
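A small sketch of the string-target parsing in `doc_to_target` above: digit-only strings become ints (choice indices) when choices exist, bracketed strings are parsed as Python literals when possible, and everything else stays a plain string (the example inputs are hypothetical):

    import ast

    def parse_target(target_string: str, has_choices: bool):
        if target_string.isdigit() and has_choices:
            return ast.literal_eval(target_string)      # "2" -> 2 (a choice index)
        if len(target_string) >= 2 and target_string[0] == "[" and target_string[-1] == "]":
            try:
                return ast.literal_eval(target_string)  # "['yes', 'no']" -> list of targets
            except (SyntaxError, ValueError):
                return target_string                    # "[citation needed]" stays a string
        return target_string

    assert parse_target("2", has_choices=True) == 2
    assert parse_target("['yes', 'no']", has_choices=False) == ["yes", "no"]
    assert parse_target("[citation needed]", has_choices=False) == "[citation needed]"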
1322
+ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
1323
+ if self.prompt is not None:
1324
+ doc_to_choice = self.prompt
1325
+ elif doc_to_choice is not None:
1326
+ doc_to_choice = doc_to_choice
1327
+ elif self.config.doc_to_choice is None:
1328
+ eval_logger.error("doc_to_choice was called but not set in config")
1329
+ else:
1330
+ doc_to_choice = self.config.doc_to_choice
1331
+
1332
+ if isinstance(doc_to_choice, str):
1333
+ if doc_to_choice in self.features:
1334
+ return doc[doc_to_choice]
1335
+ else:
1336
+ return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
1337
+ elif isinstance(doc_to_choice, list):
1338
+ return doc_to_choice
1339
+ elif isinstance(doc_to_choice, dict):
1340
+ return list(doc_to_choice.values())
1341
+ elif callable(doc_to_choice):
1342
+ return doc_to_choice(doc)
1343
+ elif hasattr(doc_to_choice, "get_answer_choices_list"):
1344
+ return doc_to_choice.get_answer_choices_list(doc)
1345
+ else:
1346
+ raise TypeError
1347
+
1348
+ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
1349
+ if doc_to_image is not None:
1350
+ doc_to_image = doc_to_image
1351
+ elif self.config.doc_to_image is not None:
1352
+ doc_to_image = self.config.doc_to_image
1353
+ else:
1354
+ return None
1355
+
1356
+ if isinstance(doc_to_image, list):
1357
+ image_feature = [
1358
+ self.doc_to_image(doc, feature) for feature in doc_to_image
1359
+ ]
1360
+ return [feature for feature in image_feature if feature is not None]
1361
+ elif isinstance(doc_to_image, str):
1362
+ if doc_to_image in self.features:
1363
+ return doc[doc_to_image]
1364
+ else:
1365
+ return ast.literal_eval(utils.apply_template(doc_to_image, doc))
1366
+ elif callable(doc_to_image):
1367
+ return doc_to_image(doc)
1368
+ else:
1369
+ return None
1370
+
1371
+ def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
1372
+ if doc_to_audio is not None:
1373
+ doc_to_audio = doc_to_audio
1374
+ elif self.config.doc_to_audio is not None:
1375
+ doc_to_audio = self.config.doc_to_audio
1376
+ else:
1377
+ return None
1378
+
1379
+ if isinstance(doc_to_audio, list):
1380
+ audio_feature = [
1381
+ self.doc_to_audio(doc, feature) for feature in doc_to_audio
1382
+ ]
1383
+ return [feature for feature in audio_feature if feature is not None]
1384
+ elif isinstance(doc_to_audio, str):
1385
+ if doc_to_audio in self.features:
1386
+ return doc[doc_to_audio]
1387
+ else:
1388
+ return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
1389
+ elif callable(doc_to_audio):
1390
+ return doc_to_audio(doc)
1391
+ else:
1392
+ return None
1393
+
1394
+ def doc_to_prefix(self, doc):
1395
+ if (gen_prefix := self.config.gen_prefix) is not None:
1396
+ if gen_prefix in self.features:
1397
+ return doc[gen_prefix]
1398
+ else:
1399
+ return utils.apply_template(gen_prefix, doc)
1400
+ return None
1401
+
1402
+ def construct_requests(
1403
+ self, doc: dict, ctx: str, **kwargs
1404
+ ) -> Union[List[Instance], Instance]:
1405
+ apply_chat_template = kwargs.pop("apply_chat_template", False)
1406
+ chat_template: Callable | None = kwargs.pop("chat_template", None)
1407
+
1408
+ aux_arguments = None
1409
+
1410
+ if self.OUTPUT_TYPE == "loglikelihood":
1411
+ arguments = (ctx, self.doc_to_target(doc))
1412
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1413
+ arguments = (self.doc_to_target(doc),)
1414
+ elif self.OUTPUT_TYPE == "multiple_choice":
1415
+ choices = self.doc_to_choice(doc)
1416
+ target_delimiter = self.config.target_delimiter
1417
+ if apply_chat_template:
1418
+ target_delimiter = ""
1419
+ if self.multiple_input:
1420
+ # If there are multiple inputs, choices are placed in the ctx
1421
+ # apply chat_template to choices if apply_chat_template
1422
+ cont = self.doc_to_target(doc)
1423
+
1424
+ arguments = [
1425
+ (
1426
+ ctx
1427
+ + (
1428
+ chat_template([{"role": "user", "content": choice}])
1429
+ if apply_chat_template
1430
+ else choice
1431
+ ),
1432
+ f"{target_delimiter}{cont}",
1433
+ )
1434
+ for choice in choices
1435
+ ]
1436
+ else:
1437
+ # Otherwise they are placed in the continuation
1438
+ arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
1439
+
1440
+ # TODO: we should raise a warning telling users this will at most ~2x runtime.
1441
+ if "acc_mutual_info" in self._metric_fn_list.keys():
1442
+ # if we are calculating multiple choice accuracy
1443
+ # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
1444
+
1445
+ # here mutual info refers to calculating
1446
+ # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
1447
+ # in other words normalizing by subtracting the unconditional logprob of each choice.
1448
+ aux_arguments = [("", f"{choice}") for choice in choices]
1449
+
1450
+ arguments.extend(aux_arguments)
1451
+
1452
+ elif self.OUTPUT_TYPE == "generate_until":
1453
+ arguments = (ctx, deepcopy(self.config.generation_kwargs))
1454
+
1455
+ multimodal_arg = {}
1456
+ if (
1457
+ self.config.doc_to_image
1458
+ ): # TODO: ensure that non-multimodal tasks aren't getting visual args
1459
+ multimodal_arg = {
1460
+ **multimodal_arg,
1461
+ **{"visual": self.doc_to_image(doc)},
1462
+ }
1463
+
1464
+ if (
1465
+ self.config.doc_to_audio
1466
+ ): # TODO: ensure that non-multimodal tasks aren't getting audio args
1467
+ multimodal_arg = {
1468
+ **multimodal_arg,
1469
+ **{"audio": self.doc_to_audio(doc)},
1470
+ }
1471
+
1472
+ if bool(multimodal_arg):
1473
+ if isinstance(arguments, list):
1474
+ arguments = [arg + (multimodal_arg,) for arg in arguments]
1475
+ else:
1476
+ arguments = arguments + (multimodal_arg,)
1477
+
1478
+ if self.OUTPUT_TYPE == "multiple_choice":
1479
+ request_list = [
1480
+ Instance(
1481
+ request_type="loglikelihood",
1482
+ doc=doc,
1483
+ arguments=arg,
1484
+ idx=i,
1485
+ **kwargs,
1486
+ )
1487
+ for i, arg in enumerate(arguments)
1488
+ ]
1489
+
1490
+ return request_list
1491
+
1492
+ return Instance(
1493
+ request_type=self.OUTPUT_TYPE,
1494
+ doc=doc,
1495
+ arguments=arguments,
1496
+ idx=0,
1497
+ **kwargs,
1498
+ )
1499
+
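For the `acc_mutual_info` branch above, a sketch of the request arguments it produces: one conditional loglikelihood per choice, with one unconditional (empty-context) loglikelihood per choice appended afterwards (prompt and choices are hypothetical):

    ctx = "Question: Is the sky blue?\nAnswer:"
    choices = ["yes", "no"]
    target_delimiter = " "

    arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
    # Unconditional P(choice) requests used to compute log P(choice|ctx) - log P(choice).
    arguments.extend([("", f"{choice}") for choice in choices])

    # Twice as many loglikelihood requests -> roughly doubles runtime for this metric.
    assert len(arguments) == 2 * len(choices)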
1500
+ def process_results(self, doc, results):
1501
+ if callable(self.config.process_results):
1502
+ return self.config.process_results(doc, results)
1503
+
1504
+ result_dict = {}
1505
+ use_metric = list(self._metric_fn_list.keys())
1506
+ if self.OUTPUT_TYPE == "loglikelihood":
1507
+ results = results[0]
1508
+ ll, is_greedy = results
1509
+ return {
1510
+ **({"perplexity": ll} if "perplexity" in use_metric else {}),
1511
+ **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
1512
+ }
1513
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1514
+ (loglikelihood,) = results
1515
+ _words = self.count_words(self.doc_to_target(doc))
1516
+ _bytes = self.count_bytes(self.doc_to_target(doc))
1517
+ return {
1518
+ **(
1519
+ {"word_perplexity": (loglikelihood, _words)}
1520
+ if "word_perplexity" in use_metric
1521
+ else {}
1522
+ ),
1523
+ **(
1524
+ {"byte_perplexity": (loglikelihood, _bytes)}
1525
+ if "byte_perplexity" in use_metric
1526
+ else {}
1527
+ ),
1528
+ **(
1529
+ {"bits_per_byte": (loglikelihood, _bytes)}
1530
+ if "bits_per_byte" in use_metric
1531
+ else {}
1532
+ ),
1533
+ }
1534
+ elif self.OUTPUT_TYPE == "multiple_choice":
1535
+ lls, is_greedy = zip(*results)
1536
+
1537
+ # retrieve choices in List[str] form, to compute choice lengths, etc.
1538
+ choices = self.doc_to_choice(doc)
1539
+ completion_len = np.array([float(len(i)) for i in choices])
1540
+
1541
+ if (
1542
+ 2 * len(choices) == len(lls)
1543
+ and "acc_mutual_info" in self._metric_fn_list.keys()
1544
+ ):
1545
+ # then we are doing mutual info.
1546
+ # this stores the "dryrun" / unconditional answer loglikelihoods
1547
+ lls_unconditional = lls[1::2]
1548
+ if len(lls_unconditional) != len(choices):
1549
+ raise ValueError
1550
+ # and this stores our "regular" conditional loglikelihoods
1551
+ lls = lls[::2]
1552
+
1553
+ pred = np.argmax(lls)
1554
+ pred_norm = np.argmax(lls / completion_len)
1555
+
1556
+ if self.multiple_input:
1557
+ gold = self.doc_to_text(doc)
1558
+ else:
1559
+ gold = self.doc_to_target(doc)
1560
+
1561
+ gold_index_error = False
1562
+ if isinstance(gold, list):
1563
+ gold = [i if i < len(choices) else -100 for i in gold]
1564
+ if -100 in gold:
1565
+ gold_index_error = True
1566
+ else:
1567
+ if isinstance(gold, int):
1568
+ gold = gold if gold < len(choices) else -100
1569
+ elif isinstance(gold, str):
1570
+ gold = choices.index(gold) if gold in choices else -100
1571
+
1572
+ if gold == -100:
1573
+ gold_index_error = True
1574
+
1575
+ if gold_index_error:
1576
+ eval_logger.warning(
1577
+ f"Label index was not in within range of available choices,"
1578
+ f"Sample:\n\n{doc}\n\n"
1579
+ )
1580
+
1581
+ if self.multiple_target:
1582
+ acc = 1.0 if pred in gold else 0.0
1583
+ acc_norm = 1.0 if pred_norm in gold else 0.0
1584
+ exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
1585
+ else:
1586
+ acc = 1.0 if pred == gold else 0.0
1587
+ acc_norm = 1.0 if pred_norm == gold else 0.0
1588
+ # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
1589
+ exact_match = int(is_greedy[gold]) if gold != -100 else 0
1590
+
1591
+ prob_norm = utils.softmax(lls)
1592
+
1593
+ # TODO use keyword arguments to the metric?
1594
+ # gold, pred, norm stuff, the original lls,
1595
+ result_dict = {
1596
+ **({"acc": acc} if "acc" in use_metric else {}),
1597
+ **({"f1": (gold, pred)} if "f1" in use_metric else {}),
1598
+ **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
1599
+ **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
1600
+ **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
1601
+ **(
1602
+ {"brier_score": (gold, prob_norm)}
1603
+ if "brier_score" in use_metric
1604
+ else {}
1605
+ ),
1606
+ }
1607
+
1608
+ if "acc_mutual_info" in use_metric:
1609
+ lls_mutual_info = [
1610
+ ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
1611
+ ]
1612
+ acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
1613
+ result_dict["acc_mutual_info"] = acc_mutual_info
1614
+
1615
+ elif self.OUTPUT_TYPE == "generate_until":
1616
+ gold = self.doc_to_target(doc)
1617
+ result = results[0]
1618
+ if self.config.doc_to_choice is not None:
1619
+ # If you set doc_to_choice,
1620
+ # it assumes that doc_to_target returns a number.
1621
+ choices = self.doc_to_choice(doc)
1622
+ gold = choices[gold]
1623
+ # we expect multiple_targets to be a list.
1624
+ elif self.multiple_target:
1625
+ gold = list(gold)
1626
+ # TODO: handle this better
1627
+ elif type(gold) is not type(result) and not (
1628
+ "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
1629
+ ):
1630
+ # cast gold to the same type as result
1631
+ gold = type(result)(gold)
1632
+
1633
+ for metric in self._metric_fn_list.keys():
1634
+ if self.multiple_target:
1635
+ # in the case where we have multiple targets,
1636
+ # return true if any are true
1637
+ # TODO: this may break for multiple_target, non zero-or-1 metrics
1638
+ scores = []
1639
+ if not isinstance(gold, list):
1640
+ # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
1641
+ # print(gold)
1642
+ gold = [gold]
1643
+ if metric == "exact_match":
1644
+ result = [result for _ in range(len(gold))]
1645
+ scores = self._metric_fn_list[metric](
1646
+ references=gold,
1647
+ predictions=result,
1648
+ **self._metric_fn_kwargs[metric],
1649
+ )[metric]
1650
+ result_score = 1.0 if scores > 0.0 else 0.0
1651
+ else:
1652
+ for gold_option in gold:
1653
+ try:
1654
+ result_score = self._metric_fn_list[metric](
1655
+ references=[gold_option],
1656
+ predictions=[result],
1657
+ **self._metric_fn_kwargs[metric],
1658
+ )
1659
+ except (
1660
+ TypeError
1661
+ ): # TODO: this is hacky and I don't want to do it
1662
+ result_score = self._metric_fn_list[metric](
1663
+ [gold_option, result]
1664
+ )
1665
+ if isinstance(result_score, dict):
1666
+ # TODO: this handles the case where HF evaluate returns a dict.
1667
+ result_score = result_score[metric]
1668
+ scores.append(result_score)
1669
+ if any(scores):
1670
+ result_score = 1.0
1671
+ else:
1672
+ result_score = 0.0
1673
+ else:
1674
+ try:
1675
+ result_score = self._metric_fn_list[metric](
1676
+ references=[gold],
1677
+ predictions=[result],
1678
+ **self._metric_fn_kwargs[metric],
1679
+ )
1680
+ except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
1681
+ result_score = self._metric_fn_list[metric]([gold, result])
1682
+ if isinstance(result_score, dict):
1683
+ # TODO: this handles the case where HF evaluate returns a dict.
1684
+ # This allows for multiple metrics to be returned from the same function
1685
+ for k, v in result_score.items():
1686
+ result_dict[k] = v
1687
+ else:
1688
+ result_dict[metric] = result_score
1689
+ else:
1690
+ raise ValueError(
1691
+ f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
1692
+ "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
1693
+ )
1694
+
1695
+ return result_dict
1696
+
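A numeric sketch of the multiple-choice scoring above with made-up loglikelihoods, showing how `acc_norm` (which divides each loglikelihood by the choice's character length) can flip the prediction relative to raw `acc`:

    import numpy as np

    choices = ["a long and detailed answer", "no"]
    lls = np.array([-8.0, -4.0])          # conditional loglikelihood of each choice
    completion_len = np.array([float(len(c)) for c in choices])
    gold = 0

    acc = 1.0 if np.argmax(lls) == gold else 0.0                        # picks "no"     -> 0.0
    acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0  # picks choice 0 -> 1.0
    print(acc, acc_norm)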
1697
+ def aggregation(self) -> dict:
1698
+ return self._aggregation_list
1699
+
1700
+ def higher_is_better(self) -> dict:
1701
+ return self._higher_is_better
1702
+
1703
+ def get_config(self, key: str) -> Any:
1704
+ return getattr(self._config, key, None)
1705
+
1706
+ @property
1707
+ def task_name(self) -> Any:
1708
+ return getattr(self.config, "task", None)
1709
+
1710
+ def __repr__(self):
1711
+ return (
1712
+ f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
1713
+ f"output_type={self.OUTPUT_TYPE},"
1714
+ f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
1715
+ f"num_samples={len(self.eval_docs)})"
1716
+ )
1717
+
1718
+
1719
+ class MultipleChoiceTask(Task):
1720
+ OUTPUT_TYPE = "loglikelihood"
1721
+
1722
+ def doc_to_target(self, doc: dict) -> str:
1723
+ return " " + doc["choices"][doc["gold"]]
1724
+
1725
+ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
1726
+ # TODO: add mutual info here?
1727
+ return [
1728
+ Instance(
1729
+ request_type="loglikelihood",
1730
+ doc=doc,
1731
+ arguments=(ctx, " {}".format(choice)),
1732
+ idx=i,
1733
+ **kwargs,
1734
+ )
1735
+ for i, choice in enumerate(doc["choices"])
1736
+ ]
1737
+
1738
+ def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
1739
+ results = [
1740
+ res[0] for res in results
1741
+ ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
1742
+ gold = doc["gold"]
1743
+
1744
+ acc = 1.0 if np.argmax(results) == gold else 0.0
1745
+ completion_len = np.array([float(len(i)) for i in doc["choices"]])
1746
+ acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
1747
+
1748
+ return {
1749
+ "acc": acc,
1750
+ "acc_norm": acc_norm,
1751
+ }
1752
+
1753
+ def higher_is_better(self) -> dict:
1754
+ return {
1755
+ "acc": True,
1756
+ "acc_norm": True,
1757
+ }
1758
+
1759
+ def aggregation(self) -> dict:
1760
+ return {
1761
+ "acc": mean,
1762
+ "acc_norm": mean,
1763
+ }
1764
+
1765
+
1766
+ class PerplexityTask(Task):
1767
+ OUTPUT_TYPE = "loglikelihood_rolling"
1768
+
1769
+ def has_training_docs(self) -> bool:
1770
+ return False
1771
+
1772
+ def fewshot_examples(self, k: int, rnd) -> List:
1773
+ if k != 0:
1774
+ raise ValueError(
1775
+ "The number of fewshot examples must be 0 for perplexity tasks."
1776
+ )
1777
+ return []
1778
+
1779
+ def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
1780
+ if num_fewshot != 0:
1781
+ raise ValueError(
1782
+ "The number of fewshot examples must be 0 for perplexity tasks."
1783
+ )
1784
+
1785
+ return ""
1786
+
1787
+ def higher_is_better(self) -> dict:
1788
+ return {
1789
+ "word_perplexity": False,
1790
+ "byte_perplexity": False,
1791
+ "bits_per_byte": False,
1792
+ }
1793
+
1794
+ def doc_to_decontamination_query(self, doc):
1795
+ return doc
1796
+
1797
+ def doc_to_text(self, doc) -> str:
1798
+ return ""
1799
+
1800
+ def doc_to_target(self, doc):
1801
+ return doc
1802
+
1803
+ def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
1804
+ if bool(ctx):
1805
+ raise ValueError
1806
+
1807
+ return Instance(
1808
+ request_type=self.OUTPUT_TYPE,
1809
+ doc=doc,
1810
+ arguments=(self.doc_to_target(doc),),
1811
+ idx=0,
1812
+ **kwargs,
1813
+ )
1814
+
1815
+ def process_results(self, doc: dict, results: Tuple[float]) -> dict:
1816
+ (loglikelihood,) = results
1817
+ words = self.count_words(self.doc_to_target(doc))
1818
+ bytes_ = self.count_bytes(self.doc_to_target(doc))
1819
+ return {
1820
+ "word_perplexity": (loglikelihood, words),
1821
+ "byte_perplexity": (loglikelihood, bytes_),
1822
+ "bits_per_byte": (loglikelihood, bytes_),
1823
+ }
1824
+
1825
+ def aggregation(self) -> dict:
1826
+ return {
1827
+ "word_perplexity": weighted_perplexity,
1828
+ "byte_perplexity": weighted_perplexity,
1829
+ "bits_per_byte": bits_per_byte,
1830
+ }
1831
+
1832
+ @classmethod
1833
+ def count_bytes(cls, doc) -> int:
1834
+ return len(doc.encode("utf-8"))
1835
+
1836
+ @classmethod
1837
+ def count_words(cls, doc) -> int:
1838
+ """Downstream tasks with custom word boundaries should override this!"""
1839
+ return len(re.split(r"\s+", doc))
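As a hedged sketch of the aggregation arithmetic behind metric names like `weighted_perplexity` and `bits_per_byte` (these follow the standard definitions; the harness's exact implementations live in its metrics module, and the numbers below are made up):

    import math

    # Each item is a (sum_of_loglikelihoods, weight) pair as returned by process_results.
    items = [(-120.0, 30), (-80.0, 25)]   # e.g. (loglikelihood, word_count) per document

    total_ll = sum(ll for ll, _ in items)
    total_weight = sum(w for _, w in items)

    word_perplexity = math.exp(-total_ll / total_weight)        # weighted perplexity over words
    bits_per_byte = -total_ll / (total_weight * math.log(2))    # if the weights were byte counts
    print(word_perplexity, bits_per_byte)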
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py ADDED
File without changes
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py ADDED
@@ -0,0 +1,59 @@
1
+ import hashlib
2
+ import logging
3
+ import os
4
+
5
+ import dill
6
+
7
+
8
+ eval_logger = logging.getLogger(__name__)
9
+
10
+
11
+ MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
12
+
13
+ OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
14
+
15
+
16
+ PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
17
+
18
+ # This should be sufficient for uniqueness
19
+ HASH_INPUT = "EleutherAI-lm-evaluation-harness"
20
+
21
+ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
22
+
23
+ FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
24
+
25
+
26
+ def load_from_cache(file_name: str, cache: bool = False):
27
+ if not cache:
28
+ return
29
+ try:
30
+ path = f"{PATH}/{file_name}{FILE_SUFFIX}"
31
+
32
+ with open(path, "rb") as file:
33
+ cached_task_dict = dill.loads(file.read())
34
+ return cached_task_dict
35
+
36
+ except Exception:
37
+ eval_logger.debug(f"{file_name} is not cached, generating...")
38
+ pass
39
+
40
+
41
+ def save_to_cache(file_name, obj):
42
+ if not os.path.exists(PATH):
43
+ os.mkdir(PATH)
44
+
45
+ file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
46
+
47
+ eval_logger.debug(f"Saving {file_path} to cache...")
48
+ with open(file_path, "wb") as file:
49
+ file.write(dill.dumps(obj))
50
+
51
+
52
+ # NOTE the "key" param is to allow for flexibility
53
+ def delete_cache(key: str = ""):
54
+ files = os.listdir(PATH)
55
+
56
+ for file in files:
57
+ if file.startswith(key) and file.endswith(FILE_SUFFIX):
58
+ file_path = f"{PATH}/{file}"
59
+ os.unlink(file_path)
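A minimal usage sketch for the cache helpers above; the cache key string is hypothetical, and it assumes the package is installed so `lm_eval.caching.cache` is importable (with `dill` available):

    from lm_eval.caching.cache import load_from_cache, save_to_cache

    task_dict = {"task": "example", "docs": [1, 2, 3]}   # any picklable object

    cached = load_from_cache("requests-example", cache=True)   # None on a cache miss
    if cached is None:
        save_to_cache("requests-example", task_dict)
        cached = load_from_cache("requests-example", cache=True)
    assert cached == task_dict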
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py ADDED
File without changes
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py ADDED
@@ -0,0 +1,174 @@
1
+ import datetime
2
+ import io
3
+ import json
4
+ import mmap
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import jsonlines
10
+ import tqdm
11
+ import zstandard
12
+
13
+
14
+ def json_serial(obj: Any) -> str:
15
+ """JSON serializer for objects not serializable by default json code"""
16
+
17
+ if isinstance(obj, (datetime.datetime,)):
18
+ return obj.isoformat()
19
+ raise TypeError("Type %s not serializable" % type(obj))
20
+
21
+
22
+ # Modified version of lm_dataformat Archive for single file.
23
+ class Archive:
24
+ def __init__(self, file_path: str, compression_level: int = 3) -> None:
25
+ self.file_path = file_path
26
+ dir_name = os.path.dirname(file_path)
27
+ if dir_name:
28
+ os.makedirs(dir_name, exist_ok=True)
29
+ self.fh = open(self.file_path, "wb")
30
+ self.cctx = zstandard.ZstdCompressor(level=compression_level)
31
+ self.compressor = self.cctx.stream_writer(self.fh)
32
+
33
+ def add_data(self, data, meta=None) -> None:
34
+ if meta is None:
35
+ meta = {}
36
+ self.compressor.write(
37
+ json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
38
+ "UTF-8"
39
+ )
40
+ + b"\n"
41
+ )
42
+
43
+ def commit(self) -> None:
44
+ self.compressor.flush(zstandard.FLUSH_FRAME)
45
+ self.fh.flush()
46
+ self.fh.close()
47
+
48
+
49
+ # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
50
+ class Reader:
51
+ def __init__(self) -> None:
52
+ pass
53
+
54
+ def read(
55
+ self,
56
+ file,
57
+ get_meta: bool = False,
58
+ autojoin_paragraphs: bool = True,
59
+ para_joiner: str = "\n\n",
60
+ ):
61
+ with open(file, "rb") as fh:
62
+ self.fh = fh
63
+ cctx = zstandard.ZstdDecompressor()
64
+ reader = io.BufferedReader(cctx.stream_reader(fh))
65
+ rdr = jsonlines.Reader(reader)
66
+ for ob in rdr:
67
+ # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
68
+ if isinstance(ob, str):
69
+ assert not get_meta
70
+ yield ob
71
+ continue
72
+
73
+ text = ob["text"]
74
+
75
+ if autojoin_paragraphs and isinstance(text, list):
76
+ text = para_joiner.join(text)
77
+
78
+ if get_meta:
79
+ yield text, (ob["meta"] if "meta" in ob else {})
80
+ else:
81
+ yield text
82
+
83
+
84
+ class TextArchive:
85
+ def __init__(self, file_path, mode: str = "rb+") -> None:
86
+ self.file_path = file_path
87
+ dir_name = os.path.dirname(file_path)
88
+ if dir_name:
89
+ os.makedirs(dir_name, exist_ok=True)
90
+
91
+ if not os.path.exists(file_path):
92
+ Path(file_path).touch()
93
+
94
+ self.fh = open(self.file_path, mode)
95
+
96
+ def add_data(self, data) -> None:
97
+ self.fh.write(data.encode("UTF-8") + b"\n")
98
+
99
+ def commit(self) -> None:
100
+ self.fh.flush()
101
+ self.fh.close()
102
+
103
+
104
+ class TextReader:
105
+ def __init__(self, file_path) -> None:
106
+ self.file_path = file_path
107
+
108
+ # Optimized mmap read with infrequent tqdm updates to maintain speed
109
+ # Tested up to 250MB/s.
110
+ def read_tqdm(self, update_frequency: int = 10000):
111
+ current_file_position = 0
112
+ line_counter = 0
113
+ with (
114
+ open(self.file_path, "r", encoding="utf-8") as fh,
115
+ tqdm.tqdm(
116
+ total=os.path.getsize(self.file_path),
117
+ dynamic_ncols=True,
118
+ unit="byte",
119
+ unit_scale=1,
120
+ ) as progress,
121
+ ):
122
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
123
+ for line in iter(mmap_obj.readline, b""):
124
+ line = line.decode("utf-8")
125
+ line_counter += 1
126
+ if line_counter == update_frequency:
127
+ new_file_pos = mmap_obj.tell()
128
+ bytes_read = new_file_pos - current_file_position
129
+ current_file_position = new_file_pos
130
+ progress.update(bytes_read)
131
+ line_counter = 0
132
+ yield line[:-1]
133
+
134
+ def read_and_tell(self):
135
+ current_file_position = 0
136
+ with open(self.file_path, "r", encoding="utf8") as fh:
137
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
138
+ for line in iter(mmap_obj.readline, b""):
139
+ line = line.decode("utf-8")
140
+ new_file_pos = mmap_obj.tell()
141
+ raw_bytes_read = new_file_pos - current_file_position
142
+ current_file_position = new_file_pos
143
+ yield line[:-1], raw_bytes_read
144
+
145
+ def read(self):
146
+ with open(self.file_path, "r", encoding="utf8") as fh:
147
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
148
+ for line in iter(mmap_obj.readline, b""):
149
+ line = line.decode("utf-8")
150
+ yield line[:-1]
151
+
152
+ def read_slow(self):
153
+ with open(self.file_path, "r", encoding="utf8") as fh:
154
+ while True:
155
+ line = fh.readline()
156
+ if line == -1 or line == "":
157
+ break
158
+ else:
159
+ yield line[:-1]
160
+
161
+
162
+ # Optimized for speed. Decompresses the archive in shell before
163
+ # using the mmap'd TextReader.
164
+ class ZStdTextReader:
165
+ def __init__(self, file) -> None:
166
+ self.file = file
167
+
168
+ def read_tqdm(self):
169
+ decompressed_file = self.file[:-4]
170
+ print("Decompressing file, please wait...")
171
+ os.system(f"zstd -d {self.file}") # linux decompress is faster
172
+ reader = TextReader(decompressed_file)
173
+ yield from reader.read_tqdm()
174
+ os.remove(decompressed_file)
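A small round-trip usage sketch for `Archive` and `Reader` above; the output path and document text are hypothetical, and it assumes `zstandard` and `jsonlines` are installed along with the package itself:

    from lm_eval.decontamination.archiver import Archive, Reader

    archive = Archive("data/example.jsonl.zst")   # hypothetical output path
    archive.add_data("The quick brown fox.", meta={"source": "demo"})
    archive.commit()

    reader = Reader()
    for text, meta in reader.read("data/example.jsonl.zst", get_meta=True):
        print(text, meta)   # -> The quick brown fox. {'source': 'demo'}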
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py ADDED
@@ -0,0 +1,166 @@
1
+ import collections
2
+ import glob
3
+ import json
4
+ import os
5
+ import pickle
6
+ import random
7
+ import time
8
+
9
+ from .archiver import ZStdTextReader
10
+ from .janitor import Janitor, word_ngrams
11
+
12
+
13
+ # Was used for testing the evaluator decoupled from the full logic below
14
+ def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
15
+ simulated_overlap = 0.1
16
+ contaminated = int(len(docs) * simulated_overlap)
17
+ return random.sample(range(len(docs)), contaminated)
18
+
19
+
20
+ # Returns a dictionary containing all overlapping documents in each
21
+ # task. In the standard use case, an overlap occurs when any of the 13-grams
22
+ # found in the task document exist in the training set documents.
23
+ #
24
+ # To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
25
+ # scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
26
+ # files. These should exist in the "ngrams_path" provided to this function.
27
+
28
+
29
+ # Algorithm:
30
+ # 1. Build lookups for each dataset {ngram: list(document_ids)}
31
+ # 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
32
+ # 3. Full scan the 13-grams from the training set against the merged lookup,
33
+ # saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
34
+ # 4. Strip the task_set from the dictionary keys and return
35
+ #
36
+ # We cache the task+set lookups as well as the overlaps.
37
+ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
38
+ # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
39
+
40
+ info_dict_path = os.path.join(ngrams_path, "info.json")
41
+ info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
42
+ ngrams_n_size = info_dict["ngram_size"]
43
+
44
+ janitor = Janitor()
45
+
46
+ # Build lookup for each dataset first in case we use different task combinations later
47
+ print("Building Lookups...")
48
+ start = time.perf_counter()
49
+
50
+ def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
51
+ return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
52
+
53
+ lookups = {}
54
+ duplicates = {} # (task_name, task_set): set(doc_ids)}
55
+ sets_to_decontaminate = len(docs_by_task_set.keys())
56
+
57
+ for (task_name, task_set), docs in docs_by_task_set.items():
58
+ if not os.path.exists(f"data/{task_name}"):
59
+ os.mkdir(f"data/{task_name}")
60
+
61
+ # Check if we've decontaminated this combination before
62
+ overlaps_dump_path = get_overlaps_dump_path(
63
+ task_name, task_set, ngrams_n_size, limit
64
+ )
65
+ if os.path.exists(overlaps_dump_path):
66
+ duplicates[(task_name, task_set)] = pickle.load(
67
+ open(overlaps_dump_path, "rb")
68
+ )
69
+ sets_to_decontaminate -= 1
70
+ continue
71
+ else:
72
+ duplicates[(task_name, task_set)] = set()
73
+
74
+ # Build/load the task lookup {ngram: set(documents)}.
75
+ task_set_lookup_path = (
76
+ f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
77
+ )
78
+ if os.path.exists(task_set_lookup_path):
79
+ print(f"{task_set_lookup_path} available, loading...")
80
+ lookups[(task_name, task_set)] = pickle.load(
81
+ open(task_set_lookup_path, "rb")
82
+ )
83
+ else:
84
+ print(f"{task_set_lookup_path} not available, building...")
85
+ lookup = collections.defaultdict(set)
86
+
87
+ for doc_id, document in enumerate(docs):
88
+ ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
89
+ for ngram in ngrams:
90
+ lookup[ngram].add(doc_id)
91
+
92
+ pickle.dump(lookup, open(task_set_lookup_path, "wb"))
93
+ lookups[(task_name, task_set)] = lookup
94
+
95
+ elapsed = time.perf_counter() - start
96
+ print(f"Building lookups took {elapsed:0.5f} seconds.")
97
+
98
+ matched_ngrams = []
99
+
100
+ if sets_to_decontaminate > 0:
101
+ print("Merging lookups...")
102
+ start = time.perf_counter()
103
+ merged_lookup = collections.defaultdict(list)
104
+ for (task_name, task_set), lookup in lookups.items():
105
+ for ngram, doc_ids in lookup.items():
106
+ merged_lookup[ngram].append((task_name, task_set, doc_ids))
107
+
108
+ elapsed = time.perf_counter() - start
109
+ print(f"Merging lookups took {elapsed:0.5f} seconds.")
110
+
111
+ print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
112
+ files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
113
+ print(files)
114
+
115
+ for file in files:
116
+ start = time.perf_counter()
117
+ print(f"Scanning {file}")
118
+ reader = ZStdTextReader(file)
119
+ total_ngrams = 0
120
+ unique_ngrams = 0
121
+ matching_unique = 0
122
+ non_matching_unique = 0
123
+
124
+ current_ngram = ""
125
+ for line in reader.read_tqdm(): # Scan training set ngrams file
126
+ total_ngrams += 1
127
+ [ngram, document_id] = line.rsplit(" ", 1)
128
+ if (
129
+ ngram != current_ngram
130
+ ): # Only need to match the ngram once in training set
131
+ unique_ngrams += 1
132
+ current_ngram = ngram
133
+ if ngram in merged_lookup:
134
+ matched_ngrams.append(ngram) # For logging
135
+ matching_unique += 1
136
+ for task_name, task_set, doc_ids in merged_lookup[ngram]:
137
+ task_doc_set = duplicates[(task_name, task_set)]
138
+ for doc_id in doc_ids: # Record contamination across all relevant task/set combos
139
+ task_doc_set.add(doc_id)
140
+ del merged_lookup[ngram] # No point matching again
141
+ else:
142
+ non_matching_unique += 1
143
+
144
+ print(f"Total Ngrams: {total_ngrams}")
145
+ print(f"Unique Ngrams: {unique_ngrams}")
146
+ print(f"Unique Matching: {matching_unique}")
147
+ print(f"Unique Non Matching: {non_matching_unique}")
148
+ print("Matched ngrams:")
149
+ for ngram in matched_ngrams:
150
+ print(ngram)
151
+
152
+ elapsed = time.perf_counter() - start
153
+ print(f"Read took {elapsed:0.5f} seconds.")
154
+ print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
155
+
156
+ print(duplicates)
157
+
158
+ # Dump overlaps separately
159
+ for (task_name, task_set), doc_ids in duplicates.items():
160
+ overlaps_dump_path = get_overlaps_dump_path(
161
+ task_name, task_set, ngrams_n_size, limit
162
+ )
163
+ pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
164
+
165
+ # Strip task set and return
166
+ return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
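To make step 1 of the algorithm above concrete, a standalone sketch of building one task's `{ngram: set(doc_ids)}` lookup, using toy documents and 2-grams instead of 13-grams for brevity:

    import collections

    def toy_word_ngrams(s, n):
        tokens = s.split()
        return (" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    docs = ["the cat sat on the mat", "the cat ran away"]
    lookup = collections.defaultdict(set)
    for doc_id, document in enumerate(docs):
        for ngram in toy_word_ngrams(document, 2):
            lookup[ngram].add(doc_id)

    print(lookup["the cat"])   # {0, 1} -> both documents share this 2-gram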
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py ADDED
@@ -0,0 +1,328 @@
1
+ import pickle
2
+ import re
3
+ import string
4
+ import traceback
5
+ from typing import Iterator, List, Sequence, Tuple, TypeVar
6
+
7
+
8
+ # This is a cpp module. Compile janitor_util.cpp with:
9
+ # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
10
+ try:
11
+ import janitor_util
12
+
13
+ JANITOR_CPP = True
14
+ except Exception:
15
+ print("WARNING: C++ module could not be loaded. Janitor running in python mode")
16
+ traceback.print_exc()
17
+ JANITOR_CPP = False
18
+
19
+ T = TypeVar("T")
20
+
21
+
22
+ # Implementation from nltk source
23
+ # https://www.nltk.org/_modules/nltk/util.html
24
+ def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
25
+ history = []
26
+ while n > 1:
27
+ # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
28
+ try:
29
+ next_item = next(sequence)
30
+ except StopIteration:
31
+ # no more data, terminate the generator
32
+ return
33
+ history.append(next_item)
34
+ n -= 1
35
+ for item in sequence:
36
+ history.append(item)
37
+ yield tuple(history)
38
+ del history[0]
39
+
40
+
41
+ def word_ngrams(s: str, n: int) -> Iterator[str]:
42
+ """Splits a string into ngram words"""
43
+ tokens = s.split() # not a generator :(
44
+ ngram_seqs = form_ngrams(iter(tokens), n)
45
+ return (" ".join(ngram) for ngram in ngram_seqs)
46
+
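A quick check of what these helpers yield (assuming the package is installed so they can be imported); `form_ngrams` is generic over element type, while `word_ngrams` joins word tuples back into strings:

    from lm_eval.decontamination.janitor import form_ngrams, word_ngrams

    print(list(form_ngrams(iter([1, 2, 3, 4]), 2)))
    # [(1, 2), (2, 3), (3, 4)]
    print(list(word_ngrams("the cat sat on the mat", 3)))
    # ['the cat sat', 'cat sat on', 'sat on the', 'on the mat']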
47
+
48
+ # Does character sequences only - combined faster function to play around with later
49
+ # def word_ngrams_indices_combined(sequence, n):
50
+ # current_word = ""
51
+ # history = []
52
+ # gap = False;
53
+ # start = 0
54
+ # end = 0
55
+ # for character in sequence:
56
+ # if character == " ":
57
+ # if not gap:
58
+ # gap = True
59
+ # history.append(current_word)
60
+ # end += len(current_word) - 1
61
+ # current_word = ""
62
+ # if len(history) == n:
63
+ # yield (tuple(history), start, end)
64
+ # del history[0]
65
+ # start = end + 1
66
+ # end = start
67
+ # else:
68
+ # gap = False
69
+ # current_word += character
70
+
71
+
72
+ # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
73
+ def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
74
+ """Splits a string on whitespaces and records the indices of each in the original string.
75
+ @:return generator((word, (start_idx, end_idx)), ...)
76
+ """
77
+ return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
78
+
79
+
80
+ def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
81
+ """Splits a string into pairs of (ngram words, their start/end indices)"""
82
+ tokens_with_indices = split_indices(s)
83
+
84
+ # Generator of ngrams of (word, idx_pairs)
85
+ # (
86
+ # [(word, (start,end)), (word, (start, end))...],
87
+ # [(word, (start, end)), ...],
88
+ # ...
89
+ # )
90
+ ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
91
+
92
+ # Generator of pairs of word and index ngrams
93
+ # (
94
+ # ([word, word, ...], [(start,end), (start,end), ...]),
95
+ # ...
96
+ # )
97
+ ngram_indices_pairs = (
98
+ zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
99
+ )
100
+
101
+ # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
102
+ return (
103
+ (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
104
+ for ngram_seq, indices in ngram_indices_pairs
105
+ )
106
+
107
+
108
+ class Janitor:
109
+ # FIXME delete_chars: Should anything else go here? Special chars?
110
+ def __init__(
111
+ self,
112
+ ngram_n: int = 13,
113
+ window_to_remove: int = 200,
114
+ too_dirty_cutoff: int = 10,
115
+ minimum_slice_length: int = 200,
116
+ delete_chars: str = string.punctuation,
117
+ ) -> None:
118
+ self.ngram_n = ngram_n
119
+ self.window_to_remove = window_to_remove
120
+ self.too_dirty_cutoff = too_dirty_cutoff
121
+ self.minimum_slice_length = minimum_slice_length
122
+ self.delete_chars = delete_chars
123
+
124
+ self.dirt_ngrams = set()
125
+
126
+ # If in python, we'll translate uppercase to lowercase and delete naughty characters.
127
+ # This is fast by python standards
128
+ # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
129
+ self.translation_table = str.maketrans(
130
+ string.ascii_lowercase + string.ascii_uppercase, # These characters
131
+ string.ascii_lowercase * 2, # Become these characters
132
+ self.delete_chars, # These are deleted
133
+ )
134
+
135
+ ##############
136
+ # I/O for saving contamination ngrams
137
+ ##############
138
+
139
+ def save_contamination_ngrams(self, filename: str) -> None:
140
+ with open(filename, "wb") as fp:
141
+ pickle.dump(filename, fp)
142
+
143
+ def load_contamination_ngrams(self, filename: str) -> None:
144
+ with open(filename, "rb") as fp:
145
+ self.dirt_ngrams = pickle.load(fp)
146
+
147
+ ##############
148
+ # Call these :)
149
+ ##############
150
+
151
+ def register_contaminant(self, dirt_string: str) -> None:
152
+ """Register a string as contamination to be removed, e.g. a test set
153
+ This breaks the dirt_string into ngrams to store for future cleaning"""
154
+ if JANITOR_CPP:
155
+ return self.register_contaminant_cpp(dirt_string)
156
+ else:
157
+ print("WARNING: Janitor running in python mode")
158
+ return self.register_contaminant_python(dirt_string)
159
+
160
+ def clean(self, dirty_string: str) -> List[str]:
161
+ """Clean a string (e.g. a training set) by removing all ngrams previously
162
+ registered as contaminants. Returns a list of clean chunks, or empty if
163
+ the string was too dirty"""
164
+ if JANITOR_CPP:
165
+ return self.clean_cpp(dirty_string)
166
+ else:
167
+ print("WARNING: Janitor running in python mode")
168
+ return self.clean_python(dirty_string)
169
+
170
+ def _split_chunks(
171
+ self, dirty_string: str, dirty_parts: Sequence[Tuple]
172
+ ) -> List[str]:
173
+ clean_chunks = []
174
+ splice_idx = 0
175
+ end = -1
176
+ for i, (ngram, start, end) in enumerate(dirty_parts):
177
+ if i >= self.too_dirty_cutoff:
178
+ return []
179
+ start = max(0, start - self.window_to_remove)
180
+ end = min(len(dirty_string), end + self.window_to_remove)
181
+
182
+ if start - splice_idx > self.minimum_slice_length:
183
+ clean_chunks.append(dirty_string[splice_idx:start])
184
+ splice_idx = end
185
+
186
+ if end < len(dirty_string) - self.minimum_slice_length:
187
+ clean_chunks.append(dirty_string[end + 1 :])
188
+
189
+ return clean_chunks
190
+
191
+ ##############
192
+ # Fast C++
193
+ ##############
194
+
195
+ def register_contaminant_cpp(self, dirt_string) -> None:
196
+ self.dirt_ngrams.update(
197
+ janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
198
+ )
199
+
200
+ def clean_cpp(self, dirty_string: str) -> List[str]:
201
+ contamination_indices = janitor_util.clean_ngram_with_indices(
202
+ dirty_string, self.delete_chars, self.ngram_n
203
+ )
204
+ return self._split_chunks(dirty_string, contamination_indices)
205
+
206
+ ##############
207
+ # Slow python
208
+ ##############
209
+
210
+ def normalize_string(self, s: str) -> str:
211
+ return s.translate(self.translation_table)
212
+
213
+ def register_contaminant_python(self, dirt_string: str) -> None:
214
+ self.dirt_ngrams.update(
215
+ word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
216
+ )
217
+
218
+ def clean_python(self, dirty_string: str) -> List[str]:
219
+ contamination_indices = (
220
+ (None, *idx_pair)
221
+ for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
222
+ if self.normalize_string(dirty_ngram) in self.dirt_ngrams
223
+ )
224
+ return self._split_chunks(dirty_string, contamination_indices)
225
+
226
+
227
+ ##################################################################
228
+ # Tests
229
+ #################################################################
230
+
231
+ # def print_cpp():
232
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
233
+
234
+ # for i in range(1, 10, 2):
235
+ # pprint(janitor_util.clean_ngram(source, string.punctuation, i))
236
+ # for ngram, start, end in \
237
+ # janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
238
+ # print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
239
+
240
+
241
+ # def test_cpp():
242
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
243
+ # contaminant = "dirty boy. Clean he he"
244
+
245
+ # jan_python = Janitor()
246
+ # jan_cpp = Janitor()
247
+
248
+ # jan_python.register_contaminant_python(contaminant)
249
+ # jan_cpp.register_contaminant(contaminant)
250
+
251
+ # assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
252
+
253
+ # assert jan_python.clean_python(source) == jan_cpp.clean(source), \
254
+ # (jan_python.clean_python(source), jan_cpp.clean(source))
255
+
256
+ # print("Passed test, python==cpp")
257
+
258
+
259
+ # def benchmark():
260
+ # # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
261
+ # setup = \
262
+ # """
263
+ # with open("data/enwik8", "r") as f:
264
+ # data = f.read()
265
+ # jan = Janitor(too_dirty_cutoff=1000)
266
+ # jan.register_contaminant('''
267
+ # theories is that there is a connection between &quot;geekdom&quot; and autism.
268
+ # This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
269
+ # The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
270
+ # movement{{ref|Wired}}. This article, many professionals assert, is just one example of
271
+ # the media's application of mental disease labels to what is actually variant normal behavior
272
+ # &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
273
+ # interests, even when they seem unusual to others, are not in themselves signs of autism or
274
+ # Asperger's syndrome. Others assert that it is actually the medical profession which is applying
275
+ # mental disease labels to children who in the past would have simply been accepted as a little
276
+ # different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
277
+ # Due to the recent publicity surrounding autism and autis
278
+ # ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
279
+ # oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
280
+ # paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
281
+ # would last, took a cautious approach, preferring to save the revenue rather than investing it in
282
+ # development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
283
+ # to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
284
+ # brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
285
+ # with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
286
+ # ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
287
+ # ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
288
+ # Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
289
+ # [[United Arab Emirates]]. After the Emirates gained independence in 1971,
290
+ # ''')
291
+ # """
292
+
293
+ # n = 1
294
+ # print(f"Timing {n} run on 100 MB")
295
+ # print("Register contaminant")
296
+ # # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
297
+ # print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
298
+
299
+ # print("Clean")
300
+ # # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
301
+ # print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
302
+
303
+
304
+ # def test_janitor_general():
305
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
306
+ # contaminant = "dirty boy. Clean he he"
307
+
308
+ # jan = Janitor(ngram_n=3)
309
+ # jan.register_contaminant(contaminant)
310
+ # cleaned = " ".join(jan.clean(source))
311
+ # for contam in jan.dirt_ngrams:
312
+ # assert contam not in cleaned, contam
313
+
314
+ # filename = "data/saved_contam"
315
+ # jan.save_contamination_ngrams(filename)
316
+
317
+ # jan = Janitor(ngram_n=3)
318
+ # jan.load_contamination_ngrams(filename)
319
+ # cleaned = " ".join(jan.clean(source))
320
+ # for contam in jan.dirt_ngrams:
321
+ # assert contam not in cleaned, contam
322
+
323
+
324
+ # if __name__ == "__main__":
325
+ # test_janitor_general()
326
+ # # print_cpp()
327
+ # # test_cpp()
328
+ # # benchmark()
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py ADDED
@@ -0,0 +1,736 @@
1
+ import itertools
2
+ import json
3
+ import logging
4
+ import random
5
+ import time
6
+ from collections import defaultdict
7
+ from typing import TYPE_CHECKING, List, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ import lm_eval.api.metrics
13
+ import lm_eval.api.registry
14
+ import lm_eval.api.task
15
+ import lm_eval.models
16
+ from lm_eval.caching.cache import delete_cache
17
+ from lm_eval.evaluator_utils import (
18
+ consolidate_group_results,
19
+ consolidate_results,
20
+ get_sample_size,
21
+ get_subtask_list,
22
+ get_task_list,
23
+ prepare_print_tasks,
24
+ print_writeout,
25
+ run_task_tests,
26
+ )
27
+ from lm_eval.loggers import EvaluationTracker
28
+ from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
29
+ from lm_eval.tasks import (
30
+ TaskManager,
31
+ get_task_dict,
32
+ )
33
+ from lm_eval.utils import (
34
+ handle_non_serializable,
35
+ hash_string,
36
+ positional_deprecated,
37
+ setup_logging,
38
+ simple_parse_args_string,
39
+ )
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ from lm_eval.api.model import LM
44
+ from lm_eval.api.task import Task
45
+
46
+ eval_logger = logging.getLogger(__name__)
47
+
48
+
49
+ @positional_deprecated
50
+ def simple_evaluate(
51
+ model,
52
+ model_args: Optional[Union[str, dict]] = None,
53
+ tasks: Optional[List[Union[str, dict, object]]] = None,
54
+ num_fewshot: Optional[int] = None,
55
+ batch_size: Optional[Union[int, str]] = None,
56
+ max_batch_size: Optional[int] = None,
57
+ device: Optional[str] = None,
58
+ use_cache: Optional[str] = None,
59
+ cache_requests: bool = False,
60
+ rewrite_requests_cache: bool = False,
61
+ delete_requests_cache: bool = False,
62
+ limit: Optional[Union[int, float]] = None,
63
+ bootstrap_iters: int = 100000,
64
+ check_integrity: bool = False,
65
+ write_out: bool = False,
66
+ log_samples: bool = True,
67
+ evaluation_tracker: Optional[EvaluationTracker] = None,
68
+ system_instruction: Optional[str] = None,
69
+ apply_chat_template: Union[bool, str] = False,
70
+ fewshot_as_multiturn: bool = False,
71
+ gen_kwargs: Union[str, dict, None] = None,
72
+ task_manager: Optional[TaskManager] = None,
73
+ verbosity=None,
74
+ predict_only: bool = False,
75
+ random_seed: int = 0,
76
+ numpy_random_seed: int = 1234,
77
+ torch_random_seed: int = 1234,
78
+ fewshot_random_seed: int = 1234,
79
+ confirm_run_unsafe_code: bool = False,
80
+ metadata: Optional[dict] = None,
81
+ ):
82
+ """Instantiate and evaluate a model on a list of tasks.
83
+
84
+ :param model: Union[str, LM]
85
+ Name of model or LM object, see lm_eval.models.get_model
86
+ :param model_args: Optional[str, dict]
87
+ String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
88
+ Ignored if `model` argument is a LM object.
89
+ :param tasks: list[Union[str, dict, Task]]
90
+ List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
91
+ :param num_fewshot: int
92
+ Number of examples in few-shot context
93
+ :param batch_size: int or str, optional
94
+ Batch size for model
95
+ :param max_batch_size: int, optional
96
+ Maximal batch size to try with automatic batch size detection
97
+ :param device: str, optional
98
+ PyTorch device (e.g. "cpu" or "cuda:0") for running models
99
+ :param use_cache: str, optional
100
+ A path to a sqlite db file for caching model responses. `None` if not caching.
101
+ :param cache_requests: bool, optional
102
+ Speed up evaluation by caching the building of dataset requests. `None` if not caching.
103
+ :param rewrite_requests_cache: bool, optional
104
+ Rewrites all the request cache if set to `True`. `None` if not desired.
105
+ :param delete_requests_cache: bool, optional
106
+ Deletes all the request cache if set to `True`. `None` if not desired.
107
+ :param limit: int or float, optional
108
+ Limit the number of examples per task (only use this for testing). If <1, the limit is interpreted as a fraction of the total number of examples.
109
+ :param bootstrap_iters:
110
+ Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 to skip stderr calculations.
111
+ :param check_integrity: bool
112
+ Whether to run the relevant part of the test suite for the tasks
113
+ :param write_out: bool
114
+ If True, write out an example document and model input for checking task integrity
115
+ :param log_samples: bool
116
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
117
+ :param system_instruction: str
118
+ System instruction to be applied to the prompt
119
+ :param apply_chat_template: Union[bool, str]
120
+ Specifies whether to apply a chat template to the prompt.
121
+ - If set to True, the default chat template is applied.
122
+ - If set to a string, applies the specified chat template by name.
123
+ Defaults to False (no chat template applied).
124
+ :param fewshot_as_multiturn: bool
125
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
126
+ :param gen_kwargs: dict or comma-separated string
127
+ Arguments for model generation
128
+ Ignored for all tasks with loglikelihood output_type
129
+ :param verbosity: str
130
+ Verbosity level for logging
131
+ :param predict_only: bool
132
+ If true only model outputs will be generated and returned. Metrics will not be evaluated
133
+ :param random_seed: int
134
+ Random seed for python's random module. If set to None, the seed will not be set.
135
+ :param numpy_random_seed: int
136
+ Random seed for numpy. If set to None, the seed will not be set.
137
+ :param torch_random_seed: int
138
+ Random seed for torch. If set to None, the seed will not be set.
139
+ :param fewshot_random_seed: int
140
+ Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
141
+ :param metadata: dict
142
+ Additional metadata to be added to the task manager. Will get passed to the download function of the task.
143
+
144
+ :return
145
+ Dictionary of results
146
+ """
147
+ if verbosity is not None:
148
+ setup_logging(verbosity=verbosity)
149
+ start_date = time.time()
150
+
151
+ if isinstance(model_args, str) and (
152
+ "instruct" in model_args and not apply_chat_template
153
+ ):
154
+ eval_logger.warning(
155
+ "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
156
+ )
157
+
158
+ if delete_requests_cache:
159
+ eval_logger.info("Deleting requests cache...")
160
+ delete_cache()
161
+
162
+ seed_message = []
163
+ if random_seed is not None:
164
+ # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
165
+ seed_message.append(f"Setting random seed to {random_seed}")
166
+ random.seed(random_seed)
167
+
168
+ if numpy_random_seed is not None:
169
+ seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
170
+ np.random.seed(numpy_random_seed)
171
+
172
+ if torch_random_seed is not None:
173
+ seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
174
+ torch.manual_seed(torch_random_seed)
175
+
176
+ if fewshot_random_seed is not None:
177
+ seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
178
+
179
+ if seed_message:
180
+ eval_logger.info(" | ".join(seed_message))
181
+
182
+ if tasks is None:
183
+ tasks = []
184
+ if len(tasks) == 0:
185
+ raise ValueError(
186
+ "No tasks specified, or no tasks found. Please verify the task names."
187
+ )
188
+
189
+ if gen_kwargs is not None:
190
+ if isinstance(gen_kwargs, str):
191
+ gen_kwargs = simple_parse_args_string(gen_kwargs)
192
+ eval_logger.warning(
193
+ f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
194
+ "Ensure 'do_sample=True' for non-greedy decoding!"
195
+ )
196
+ if not gen_kwargs:
197
+ gen_kwargs = None
198
+
199
+ if isinstance(model, str):
200
+ if model_args is None:
201
+ eval_logger.warning("model_args not specified. Using defaults.")
202
+ model_args = ""
203
+
204
+ if isinstance(model_args, dict):
205
+ eval_logger.info(
206
+ f"Initializing {model} model, with arguments: {model_args}"
207
+ )
208
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
209
+ model_args,
210
+ {
211
+ "batch_size": batch_size,
212
+ "max_batch_size": max_batch_size,
213
+ "device": device,
214
+ },
215
+ )
216
+
217
+ else:
218
+ eval_logger.info(
219
+ f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
220
+ )
221
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
222
+ model_args,
223
+ {
224
+ "batch_size": batch_size,
225
+ "max_batch_size": max_batch_size,
226
+ "device": device,
227
+ },
228
+ )
229
+ else:
230
+ if not isinstance(model, lm_eval.api.model.LM):
231
+ raise TypeError(
232
+ f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
233
+ )
234
+ eval_logger.info("Using pre-initialized model")
235
+ lm = model
236
+
237
+ if use_cache is not None:
238
+ eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
239
+ lm = lm_eval.api.model.CachingLM(
240
+ lm,
241
+ use_cache
242
+ # each rank receives a different cache db.
243
+ # necessary to avoid multiple writes to cache at once
244
+ + "_rank"
245
+ + str(lm.rank)
246
+ + ".db",
247
+ )
248
+
249
+ if task_manager is None:
250
+ metadata = (
251
+ simple_parse_args_string(model_args)
252
+ if isinstance(model_args, str)
253
+ else model_args
254
+ if isinstance(model_args, dict)
255
+ else {}
256
+ ) | (metadata or {})
257
+ task_manager = TaskManager(metadata=metadata)
258
+
259
+ task_dict = get_task_dict(
260
+ tasks,
261
+ task_manager,
262
+ )
263
+
264
+ # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
265
+ # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
266
+ def _adjust_config(task_dict):
267
+ adjusted_task_dict = {}
268
+ for task_name, task_obj in task_dict.items():
269
+ if isinstance(task_obj, dict):
270
+ adjusted_task_dict = {
271
+ **adjusted_task_dict,
272
+ **{task_name: _adjust_config(task_obj)},
273
+ }
274
+
275
+ else:
276
+ if task_obj.get_config("output_type") == "generate_until":
277
+ if gen_kwargs is not None:
278
+ task_obj.set_config(
279
+ key="generation_kwargs", value=gen_kwargs, update=True
280
+ )
281
+ eval_logger.info(
282
+ f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
283
+ )
284
+
285
+ if predict_only:
286
+ eval_logger.info(
287
+ f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
288
+ )
289
+ # we have to change the class properties post-hoc. This is pretty hacky.
290
+ task_obj.override_metric(metric_name="bypass")
291
+
292
+ # override tasks' fewshot values to the provided num_fewshot arg value
293
+ # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
294
+ if num_fewshot is not None:
295
+ if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
296
+ eval_logger.info(
297
+ f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
298
+ )
299
+ else:
300
+ eval_logger.warning(
301
+ f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
302
+ )
303
+ task_obj.set_config(key="num_fewshot", value=num_fewshot)
304
+ else:
305
+ # if num_fewshot not provided, and the task does not define a default one, default to 0
306
+ if (
307
+ default_num_fewshot := task_obj.get_config("num_fewshot")
308
+ ) is None:
309
+ task_obj.set_config(key="num_fewshot", value=0)
310
+ # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
311
+ task_obj.set_fewshot_seed(seed=fewshot_random_seed)
312
+
313
+ adjusted_task_dict[task_name] = task_obj
314
+
315
+ return adjusted_task_dict
316
+
317
+ task_dict = _adjust_config(task_dict)
318
+
319
+ if check_integrity:
320
+ run_task_tests(task_list=tasks)
321
+
322
+ if evaluation_tracker is not None:
323
+ evaluation_tracker.general_config_tracker.log_experiment_args(
324
+ model_source=model,
325
+ model_args=model_args,
326
+ system_instruction=system_instruction,
327
+ chat_template=lm.chat_template(apply_chat_template)
328
+ if apply_chat_template
329
+ else None,
330
+ fewshot_as_multiturn=fewshot_as_multiturn,
331
+ )
332
+
333
+ results = evaluate(
334
+ lm=lm,
335
+ task_dict=task_dict,
336
+ limit=limit,
337
+ cache_requests=cache_requests,
338
+ rewrite_requests_cache=rewrite_requests_cache,
339
+ bootstrap_iters=bootstrap_iters,
340
+ write_out=write_out,
341
+ log_samples=True if predict_only else log_samples,
342
+ system_instruction=system_instruction,
343
+ apply_chat_template=apply_chat_template,
344
+ fewshot_as_multiturn=fewshot_as_multiturn,
345
+ verbosity=verbosity,
346
+ confirm_run_unsafe_code=confirm_run_unsafe_code,
347
+ )
348
+ if verbosity is not None:
349
+ setup_logging(verbosity=verbosity)
350
+
351
+ if lm.rank == 0:
352
+ if isinstance(model, str):
353
+ model_name = model
354
+ elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
355
+ model_name = model.config._name_or_path
356
+ else:
357
+ model_name = type(model).__name__
358
+
359
+ # add info about the model and few shot config
360
+ results["config"] = {
361
+ "model": model_name,
362
+ "model_args": model_args,
363
+ }
364
+ # add more detailed model info if available
365
+ if isinstance(lm, lm_eval.models.huggingface.HFLM):
366
+ results["config"].update(lm.get_model_info())
367
+ # add info about execution
368
+ results["config"].update(
369
+ {
370
+ "batch_size": batch_size,
371
+ "batch_sizes": (
372
+ list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
373
+ ),
374
+ "device": device,
375
+ "use_cache": use_cache,
376
+ "limit": limit,
377
+ "bootstrap_iters": bootstrap_iters,
378
+ "gen_kwargs": gen_kwargs,
379
+ "random_seed": random_seed,
380
+ "numpy_seed": numpy_random_seed,
381
+ "torch_seed": torch_random_seed,
382
+ "fewshot_seed": fewshot_random_seed,
383
+ }
384
+ )
385
+ results["git_hash"] = get_git_commit_hash()
386
+ results["date"] = start_date
387
+ add_env_info(results) # additional environment info to results
388
+ add_tokenizer_info(results, lm) # additional info about tokenizer
389
+ return results
390
+ else:
391
+ return None
392
+
393
+
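Editorial note: as an illustration of how the arguments documented above fit together, the sketch below runs a single task end to end. The model string, model_args value, and task name are placeholders chosen for this example, not values prescribed by this upload.

# Illustrative call only; "hf" is the Hugging Face wrapper name in this harness,
# and the pretrained path / task name are hypothetical placeholders.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=org/some-instruct-model,trust_remote_code=True",
    tasks=["gsm8k"],
    num_fewshot=0,
    batch_size=1,
    apply_chat_template=True,
    log_samples=True,
)
if results is not None:                 # non-zero ranks return None
    print(results["results"]["gsm8k"])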
394
+ @positional_deprecated
395
+ def evaluate(
396
+ lm: "LM",
397
+ task_dict,
398
+ limit: Optional[int] = None,
399
+ cache_requests: bool = False,
400
+ rewrite_requests_cache: bool = False,
401
+ bootstrap_iters: Optional[int] = 100000,
402
+ write_out: bool = False,
403
+ log_samples: bool = True,
404
+ system_instruction: Optional[str] = None,
405
+ apply_chat_template: Union[bool, str] = False,
406
+ fewshot_as_multiturn: bool = False,
407
+ verbosity: str = "INFO",
408
+ confirm_run_unsafe_code: bool = False,
409
+ ):
410
+ """Instantiate and evaluate a model on a list of tasks.
411
+
412
+ :param lm: obj
413
+ Language Model
414
+ :param task_dict: dict[str, Task]
415
+ Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
416
+ :param limit: int, optional
417
+ Limit the number of examples per task (only use this for testing)
418
+ :param cache_requests: bool, optional
419
+ Speed up evaluation by caching the building of dataset requests.
420
+ :param rewrite_requests_cache: bool, optional
421
+ Rewrites all the request cache if set to `True`.
422
+ :param bootstrap_iters:
423
+ Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
424
+ :param write_out: bool
425
+ If True, write out an example document and model input for checking task integrity
426
+ :param log_samples: bool
427
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
428
+ :param system_instruction: str
429
+ System instruction to be applied to the prompt
430
+ :param apply_chat_template: Union[bool, str]
431
+ Specifies whether to apply a chat template to the prompt.
432
+ - If set to True, the default chat template is applied.
433
+ - If set to a string, applies the specified chat template by name.
434
+ Defaults to False (no chat template applied).
435
+ :param fewshot_as_multiturn: bool
436
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
437
+ :param verbosity: str
438
+ Verbosity level for logging
439
+ :param confirm_run_unsafe_code: bool
440
+ Whether to confirm running tasks marked as unsafe.
441
+ :return
442
+ Dictionary of results
443
+ """
444
+
445
+ if apply_chat_template:
446
+ eval_logger.warning(
447
+ "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
448
+ )
449
+
450
+ # tracks all Instances/requests a model must generate output on.
451
+ requests = defaultdict(list)
452
+ # stores the amount to pad out reqs per req. type so that
453
+ # number of fwd passes per distributed rank is equal
454
+ padding_requests = defaultdict(int)
455
+
456
+ # get lists of group hierarchy and each type of request
457
+ eval_tasks = get_task_list(task_dict)
458
+ if not log_samples:
459
+ if not all(
460
+ "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
461
+ for task_output in eval_tasks
462
+ ):
463
+ raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
464
+
465
+ # validation checks:
466
+ # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
467
+ # 2.are we running code that is marked as unsafe.
468
+ incompatible_tasks = []
469
+ for task_output in eval_tasks:
470
+ task: Task = task_output.task
471
+
472
+ if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
473
+ incompatible_tasks.append(task_output.task_name)
474
+ elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
475
+ raise ValueError(
476
+ f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
477
+ )
478
+ if len(incompatible_tasks) > 0:
479
+ if not getattr(lm, "MULTIMODAL", False):
480
+ raise ValueError(
481
+ f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
482
+ )
483
+ else:
484
+ raise ValueError(
485
+ f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
486
+ )
487
+ # end validation check
488
+
489
+ # Cache the limit arg.
490
+ limit_arg = limit
491
+ limits = []
492
+ for task_output in eval_tasks:
493
+ task: Task = task_output.task
494
+
495
+ limit = get_sample_size(task, limit_arg)
496
+ limits.append(limit)
497
+ task.build_all_requests(
498
+ limit=limit,
499
+ rank=lm.rank,
500
+ world_size=lm.world_size,
501
+ cache_requests=cache_requests,
502
+ rewrite_requests_cache=rewrite_requests_cache,
503
+ system_instruction=system_instruction,
504
+ apply_chat_template=bool(apply_chat_template),
505
+ fewshot_as_multiturn=fewshot_as_multiturn,
506
+ chat_template=getattr(lm, "apply_chat_template")
507
+ if apply_chat_template
508
+ else None,
509
+ tokenizer_name=getattr(lm, "tokenizer_name", "")
510
+ if apply_chat_template
511
+ else "",
512
+ )
513
+ eval_logger.debug(
514
+ f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
515
+ )
516
+ if write_out:
517
+ print_writeout(task)
518
+ # aggregate Instances by LM method requested to get output.
519
+ for instance in task.instances:
520
+ reqtype = instance.request_type
521
+ requests[reqtype].append(instance)
522
+
523
+ if lm.world_size > 1:
524
+ instances_rnk = torch.tensor(len(task._instances), device=lm.device)
525
+ gathered_item = (
526
+ lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
527
+ )
528
+ # "multiple_choice" task types dispatch (several) "loglikelihood" request types
529
+ reqtype = (
530
+ "loglikelihood"
531
+ if task.OUTPUT_TYPE == "multiple_choice"
532
+ else task.OUTPUT_TYPE
533
+ )
534
+ # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
535
+ numpad = max(gathered_item) - gathered_item[lm.rank]
536
+ # todo: may not account for padding in cases like SquadV2 which has multiple req types
537
+ padding_requests[reqtype] += numpad
538
+
539
+ ### Run LM on inputs, get all outputs ###
540
+ # execute each type of request
541
+ for reqtype, reqs in requests.items():
542
+ eval_logger.info(f"Running {reqtype} requests")
543
+ # create `K` copies of each request `req` based off `K = req.repeats`
544
+ cloned_reqs = []
545
+ for req in reqs:
546
+ cloned_reqs.extend([req] * req.repeats)
547
+
548
+ if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
549
+ for _ in range(padding_requests[reqtype]):
550
+ cloned_reqs.extend([req] * req.repeats)
551
+
552
+ # run requests through model
553
+ resps = getattr(lm, reqtype)(cloned_reqs)
554
+
555
+ # put responses from model into a list of length K for each request.
556
+ for x, req in zip(resps, cloned_reqs):
557
+ req.resps.append(x)
558
+
559
+ if lm.world_size > 1:
560
+ lm.accelerator.wait_for_everyone()
561
+
562
+ RANK = lm.rank
563
+ WORLD_SIZE = lm.world_size
564
+ ### Postprocess outputs ###
565
+ # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
566
+ for task_output, limit in zip(eval_tasks, limits):
567
+ task = task_output.task
568
+ task.apply_filters()
569
+
570
+ ### Collect values of metrics on all datapoints ###
571
+ # # unpack results and sort back in order and return control to Task
572
+ # TODO: make it possible to use a different metric per filter
573
+ # Pre-process task.instances to group by doc_id
574
+ instances_by_doc_id = defaultdict(list)
575
+ for instance in task.instances:
576
+ instances_by_doc_id[instance.doc_id].append(instance)
577
+ # Sort instances within each group
578
+ for instances in instances_by_doc_id.values():
579
+ instances.sort(key=lambda x: x.idx)
580
+ # iterate over different filters used
581
+ for filter_key in task.instances[0].filtered_resps.keys():
582
+ doc_iterator = task.doc_iterator(
583
+ rank=RANK, limit=limit, world_size=WORLD_SIZE
584
+ )
585
+ for doc_id, doc in doc_iterator:
586
+ requests = instances_by_doc_id[doc_id]
587
+ metrics = task.process_results(
588
+ doc, [req.filtered_resps[filter_key] for req in requests]
589
+ )
590
+ if log_samples:
591
+ target = task.doc_to_target(doc)
592
+ example = {
593
+ "doc_id": doc_id,
594
+ "doc": doc,
595
+ "target": target,
596
+ "arguments": [req.args for req in requests],
597
+ "resps": [req.resps for req in requests],
598
+ "filtered_resps": [
599
+ req.filtered_resps[filter_key] for req in requests
600
+ ],
601
+ "filter": filter_key,
602
+ "metrics": list(metrics.keys()),
603
+ "doc_hash": hash_string(
604
+ json.dumps(
605
+ requests[0].doc,
606
+ indent=2,
607
+ default=handle_non_serializable,
608
+ ensure_ascii=False,
609
+ )
610
+ ),
611
+ "prompt_hash": hash_string(requests[0].arguments[0]),
612
+ "target_hash": hash_string(str(target)),
613
+ }
614
+ example.update(metrics)
615
+ task_output.logged_samples.append(example)
616
+ for metric, value in metrics.items():
617
+ task_output.sample_metrics[(metric, filter_key)].append(value)
618
+
619
+ if WORLD_SIZE > 1:
620
+ # if multigpu, then gather data across all ranks to rank 0
621
+ # first gather logged samples across all ranks
622
+ for task_output in eval_tasks:
623
+ if log_samples:
624
+ # for task_name, task_samples in list(samples.items()):
625
+ full_samples = [None] * WORLD_SIZE if RANK == 0 else None
626
+ torch.distributed.gather_object(
627
+ obj=task_output.logged_samples,
628
+ object_gather_list=full_samples,
629
+ dst=0,
630
+ )
631
+
632
+ if RANK == 0:
633
+ task_output.logged_samples = list(
634
+ itertools.chain.from_iterable(full_samples)
635
+ )
636
+
637
+ # then collect metrics across all ranks
638
+ for metrics in task_output.sample_metrics:
639
+ metric_list = [None] * WORLD_SIZE if RANK == 0 else None
640
+ torch.distributed.gather_object(
641
+ obj=task_output.sample_metrics[metrics],
642
+ object_gather_list=metric_list,
643
+ dst=0,
644
+ )
645
+ if RANK == 0:
646
+ task_output.sample_metrics[metrics] = list(
647
+ itertools.chain.from_iterable(metric_list)
648
+ )
649
+
650
+ if RANK == 0:
651
+ ### Aggregate results over all datapoints ###
652
+ # aggregate results ; run bootstrap CIs
653
+ for task_output in eval_tasks:
654
+ task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
655
+ (
656
+ results,
657
+ samples,
658
+ configs,
659
+ versions,
660
+ num_fewshot,
661
+ higher_is_better,
662
+ ) = consolidate_results(eval_tasks)
663
+
664
+ ### Calculate group metrics ###
665
+ if bool(results):
666
+ results, versions, show_group_table, *_ = consolidate_group_results(
667
+ results, versions, task_dict
668
+ )
669
+
670
+ results_agg, group_agg = prepare_print_tasks(task_dict, results)
671
+ subtask_list = get_subtask_list(task_dict)
672
+
673
+ # collect all higher_is_better values for metrics
674
+ # in the group's subtasks.
675
+ # TODO: clean this up ; unify with the below metric_list loop?
676
+ _higher_is_better = {}
677
+ for group, task_list in subtask_list.items():
678
+ if (
679
+ len(task_list) != 0
680
+ ): # subtask list will list "task_name": [] for solo tasks
681
+ for task in task_list:
682
+ for m, h in higher_is_better[task].items():
683
+ if m not in _higher_is_better.keys():
684
+ _higher_is_better[m] = h
685
+
686
+ if (
687
+ m in _higher_is_better
688
+ and _higher_is_better[m] is not None
689
+ and _higher_is_better[m] != h
690
+ ):
691
+ eval_logger.warning(
692
+ f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
693
+ )
694
+ _higher_is_better[m] = None
695
+ higher_is_better[group] = _higher_is_better
696
+
697
+ results_dict = {
698
+ "results": dict(results_agg.items()),
699
+ **(
700
+ {"groups": dict(group_agg.items())}
701
+ if (bool(group_agg) & show_group_table)
702
+ else {}
703
+ ),
704
+ "group_subtasks": dict(reversed(subtask_list.items())),
705
+ "configs": dict(sorted(configs.items())),
706
+ "versions": dict(sorted(versions.items())),
707
+ "n-shot": dict(sorted(num_fewshot.items())),
708
+ "higher_is_better": dict(sorted(higher_is_better.items())),
709
+ "n-samples": {
710
+ task_output.task_name: {
711
+ "original": len(task_output.task.eval_docs),
712
+ "effective": min(
713
+ limit if limit else len(task_output.task.eval_docs),
714
+ len(task_output.task.eval_docs),
715
+ ),
716
+ }
717
+ for task_output, limit in zip(eval_tasks, limits)
718
+ },
719
+ }
720
+ if log_samples:
721
+ results_dict["samples"] = dict(samples)
722
+
723
+ return results_dict
724
+
725
+ else:
726
+ return None
727
+
728
+
729
+ def request_caching_arg_to_dict(cache_requests: str) -> dict:
730
+ request_caching_args = {
731
+ "cache_requests": cache_requests in {"true", "refresh"},
732
+ "rewrite_requests_cache": cache_requests == "refresh",
733
+ "delete_requests_cache": cache_requests == "delete",
734
+ }
735
+
736
+ return request_caching_args
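Editorial note: for reference, the mapping this helper implements, with a tiny usage check; the values follow directly from the function body above.

# "true"    -> cache_requests=True,  rewrite_requests_cache=False, delete_requests_cache=False
# "refresh" -> cache_requests=True,  rewrite_requests_cache=True,  delete_requests_cache=False
# "delete"  -> cache_requests=False, rewrite_requests_cache=False, delete_requests_cache=True
flags = request_caching_arg_to_dict("refresh")
assert flags["cache_requests"] and flags["rewrite_requests_cache"] and not flags["delete_requests_cache"]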
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py ADDED
@@ -0,0 +1,554 @@
1
+ import collections
2
+ import logging
3
+ import math
4
+ import pathlib
5
+ import sys
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from lm_eval.api.group import ConfigurableGroup
9
+ from lm_eval.api.metrics import (
10
+ aggregate_subtask_metrics,
11
+ mean,
12
+ pooled_sample_stderr,
13
+ stderr_for_metric,
14
+ )
15
+ from lm_eval.api.task import Task
16
+ from lm_eval.utils import positional_deprecated
17
+
18
+
19
+ eval_logger = logging.getLogger(__name__)
20
+
21
+
22
+ class TaskOutput:
23
+ """
24
+ Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.
25
+
26
+ Attributes:
27
+ task (object): The task object.
28
+ task_name (str): The name of the task.
29
+ task_config (dict): The configuration of the task.
30
+ version (str): The version of the task.
31
+ group_name (str): The name of the task group.
32
+ n_shot (int): The number of shots for the task.
33
+ task_alias (str): The alias of the task.
34
+ group_alias (str): The alias of the task group.
35
+ is_group (bool): Indicates if the task is a group.
36
+ logged_samples (list): The list of logged samples.
37
+ sample_len (int): The length of the samples.
38
+ sample_metrics (defaultdict): The dictionary of samples' metrics.
39
+ agg_metrics (defaultdict): The dictionary of aggregate metrics.
40
+
41
+ Methods:
42
+ from_taskdict(cls, task_name: str, task):
43
+ Creates a TaskOutput instance from a task dictionary.
44
+
45
+ calculate_aggregate_metric(bootstrap_iters=100000) -> None:
46
+ Calculates the aggregate metrics for the task.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ task=None,
52
+ task_name=None,
53
+ task_config=None,
54
+ version=None,
55
+ group_name=None,
56
+ n_shot=None,
57
+ task_alias=None,
58
+ group_alias=None,
59
+ is_group=None,
60
+ ):
61
+ self.task = task
62
+ self.task_config = task_config
63
+ self.task_name = task_name
64
+ self.group_name = group_name
65
+ self.version = version
66
+ self.n_shot = n_shot
67
+ self.task_alias = task_alias
68
+ self.group_alias = group_alias
69
+ self.is_group = is_group
70
+ self.logged_samples = []
71
+ self.sample_len = None
72
+ self.sample_metrics = collections.defaultdict(list)
73
+ self.agg_metrics = collections.defaultdict(list)
74
+
75
+ @classmethod
76
+ def from_taskdict(cls, task_name: str, task):
77
+ if isinstance(task, tuple):
78
+ group_name, task = task
79
+ else:
80
+ group_name = None
81
+ if not task:
82
+ # these get filtered out in get_task_list
83
+ # once they are added to group hierarchy
84
+ is_group = True
85
+ return cls(
86
+ task=task, task_name=task_name, is_group=is_group, group_name=group_name
87
+ )
88
+ version = task.VERSION
89
+ task_config = dict(task.dump_config())
90
+ if (n_shot := task_config.get("num_fewshot")) == 0:
91
+ n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
92
+ task_alias = task_config.get("alias")
93
+ group_alias = task_config.get("group_alias")
94
+ return cls(
95
+ task=task,
96
+ task_name=task_name,
97
+ task_config=task_config,
98
+ group_name=group_name,
99
+ version=version,
100
+ n_shot=n_shot,
101
+ task_alias=task_alias,
102
+ group_alias=group_alias,
103
+ )
104
+
105
+ def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
106
+ for (metric, filter_key), items in self.sample_metrics.items():
107
+ try:
108
+ agg_fn = self.task.aggregation()[metric]
109
+ except KeyError:
110
+ # This is when process results output an arbitrary metric
111
+ # TODO: Handle this better and allow other aggregate functions other than mean.
112
+ agg_fn = mean
113
+ metric_key = f"{metric},{filter_key}"
114
+ self.agg_metrics[metric_key] = agg_fn(items)
115
+ self.sample_len = len(items) # TODO: same sample size for each metric?
116
+ if isinstance(bootstrap_iters, int):
117
+ stderr_fn = stderr_for_metric(
118
+ metric=agg_fn,
119
+ bootstrap_iters=min(bootstrap_iters, 100)
120
+ if metric in ["bleu", "chrf", "ter"]
121
+ else bootstrap_iters,
122
+ )
123
+ self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
124
+ stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
125
+ )
126
+ else:
127
+ raise ValueError(
128
+ f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
129
+ )
130
+
131
+ def __repr__(self):
132
+ return (
133
+ f"TaskOutput(task_name={self.task_name}, "
134
+ f"group_name={self.group_name}, "
135
+ f"version={self.version}, "
136
+ f"n_shot={self.n_shot}, "
137
+ f"task_alias={self.task_alias}, "
138
+ f"group_alias={self.group_alias})"
139
+ )
140
+
141
+
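Editorial note: a small sketch of the key layout that `calculate_aggregate_metric` fills in; the metric and filter names, and the numbers, are illustrative only.

# sample_metrics is keyed by (metric, filter) tuples; agg_metrics by "metric,filter" strings:
# task_output.sample_metrics[("exact_match", "strict-match")]   -> [1, 0, 1, ...]
# task_output.agg_metrics["exact_match,strict-match"]           -> 0.73
# task_output.agg_metrics["exact_match_stderr,strict-match"]    -> 0.012, or "N/A" with a single sample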
142
+ def get_task_list(task_dict: dict) -> List[TaskOutput]:
143
+ outputs = []
144
+ for task_name, task_obj in task_dict.items():
145
+ if isinstance(task_obj, dict):
146
+ _outputs = get_task_list(task_obj)
147
+ outputs.extend(_outputs)
148
+ else:
149
+ task_output = TaskOutput.from_taskdict(task_name, task_obj)
150
+ outputs.append(task_output)
151
+
152
+ return outputs
153
+
154
+
155
+ def get_subtask_list(task_dict, task_root=None, depth=0):
156
+ subtask_list = {}
157
+ for group_obj, task_obj in task_dict.items():
158
+ if isinstance(group_obj, ConfigurableGroup):
159
+ # group_name = group_obj.group_name
160
+ group_name = group_obj.group_name
161
+ else:
162
+ group_name = group_obj
163
+ if isinstance(task_obj, dict):
164
+ _subtask_list = get_subtask_list(
165
+ task_obj, task_root=group_name, depth=depth + 1
166
+ )
167
+ if task_root:
168
+ subtask_list.setdefault((task_root, depth), []).extend(
169
+ [
170
+ _task
171
+ for (_task, _depth) in _subtask_list.keys()
172
+ if (_depth - 1) == depth
173
+ ]
174
+ )
175
+
176
+ subtask_list = {**subtask_list, **_subtask_list}
177
+ else:
178
+ if isinstance(task_obj, ConfigurableGroup):
179
+ # group_or_task_name = task_obj.group_name
180
+ group_or_task_name = task_obj.group_name
181
+ elif isinstance(task_obj, Task):
182
+ # group_or_task_name = task_obj.task_name
183
+ group_or_task_name = task_obj.task_name
184
+
185
+ if task_root is None:
186
+ subtask_list.setdefault((group_or_task_name, depth), [])
187
+ else:
188
+ subtask_list.setdefault((task_root, depth), []).append(
189
+ group_or_task_name
190
+ )
191
+
192
+ if depth == 0:
193
+ _subtask_list = {}
194
+ for group_key, task_list in subtask_list.items():
195
+ group_name, depth = group_key
196
+ _subtask_list[group_name] = task_list
197
+ subtask_list = _subtask_list
198
+
199
+ return subtask_list
200
+
201
+
202
+ def print_writeout(task) -> None:
203
+ for inst in task.instances:
204
+ # print the prompt for the first few documents
205
+ if inst.doc_id < 1:
206
+ eval_logger.info(
207
+ f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
208
+ \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
209
+ )
210
+ eval_logger.info(f"Request: {str(inst)}")
211
+
212
+
213
+ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
214
+ if limit is not None:
215
+ limit = (
216
+ int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
217
+ )
218
+ return limit
219
+
220
+
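Editorial note: to make the limit semantics concrete, a value below 1.0 is treated as a fraction of the task's eval docs, an integer as an absolute cap, and None leaves the task untouched. The document count below is purely illustrative.

# Hypothetical task with 1319 eval docs:
# get_sample_size(task, 0.1)  -> ceil(1319 * 0.1) = 132
# get_sample_size(task, 100)  -> 100
# get_sample_size(task, None) -> None   (no limit applied)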
221
+ def prepare_print_tasks(
222
+ task_dict: dict,
223
+ results: dict,
224
+ task_depth=0,
225
+ group_depth=0,
226
+ ) -> Tuple[dict, dict]:
227
+ """
228
+ @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
229
+ value is a list of task names.
230
+ @param results: Dictionary containing the results of each task. Each key is a
231
+ group name and its value is a dictionary of task results.
232
+ @param task_depth: The indentation level for printing the task
233
+ hierarchy. Default is 0.
234
+ @param group_depth: The indentation level for printing the group
235
+ hierarchy. Default is 0.
236
+ @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
237
+ aggregated results for each task, and groups_agg contains aggregated results for each group.
238
+
239
+ Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
240
+ """
241
+
242
+ def _sort_task_dict(task_dict):
243
+ """
244
+ Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
245
+ Required so that we end up sorting within each sub-header correctly.
246
+ """
247
+
248
+ return dict(
249
+ sorted(
250
+ task_dict.items(),
251
+ key=lambda item: item[0].group_name
252
+ if isinstance(item[0], ConfigurableGroup)
253
+ else item[0],
254
+ )
255
+ )
256
+
257
+ task_agg = collections.defaultdict(dict)
258
+ group_agg = collections.defaultdict(dict)
259
+ task_dict = _sort_task_dict(task_dict)
260
+ for task_or_group_name, task_or_group_obj in task_dict.items():
261
+ tab_string = " " * task_depth + "- " if task_depth > 0 else ""
262
+ if isinstance(task_or_group_name, ConfigurableGroup):
263
+ # string_name = task_or_group_name.group_name
264
+ name = task_or_group_name.group_name
265
+ from_configurable_group = True
266
+ task_or_group_obj = _sort_task_dict(task_or_group_obj)
267
+ elif isinstance(task_or_group_name, str):
268
+ name = task_or_group_name
269
+ if isinstance(task_or_group_obj, Task):
270
+ # string_name = task_or_group_obj.task_name
271
+ name = task_or_group_obj.task_name
272
+ from_configurable_group = False
273
+
274
+ task_agg[name] = results[name].copy()
275
+ if from_configurable_group:
276
+ if task_or_group_name.group_alias is not None:
277
+ alias = task_or_group_name.group_alias
278
+ else:
279
+ alias = task_or_group_name.group
280
+ else:
281
+ if "alias" in task_agg[name]:
282
+ alias = task_agg[name]["alias"]
283
+ else:
284
+ alias = name
285
+
286
+ task_agg[name]["alias"] = tab_string + alias
287
+ if "samples" in task_agg[name]:
288
+ task_agg[name].pop("samples")
289
+
290
+ if from_configurable_group and (" " not in results[name]):
291
+ group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
292
+ group_agg[name] = results[name].copy()
293
+ group_agg[name]["alias"] = group_tab_string + alias
294
+ if "samples" in group_agg[name]:
295
+ group_agg[name].pop("samples")
296
+
297
+ if isinstance(task_or_group_obj, dict):
298
+ task_depth += 1
299
+ group_depth += 1
300
+ _task_agg, _group_agg = prepare_print_tasks(
301
+ task_or_group_obj, results, task_depth, group_depth
302
+ )
303
+ task_agg = {
304
+ **task_agg,
305
+ **_task_agg,
306
+ }
307
+ group_agg = {**group_agg, **_group_agg}
308
+ task_depth -= 1
309
+ group_depth -= 1
310
+ return task_agg, group_agg
311
+
312
+
313
+ def consolidate_results(
314
+ eval_tasks: List[TaskOutput],
315
+ ) -> Tuple[dict, dict, dict, dict, dict, dict]:
316
+ """
317
+ @param eval_tasks: list(TaskOutput).
318
+ @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
319
+
320
+ Consolidates the results of multiple evaluation tasks into a single structure.
321
+
322
+ The method iterates over each evaluation instance and extracts relevant information to create the consolidated
323
+ results structure. The consolidated results structure has the following properties:
324
+
325
+ - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
326
+ metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
327
+ aliases specified in the task configuration.
328
+ - samples: A defaultdict with task names as keys and lists of log samples as values.
329
+ - configs: A defaultdict with task names as keys and task configurations as values.
330
+ - versions: A defaultdict with task names as keys and task versions as values.
331
+ - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
332
+ - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
333
+ for each metric as values.
334
+
335
+ The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
336
+ """
337
+ # stores the final result for each task, for each metric/filter pair.
338
+ results = collections.defaultdict(dict)
339
+ # logs info about each document evaluated.
340
+ samples = collections.defaultdict(list)
341
+ # store num-fewshot value per task
342
+ num_fewshot = collections.defaultdict(int)
343
+ # Tracks the YAML configs of all chosen task
344
+ configs = collections.defaultdict(dict)
345
+ # Tracks each task's version.
346
+ versions = collections.defaultdict(dict)
347
+ # Track `higher_is_better` for each metric
348
+ higher_is_better = collections.defaultdict(dict)
349
+
350
+ for task_output in eval_tasks:
351
+ if "task_alias" in (task_config := task_output.task_config):
352
+ results[task_output.task_name]["alias"] = task_config["task_alias"]
353
+ else:
354
+ results[task_output.task_name]["alias"] = task_output.task_name
355
+ if group_alias := task_output.group_alias:
356
+ if group_alias not in results and (group_name := task_output.group_name):
357
+ results[group_name]["alias"] = group_alias
358
+ num_fewshot[task_output.task_name] = task_output.n_shot
359
+ configs[task_output.task_name] = task_output.task_config
360
+ versions[task_output.task_name] = task_output.version
361
+ samples[task_output.task_name] = task_output.logged_samples
362
+ higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
363
+ for (metric, filter_key), items in task_output.sample_metrics.items():
364
+ metric_key = f"{metric},{filter_key}"
365
+ results[task_output.task_name][metric_key] = task_output.agg_metrics[
366
+ metric_key
367
+ ]
368
+ results[task_output.task_name]["samples"] = task_output.sample_len
369
+ results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
370
+ task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
371
+ )
372
+ return results, samples, configs, versions, num_fewshot, higher_is_better
373
+
374
+
375
+ def consolidate_group_results(
376
+ results,
377
+ versions,
378
+ task_dict,
379
+ task_root=None,
380
+ show_group_table=False,
381
+ task_aggregation_list=None,
382
+ ) -> Tuple[dict, dict, bool, Union[None,]]:
383
+ """
384
+ (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
385
+
386
+ @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
387
+
388
+ - results: A defaultdict with task names (and, after this function is called, group names of
389
+ groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
390
+ - versions: A defaultdict with task names (and, after this function is called, group names of
391
+ groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
392
+ - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
393
+ - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
394
+
395
+ The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
396
+ In the top-level invocation of this function, task_aggregation_list is ignored.
397
+ """
398
+ if task_root is None:
399
+ task_root = {}
400
+
401
+ if task_aggregation_list is None:
402
+ task_aggregation_list = {}
403
+
404
+ for group_or_task, group_or_task_info in task_dict.items():
405
+ # Convert to string
406
+ if isinstance(group_or_task, ConfigurableGroup):
407
+ group_config = group_or_task.config
408
+ group_or_task = group_or_task.group_name
409
+ else:
410
+ group_config = None
411
+
412
+ if isinstance(group_or_task_info, Task):
413
+ if task_root:
414
+ task_aggregation_list.setdefault(task_root, []).append(
415
+ group_or_task_info.task_name
416
+ )
417
+ else:
418
+ (
419
+ results,
420
+ versions,
421
+ show_group_table,
422
+ _task_aggregation_list,
423
+ ) = consolidate_group_results(
424
+ results,
425
+ versions,
426
+ group_or_task_info,
427
+ group_or_task,
428
+ show_group_table,
429
+ task_aggregation_list,
430
+ )
431
+ if task_root:
432
+ task_aggregation_list.setdefault(task_root, []).extend(
433
+ task_aggregation_list.get(group_or_task, [])
434
+ )
435
+
436
+ if (group_config is None) or (
437
+ group_config["aggregate_metric_list"] is None
438
+ ):
439
+ results[group_or_task][" "] = " "
440
+ continue
441
+
442
+ if "aggregate_metric_list" in group_config:
443
+ agg_metric_list = group_config["aggregate_metric_list"]
444
+
445
+ show_group_table = show_group_table | bool(
446
+ group_config["aggregate_metric_list"]
447
+ )
448
+
449
+ task_list = _task_aggregation_list[group_or_task]
450
+
451
+ metric_list = list(
452
+ {
453
+ key
454
+ for task in task_list
455
+ for key in results[task].keys()
456
+ if "_stderr" not in key and key not in ["task", "alias", "samples"]
457
+ }
458
+ )
459
+ for metric in metric_list:
460
+ stderr = "_stderr,".join(metric.split(","))
461
+
462
+ # gather metrics, sizes, and stderrs from subtasks
463
+ metrics = [
464
+ results[task][metric]
465
+ for task in task_list
466
+ if metric in results[task]
467
+ ] # TODO: copy?
468
+ stderrs = [
469
+ results[task][stderr]
470
+ for task in task_list
471
+ if stderr in results[task]
472
+ ]
473
+ sizes = [
474
+ results[task]["samples"]
475
+ for task in task_list
476
+ if metric in results[task]
477
+ ]
478
+
479
+ for metric_config in agg_metric_list:
480
+ for filter_name in metric_config["filter_list"]:
481
+ if metric != ",".join([metric_config["metric"], filter_name]):
482
+ continue
483
+
484
+ # compute group's pooled metric and stderr
485
+ if metric_config["aggregation"] == "mean":
486
+ aggregate_fn = aggregate_subtask_metrics
487
+ elif callable(metric_config["aggregation"]):
488
+ aggregate_fn = metric_config["aggregation"]
489
+ else:
490
+ raise ValueError(
491
+ f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
492
+ )
493
+
494
+ results[group_or_task][metric] = aggregate_fn(
495
+ metrics,
496
+ sizes,
497
+ metric_config["weight_by_size"],
498
+ )
499
+ # TODO: calculate groups' metrics using arbitrary agg fns
500
+ if "N/A" in stderrs:
501
+ results[group_or_task][stderr] = "N/A"
502
+ else:
503
+ # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
504
+ results[group_or_task][stderr] = pooled_sample_stderr(
505
+ stderrs, sizes
506
+ )
507
+
508
+ results[group_or_task]["samples"] = sum(sizes)
509
+ group_metadata = group_config.get("metadata", None)
510
+ if group_metadata is not None:
511
+ versions[group_or_task] = group_metadata.get("version", None)
512
+ # print(results)
513
+ return results, versions, show_group_table, task_aggregation_list
514
+
515
+
516
+ @positional_deprecated
517
+ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
518
+ """
519
+ Search upward in the directory tree to a maximum of three layers
520
+ to find and return the package root (containing the 'tests' folder)
521
+ """
522
+ cur_path = start_path.resolve()
523
+ max_layers = 3
524
+ for _ in range(max_layers):
525
+ if (cur_path / "tests" / "test_version_stable.py").exists():
526
+ return cur_path
527
+ else:
528
+ cur_path = cur_path.parent.resolve()
529
+ raise FileNotFoundError(
530
+ f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
531
+ )
532
+
533
+
534
+ @positional_deprecated
535
+ def run_task_tests(task_list: List[str]):
536
+ """
537
+ Find the package root and run the tests for the given tasks
538
+ """
539
+ import pytest
540
+
541
+ package_root = find_test_root(start_path=pathlib.Path(__file__))
542
+ task_string = " or ".join(task_list)
543
+ args = [
544
+ f"{package_root}/tests/test_version_stable.py",
545
+ f"--rootdir={package_root}",
546
+ "-k",
547
+ f"{task_string}",
548
+ ]
549
+ sys.path.append(str(package_root))
550
+ pytest_return_val = pytest.main(args)
551
+ if pytest_return_val:
552
+ raise ValueError(
553
+ f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
554
+ )
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from functools import partial
2
+ from typing import List
3
+
4
+ from lm_eval.api.filter import FilterEnsemble
5
+ from lm_eval.api.registry import get_filter
6
+
7
+ from . import custom, extraction, selection, transformation
8
+
9
+
10
+ def build_filter_ensemble(
11
+ filter_name: str, components: List[List[str]]
12
+ ) -> FilterEnsemble:
13
+ """
14
+ Create a filtering pipeline.
15
+ """
16
+ filters = []
17
+ for function, kwargs in components:
18
+ if kwargs is None:
19
+ kwargs = {}
20
+ # create a filter given its name in the registry
21
+ f = partial(get_filter(function), **kwargs)
22
+ # add the filter as a pipeline step
23
+ filters.append(f)
24
+
25
+ return FilterEnsemble(name=filter_name, filters=filters)
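Editorial note: a minimal sketch of assembling a pipeline from registered filter names. The two names used here ("regex", "remove_whitespace") are registered in lm_eval/filters/extraction.py later in this upload; the pattern is simply the GSM8K-style default and the ensemble name is arbitrary.

# Each component is a [registry_name, kwargs] pair; kwargs of None becomes {}.
ensemble = build_filter_ensemble(
    "get-answer",
    [
        ["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}],
        ["remove_whitespace", None],
    ],
)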
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py ADDED
@@ -0,0 +1,17 @@
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("custom")
6
+ class CustomFilter(Filter):
7
+ """
8
+ Custom filter that applies a custom, user-defined function to the model responses.
9
+ """
10
+
11
+ def __init__(self, **kwargs) -> None:
12
+ self.filter_fn = kwargs.pop("filter_fn")
13
+
14
+ super().__init__(**kwargs)
15
+
16
+ def apply(self, resps, docs):
17
+ return self.filter_fn(resps, docs)
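A hypothetical sketch (names are illustrative, not from the diff) of plugging a user function into CustomFilter:

from lm_eval.filters.custom import CustomFilter

def keep_first_line(resps, docs):
    # user-defined post-processing: keep only the first line of each response
    return [[r.splitlines()[0] if r else r for r in inst] for inst in resps]

flt = CustomFilter(filter_fn=keep_first_line)
print(flt.apply([["42\nscratch work"]], [{}]))  # [['42']]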
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py ADDED
@@ -0,0 +1,25 @@
+ from lm_eval.api.filter import Filter
+ from lm_eval.api.registry import register_filter
+
+
+ @register_filter("decontaminate")
+ class DecontaminationFilter(Filter):
+     """
+     A filter which evaluates
+     """
+
+     name = "track_decontamination"
+
+     def __init__(self, path) -> None:
+         """
+
+         TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
+         should further cache result on a given (task_name, doc_id)
+         """
+         self._decontam_results = None
+
+     def apply(self, resps, docs) -> None:
+         """
+         Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
+         """
+         pass
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py ADDED
@@ -0,0 +1,188 @@
1
+ import re
2
+ import sys
3
+ import unicodedata
4
+
5
+ from lm_eval.api.filter import Filter
6
+ from lm_eval.api.registry import register_filter
7
+
8
+
9
+ @register_filter("regex")
10
+ class RegexFilter(Filter):
11
+ """A filter that extracts values from text using regex pattern matching.
12
+
13
+ This filter applies a regex pattern to each model response and extracts matched values.
14
+ If no match is found, returns a fallback value. Useful for extracting structured data
15
+ (like numbers) from unstructured model outputs.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
21
+ group_select: int = 0,
22
+ fallback: str = "[invalid]",
23
+ ) -> None:
24
+ """
25
+ pass a string `regex` to run `re.compile(r"regex")` on.
26
+ `fallback` defines the output returned if no matches for the regex are located.
27
+ """
28
+ self.regex_pattern = regex_pattern
29
+ self.regex = re.compile(regex_pattern)
30
+ self.group_select = group_select
31
+ self.fallback = fallback
32
+
33
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
34
+ # here, we assume we have a list, in which each element is
35
+ # a list of model responses for some particular input/target pair.
36
+ # so we process each of these (same input/target response sets)
37
+ # independently (and keep them a list.)
38
+ def filter_set(inst):
39
+ filtered = []
40
+ for resp in inst:
41
+ match = self.regex.findall(resp)
42
+ if match:
43
+ match = match[self.group_select]
44
+ if isinstance(match, tuple):
45
+ match = [m for m in match if m]
46
+ if match:
47
+ match = match[0]
48
+ else:
49
+ match = self.fallback
50
+ match = match.strip()
51
+ else:
52
+ match = self.fallback
53
+ filtered.append(match)
54
+ return filtered
55
+
56
+ filtered_resps = list(map(lambda x: filter_set(x), resps))
57
+
58
+ return filtered_resps
59
+
60
+
61
+ @register_filter("remove_whitespace")
62
+ class WhitespaceFilter(Filter):
63
+ """Filters out leading whitespace from responses."""
64
+
65
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
66
+ def filter_set(inst):
67
+ filtered_resp = []
68
+ for resp in inst:
69
+ resp = resp.lstrip()
70
+ filtered_resp.append(resp)
71
+ return filtered_resp
72
+
73
+ filtered_resps = [filter_set(resp) for resp in resps]
74
+
75
+ return filtered_resps
76
+
77
+
78
+ @register_filter("multi_choice_regex")
79
+ class MultiChoiceRegexFilter(RegexFilter):
80
+ """
81
+ A filter used to extract a model's answer on multiple choice questions with
82
+ letter answers. assumes each document has a "choices" field
83
+ containing the list of answer choices and that the answer label symbols
84
+ are of the form (A), (B), (C), ... or A, B, C.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
90
+ group_select=0,
91
+ fallback: str = "[invalid]",
92
+ ignore_case=False,
93
+ ignore_punctuation=False,
94
+ regexes_to_ignore=None,
95
+ ) -> None:
96
+ """
97
+ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
98
+ - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
99
+ - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
100
+ group_select: Selects the (group_select)th match from the findall result.
101
+ ignore_case: Ignores the case during step 1 matching
102
+ ignore_punctuation: Remove the punctuation during step 1 matching
103
+ regexes_to_ignore: Remove these regexes during step 1 matching
104
+ """
105
+ super().__init__(regex_pattern, group_select, fallback)
106
+ self.ignore_case = ignore_case
107
+ self.ignore_punctuation = ignore_punctuation
108
+ self.regexes_to_ignore = regexes_to_ignore
109
+
110
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
111
+ # here, we assume we have a list, in which each element is
112
+ # a list of model responses for some particular input/target pair.
113
+ # so we process each of these (same input/target response sets)
114
+ # independently (and keep them a list.)
115
+
116
+ def find_match(regex, resp, convert_dict={}):
117
+ match = regex.findall(resp)
118
+ if match:
119
+ match = match[self.group_select]
120
+ if isinstance(match, tuple):
121
+ match = [m for m in match if m][0]
122
+ match = match.strip()
123
+ if match and match in convert_dict:
124
+ match = convert_dict[match]
125
+ return match
126
+
127
+ punct_tbl = dict.fromkeys(
128
+ i
129
+ for i in range(sys.maxunicode)
130
+ if unicodedata.category(chr(i)).startswith("P")
131
+ )
132
+
133
+ def filter_ignores(st):
134
+ if self.regexes_to_ignore is not None:
135
+ for s in self.regexes_to_ignore:
136
+ st = re.sub(s, "", st)
137
+
138
+ if self.ignore_case:
139
+ st = st.lower()
140
+
141
+ if self.ignore_punctuation:
142
+ # https://stackoverflow.com/a/266162
143
+ st = st.translate(punct_tbl)
144
+ return st
145
+
146
+ filtered_resps = []
147
+
148
+ for r, doc in zip(resps, docs):
149
+ fallback_regexes = []
150
+ choice_to_alpha = {}
151
+ next_alpha = "A"
152
+
153
+ without_paren_fallback_regexes = []
154
+ without_paren_to_target = {}
155
+
156
+ choices = doc["choices"]
157
+ for c in choices:
158
+ m = filter_ignores(c.strip())
159
+ fallback_regexes.append(f"{re.escape(m)}")
160
+ choice_to_alpha[m] = f"({next_alpha})"
161
+
162
+ without_paren_fallback_regexes.append(next_alpha)
163
+ without_paren_to_target[next_alpha] = f"({next_alpha})"
164
+
165
+ next_alpha = chr(ord(next_alpha) + 1)
166
+ fallback_regex = re.compile("|".join(fallback_regexes))
167
+ without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
168
+ without_paren_fallback_regex = re.compile(
169
+ rf":[\s]*({without_paren_fallback_regex})"
170
+ )
171
+
172
+ filtered = []
173
+ for resp in r:
174
+ match = find_match(self.regex, resp)
175
+ if not match:
176
+ match = find_match(
177
+ fallback_regex, filter_ignores(resp), choice_to_alpha
178
+ )
179
+ if not match:
180
+ match = find_match(
181
+ without_paren_fallback_regex, resp, without_paren_to_target
182
+ )
183
+ if not match:
184
+ match = self.fallback
185
+ filtered.append(match)
186
+ filtered_resps.append(filtered)
187
+
188
+ return filtered_resps
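As a standalone illustration (not part of the diff) of the default RegexFilter pattern, which extracts GSM8K-style "#### <number>" answers:

import re

pattern = re.compile(r"#### (\-?[0-9\.\,]+)")
resp = "6 * 7 = 42, so the answer is 42.\n#### 42"
matches = pattern.findall(resp)
print(matches[0].strip() if matches else "[invalid]")  # 42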
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py ADDED
@@ -0,0 +1,61 @@
+ from collections import Counter
+
+ from lm_eval.api.filter import Filter
+ from lm_eval.api.registry import register_filter
+
+
+ # TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
+ # that takes an input and returns a scalar and then should select the max reward,
+ # or should implement different filters for different ways of handling a reward model's inference.
+
+
+ @register_filter("take_first")
+ class TakeFirstFilter(Filter):
+     def __init__(self) -> None:
+         """
+         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+         """
+
+     def apply(self, resps, docs):
+         """
+         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
+         """
+         return map(lambda r: r[0], resps)
+
+
+ @register_filter("take_first_k")
+ class TakeKFilter(Filter):
+     def __init__(self, **kwargs) -> None:
+         self.k = kwargs.pop("k")
+
+         super().__init__(**kwargs)
+
+     def apply(self, resps, docs):
+         # need resp to be subscriptable to check below
+         resps = list(resps)
+         # check we have at least k responses per doc, else we can't take the first k
+         assert len(resps[0]) >= self.k, (
+             f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+         )
+         return map(lambda r: r[: self.k], resps)
+
+
+ @register_filter("majority_vote")
+ class MajorityVoteFilter(Filter):
+     def __init__(self) -> None:
+         """
+         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+         """
+
+     def apply(self, resps, docs):
+         """
+         Each entry of `resps` is a list of model responses.
+         We select the response that occurs most frequently in each entry of `resps`.
+         """
+
+         def select_majority(resp):
+             counts = Counter(resp)
+             vote = counts.most_common(1)[0][0]
+             return vote
+
+         return map(lambda r: [select_majority(r)], resps)
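A quick illustration (not in the diff) of what the majority_vote filter computes per document:

from collections import Counter

resps = [["12", "15", "12", "12"], ["7", "7", "9"]]
print([[Counter(r).most_common(1)[0][0]] for r in resps])  # [['12'], ['7']]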
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py ADDED
@@ -0,0 +1,56 @@
+ from lm_eval.api.filter import Filter
+ from lm_eval.api.registry import register_filter
+
+
+ @register_filter("lowercase")
+ class LowercaseFilter(Filter):
+     def __init__(self) -> None:
+         pass
+
+     def apply(self, resps, docs):
+         def filter_set(inst):
+             return [resp.lower() for resp in inst]
+
+         return [filter_set(resp) for resp in resps]
+
+
+ @register_filter("uppercase")
+ class UppercaseFilter(Filter):
+     def __init__(self) -> None:
+         pass
+
+     def apply(self, resps, docs):
+         def filter_set(inst):
+             return [resp.upper() for resp in inst]
+
+         return [filter_set(resp) for resp in resps]
+
+
+ @register_filter("map")
+ class MapFilter(Filter):
+     def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
+         """
+         Initializes the MapFilter with a given mapping dictionary and default value.
+
+         Args:
+         - mapping_dict (dict): A dictionary containing the key-value mappings.
+           Default is an empty dictionary.
+         - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
+           Default is None.
+
+         Example:
+         mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
+         """
+         if mapping_dict is None:
+             mapping_dict = {}
+         assert isinstance(mapping_dict, dict), (
+             "Provided mapping_dict is not a dictionary"
+         )
+         self.mapping_dict = mapping_dict
+         self.default_value = default_value
+
+     def apply(self, resps, docs):
+         def filter_set(inst):
+             return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
+
+         return [filter_set(resp) for resp in resps]
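For reference only (not part of the commit): a lowercase-then-map pipeline applied by hand, mirroring how these filters compose:

mapping = {"yes": 1, "no": 0}
resps = [["Yes", "no"], ["NO"]]
lowered = [[r.lower() for r in inst] for inst in resps]
print([[mapping.get(r, -1) for r in inst] for inst in lowered])  # [[1, 0], [0]]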
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from . import (
+     diffllm,
+     huggingface,
+ )
+
+
+ # TODO: implement __all__
+
+
+ try:
+     # enable hf hub transfer if available
+     import hf_transfer  # type: ignore # noqa
+     import huggingface_hub.constants  # type: ignore
+
+     huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+ except ImportError:
+     pass
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py ADDED
@@ -0,0 +1,563 @@
1
+ import logging
2
+ import gc
3
+ import random
4
+ import json
5
+ import os
6
+ import time
7
+ from datetime import timedelta
8
+ from typing import List, Optional, Tuple, Type, TypeVar, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import transformers
13
+ from accelerate import (
14
+ Accelerator,
15
+ InitProcessGroupKwargs,
16
+ find_executable_batch_size,
17
+ )
18
+ from datasets import Dataset
19
+ from packaging import version
20
+ from tqdm import tqdm
21
+
22
+ from lm_eval import utils
23
+ from lm_eval.api.instance import Instance
24
+ from lm_eval.api.model import LM
25
+ from lm_eval.api.registry import register_model
26
+ from lm_eval.models.utils import Collator, get_dtype
27
+
28
+ eval_logger = logging.getLogger(__name__)
29
+ T = TypeVar("T", bound="LM")
30
+
31
+
32
+ def empty_cache_by_memory(threshold_gb=70):
33
+ """
34
+ Empty CUDA cache if allocated memory exceeds threshold
35
+ Args:
36
+ threshold_gb: Memory threshold in GB
37
+ """
38
+ if torch.cuda.is_available():
39
+ # Get current memory allocated
40
+ allocated = torch.cuda.memory_allocated() / 1024**3 # Convert to GB
41
+
42
+ if allocated > threshold_gb:
43
+ # Clear cache
44
+ gc.collect()
45
+ torch.cuda.empty_cache()
46
+ print(f"Cache cleared. Memory freed: {allocated:.2f} GB")
47
+
48
+ @register_model("diffllm")
49
+ class DiffLLM(LM):
50
+ def __init__(
51
+ self,
52
+ pretrained: Union[str, transformers.PreTrainedModel],
53
+ batch_size: Optional[Union[int, str]] = 1,
54
+ device: Optional[str] = "cuda",
55
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
56
+ max_prompt_len: Optional[int] = 1024,
57
+ max_new_tokens: Optional[int] = 128,
58
+ nll_type: Optional[str] = "mc",
59
+ log_type: Optional[str] = "ftb",
60
+ classifier_free_guidance: Optional[float] = 1.0,
61
+ pad_to_max_len: Optional[bool] = False,
62
+ sampling_eps: Optional[float] = 1e-3,
63
+ diffusion_steps: Optional[int] = 32,
64
+ trust_remote_code: Optional[bool] = True,
65
+ parallelize: Optional[bool] = False,
66
+ autogptq: Optional[Union[bool, str]] = False,
67
+ **kwargs,
68
+ ) -> None:
69
+ super().__init__()
70
+
71
+ # prepare for parallelism
72
+ assert isinstance(device, str)
73
+ assert isinstance(pretrained, str)
74
+ assert isinstance(batch_size, (int, str))
75
+
76
+ gpus = torch.cuda.device_count()
77
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
78
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
79
+
80
+ self.accelerator = accelerator
81
+
82
+ if "npu" in accelerator.device.type:
83
+ gpus = torch.npu.device_count()
84
+
85
+ # using one process with no model parallelism
86
+ if not (parallelize or accelerator.num_processes > 1):
87
+ # use user-passed device
88
+ device_list = set(
89
+ ["cuda", "cpu"]
90
+ + [f"cuda:{i}" for i in range(gpus)]
91
+ + ["mps", "mps:0"]
92
+ + [f"npu:{i}" for i in range(gpus)]
93
+ )
94
+ if device and device in device_list:
95
+ self._device = torch.device(device)
96
+ eval_logger.info(f"Using device '{device}'")
97
+ if device in ("mps", "mps:0") and version.parse(
98
+ torch.__version__
99
+ ) < version.parse("2.1"):
100
+ raise RuntimeError(
101
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
102
+ )
103
+ else:
104
+ eval_logger.info("Device not specified")
105
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
106
+ self._device = (
107
+ torch.device("cuda")
108
+ if torch.cuda.is_available()
109
+ else torch.device("cpu")
110
+ )
111
+ else:
112
+ if device != "cuda":
113
+ eval_logger.info(
114
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
115
+ )
116
+ self._device = self.accelerator.device
117
+
118
+ self.batch_size_per_gpu = batch_size
119
+ if isinstance(batch_size, str):
120
+ self.batch_size_per_gpu = int(batch_size)
121
+ self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
122
+
123
+ if isinstance(pretrained, str):
124
+ if gpus >= 1 or str(self.device) == "mps":
125
+ if not (parallelize or autogptq or (hasattr(self, "accelerator") and self.accelerator.num_processes > 1)):
126
+ try:
127
+ self.model.to(self.device)
128
+ except ValueError:
129
+ eval_logger.debug(
130
+ "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
131
+ )
132
+ if gpus > 1:
133
+ if self.accelerator.num_processes > 1:
134
+ self._device = torch.device(f"{accelerator.device}")
135
+ self._rank = self.accelerator.local_process_index
136
+ self._world_size = self.accelerator.num_processes
137
+ else:
138
+ self._rank = 0
139
+ self._world_size = 1
140
+ else:
141
+ self._rank = 0
142
+ self._world_size = 1
143
+ else:
144
+ eval_logger.warning(
145
+ "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
146
+ )
147
+ self._rank = 0
148
+ self._world_size = 1
149
+
150
+ self.max_prompt_len = max_prompt_len
151
+ self.max_new_tokens = max_new_tokens
152
+ self.diffusion_steps = diffusion_steps
153
+ self.temperature = kwargs.get("temperature", 0.7)
154
+ self.top_p = kwargs.get("top_p", 0.95)
155
+ self.alg = kwargs.get("alg", "entropy")
156
+ self.alg_temp = kwargs.get("alg_temp", 0.0)
157
+ self.top_k = kwargs.get("top_k", None)
158
+
159
+ self.nll_type = nll_type
160
+ self.log_type = log_type
161
+ self.classifier_free_guidance = classifier_free_guidance
162
+ self.pad_to_max_len = pad_to_max_len
163
+ self.sampling_eps = sampling_eps
164
+
165
+ self.mask_id = 151666
166
+ self.eos_id = 151643
167
+
168
+ raw_use_hts = kwargs.get("use_hts", False)
169
+ if isinstance(raw_use_hts, str):
170
+ self.use_hts = raw_use_hts.lower() == "true"
171
+ else:
172
+ self.use_hts = bool(raw_use_hts)
173
+
174
+ self.realtime_output = kwargs.get("realtime_output", "eval_results.jsonl")
175
+
176
+ if self.use_hts:
177
+ from .hts_sampler import HTSSampler
178
+ self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
179
+ eval_logger.info(f"Rank {self.rank}: HTS Sampler initialized for Dream.")
180
+
181
+ @property
182
+ def batch_size(self):
183
+ return self.batch_size_per_gpu
184
+
185
+ @property
186
+ def device(self):
187
+ return self._device
188
+
189
+ @property
190
+ def rank(self):
191
+ return self._rank
192
+
193
+ @property
194
+ def world_size(self):
195
+ return self._world_size
196
+
197
+ def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
198
+ self.model = (
199
+ transformers.AutoModel.from_pretrained(
200
+ pretrained,
201
+ torch_dtype=get_dtype(dtype),
202
+ trust_remote_code=trust_remote_code,
203
+ )
204
+ .eval()
205
+ ).to(self.device)
206
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
207
+ pretrained, trust_remote_code=trust_remote_code
208
+ )
209
+
210
+ def tok_decode(self, tokens, skip_special_tokens=True):
211
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
212
+
213
+ def tok_encode(self, text, add_special_tokens=True):
214
+ return self.tokenizer(
215
+ text, return_tensors="pt", add_special_tokens=add_special_tokens
216
+ ).input_ids
217
+
218
+ @classmethod
219
+ def create_from_arg_string(
220
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
221
+ ) -> T:
222
+ additional_config = {} if additional_config is None else additional_config
223
+ args = utils.simple_parse_args_string(arg_string)
224
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
225
+ return cls(**args, **args2)
226
+
227
+ def apply_chat_template(
228
+ self, chat_history, add_generation_prompt: bool = True
229
+ ) -> str:
230
+ chat_templated = self.tokenizer.apply_chat_template(
231
+ chat_history,
232
+ tokenize=False,
233
+ add_generation_prompt=add_generation_prompt,
234
+ continue_final_message=not add_generation_prompt,
235
+ )
236
+ return chat_templated
237
+
238
+ @property
239
+ def tokenizer_name(self) -> str:
240
+ return self.tokenizer.name_or_path.replace("/", "__")
241
+
242
+ def _generate_batch(self, prompts: List[str], gen_kwargs: dict = None) -> Tuple[List[str], List[dict]]:
243
+ raw_val = gen_kwargs.get("use_hts", self.use_hts)
244
+ use_hts_now = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
245
+
246
+ all_stats = []
247
+ if not use_hts_now:
248
+ prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids
249
+ prompt_ids = prompt_ids[:, -self.max_prompt_len:]
250
+ attn_mask = prompt_ids.ne(self.tokenizer.pad_token_id).to(self.device)
251
+ prompt_ids = prompt_ids.to(device=self.device)
252
+
253
+ generation_ids = self.model.diffusion_generate(
254
+ prompt_ids,
255
+ attention_mask=attn_mask,
256
+ max_new_tokens=self.max_new_tokens,
257
+ output_history=False,
258
+ return_dict_in_generate=True,
259
+ steps=self.diffusion_steps,
260
+ temperature=self.temperature,
261
+ top_p=self.top_p,
262
+ top_k=self.top_k,
263
+ alg=self.alg,
264
+ alg_temp=self.alg_temp,
265
+ )
266
+ responses = [
267
+ self.tokenizer.decode(g[len(p) :].tolist()).split(self.tokenizer.eos_token)[0]
268
+ for p, g in zip(prompt_ids, generation_ids.sequences)
269
+ ]
270
+ all_stats = [{} for _ in responses]
271
+ return responses, all_stats
272
+ else:
273
+ if not hasattr(self, "hts_sampler"):
274
+ from .hts_sampler import HTSSampler
275
+ self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
276
+
277
+ results = []
278
+ for prompt in prompts:
279
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
280
+
281
+ final_codes, stats = self.hts_sampler.generate_hts(
282
+ prompt_text=prompt,
283
+ input_ids=input_ids,
284
+ initial_N=int(gen_kwargs.get("initial_N", 4)),
285
+ final_K=int(gen_kwargs.get("final_K", 1)),
286
+ hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)),
287
+ reward_mode=gen_kwargs.get("reward_mode", "svf"),
288
+ task_type=gen_kwargs.get("task_type", "code"),
289
+ steps=self.diffusion_steps,
290
+ gen_length=self.max_new_tokens,
291
+ temperature=float(gen_kwargs.get("temperature", self.temperature)),
292
+ top_p=float(gen_kwargs.get("top_p", self.top_p)),
293
+ top_k=gen_kwargs.get("top_k", self.top_k),
294
+ until=gen_kwargs.get("until", []),
295
+ hts_mode=True,
296
+ mask_id=self.mask_id,
297
+ eos_id=self.eos_id
298
+ )
299
+
300
+ results.append(final_codes[0])
301
+ all_stats.append(stats)
302
+ return results, all_stats
303
+
304
+ def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
305
+ res = []
306
+
307
+ gen_kwargs_first = requests[0].args[1]
308
+ actual_output_path = gen_kwargs_first.get("realtime_output", self.realtime_output)
309
+
310
+ raw_val = gen_kwargs_first.get("use_hts", self.use_hts)
311
+ self.use_hts = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
312
+
313
+ rank_tmp_file = actual_output_path.replace(".jsonl", f"_rank{self.rank}.tmp")
314
+
315
+ output_dir = os.path.dirname(rank_tmp_file)
316
+ if output_dir and not os.path.exists(output_dir):
317
+ os.makedirs(output_dir, exist_ok=True)
318
+
319
+ pbar = tqdm(
320
+ total=len(requests),
321
+ disable=(disable_tqdm or (self.rank != 0)),
322
+ desc="Running generate_until",
323
+ )
324
+
325
+ for batch_idx in range(0, len(requests), self.batch_size_per_gpu):
326
+ batch_requests = requests[batch_idx : batch_idx + self.batch_size_per_gpu]
327
+ contexts, task_gen_args = zip(*[req.arguments for req in batch_requests])
328
+
329
+ responses, stats_list = self._generate_batch(contexts, gen_kwargs=task_gen_args[0])
330
+
331
+ for i, r in enumerate(responses):
332
+ r = r.replace("```python", "").replace("```", "")
333
+
334
+ for s in task_gen_args[0].get('until', []):
335
+ r = r.split(s)[0]
336
+
337
+ target_val = getattr(batch_requests[i], "target", None)
338
+ if target_val is None or target_val == "N/A":
339
+ target_val = batch_requests[i].doc.get("answer", batch_requests[i].doc.get("solution", "N/A"))
340
+
341
+ save_data = {
342
+ "doc": batch_requests[i].doc,
343
+ "target": target_val,
344
+ "prompt": contexts[i],
345
+ "response": r,
346
+ }
347
+
348
+ if self.use_hts:
349
+ save_data.update(stats_list[i])
350
+
351
+ with open(rank_tmp_file, "a", encoding="utf-8") as f:
352
+ f.write(json.dumps(save_data, ensure_ascii=False) + "\n")
353
+ f.flush()
354
+
355
+ responses[i] = r
356
+
357
+ if self.rank == 0 and batch_idx == 0:
358
+ print(f"Sample Response:\n{responses[0]}\n")
359
+
360
+ res.extend(responses)
361
+ pbar.update(len(batch_requests))
362
+
363
+ pbar.close()
364
+
365
+ self.accelerator.wait_for_everyone()
366
+
367
+ if self.rank == 0:
368
+ eval_logger.info(f"Merging rank files into {actual_output_path}...")
369
+ with open(actual_output_path, "w", encoding="utf-8") as final_f:
370
+ for r in range(self.world_size):
371
+ temp_f = actual_output_path.replace(".jsonl", f"_rank{r}.tmp")
372
+ if os.path.exists(temp_f):
373
+ with open(temp_f, "r", encoding="utf-8") as tf:
374
+ for line in tf:
375
+ final_f.write(line)
376
+ os.remove(temp_f)
377
+ eval_logger.info("Merge completed.")
378
+
379
+ return res
380
+
381
+ def _forward_process(self, batch):
382
+ b, l = batch.shape
383
+ u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
384
+ indices = torch.arange(b, device=batch.device).float()
385
+ t = (u0 + indices / b) % 1
386
+ p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
387
+ p_mask = p_mask[:, None].repeat(1, l)
388
+ mask_indices = torch.rand((b, l), device=batch.device) < p_mask
389
+ mask_indices[:, 0] = False
390
+ mask_indices[:, -1] = False
391
+ noisy_batch = torch.where(mask_indices, self.mask_id, batch)
392
+ return noisy_batch, p_mask
393
+
394
+ @torch.no_grad()
395
+ def get_logits(self, batch, prompt_index):
396
+ if self.classifier_free_guidance > 1.:
397
+ assert len(prompt_index) == batch.shape[1]
398
+ prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
399
+ un_batch = batch.clone()
400
+ un_batch[prompt_index] = self.mask_id
401
+ batch = torch.cat([batch, un_batch])
402
+
403
+ if self.pad_to_max_len:
404
+ raise NotImplementedError
405
+ else:
406
+ input = batch
407
+
408
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
409
+ logits = self.model(input, 'full').logits
410
+
411
+ if self.classifier_free_guidance > 1.:
412
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
413
+ logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
414
+ return logits[:, :batch.shape[1]]
415
+
416
+ @torch.no_grad()
417
+ def _eval_target_nll_mc(self, prefix, target):
418
+ if prefix is None:
419
+ seq = target[None, :]
420
+ else:
421
+ seq = torch.concatenate([prefix, target])[None, :]
422
+ seq = seq.repeat((self.batch_size, 1)).to(self.device)
423
+
424
+ if self.log_type == 'ftb':
425
+ prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
426
+ else:
427
+ prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
428
+
429
+ loss_acc = []
430
+ mc_num = self.diffusion_steps
431
+ for _ in range(max(mc_num // self.batch_size, 1)):
432
+ perturbed_seq = seq.clone()
433
+ perturbed_seq_, p_mask = self._forward_process(seq)
434
+ if self.log_type == 'ftb':
435
+ perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
436
+ elif self.log_type == 'btf':
437
+ perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
438
+ elif self.log_type == 'union':
439
+ perturbed_seq = perturbed_seq_
440
+ else:
441
+ raise NotImplementedError(self.log_type)
442
+
443
+ mask_indices = perturbed_seq == self.mask_id
444
+
445
+ logits = self.get_logits(perturbed_seq, prompt_index)
446
+
447
+ loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
448
+ loss = loss.sum() / self.batch_size
449
+ loss_acc.append(loss.item())
450
+ del logits, loss, perturbed_seq, perturbed_seq_, p_mask, mask_indices
451
+ empty_cache_by_memory(threshold_gb=70)
452
+
453
+ return sum(loss_acc) / len(loss_acc)
454
+
455
+ @torch.no_grad()
456
+ def _eval_target_nll_ar(self, prefix, target):
457
+ prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
458
+ assert self.log_type in ['ftb', 'btf']
459
+ assert self.nll_type in ['ar_ftb', 'ar_btf']
460
+
461
+ if self.log_type == 'ftb':
462
+ prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
463
+ else:
464
+ prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
465
+
466
+ if self.log_type == 'ftb':
467
+ perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
468
+ else:
469
+ perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
470
+
471
+ mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
472
+ if self.nll_type == 'ar_ftb':
473
+ mask_index = torch.triu(mask_index)
474
+ else:
475
+ mask_index = torch.tril(mask_index)
476
+ perturbed_[mask_index] = self.mask_id
477
+ if self.log_type == 'ftb':
478
+ perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
479
+ else:
480
+ perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
481
+
482
+ logits_ = []
483
+ num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
484
+ for i in range(num):
485
+ end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
486
+ perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
487
+ perturbed_seq_ = perturbed_seq_.to(self.device)
488
+ if len(perturbed_seq_.shape) == 1:
489
+ perturbed_seq_ = perturbed_seq_.unsqueeze(0)
490
+ logits = self.get_logits(perturbed_seq_, prompt_index)
491
+ logits_.append(logits.cpu())
492
+ logits = torch.cat(logits_, dim=0)
493
+
494
+ temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
495
+ if self.nll_type == 'ar_ftb':
496
+ temp_index = torch.triu(temp_index, diagonal=1)
497
+ else:
498
+ temp_index = torch.tril(temp_index, diagonal=-1)
499
+ mask_index[temp_index] = False
500
+ if self.log_type == 'ftb':
501
+ logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
502
+ else:
503
+ logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
504
+
505
+ if self.log_type == 'ftb':
506
+ loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
507
+ else:
508
+ loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
509
+ return loss
510
+
511
+ def _encode_pair(self, context, continuation):
512
+ n_spaces = len(context) - len(context.rstrip())
513
+ if n_spaces > 0:
514
+ continuation = context[-n_spaces:] + continuation
515
+ context = context[:-n_spaces]
516
+
517
+ whole_enc = self.tokenizer.encode(context + continuation) + [
518
+ self.tokenizer.eos_token_id
519
+ ]
520
+ context_enc = self.tokenizer.encode(context)
521
+
522
+ context_enc_len = len(context_enc)
523
+ continuation_enc = whole_enc[context_enc_len:]
524
+
525
+ return context_enc, continuation_enc
526
+
527
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
528
+ def _tokenize(e):
529
+ prefix, target = self._encode_pair(e["prefix"], e["target"])
530
+ return {
531
+ "prefix_text": e["prefix"],
532
+ "target_text": e["target"],
533
+ "prefix": prefix,
534
+ "target": target,
535
+ }
536
+
537
+ ds = []
538
+ ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
539
+ ds = Dataset.from_list(ds)
540
+ ds = ds.map(_tokenize)
541
+ ds = ds.with_format("torch")
542
+
543
+ out = []
544
+ with torch.no_grad():
545
+ for elem in tqdm(ds, desc="Computing likelihood..."):
546
+ prefix = elem["prefix"]
547
+ target = elem["target"]
548
+
549
+ if self.nll_type == 'mc':
550
+ ll = -self._eval_target_nll_mc(prefix, target)
551
+ if self.log_type == 'union':
552
+ ll = ll / (len(target) + len(prefix))
553
+ elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
554
+ ll = -self._eval_target_nll_ar(prefix, target)
555
+ else:
556
+ raise NotImplementedError(self.nll_type)
557
+
558
+ is_target_greedy_dec = False
559
+ out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
560
+ return out
561
+
562
+ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
563
+ raise NotImplementedError
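A hedged usage sketch for the registered "diffllm" model (the checkpoint path is a placeholder; argument names are taken from the constructor above, and running this actually loads the model, so a GPU and the checkpoint are assumed):

from lm_eval.models.diffllm import DiffLLM

# comma-separated model_args, as they would be passed on the lm_eval command line
lm = DiffLLM.create_from_arg_string(
    "pretrained=path/to/dream-instruct-checkpoint,max_new_tokens=256,"
    "diffusion_steps=128,temperature=0.2,top_p=0.95,use_hts=true"
)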
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py ADDED
@@ -0,0 +1,41 @@
+ import random
+
+ from tqdm import tqdm
+
+ from lm_eval.api.model import LM
+ from lm_eval.api.registry import register_model
+
+
+ @register_model("dummy")
+ class DummyLM(LM):
+     def __init__(self) -> None:
+         super().__init__()
+
+     @classmethod
+     def create_from_arg_string(cls, arg_string, additional_config=None):
+         return cls()
+
+     def loglikelihood(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for _ in tqdm(requests, disable=disable_tqdm):
+             res.append((-random.random(), False))
+
+         return res
+
+     def generate_until(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for request in tqdm(requests, disable=disable_tqdm):
+             res.append("lol")
+             assert request.arguments[0].strip() != ""
+
+         return res
+
+     def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for _ in tqdm(requests, disable=disable_tqdm):
+             res.append(-random.random())
+
+         return res
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py ADDED
@@ -0,0 +1,257 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from .verifier import CodeVerifier
5
+ import logging
6
+ import re
7
+ import math
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class HTSSampler:
12
+ def __init__(self, model, tokenizer, device="cuda"):
13
+ self.model = model
14
+ self.tokenizer = tokenizer
15
+ self.device = device
16
+ self.verifier = CodeVerifier(model, tokenizer, device)
17
+
18
+ def _get_num_transfer_tokens(self, block_length, steps):
19
+ if steps == 0: return torch.tensor([], dtype=torch.int64)
20
+ base = block_length // steps
21
+ remainder = block_length % steps
22
+ num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
23
+ num_transfer_tokens[:remainder] += 1
24
+ return num_transfer_tokens
25
+
26
+ def _sample_with_temperature(self, logits, temperature, top_k, top_p):
27
+ logits = logits.to(torch.float32)
28
+ orig_probs = torch.softmax(logits, dim=-1)
29
+ x0_p, _ = torch.max(orig_probs, dim=-1)
30
+
31
+ if temperature > 0.0:
32
+ noise = torch.rand_like(logits, dtype=torch.float32)
33
+ gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
34
+ logits = logits / temperature + gumbel_noise
35
+
36
+ if top_k is not None and top_k > 0:
37
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
38
+ logits[indices_to_remove] = -float('Inf')
39
+
40
+ x0 = torch.argmax(logits, dim=-1)
41
+ return x0, x0_p
42
+
43
+ def _safe_scalar(self, val):
44
+ if isinstance(val, torch.Tensor):
45
+ if val.numel() > 1: return val.mean().item()
46
+ return val.item()
47
+ return float(val)
48
+
49
+ def _analyze_structure(self, text, task_type="code"):
50
+ score = 0.0
51
+ stripped = text.strip()
52
+ if task_type == "code":
53
+ if len(stripped) < 5: return -0.1
54
+ keywords = ["return", "print", "yield", "lambda", "class ", "def "]
55
+ if any(k in stripped for k in keywords): score += 0.05
56
+ if ":" in stripped: score += 0.02
57
+ if " " in text: score += 0.03
58
+ elif task_type == "math":
59
+ if "\\boxed{" in stripped: score += 0.1
60
+ if "The answer is" in stripped: score += 0.05
61
+ return score
62
+
63
+ def _chunked_forward(self, x, chunk_size=32, slice_indices=None):
64
+ total_batch = x.shape[0]
65
+ logits_list = []
66
+ for i in range(0, total_batch, chunk_size):
67
+ end_idx = min(i + chunk_size, total_batch)
68
+ sub_x = x[i:end_idx]
69
+ with torch.no_grad():
70
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
71
+ outputs = self.model(sub_x, 'full')
72
+ sub_logits = outputs.logits
73
+ sub_logits = torch.cat([sub_logits[:, :1, :], sub_logits[:, :-1, :]], dim=1)
74
+ if slice_indices is not None:
75
+ s_start, s_end = slice_indices
76
+ sub_logits = sub_logits[:, s_start:s_end, :]
77
+ logits_list.append(sub_logits.detach().clone())
78
+ return torch.cat(logits_list, dim=0)
79
+
80
+ def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id,
81
+ prompt_length, resample_window=6, task_type="code"):
82
+ num_survivors = len(survivor_indices)
83
+ if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone()
84
+
85
+ base_repeat = target_width // num_survivors
86
+ remainder = target_width % num_survivors
87
+ new_x_list, new_conf_list = [], []
88
+
89
+ for i in range(num_survivors):
90
+ count = base_repeat + (1 if i < remainder else 0)
91
+ if count == 0: continue
92
+ survivor_x = x[survivor_indices[i]]
93
+ survivor_conf = conf_scores[survivor_indices[i]]
94
+
95
+ new_x_list.append(survivor_x.unsqueeze(0))
96
+ new_conf_list.append(survivor_conf.unsqueeze(0))
97
+
98
+ if count > 1:
99
+ gen_part = survivor_x[prompt_length:]
100
+ gen_conf = survivor_conf[prompt_length:]
101
+ non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0]
102
+ for _ in range(count - 1):
103
+ perturbed_x = survivor_x.clone()
104
+ perturbed_conf = survivor_conf.clone()
105
+ if len(non_mask_indices) > 0:
106
+ pool_size = min(resample_window * 2, len(non_mask_indices))
107
+ current_token_confs = gen_conf[non_mask_indices]
108
+ _, candidate_pool = torch.topk(current_token_confs, k=pool_size, largest=False)
109
+
110
+ num_to_perturb = min(resample_window, pool_size)
111
+ rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb]
112
+ selected_sub_indices = candidate_pool[rand_indices]
113
+
114
+ target_idx_in_x = prompt_length + non_mask_indices[selected_sub_indices]
115
+ perturbed_x[target_idx_in_x] = mask_id
116
+ perturbed_conf[target_idx_in_x] = 0.0
117
+ new_x_list.append(perturbed_x.unsqueeze(0))
118
+ new_conf_list.append(perturbed_conf.unsqueeze(0))
119
+
120
+ return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0)
121
+
122
+ @torch.no_grad()
123
+ def generate_hts(self, prompt_text, input_ids, problem_data=None,
124
+ initial_N=1, final_K=1, survivor_K=None,
125
+ prune_step_pct=0.0, reward_mode="confidence",
126
+ temperature=0.7, block_length=32, steps=64, gen_length=1024,
127
+ top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9,
128
+ eos_id=151643, mask_id=151666,
129
+ hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.2,
130
+ hts_survivor_k=4, task_type="code", until=None, pruning_interval=20):
131
+
132
+ input_ids = input_ids.to(self.device)
133
+ prompt_length = input_ids.shape[1]
134
+ total_length = prompt_length + gen_length
135
+
136
+ x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device)
137
+ x[:, :prompt_length] = input_ids.repeat(initial_N, 1)
138
+ conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device)
139
+ conf_scores[:, :prompt_length] = 1.0
140
+
141
+ schedule = self._get_num_transfer_tokens(gen_length, steps)
142
+ current_bsz = initial_N
143
+ schedule_map = {}
144
+ ts_start, tr_end = 0, 0
145
+
146
+ if hts_mode:
147
+ ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)
148
+ else:
149
+ final_K_list = [final_K] if not isinstance(final_K, list) else final_K
150
+ prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct
151
+ for pct, width in zip(prune_pct_list, final_K_list):
152
+ if pct > 0: schedule_map[int(steps * pct)] = width
153
+
154
+ stats = {
155
+ "initial_n": initial_N,
156
+ "final_k": final_K if not isinstance(final_K, list) else final_K[-1],
157
+ "nfe": 0,
158
+ "svf_calls": 0,
159
+ "pruning_history": [],
160
+ "entropy_history": [],
161
+ "final_scores": []
162
+ }
163
+
164
+ next_allowed_pruning_step = ts_start
165
+
166
+ for step in range(steps):
167
+ perform_pruning = False
168
+ num_parents_to_select = hts_survivor_k
169
+
170
+ if hts_mode and ts_start <= step < tr_end and step >= next_allowed_pruning_step:
171
+ target_width = max(stats["final_k"], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
172
+ if current_bsz > target_width:
173
+ perform_pruning = True
174
+ elif not hts_mode and step in schedule_map:
175
+ target_width = schedule_map[step]
176
+ num_parents_to_select = target_width
177
+ if current_bsz > target_width:
178
+ perform_pruning = True
179
+
180
+ if perform_pruning:
181
+ stats["svf_calls"] += current_bsz
182
+ full_logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))
183
+ rough_ids = torch.argmax(full_logits, dim=-1)
184
+ rough_codes = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True)
185
+
186
+ candidates = []
187
+ for i in range(current_bsz):
188
+ s = self._safe_scalar(self.verifier.get_reward(prompt_text, rough_codes[i], mode=reward_mode, current_logits=full_logits[i], task_type=task_type))
189
+ s += self._analyze_structure(rough_codes[i], task_type=task_type)
190
+ clean_text = rough_codes[i].strip().replace(" ", "").replace("\n", "")
191
+ content_key = hash(clean_text[:150] + clean_text[-150:]) if clean_text else i
192
+ candidates.append({'score': s, 'idx': i, 'key': content_key})
193
+
194
+ stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]})
195
+ candidates.sort(key=lambda c: c['score'], reverse=True)
196
+
197
+ selected_indices, seen_keys = [], set()
198
+ for cand in candidates:
199
+ if len(selected_indices) >= num_parents_to_select: break
200
+ if cand['key'] not in seen_keys:
201
+ selected_indices.append(cand['idx']); seen_keys.add(cand['key'])
202
+ for cand in candidates:
203
+ if len(selected_indices) >= num_parents_to_select: break
204
+ if cand['idx'] not in selected_indices: selected_indices.append(cand['idx'])
205
+
206
+ top_indices = torch.tensor(selected_indices, device=self.device)
207
+ x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type)
208
+
209
+ current_bsz = target_width
210
+ next_allowed_pruning_step = step + pruning_interval
211
+
212
+ active_mask = (x[:current_bsz, prompt_length:] == mask_id)
213
+ if active_mask.sum() == 0: break
214
+
215
+ stats["nfe"] += current_bsz
216
+ logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))
217
+
218
+ with torch.no_grad():
219
+ probs = torch.softmax(logits.float(), dim=-1)
220
+ token_entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
221
+ sample_entropy = token_entropy.mean(dim=-1)
222
+ stats["entropy_history"].append(sample_entropy.tolist())
223
+
224
+ x0, x0_p = self._sample_with_temperature(logits, temperature, top_k, top_p)
225
+ num_transfer = schedule[step].item()
226
+ confidence = torch.where(active_mask, x0_p, -torch.inf)
227
+ transfer_idx = torch.zeros_like(x0, dtype=torch.bool)
228
+
229
+ for b in range(current_bsz):
230
+ k = min(num_transfer, active_mask[b].sum().item())
231
+ if k <= 0: continue
232
+ high_conf_mask = (confidence[b] > threshold)
233
+ if high_conf_mask.sum() >= k:
234
+ transfer_idx[b] = high_conf_mask
235
+ else:
236
+ _, topk_ids = torch.topk(confidence[b], k=k)
237
+ transfer_idx[b, topk_ids] = True
238
+
239
+ if transfer_idx.any():
240
+ x[:current_bsz, prompt_length:][transfer_idx] = x0[transfer_idx]
241
+ conf_scores[:current_bsz, prompt_length:][transfer_idx] = x0_p[transfer_idx]
242
+
243
+ final_codes = self.tokenizer.batch_decode(x[:current_bsz, prompt_length:], skip_special_tokens=True)
244
+ final_candidates = []
245
+ for i, code in enumerate(final_codes):
246
+ txt = code.split(self.tokenizer.eos_token)[0]
247
+ if until:
248
+ for term in until:
249
+ if term in txt: txt = txt.split(term)[0]
250
+ s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type))
251
+ final_candidates.append({'resp': txt, 'score': s})
252
+
253
+ final_candidates.sort(key=lambda c: c['score'], reverse=True)
254
+ stats["final_scores"] = [c['score'] for c in final_candidates]
255
+ stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)]
256
+
257
+ return [c['resp'] for c in final_candidates], stats
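The pruning width in generate_hts decays geometrically between hts_start_pct and hts_end_pct of the denoising steps; a toy illustration of that schedule (parameter values chosen arbitrarily, not from the diff):

import math

initial_N, final_k, decay = 8, 1, 1.2
ts_start = 6  # e.g. int(steps * hts_start_pct) for steps=64, hts_start_pct=0.1
for step in (6, 10, 14, 20, 30):
    width = max(final_k, math.ceil(initial_N * decay ** -(step - ts_start)))
    print(step, width)  # widths shrink 8, 4, 2, 1, 1 as the step advances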
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py ADDED
@@ -0,0 +1,1459 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ from datetime import timedelta
5
+ from pathlib import Path
6
+ from typing import Dict, List, Literal, Optional, Tuple, Union
7
+
8
+ import jinja2
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import transformers
12
+ from accelerate import (
13
+ Accelerator,
14
+ InitProcessGroupKwargs,
15
+ find_executable_batch_size,
16
+ )
17
+ from accelerate.utils import get_max_memory
18
+ from huggingface_hub import HfApi
19
+ from packaging import version
20
+ from peft import PeftModel
21
+ from peft import __version__ as PEFT_VERSION
22
+ from tqdm import tqdm
23
+ from transformers.models.auto.modeling_auto import (
24
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
25
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
26
+ )
27
+
28
+ from lm_eval import utils
29
+ from lm_eval.api.instance import Instance
30
+ from lm_eval.api.model import TemplateLM
31
+ from lm_eval.api.registry import register_model
32
+ from lm_eval.models.utils import (
33
+ Collator,
34
+ clear_torch_cache,
35
+ configure_pad_token,
36
+ get_dtype,
37
+ handle_stop_sequences,
38
+ pad_and_concat,
39
+ stop_sequences_criteria,
40
+ )
41
+
42
+
43
+ eval_logger = logging.getLogger(__name__)
44
+
45
+
46
+ @register_model("hf-auto", "hf", "huggingface")
47
+ class HFLM(TemplateLM):
48
+ """
49
+ An abstracted Huggingface model class. Enables usage with both models of
50
+ `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
51
+
52
+ Supports data-parallel multi-GPU with HF Accelerate.
53
+ """
54
+
55
+ AUTO_MODEL_CLASS = None
56
+ _DEFAULT_MAX_LENGTH = 2048
57
+
58
+ def __init__(
59
+ self,
60
+ pretrained: Union[str, transformers.PreTrainedModel],
61
+ backend: Literal["default", "causal", "seq2seq"] = "default",
62
+ # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
63
+ revision: Optional[str] = "main",
64
+ subfolder: Optional[str] = None,
65
+ tokenizer: Optional[
66
+ Union[
67
+ str,
68
+ transformers.PreTrainedTokenizer,
69
+ transformers.PreTrainedTokenizerFast,
70
+ ]
71
+ ] = None,
72
+ truncation: Optional[bool] = False,
73
+ logits_cache: bool = True,
74
+ max_length: Optional[int] = None,
75
+ device: Optional[str] = "cuda",
76
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
77
+ batch_size: Optional[Union[int, str]] = 1,
78
+ max_batch_size: Optional[int] = 64,
79
+ trust_remote_code: Optional[bool] = False,
80
+ use_fast_tokenizer: Optional[bool] = True,
81
+ add_bos_token: Optional[bool] = False,
82
+ prefix_token_id: Optional[int] = None,
83
+ # arguments used for splitting a model across GPUs naively.
84
+ # only used if `parallelize=True`.
85
+ parallelize: Optional[bool] = False,
86
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
87
+ max_cpu_memory: Optional[Union[int, str]] = None,
88
+ offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
89
+ # PEFT, delta weights and quantization options
90
+ peft: Optional[str] = None,
91
+ delta: Optional[str] = None,
92
+ autogptq: Optional[Union[bool, str]] = False,
93
+ gptqmodel: Optional[bool] = False,
94
+ gguf_file: Optional[str] = None,
95
+ **kwargs,
96
+ ) -> None:
97
+ super().__init__()
98
+ # optionally: take in an already-initialized transformers.PreTrainedModel
99
+ if not isinstance(pretrained, str):
100
+ eval_logger.warning(
101
+ "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
102
+ )
103
+ assert not parallelize, (
104
+ "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
105
+ )
106
+ self._model = pretrained
107
+ self._device = self._model.device
108
+ self._config = self._model.config
109
+ gpus = 0
110
+
111
+ else:
112
+ assert isinstance(device, str)
113
+ assert isinstance(pretrained, str)
114
+ assert isinstance(batch_size, (int, str))
115
+
116
+ gpus = torch.cuda.device_count()
117
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
118
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
119
+ if accelerator.num_processes > 1:
120
+ self.accelerator = accelerator
121
+
122
+ if "npu" in accelerator.device.type:
123
+ gpus = torch.npu.device_count()
124
+
125
+ # using one process with no model parallelism
126
+ if not (parallelize or accelerator.num_processes > 1):
127
+ # use user-passed device
128
+ device_list = set(
129
+ ["cuda", "cpu"]
130
+ + [f"cuda:{i}" for i in range(gpus)]
131
+ + ["mps", "mps:0"]
132
+ + [f"npu:{i}" for i in range(gpus)]
133
+ )
134
+ if device and device in device_list:
135
+ self._device = torch.device(device)
136
+ eval_logger.info(f"Using device '{device}'")
137
+ if device in ("mps", "mps:0") and version.parse(
138
+ torch.__version__
139
+ ) < version.parse("2.1"):
140
+ raise RuntimeError(
141
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
142
+ )
143
+ else:
144
+ eval_logger.info("Device not specified")
145
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
146
+ self._device = (
147
+ torch.device("cuda")
148
+ if torch.cuda.is_available()
149
+ else torch.device("cpu")
150
+ )
151
+ else: # Parallelism managed by accelerate
152
+ if device != "cuda":
153
+ eval_logger.info(
154
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
155
+ )
156
+ # TODO: include in warning that `load_in_8bit` etc. affect this too
157
+ self._device = (
158
+ self.accelerator.device
159
+ if hasattr(self, "accelerator")
160
+ else torch.device(device)
161
+ )
162
+
163
+ revision = str(revision) # cast to string if not already one
164
+ # TODO: update this to be less of a hack once subfolder is fixed in HF
165
+ revision = revision + ("/" + subfolder if subfolder is not None else "")
166
+
167
+ self._get_config(
168
+ pretrained,
169
+ revision=revision,
170
+ trust_remote_code=trust_remote_code,
171
+ gguf_file=gguf_file,
172
+ )
173
+
174
+ # determine which of 'causal' and 'seq2seq' backends to use for HF models
175
+ self._get_backend(
176
+ config=self.config, backend=backend, trust_remote_code=trust_remote_code
177
+ )
178
+
179
+ # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
180
+ self._create_tokenizer(
181
+ pretrained,
182
+ tokenizer,
183
+ revision=revision,
184
+ trust_remote_code=trust_remote_code,
185
+ use_fast_tokenizer=use_fast_tokenizer,
186
+ gguf_file=gguf_file,
187
+ add_bos_token=add_bos_token,
188
+ )
189
+
190
+ # if we passed `pretrained` as a string, initialize our model now
191
+ if isinstance(pretrained, str):
192
+ self._create_model(
193
+ pretrained=pretrained,
194
+ revision=revision,
195
+ dtype=dtype,
196
+ trust_remote_code=trust_remote_code,
197
+ parallelize=parallelize,
198
+ gpus=gpus,
199
+ max_memory_per_gpu=max_memory_per_gpu,
200
+ max_cpu_memory=max_cpu_memory,
201
+ offload_folder=offload_folder,
202
+ peft=peft,
203
+ delta=delta,
204
+ autogptq=autogptq,
205
+ gptqmodel=gptqmodel,
206
+ gguf_file=gguf_file,
207
+ **kwargs,
208
+ )
209
+
210
+ # access self._model through self.model property outside this method
211
+ if isinstance(self.model, torch.nn.Module):
212
+ self.model.eval()
213
+ self.model.tie_weights()
214
+
215
+ self.truncation = truncation
216
+ self.logits_cache = logits_cache
217
+ self.vocab_size = self.tokenizer.vocab_size
218
+ # select (or create) a pad token to use
219
+ self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
220
+
221
+ self.add_bos_token = add_bos_token
222
+ if "gemma" in getattr(self.config, "model_type", ""):
223
+ self.add_bos_token = True
224
+ eval_logger.info(
225
+ f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
226
+ )
227
+
228
+ self._max_length = max_length
229
+ self.pretrained = pretrained
230
+ self.delta = delta
231
+ self.peft = peft
232
+ self.revision = revision
233
+ self.batch_schedule = 1
234
+ self.batch_sizes = {}
235
+ self.max_batch_size = max_batch_size
236
+
237
+ if str(batch_size).startswith("auto"):
238
+ batch_size = batch_size.split(":")
239
+ self.batch_size_per_gpu = batch_size[0]
240
+ self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
241
+ else:
242
+ self.batch_size_per_gpu = int(batch_size)
243
+
244
+ if isinstance(pretrained, str):
245
+ if gpus >= 1 or str(self.device) == "mps":
246
+ # TODO: can remove this whole snippet except in the mps case, perhaps?
247
+ if not (parallelize or autogptq or hasattr(self, "accelerator")):
248
+ # place model onto device requested manually,
249
+ # if not using HF Accelerate or device_map
250
+ # or any other option that preloads model onto device
251
+ try:
252
+ self.model.to(self.device)
253
+ except ValueError:
254
+ eval_logger.debug(
255
+ "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
256
+ )
257
+ # multigpu data-parallel support when launched with accelerate
258
+ if gpus > 1:
259
+ if accelerator.num_processes > 1:
260
+ if parallelize:
261
+ eval_logger.warning(
262
+ "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
263
+ )
264
+ elif gpus > accelerator.num_processes:
265
+ eval_logger.warning(
266
+ "WARNING: The number of total system GPUs does not match the number of spawned processes. "
267
+ "If you would like to use data parallelism, please launch the script "
268
+ "with 'accelerate launch *script*'. "
269
+ f"Current run will proceed with {accelerator.num_processes} devices."
270
+ )
271
+ if self.accelerator.is_local_main_process:
272
+ eval_logger.info(
273
+ f"Using {gpus} devices with data parallelism"
274
+ )
275
+
276
+ self._device = torch.device(f"{accelerator.device}")
277
+ self.accelerator = accelerator
278
+
279
+ self._rank = self.accelerator.local_process_index
280
+ self._world_size = self.accelerator.num_processes
281
+ else:
282
+ # if we aren't launching via accelerate, skip distributed setup and run as a single process
283
+ self._rank = 0
284
+ self._world_size = 1
285
+ else:
286
+ # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
287
+ eval_logger.warning(
288
+ "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
289
+ )
290
+ self._rank = 0
291
+ self._world_size = 1
292
+
293
+ self.custom_prefix_token_id = prefix_token_id
294
+ if prefix_token_id is not None:
295
+ eval_logger.info(
296
+ f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
297
+ )
298
+
299
+ def _get_accelerate_args(
300
+ self,
301
+ parallelize: Optional[bool] = None,
302
+ device_map: Optional[str] = "auto",
303
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
304
+ max_cpu_memory: Optional[Union[int, str]] = None,
305
+ offload_folder: Optional[str] = "./offload",
306
+ gpus: Optional[int] = None,
307
+ ) -> dict:
308
+ """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
309
+ num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
310
+ num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
311
+ if (
312
+ num_machines == 0
313
+ and hasattr(self, "accelerator")
314
+ and self.accelerator is not None
315
+ ):
316
+ eval_logger.info(
317
+ "We are not in a distributed setting for accelerate. Setting model_parallel to False."
318
+ )
319
+ parallelize = False
320
+
321
+ if parallelize is None:
322
+ # If parallelism is unset by the user, we automatically assign model parallelism
323
+ # if enough extra GPUs are available
324
+ max_memory_all_gpus = get_max_memory()
325
+ # We just want gpu, not cpu, max memory
326
+ if "cpu" in max_memory_all_gpus:
327
+ del max_memory_all_gpus["cpu"]
328
+ parallelize = bool(num_local_processes < len(max_memory_all_gpus))
329
+ eval_logger.info(
330
+ f"Setting model parallel to {parallelize} since "
331
+ f"the number of local processes is {num_local_processes} "
332
+ f"and the number of GPUs is {len(max_memory_all_gpus)}"
333
+ )
334
+
335
+ args = {}
336
+ if parallelize: # Model parallelism will be used
337
+ max_memory = {}
338
+ if max_memory_per_gpu is not None: # Using the provided memory requirements
339
+ max_memory_per_gpu_map = {
340
+ device_idx: max_memory_per_gpu for device_idx in range(gpus)
341
+ }
342
+ else: # Estimating the possible memory requirements
343
+ max_memory_all_gpus = get_max_memory()
344
+ if "cpu" in max_memory_all_gpus:
345
+ del max_memory_all_gpus["cpu"]
346
+ if not hasattr(self, "accelerator"):
347
+ max_memory_per_gpu_map = {
348
+ k: v for k, v in max_memory_all_gpus.items()
349
+ }
350
+ else:
351
+ # use only 1 / num_processes of the GPUs if we are running under accelerate launch
352
+ max_memory_per_gpu_map = {
353
+ k: v
354
+ for k, v in max_memory_all_gpus.items()
355
+ if k % num_local_processes
356
+ == (self.accelerator.process_index % num_local_processes)
357
+ }
358
+ args["max_memory"] = max_memory_per_gpu_map
359
+ args["device_map"] = "auto" if device_map is None else device_map
360
+ eval_logger.info(
361
+ f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
362
+ )
363
+
364
+ if max_cpu_memory is not None:
365
+ max_memory["cpu"] = max_cpu_memory
366
+
367
+ args["offload_folder"] = offload_folder
368
+ elif (
369
+ device_map is None
370
+ ): # No model parallelism, we use the default provided device for our model
371
+ if hasattr(self, "accelerator"):
372
+ device_map = {"": f"{self.accelerator.device}"}
373
+ else:
374
+ device_map = {"": str(self.device)}
375
+ args["max_memory"] = None
376
+ args["device_map"] = device_map
377
+ eval_logger.info(
378
+ f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
379
+ )
380
+ else:
381
+ args["max_memory"] = None
382
+ args["device_map"] = None
383
+ eval_logger.info("Model parallel was set to False.")
384
+
385
+ return args
386
+
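+ # Illustrative sketch (not part of the original source): the kwargs returned above feed
+ # straight into `AutoModel.from_pretrained`. Assuming two visible GPUs, parallelize=True
+ # and max_memory_per_gpu="20GiB", the result is roughly
+ # {"max_memory": {0: "20GiB", 1: "20GiB"}, "device_map": "auto", "offload_folder": "./offload"},
+ # while parallelize=False with no explicit device_map collapses to
+ # {"max_memory": None, "device_map": {"": "cuda:0"}}.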
387
+ @property
388
+ def config(self):
389
+ # return the associated transformers.AutoConfig for the given pretrained model.
390
+ return self._config
391
+
392
+ @property
393
+ def model(self):
394
+ # returns the model, unwrapping it if using Accelerate
395
+ if hasattr(self, "accelerator"):
396
+ return self.accelerator.unwrap_model(self._model)
397
+ else:
398
+ return self._model
399
+
400
+ @property
401
+ def eot_token_id(self):
402
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
403
+ return self.tokenizer.eos_token_id
404
+
405
+ @property
406
+ def prefix_token_id(self):
407
+ # it is used as prefix for loglikelihood
408
+ if self.custom_prefix_token_id is not None:
409
+ return self.custom_prefix_token_id
410
+ if self.tokenizer.bos_token_id is not None:
411
+ return self.tokenizer.bos_token_id
412
+ return self.tokenizer.eos_token_id
413
+
414
+ @property
415
+ def max_length(self):
416
+ if self._max_length: # if max length manually set, return it
417
+ return self._max_length
418
+ seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
419
+ for attr in seqlen_config_attrs:
420
+ if hasattr(self.model.config, attr):
421
+ return getattr(self.model.config, attr)
422
+ if hasattr(self.tokenizer, "model_max_length"):
423
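+ # This magic constant is `transformers`' VERY_LARGE_INTEGER (int(1e30)), the sentinel a
+ # tokenizer reports when it has no meaningful `model_max_length`; fall back to the default.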
+ if self.tokenizer.model_max_length == 1000000000000000019884624838656:
424
+ return self._DEFAULT_MAX_LENGTH
425
+ return self.tokenizer.model_max_length
426
+ return self._DEFAULT_MAX_LENGTH
427
+
428
+ @property
429
+ def max_gen_toks(self) -> int:
430
+ return 256
431
+
432
+ @property
433
+ def batch_size(self):
434
+ return self.batch_size_per_gpu
435
+
436
+ @property
437
+ def device(self):
438
+ return self._device
439
+
440
+ @property
441
+ def rank(self):
442
+ return self._rank
443
+
444
+ @property
445
+ def world_size(self):
446
+ return self._world_size
447
+
448
+ @property
449
+ def tokenizer_name(self) -> str:
450
+ return self.tokenizer.name_or_path.replace("/", "__")
451
+
452
+ def _get_backend(
453
+ self,
454
+ config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
455
+ backend: Literal["default", "causal", "seq2seq"] = "default",
456
+ trust_remote_code: Optional[bool] = False,
457
+ ) -> None:
458
+ """
459
+ Helper method during initialization.
460
+ Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
461
+ sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
462
+
463
+ **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
464
+ user must set `self.backend` to be either "causal" or "seq2seq" manually!**
465
+ """
466
+
467
+ assert backend in ["default", "causal", "seq2seq"]
468
+
469
+ if backend != "default":
470
+ # if we've settled on non-default backend, use that manually
471
+ if backend == "causal":
472
+ self.backend = backend
473
+ elif backend == "seq2seq":
474
+ self.backend = backend
475
+ eval_logger.info(
476
+ f"Overrode HF model backend type, and using type '{self.backend}'"
477
+ )
478
+ else:
479
+ # determine and use the default HF backend for this model, based on its config + metadata.
480
+ if (
481
+ getattr(config, "model_type")
482
+ in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
483
+ ):
484
+ # first check if model type is listed under seq2seq models, since some
485
+ # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
486
+ # these special cases should be treated as seq2seq models.
487
+ self.backend = "seq2seq"
488
+ eval_logger.debug(f"Using model type '{self.backend}'")
489
+ elif (
490
+ getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
491
+ ):
492
+ self.backend = "causal"
493
+ eval_logger.debug(f"Using model type '{self.backend}'")
494
+ else:
495
+ if not trust_remote_code:
496
+ eval_logger.warning(
497
+ "HF model type is neither marked as CausalLM or Seq2SeqLM. \
498
+ This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
499
+ "Setting backend to causal"
500
+ )
501
+ # if model type is neither in HF transformers causal or seq2seq model registries
502
+ # then we default to assuming AutoModelForCausalLM
503
+ self.backend = "causal"
504
+ eval_logger.info(
505
+ f"Model type cannot be determined. Using default model type '{self.backend}'"
506
+ )
507
+
508
+ if self.AUTO_MODEL_CLASS is None:
509
+ if self.backend == "causal":
510
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
511
+ elif self.backend == "seq2seq":
512
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
513
+
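+ # Rough illustration (model names are assumptions, not from the source): a decoder-only
+ # config such as `gpt2` resolves to backend="causal" and AutoModelForCausalLM, while an
+ # encoder-decoder config such as `google/flan-t5-base` resolves to backend="seq2seq" and
+ # AutoModelForSeq2SeqLM; unrecognized (e.g. remote-code) model types fall back to "causal".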
514
+ def _get_config(
515
+ self,
516
+ pretrained: str,
517
+ revision: str = "main",
518
+ trust_remote_code: bool = False,
519
+ gguf_file: Optional[str] = None,
520
+ ) -> None:
521
+ """Return the model config for HuggingFace models"""
522
+ self._config = transformers.AutoConfig.from_pretrained(
523
+ pretrained,
524
+ revision=revision,
525
+ trust_remote_code=trust_remote_code,
526
+ gguf_file=gguf_file,
527
+ )
528
+
529
+ def _create_model(
530
+ self,
531
+ pretrained: str,
532
+ revision: Optional[str] = "main",
533
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
534
+ trust_remote_code: Optional[bool] = False,
535
+ # arguments used for splitting a model across GPUs naively.
536
+ # only used if `parallelize=True`.
537
+ # (accelerate naive PP (device_map) options)
538
+ parallelize: Optional[bool] = False,
539
+ gpus: Optional[int] = None,
540
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
541
+ max_cpu_memory: Optional[Union[int, str]] = None,
542
+ offload_folder: Optional[str] = "./offload",
543
+ # PEFT, delta weights and quantization options
544
+ peft: Optional[str] = None,
545
+ delta: Optional[str] = None,
546
+ autogptq: Optional[Union[bool, str]] = False,
547
+ gptqmodel: Optional[bool] = False,
548
+ gguf_file: Optional[str] = None,
549
+ **kwargs,
550
+ ) -> None:
551
+ """
552
+ Initializes an HF or HF-compatible PreTrainedModel from scratch
553
+ inside HFLM, using the kwargs passed into self.__init__().
554
+
555
+ Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
556
+
557
+ For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
558
+ (such as PyTorch models that are nearly, but not quite, fully mirroring
559
+ HF's public interface relied on in this HFLM class)
560
+ please consider subclassing HFLM and overriding this and other methods as needed.
561
+ """
562
+
563
+ model_kwargs = kwargs if kwargs else {}
564
+
565
+ model_kwargs.update(
566
+ self._get_accelerate_args(
567
+ parallelize=parallelize,
568
+ device_map=kwargs.get("device_map", None),
569
+ max_memory_per_gpu=max_memory_per_gpu,
570
+ max_cpu_memory=max_cpu_memory,
571
+ offload_folder=offload_folder,
572
+ gpus=gpus,
573
+ )
574
+ )
575
+
576
+ if not autogptq and not gptqmodel:
577
+ if model_kwargs.get("load_in_4bit", None):
578
+ assert transformers.__version__ >= "4.30.0", (
579
+ "load_in_4bit requires transformers >= 4.30.0"
580
+ )
581
+ if transformers.__version__ >= "4.30.0":
582
+ if model_kwargs.get("load_in_4bit", None):
583
+ if model_kwargs.get("bnb_4bit_compute_dtype", None):
584
+ model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
585
+ model_kwargs["bnb_4bit_compute_dtype"]
586
+ )
587
+
588
+ self._model = self.AUTO_MODEL_CLASS.from_pretrained(
589
+ pretrained,
590
+ revision=revision,
591
+ torch_dtype=get_dtype(dtype),
592
+ trust_remote_code=trust_remote_code,
593
+ gguf_file=gguf_file,
594
+ **model_kwargs,
595
+ )
596
+ else:
597
+ if autogptq and gptqmodel:
598
+ raise ValueError(
599
+ "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
600
+ )
601
+
602
+ if autogptq:
603
+ try:
604
+ from auto_gptq import AutoGPTQForCausalLM
605
+ except ModuleNotFoundError as exception:
606
+ raise type(exception)(
607
+ "Tried to load auto_gptq, but auto-gptq is not installed ",
608
+ "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
609
+ )
610
+
611
+ self._model = AutoGPTQForCausalLM.from_quantized(
612
+ pretrained,
613
+ trust_remote_code=trust_remote_code,
614
+ model_basename=None if autogptq is True else Path(autogptq).stem,
615
+ use_safetensors=True
616
+ if autogptq is True
617
+ else autogptq.endswith(".safetensors"),
618
+ **model_kwargs,
619
+ )
620
+
621
+ if gptqmodel:
622
+ try:
623
+ from gptqmodel import GPTQModel
624
+ except ModuleNotFoundError as exception:
625
+ raise type(exception)(
626
+ "Tried to load gptqmodel, but gptqmodel is not installed ",
627
+ "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
628
+ )
629
+
630
+ self._model = GPTQModel.from_quantized(
631
+ pretrained, trust_remote_code=trust_remote_code, **model_kwargs
632
+ )
633
+
634
+ if peft and delta:
635
+ raise ValueError(
636
+ "Cannot use both 'peft' and 'delta' options at the same time."
637
+ )
638
+
639
+ if peft:
640
+ if model_kwargs.get("load_in_4bit", None):
641
+ if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
642
+ raise AssertionError("load_in_4bit requires peft >= 0.4.0")
643
+ if self._model.config.vocab_size != len(self.tokenizer):
644
+ # resize model for LoRAs with added tokens
645
+ eval_logger.info(
646
+ f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
647
+ )
648
+ self._model.resize_token_embeddings(len(self.tokenizer))
649
+ self._model = PeftModel.from_pretrained(
650
+ self._model, peft, revision=revision
651
+ )
652
+ elif delta:
653
+ if autogptq:
654
+ eval_logger.warning(
655
+ "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
656
+ )
657
+ _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
658
+ delta,
659
+ revision=revision,
660
+ torch_dtype=get_dtype(dtype),
661
+ trust_remote_code=trust_remote_code,
662
+ **model_kwargs,
663
+ )
664
+ for name, param in self._model.state_dict().items():
665
+ try:
666
+ param.data += _model_delta.state_dict()[name]
667
+ except KeyError:
668
+ raise KeyError(f"Delta model is missing weights for layer: {name}")
669
+ except Exception as e:
670
+ raise RuntimeError(
671
+ f"Failed to add delta weights to layer {name}. Error: {e}"
672
+ )
673
+
674
+ del _model_delta
675
+
676
+ return None
677
+
678
+ def _create_tokenizer(
679
+ self,
680
+ pretrained: Union[str, transformers.PreTrainedModel],
681
+ tokenizer: Optional[
682
+ Union[
683
+ str,
684
+ transformers.PreTrainedTokenizer,
685
+ transformers.PreTrainedTokenizerFast,
686
+ ]
687
+ ],
688
+ revision: Optional[str] = "main",
689
+ trust_remote_code: Optional[bool] = False,
690
+ use_fast_tokenizer: Optional[bool] = True,
691
+ gguf_file: Optional[str] = None,
692
+ add_bos_token: Optional[bool] = False,
693
+ ) -> None:
694
+ """
695
+ Helper method during initialization.
696
+
697
+ Create a tokenizer object corresponding to the correct
698
+ tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
699
+ """
700
+ kwargs = {
701
+ "revision": revision,
702
+ "trust_remote_code": trust_remote_code,
703
+ }
704
+
705
+ # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
706
+ if gguf_file is not None:
707
+ kwargs["gguf_file"] = gguf_file
708
+ else:
709
+ kwargs["use_fast"] = use_fast_tokenizer
710
+
711
+ if add_bos_token:
712
+ kwargs["add_bos_token"] = True
713
+
714
+ if tokenizer:
715
+ if isinstance(tokenizer, str):
716
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
717
+ tokenizer, **kwargs
718
+ )
719
+ else:
720
+ assert isinstance(
721
+ tokenizer, transformers.PreTrainedTokenizer
722
+ ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
723
+ self.tokenizer = tokenizer
724
+ else:
725
+ # Get tokenizer based on 'pretrained'
726
+ if isinstance(pretrained, str):
727
+ model_name = pretrained
728
+ else:
729
+ # get the HF hub name via accessor on model
730
+ model_name = self.model.name_or_path
731
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
732
+ model_name, **kwargs
733
+ )
734
+ return None
735
+
736
+ def _detect_batch_size(self, requests=None, pos: int = 0):
737
+ if requests:
738
+ _, context_enc, continuation_enc = requests[pos]
739
+ max_length = len(
740
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
741
+ )
742
+ max_context_enc = len(context_enc[-(self.max_length + 1) :])
743
+ max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
744
+ else:
745
+ max_length = self.max_length
746
+ max_context_enc = max_length
747
+ max_cont_enc = max_length
748
+
749
+ # if OOM, then halves batch_size and tries again
750
+ @find_executable_batch_size(starting_batch_size=self.max_batch_size)
751
+ def forward_batch(batch_size):
752
+ if self.backend == "seq2seq":
753
+ length = max(max_context_enc, max_cont_enc)
754
+ batched_conts = torch.ones(
755
+ (batch_size, length), device=self.device
756
+ ).long()
757
+ test_batch = torch.ones((batch_size, length), device=self.device).long()
758
+ call_kwargs = {
759
+ "attn_mask": test_batch,
760
+ "labels": batched_conts,
761
+ }
762
+ else:
763
+ call_kwargs = {}
764
+ test_batch = torch.ones(
765
+ (batch_size, max_length), device=self.device
766
+ ).long()
767
+ for _ in range(5):
768
+ out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841
769
+
770
+ return batch_size
771
+
772
+ try:
773
+ batch_size = forward_batch()
774
+ except RuntimeError as e:
775
+ if "No executable batch size found" in str(e):
776
+ batch_size = 1
777
+ else:
778
+ raise
779
+
780
+ if self.world_size > 1:
781
+ # if multi-GPU, always take minimum over all selected batch sizes
782
+ max_rnk_bs = torch.tensor([batch_size], device=self.device)
783
+ gathered = (
784
+ self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
785
+ )
786
+ batch_size = min(gathered)
787
+ clear_torch_cache()
788
+ return batch_size
789
+
790
+ clear_torch_cache()
791
+ return batch_size
792
+
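+ # Behavior sketch (illustrative numbers): with max_batch_size=64, `find_executable_batch_size`
+ # tries a dummy forward pass at 64 and halves to 32, 16, ... on CUDA OOM; if no size works,
+ # the handler above falls back to 1, and multi-GPU runs adopt the minimum size across ranks.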
793
+ def tok_encode(
794
+ self, string: str, left_truncate_len=None, add_special_tokens=None
795
+ ) -> List[int]:
796
+ """ """
797
+ # default for None - empty dict, use predefined tokenizer param
798
+ # used for all models except for CausalLM or predefined value
799
+ special_tokens_kwargs = {}
800
+
801
+ # by default for CausalLM - false or self.add_bos_token is set
802
+ if add_special_tokens is None:
803
+ if self.backend == "causal":
804
+ special_tokens_kwargs = {
805
+ "add_special_tokens": False or self.add_bos_token
806
+ }
807
+ # otherwise the method explicitly defines the value
808
+ else:
809
+ special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
810
+
811
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
812
+
813
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
814
+ if left_truncate_len:
815
+ encoding = encoding[-left_truncate_len:]
816
+
817
+ return encoding
818
+
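+ # Hedged usage sketch (token ids are illustrative, assuming a GPT-2 style tokenizer):
+ # >>> lm = HFLM(pretrained="gpt2")
+ # >>> lm.tok_encode("hello world")  # causal backend: no special tokens by default
+ # [31373, 995]
+ # >>> lm.tok_encode("hello world", left_truncate_len=1)
+ # [995]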
819
+ def tok_batch_encode(
820
+ self,
821
+ strings: List[str],
822
+ padding_side: str = "left",
823
+ left_truncate_len: int = None,
824
+ truncation: bool = False,
825
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
826
+ # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
827
+ old_padding_side = self.tokenizer.padding_side
828
+ self.tokenizer.padding_side = padding_side
829
+
830
+ add_special_tokens = {}
831
+ if self.backend == "causal":
832
+ add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
833
+
834
+ encoding = self.tokenizer(
835
+ strings,
836
+ truncation=truncation,
837
+ padding="longest",
838
+ return_tensors="pt",
839
+ **add_special_tokens,
840
+ )
841
+ if left_truncate_len:
842
+ original_lengths = encoding["input_ids"].size(1)
843
+ if original_lengths > left_truncate_len:
844
+ eval_logger.warning(
845
+ f"Left truncation applied. Original sequence length was {original_lengths}, "
846
+ f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
847
+ )
848
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
849
+ encoding["attention_mask"] = encoding["attention_mask"][
850
+ :, -left_truncate_len:
851
+ ]
852
+ self.tokenizer.padding_side = old_padding_side
853
+
854
+ return encoding["input_ids"], encoding["attention_mask"]
855
+
856
+ def tok_decode(self, tokens, skip_special_tokens=True):
857
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
858
+
859
+ def _model_call(self, inps, attn_mask=None, labels=None):
860
+ """
861
+ :param inps: torch.Tensor
862
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
863
+ [batch, sequence_ctx]. the size of sequence may vary from call to call
864
+ :param attn_mask: torch.Tensor, optional
865
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
866
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
867
+ :param labels: torch.Tensor, optional
868
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
869
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
870
+ :return
871
+ A torch tensor of shape [batch, sequence, vocab] with the
872
+ logits returned from the model's decoder
873
+ """
874
+ with torch.no_grad():
875
+ if attn_mask is not None or labels is not None:
876
+ assert attn_mask is not None and labels is not None
877
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
878
+ return self.model(
879
+ input_ids=inps, attention_mask=attn_mask, labels=labels
880
+ ).logits
881
+ else:
882
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
883
+ return self.model(inps).logits
884
+
885
+ def _model_generate(self, context, max_length, stop, **generation_kwargs):
886
+ # temperature = 0.0 if not set
887
+ # if do_sample is false and temp==0.0:
888
+ # remove temperature, as do_sample=False takes care of this
889
+ # and we don't want a warning from HF
890
+ generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
891
+ do_sample = generation_kwargs.get("do_sample", None)
892
+
893
+ # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
894
+ if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
895
+ generation_kwargs["do_sample"] = do_sample = False
896
+
897
+ if do_sample is False and generation_kwargs.get("temperature") == 0.0:
898
+ generation_kwargs.pop("temperature")
899
+ # build stopping criteria
900
+ stopping_criteria = stop_sequences_criteria(
901
+ self.tokenizer, stop, context.shape[1], context.shape[0]
902
+ )
903
+ return self.model.generate(
904
+ input_ids=context,
905
+ max_length=max_length,
906
+ stopping_criteria=stopping_criteria,
907
+ pad_token_id=self.tokenizer.pad_token_id,
908
+ use_cache=True,
909
+ **generation_kwargs,
910
+ )
911
+
912
+ def _select_cont_toks(
913
+ self, logits: torch.Tensor, contlen: int = None, inplen: int = None
914
+ ) -> torch.Tensor:
915
+ if self.backend == "causal":
916
+ assert contlen and inplen, (
917
+ "Must pass input len and cont. len to select scored logits for causal LM"
918
+ )
919
+ # discard right-padding.
920
+ # also discard the input/context tokens. we'll only score continuations.
921
+ logits = logits[inplen - contlen : inplen]
922
+ elif self.backend == "seq2seq":
923
+ assert contlen and not inplen, (
924
+ "Selecting scored logits for Seq2SeqLM requires only cont. len"
925
+ )
926
+ # only discard right-padding.
927
+ # the logits input to this fn only contain decoder-side tokens.
928
+ logits = logits[:contlen]
929
+
930
+ return logits
931
+
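+ # Shape sketch (illustrative numbers): for a causal model with inplen=10 and contlen=4,
+ # logits of shape [10, vocab] are sliced to logits[6:10], i.e. exactly the positions that
+ # predict the 4 continuation tokens; for seq2seq, logits[:4] keeps the decoder-side
+ # continuation positions and drops right-padding.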
932
+ def loglikelihood_rolling(
933
+ self, requests: List[Instance], disable_tqdm: bool = False
934
+ ) -> List[float]:
935
+ adaptive_batch_size = None
936
+ if self.batch_size == "auto":
937
+ # using rolling window with maximum context
938
+ print("Passed argument batch_size = auto. Detecting largest batch size")
939
+ batch_size = self._detect_batch_size()
940
+ print(f"Determined Largest batch size: {batch_size}")
941
+ adaptive_batch_size = batch_size
942
+
943
+ # First, collect all windows from all requests
944
+ all_windows = [] # List of (request_idx, window) tuples
945
+ request_window_counts = [] # Track number of windows per request
946
+
947
+ for req_idx, (string,) in enumerate(
948
+ tqdm(
949
+ [req.args for req in requests],
950
+ disable=(disable_tqdm or (self.rank != 0)),
951
+ )
952
+ ):
953
+ rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
954
+ map(
955
+ utils.make_disjoint_window,
956
+ utils.get_rolling_token_windows(
957
+ token_list=self.tok_encode(string),
958
+ prefix_token=self.prefix_token_id,
959
+ max_seq_len=self.max_length,
960
+ context_len=1,
961
+ ),
962
+ )
963
+ )
964
+
965
+ # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
966
+ windows = [(None,) + x for x in rolling_token_windows]
967
+
968
+ # Store windows with their request index
969
+ all_windows.extend((req_idx, window) for window in windows)
970
+ request_window_counts.append(len(windows))
971
+
972
+ # Handle distributed case padding
973
+ pad_amnt = 0
974
+ if self.world_size > 1:
975
+ mytensor = torch.tensor(len(all_windows), device=self.device)
976
+ gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
977
+ pad_amnt = max(gathered) - gathered[self.rank]
978
+ if pad_amnt > 0:
979
+ all_windows += pad_amnt * [all_windows[0]]
980
+
981
+ all_nlls = []
982
+ batch_size = adaptive_batch_size or self.batch_size
983
+ for i in range(0, len(all_windows), batch_size):
984
+ batch = all_windows[i : i + batch_size]
985
+ # Extract just the windows for processing, keeping track of request indices
986
+ batch_indices, batch_windows = zip(*batch)
987
+
988
+ batch_nlls = self._loglikelihood_tokens(
989
+ requests=batch_windows,
990
+ disable_tqdm=False,
991
+ override_bs=len(batch_windows),
992
+ )
993
+ # Store results with their request indices
994
+ all_nlls.extend(zip(batch_indices, batch_nlls))
995
+
996
+ # Remove padding if necessary
997
+ if (self.world_size > 1) and (pad_amnt > 0):
998
+ all_nlls = all_nlls[:-pad_amnt]
999
+
1000
+ # Reconstruct per-request loglikelihoods
1001
+ loglikelihoods = []
1002
+ current_idx = 0
1003
+ for window_count in request_window_counts:
1004
+ # Get all nlls for this request
1005
+ request_nlls = all_nlls[current_idx : current_idx + window_count]
1006
+ # Sum up the nlls for this request (discarding is_greedy)
1007
+ request_total = sum(nll[0] for _, nll in request_nlls)
1008
+ loglikelihoods.append(request_total)
1009
+ current_idx += window_count
1010
+
1011
+ string = requests[len(loglikelihoods) - 1].args[0]
1012
+ self.cache_hook.add_partial(
1013
+ "loglikelihood_rolling", (string,), request_total
1014
+ )
1015
+
1016
+ return loglikelihoods
1017
+
1018
+ def _batch_scheduler(self, pos, n_reordered_requests):
1019
+ sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
1020
+ if sched in self.batch_sizes:
1021
+ return self.batch_sizes[sched]
1022
+ if (len(self.batch_sizes) > 1) and (
1023
+ self.batch_sizes[sched - 1] == self.max_batch_size
1024
+ ):
1025
+ # if previous batch size is already maximal, skip recomputation
1026
+ self.batch_sizes[sched] = self.max_batch_size
1027
+ return self.batch_sizes[sched]
1028
+ print(
1029
+ f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
1030
+ )
1031
+ self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
1032
+ print(f"Determined largest batch size: {self.batch_sizes[sched]}")
1033
+ return self.batch_sizes[sched]
1034
+
1035
+ def _loglikelihood_tokens(
1036
+ self,
1037
+ requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
1038
+ disable_tqdm: bool = False,
1039
+ override_bs: int = None,
1040
+ ) -> List[Tuple[float, bool]]:
1041
+ # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
1042
+ res = []
1043
+
1044
+ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
1045
+ """Defines the key for the sorted method"""
1046
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
1047
+ # - time estimates will always be over not underestimates, which is more useful for planning
1048
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
1049
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
1050
+ # automatic adaptive batches much much easier to implement
1051
+ # - any OOMs will happen right away rather than near the end
1052
+
1053
+ toks = req[1] + req[2]
1054
+ return -len(toks), tuple(toks)
1055
+
1056
+ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
1057
+ """Defines the key to group and lookup one-token continuations"""
1058
+ # Use with group_by="contexts" (optional)"
1059
+ # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
1060
+ # speeds up some multiple-choice tasks proportionally to the number of choices.
1061
+ # groups requests by context+continuation[:-1] and infer on one request/group.
1062
+ return req[-2] + req[-1][:-1]
1063
+
1064
+ re_ord = Collator(
1065
+ requests,
1066
+ sort_fn=_collate,
1067
+ group_by="contexts"
1068
+ if self.backend == "causal" and self.logits_cache
1069
+ else None,
1070
+ group_fn=_lookup_one_token_cont,
1071
+ )
1072
+
1073
+ # automatic (variable) batch size detection for vectorization
1074
+ # pull longest context sample from request
1075
+ n_reordered_requests = len(re_ord)
1076
+ batch_size = (
1077
+ self.batch_size
1078
+ if self.batch_size != "auto"
1079
+ else override_bs
1080
+ if override_bs is not None
1081
+ else 0
1082
+ )
1083
+ batch_fn = (
1084
+ self._batch_scheduler
1085
+ if self.batch_size == "auto"
1086
+ and n_reordered_requests > 0
1087
+ and not override_bs
1088
+ else None
1089
+ )
1090
+
1091
+ chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
1092
+ pbar = tqdm(
1093
+ total=len(requests),
1094
+ disable=(disable_tqdm or (self.rank != 0)),
1095
+ desc="Running loglikelihood requests",
1096
+ )
1097
+ for chunk in chunks:
1098
+ inps = []
1099
+ cont_toks_list = []
1100
+ inplens = []
1101
+
1102
+ conts = []
1103
+ encoder_attns = []
1104
+
1105
+ padding_len_inp = None
1106
+ padding_len_cont = None
1107
+ # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
1108
+ # tensors, then we pack them together into a batch, call the model, and then pick it all apart
1109
+ # again because vectorizing is annoying
1110
+
1111
+ for _, context_enc, continuation_enc in chunk:
1112
+ # sanity check
1113
+ assert len(context_enc) > 0
1114
+ assert len(continuation_enc) > 0
1115
+ assert len(continuation_enc) <= self.max_length
1116
+
1117
+ # how this all works (illustrated on a causal decoder-only setup):
1118
+ # CTX CONT
1119
+ # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
1120
+ # model \ \
1121
+ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
1122
+ # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
1123
+
1124
+ # when too long to fit in context, truncate from the left
1125
+ if self.backend == "causal":
1126
+ total_length = len(context_enc) + len(continuation_enc)
1127
+ if total_length > self.max_length + 1:
1128
+ eval_logger.warning(
1129
+ f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
1130
+ f"exceeds model's maximum length ({self.max_length}). "
1131
+ f"Truncating {total_length - self.max_length + 1} tokens from the left."
1132
+ )
1133
+ inp = torch.tensor(
1134
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
1135
+ dtype=torch.long,
1136
+ device=self.device,
1137
+ )
1138
+ (inplen,) = inp.shape
1139
+ elif self.backend == "seq2seq":
1140
+ inp = torch.tensor(
1141
+ (context_enc)[-self.max_length :],
1142
+ dtype=torch.long,
1143
+ device=self.device,
1144
+ )
1145
+ (inplen,) = inp.shape
1146
+
1147
+ # build encoder attn masks
1148
+ encoder_attns.append(torch.ones_like(inp))
1149
+
1150
+ cont = torch.tensor(
1151
+ (continuation_enc)[-self.max_length :],
1152
+ # TODO: left-shift these?
1153
+ # TODO: our code assumes we never end up truncating conts for either model type
1154
+ dtype=torch.long,
1155
+ device=self.device,
1156
+ )
1157
+ (contlen,) = cont.shape
1158
+
1159
+ conts.append(cont)
1160
+
1161
+ padding_len_cont = (
1162
+ max(padding_len_cont, contlen)
1163
+ if padding_len_cont is not None
1164
+ else contlen
1165
+ )
1166
+
1167
+ padding_len_inp = (
1168
+ max(padding_len_inp, inplen)
1169
+ if padding_len_inp is not None
1170
+ else inplen
1171
+ )
1172
+
1173
+ inps.append(inp) # [1, inp_length]
1174
+ cont_toks_list.append(continuation_enc)
1175
+ inplens.append(inplen)
1176
+
1177
+ # create encoder attn mask and batched conts, if seq2seq
1178
+ call_kwargs = {}
1179
+ if self.backend == "causal":
1180
+ batched_inps = pad_and_concat(
1181
+ padding_len_inp, inps, padding_side="right"
1182
+ ) # [batch, padding_len_inp]
1183
+ elif self.backend == "seq2seq":
1184
+ # TODO: left-pad encoder inps and mask?
1185
+ batched_inps = pad_and_concat(
1186
+ padding_len_inp, inps
1187
+ ) # [batch, padding_len_inp]
1188
+ batched_conts = pad_and_concat(
1189
+ padding_len_cont, conts
1190
+ ) # [batch, padding_len_cont]
1191
+ batched_encoder_mask = pad_and_concat(
1192
+ padding_len_inp, encoder_attns
1193
+ ) # [batch, padding_len_inp]
1194
+ call_kwargs = {
1195
+ "attn_mask": batched_encoder_mask,
1196
+ "labels": batched_conts,
1197
+ }
1198
+
1199
+ multi_logits = F.log_softmax(
1200
+ self._model_call(batched_inps, **call_kwargs), dim=-1
1201
+ ) # [batch, padding_length (inp or cont), vocab]
1202
+
1203
+ for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
1204
+ chunk, multi_logits, inplens, cont_toks_list
1205
+ ):
1206
+ # Slice to original seq length
1207
+ contlen = len(cont_toks)
1208
+ # take only logits in the continuation
1209
+ # (discard context toks if decoder-only ; discard right-padding)
1210
+ # also discards + checks for "virtual tokens" in the causal LM's input window
1211
+ # from prompt/prefix tuning tokens, if applicable
1212
+ ctx_len = (
1213
+ inplen + (logits.shape[0] - padding_len_inp)
1214
+ if self.backend == "causal"
1215
+ else None
1216
+ )
1217
+ logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
1218
+ logits = logits.unsqueeze(0) # [1, seq, vocab]
1219
+
1220
+ # Check if per-token argmax is exactly equal to continuation
1221
+ greedy_tokens = logits.argmax(dim=-1)
1222
+
1223
+ # check for one-token continuation cache hits.
1224
+ # noop in case group_by != "contexts" or no cache hit and returns the
1225
+ # original args. Otherwise, expands the logits batch dimension and yields each
1226
+ # batch along with matching continuation tokens and prompt strings.
1227
+ # logits -> [1, seq, vocab]
1228
+ for request_str, cont_toks, logits in re_ord.get_cache(
1229
+ req_str=request_str,
1230
+ cxt_toks=ctx_tokens,
1231
+ cont_toks=cont_toks,
1232
+ logits=logits,
1233
+ ):
1234
+ cont_toks = torch.tensor(
1235
+ cont_toks, dtype=torch.long, device=self.device
1236
+ ).unsqueeze(0) # [1, seq]
1237
+ max_equal = (greedy_tokens == cont_toks).all()
1238
+
1239
+ # Obtain log-probs at the corresponding continuation token indices
1240
+ # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
1241
+ logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
1242
+ -1
1243
+ ) # [1, seq]
1244
+
1245
+ # Answer: (log prob, is-exact-match)
1246
+ answer = (float(logits.sum()), bool(max_equal))
1247
+
1248
+ res.append(answer)
1249
+
1250
+ if request_str is not None:
1251
+ # special case: loglikelihood_rolling produces a number of loglikelihood requests
1252
+ # all with cache key None. instead do add_partial on the per-example level
1253
+ # in the loglikelihood_rolling() function for those.
1254
+ self.cache_hook.add_partial(
1255
+ "loglikelihood", request_str, answer
1256
+ )
1257
+ pbar.update(1)
1258
+
1259
+ pbar.close()
1260
+
1261
+ return re_ord.get_original(res)
1262
+
1263
+ def generate_until(
1264
+ self, requests: List[Instance], disable_tqdm: bool = False
1265
+ ) -> List[str]:
1266
+ res = []
1267
+
1268
+ def _collate(req: Tuple[str, dict]):
1269
+ """Defines the key for the sorted method"""
1270
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
1271
+ # - time estimates will always be over not underestimates, which is more useful for planning
1272
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
1273
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
1274
+ # automatic adaptive batches much much easier to implement
1275
+ # - any OOMs will happen right away rather than near the end
1276
+ toks = self.tok_encode(req[0])
1277
+ return -len(toks), req[0]
1278
+
1279
+ pbar = tqdm(
1280
+ total=len(requests),
1281
+ disable=(disable_tqdm or (self.rank != 0)),
1282
+ desc="Running generate_until requests",
1283
+ )
1284
+ adaptive_batch_size = None
1285
+ if self.batch_size == "auto":
1286
+ # using rolling window with maximum context
1287
+ print("Passed argument batch_size = auto. Detecting largest batch size")
1288
+ batch_size = self._detect_batch_size()
1289
+ print(f"Determined Largest batch size: {batch_size}")
1290
+ adaptive_batch_size = batch_size
1291
+ # for each different set of kwargs, we execute all requests, by batch.
1292
+ batch_size = (
1293
+ self.batch_size
1294
+ if self.batch_size != "auto"
1295
+ else adaptive_batch_size
1296
+ if adaptive_batch_size is not None
1297
+ else 0
1298
+ )
1299
+ batch_fn = (
1300
+ self._batch_scheduler
1301
+ if self.batch_size == "auto" and not adaptive_batch_size
1302
+ else None
1303
+ )
1304
+
1305
+ # we group requests by their generation_kwargs,
1306
+ # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
1307
+ # in the same batch.
1308
+ # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
1309
+ re_ords = Collator(
1310
+ [reg.args for reg in requests],
1311
+ sort_fn=_collate,
1312
+ group_by="gen_kwargs",
1313
+ group_fn=lambda x: x[1],
1314
+ )
1315
+ chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
1316
+ eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
1317
+ for chunk in chunks:
1318
+ contexts, all_gen_kwargs = zip(*chunk)
1319
+ # we assume all gen kwargs in the batch are the same
1320
+ # this is safe to assume because the `grouper` object ensures it.
1321
+ gen_kwargs = all_gen_kwargs[0]
1322
+ # unpack our keyword arguments.
1323
+ if isinstance(gen_kwargs, dict):
1324
+ kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
1325
+ # add EOS token to stop sequences
1326
+ until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
1327
+ else:
1328
+ raise ValueError(
1329
+ f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
1330
+ )
1331
+ if "max_gen_toks" in kwargs.keys():
1332
+ max_gen_toks = kwargs.pop("max_gen_toks")
1333
+ else:
1334
+ max_gen_toks = self.max_gen_toks
1335
+
1336
+ # set the max length in tokens of inputs ("context_enc")
1337
+ if self.backend == "causal":
1338
+ # max len for inputs = max length, minus room to generate the max new tokens
1339
+ max_ctx_len = self.max_length - max_gen_toks
1340
+ assert max_ctx_len > 0, (
1341
+ f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
1342
+ )
1343
+ elif self.backend == "seq2seq":
1344
+ # max len for inputs = encoder's whole max_length
1345
+ max_ctx_len = self.max_length
1346
+
1347
+ # encode, pad, and truncate contexts for this batch
1348
+ context_enc, attn_masks = self.tok_batch_encode(
1349
+ contexts,
1350
+ left_truncate_len=max_ctx_len,
1351
+ truncation=self.truncation,
1352
+ )
1353
+ context_enc = context_enc.to(self.device)
1354
+ attn_masks = attn_masks.to(self.device)
1355
+
1356
+ if "max_length" not in kwargs:
1357
+ kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
1358
+
1359
+ # perform batched generation
1360
+ cont = self._model_generate(
1361
+ context=context_enc,
1362
+ attention_mask=attn_masks,
1363
+ stop=until,
1364
+ **kwargs,
1365
+ )
1366
+
1367
+ cont_toks_list = cont.tolist()
1368
+ for cont_toks, context in zip(cont_toks_list, contexts):
1369
+ # discard context + left-padding toks if using causal decoder-only LM
1370
+ if self.backend == "causal":
1371
+ cont_toks = cont_toks[context_enc.shape[1] :]
1372
+
1373
+ s = self.tok_decode(cont_toks)
1374
+
1375
+ # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
1376
+ for term in until:
1377
+ if len(term) > 0:
1378
+ # ignore '' separator,
1379
+ # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
1380
+ s = s.split(term)[0]
1381
+
1382
+ res.append(s)
1383
+
1384
+ self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
1385
+ pbar.update(1)
1386
+ # reorder this group of results back to original unsorted form
1387
+ res = re_ords.get_original(res)
1388
+
1389
+ pbar.close()
1390
+
1391
+ return res
1392
+
1393
+ def apply_chat_template(
1394
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
1395
+ ) -> str:
1396
+ """
1397
+ Method to apply a chat template to a list of chat history between user and model.
1398
+ """
1399
+ try:
1400
+ chat_templated = self.tokenizer.apply_chat_template(
1401
+ chat_history,
1402
+ tokenize=False,
1403
+ add_generation_prompt=add_generation_prompt,
1404
+ continue_final_message=not add_generation_prompt,
1405
+ )
1406
+ except jinja2.exceptions.TemplateError:
1407
+ eval_logger.warning(
1408
+ "Failed to apply chat template. removing the system role in chat history."
1409
+ )
1410
+ chat_history = [msg for msg in chat_history if msg["role"] != "system"]
1411
+ chat_templated = self.tokenizer.apply_chat_template(
1412
+ chat_history,
1413
+ tokenize=False,
1414
+ add_generation_prompt=add_generation_prompt,
1415
+ continue_final_message=not add_generation_prompt,
1416
+ )
1417
+
1418
+ return chat_templated
1419
+
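+ # Input sketch (hypothetical content): `chat_history` is a list of role/content dicts, e.g.
+ # [{"role": "user", "content": "Solve 2+2"}]; the tokenizer's chat template renders it into
+ # a single prompt string, and templates that reject a "system" turn are retried without the
+ # system messages, as handled above.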
1420
+ def get_model_info(self) -> dict:
1421
+ """
1422
+ Method to get Hugging Face model information for experiment reproducibility.
1423
+ """
1424
+
1425
+ def get_model_num_params(model) -> int:
1426
+ if hasattr(model, "num_parameters"):
1427
+ return model.num_parameters()
1428
+ if hasattr(model, "parameters"):
1429
+ return sum(p.numel() for p in model.parameters())
1430
+ else:
1431
+ return -1
1432
+
1433
+ def get_model_dtype(model) -> str:
1434
+ if hasattr(model, "dtype"):
1435
+ return model.dtype
1436
+ else:
1437
+ return ""
1438
+
1439
+ def get_model_sha(pretrained: str, revision: str) -> str:
1440
+ try:
1441
+ model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
1442
+ return model_info.sha
1443
+ except Exception as e:
1444
+ eval_logger.debug(
1445
+ f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
1446
+ )
1447
+ return ""
1448
+
1449
+ model_info = {
1450
+ "model_num_parameters": get_model_num_params(self._model),
1451
+ "model_dtype": get_model_dtype(self._model),
1452
+ "model_revision": self.revision,
1453
+ "model_sha": get_model_sha(self.pretrained, self.revision),
1454
+ }
1455
+ if self.peft:
1456
+ model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
1457
+ if self.delta:
1458
+ model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
1459
+ return model_info
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py ADDED
@@ -0,0 +1,731 @@
1
+ import collections
2
+ import fnmatch
3
+ import gc
4
+ import itertools
5
+ import logging
6
+ import time
7
+ from functools import wraps
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ Union,
21
+ )
22
+
23
+ import torch
24
+ import transformers
25
+
26
+
27
+ eval_logger = logging.getLogger(__name__)
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers import PreTrainedTokenizerBase
32
+ from transformers.configuration_utils import PretrainedConfig
33
+
34
+
35
+ def chunks(iter, n: int = 0, fn=None):
36
+ """
37
+ Divides an iterable into chunks of specified size or based on a given function.
38
+ Useful for batching
39
+
40
+ Parameters:
41
+ - iter: The input iterable to be divided into chunks.
42
+ - n: An integer representing the size of each chunk. Default is 0.
43
+ - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
44
+
45
+ Returns:
46
+ An iterator that yields chunks of the input iterable.
47
+
48
+ Example usage:
49
+ ```
50
+ data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
51
+ for chunk in chunks(data, 3):
52
+ print(chunk)
53
+ ```
54
+ Output:
55
+ ```
56
+ [1, 2, 3]
57
+ [4, 5, 6]
58
+ [7, 8, 9]
59
+ [10]
60
+ ```
61
+ """
62
+ arr = []
63
+ for i, x in enumerate(iter):
64
+ arr.append(x)
65
+ if len(arr) == (fn(i, iter) if fn else n):
66
+ yield arr
67
+ arr = []
68
+
69
+ if arr:
70
+ yield arr
71
+
72
+
73
+ class MultiChoice:
74
+ def __init__(self, choices) -> None:
75
+ self.choices = choices
76
+
77
+ # Simple wildcard support (linux filename patterns)
78
+ def __contains__(self, values) -> bool:
79
+ for value in values.split(","):
80
+ if len(fnmatch.filter(self.choices, value)) == 0:
81
+ eval_logger.info("Available tasks to choose:")
82
+ for choice in self.choices:
83
+ eval_logger.info(f" - {choice}")
84
+ raise ValueError("'{}' is not in task list".format(value))
85
+ return True
86
+
87
+ def __iter__(self) -> Iterator:
88
+ for choice in self.choices:
89
+ yield choice
90
+
91
+
92
+ class Grouper:
93
+ """
94
+ takes an array `arr` and function `fn` and returns a dictionary
95
+ with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
96
+ objects in `arr` satisfying `key == fn(ob)`.
97
+ """
98
+
99
+ def __init__(self, arr, fn) -> None:
100
+ # self.orig_arr = arr
101
+ self.size = len(arr)
102
+ arr = list(enumerate(arr))
103
+
104
+ def group_return_dict(arr, fn):
105
+ res = collections.defaultdict(list)
106
+
107
+ for ob in arr:
108
+ res[fn(ob)].append(ob)
109
+ return res
110
+
111
+ arr = group_return_dict(arr, lambda x: fn(x[1]))
112
+
113
+ # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
114
+ self.arr = arr
115
+ self._grouped = None
116
+
117
+ def get_grouped(self):
118
+ # return the contents but not indices for our grouped dict.
119
+ if self._grouped:
120
+ return self._grouped
121
+ grouped = {}
122
+ for key in self.arr.keys():
123
+ # drop the index from each element of self.arr
124
+ grouped[key] = [y[1] for y in self.arr[key]]
125
+ self._grouped = grouped
126
+ return grouped
127
+
128
+ def get_original(self, grouped_dict):
129
+ # take in a grouped dictionary with e.g. results for each key listed
130
+ # in the same order as the instances in `self.arr`, and
131
+ # return the results in the same (single list) order as `self.orig_arr`.
132
+ res = [None] * self.size
133
+ cov = [False] * self.size
134
+ # orig = [None] * self.size
135
+
136
+ assert grouped_dict.keys() == self.arr.keys()
137
+
138
+ for key in grouped_dict.keys():
139
+ for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
140
+ res[ind] = v
141
+ cov[ind] = True
142
+ # orig[ind] = _
143
+
144
+ assert all(cov)
145
+ # assert orig == self.orig_arr
146
+
147
+ return res
148
+
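+ # Minimal round-trip sketch (values invented for illustration):
+ # >>> g = Grouper(["a", "bb", "c"], fn=len)
+ # >>> g.get_grouped()
+ # {1: ['a', 'c'], 2: ['bb']}
+ # >>> g.get_original({1: ["A", "C"], 2: ["BB"]})
+ # ['A', 'BB', 'C']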
149
+
150
+ def pad_and_concat(
151
+ max_length: int,
152
+ tensors: List[torch.Tensor],
153
+ padding_side: Literal["right", "left"] = "right",
154
+ ):
155
+ """
156
+ Method for padding a list of tensors given the maximum tensor
157
+ length in the batch. Used for batching inputs and continuations for
158
+ both causal and seq2seq models.
159
+ """
160
+ assert padding_side == "left" or padding_side == "right", (
161
+ f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
162
+ )
163
+
164
+ for i, tensor in enumerate(tensors):
165
+ if len(tensor.shape) == 2:
166
+ tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
167
+ tensor_len = tensor.shape[0]
168
+ if tensor_len < max_length:
169
+ if padding_side == "right":
170
+ # right-pad
171
+ tensors[i] = torch.cat(
172
+ [
173
+ tensor, # [seq]
174
+ torch.zeros(
175
+ max_length - tensor_len,
176
+ dtype=torch.long,
177
+ device=tensor.device,
178
+ ), # [padding_length - seq]
179
+ ],
180
+ dim=0,
181
+ ).unsqueeze(0)
182
+ else:
183
+ # left-pad
184
+ tensors[i] = torch.cat(
185
+ [
186
+ torch.zeros(
187
+ max_length - tensor_len,
188
+ dtype=torch.long,
189
+ device=tensor.device,
190
+ ), # [padding_length - seq]
191
+ tensor, # [seq]
192
+ ],
193
+ dim=0,
194
+ ).unsqueeze(0)
195
+ else:
196
+ tensors[i] = tensor.unsqueeze(0)
197
+
198
+ return torch.cat(tensors, dim=0)
199
+
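+ # Shape sketch (illustrative values): padding sequences of length 2 and 3 to max_length=3
+ # >>> a, b = torch.tensor([1, 2]), torch.tensor([3, 4, 5])
+ # >>> pad_and_concat(3, [a, b], padding_side="right")
+ # tensor([[1, 2, 0],
+ #         [3, 4, 5]])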
200
+
201
+ def clear_torch_cache() -> None:
202
+ gc.collect()
203
+ torch.cuda.empty_cache()
204
+
205
+
206
+ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
207
+ """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
208
+ if isinstance(dtype, str) and dtype != "auto":
209
+ # Convert `str` args torch dtype: `float16` -> `torch.float16`
210
+ _torch_dtype = getattr(torch, dtype)
211
+ else:
212
+ _torch_dtype = dtype
213
+ return _torch_dtype
214
+
215
+
216
+ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
217
+ """Criteria to stop on the specified multi-token sequence."""
218
+
219
+ def __init__(
220
+ self,
221
+ sequence: str,
222
+ tokenizer: transformers.PreTrainedTokenizer,
223
+ initial_decoder_input_length: int,
224
+ batch_size: int,
225
+ ) -> None:
226
+ self.initial_decoder_input_length = initial_decoder_input_length
227
+ self.done_tracker = [False] * batch_size
228
+ self.sequence = sequence
229
+ self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
230
+ # print(sequence, self.sequence_ids)
231
+ # we look back for 2 more tokens than it takes to encode our stop sequence
232
+ # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
233
+ # and we don't want to mistakenly not stop a generation because our
234
+ # (string) stop sequence was output in a different tokenization
235
+
236
+ # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
237
+ # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
238
+ # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
239
+ self.sequence_id_len = len(self.sequence_ids) + 2
240
+ self.tokenizer = tokenizer
241
+
242
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
243
+ # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
244
+ lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
245
+
246
+ lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
247
+
248
+ lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
249
+
250
+ for i, done in enumerate(self.done_tracker):
251
+ if not done:
252
+ self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
253
+ return False not in self.done_tracker
254
+
255
+
256
+ def stop_sequences_criteria(
257
+ tokenizer: transformers.PreTrainedTokenizer,
258
+ stop_sequences: List[str],
259
+ initial_decoder_input_length: int,
260
+ batch_size: int,
261
+ ) -> transformers.StoppingCriteriaList:
262
+ return transformers.StoppingCriteriaList(
263
+ [
264
+ *[
265
+ MultiTokenEOSCriteria(
266
+ sequence, tokenizer, initial_decoder_input_length, batch_size
267
+ )
268
+ for sequence in stop_sequences
269
+ ],
270
+ ]
271
+ )
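+ # Minimal usage sketch; `tokenizer`, `model`, `input_ids` and `prompt_len` are
+ # placeholders, not values defined in this module:
+ #   >>> criteria = stop_sequences_criteria(tokenizer, ["\n\n"], prompt_len, batch_size=1)
+ #   >>> model.generate(input_ids, stopping_criteria=criteria, max_new_tokens=256)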
272
+
273
+
274
+ def undistribute(iterable):
275
+ """
276
+ Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .
277
+
278
+ Re-interleaves results that have been split using more_itertools.distribute:
279
+ >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
280
+ >>> list(group_1)
281
+ [1, 3, 5]
282
+ >>> list(group_2)
283
+ [2, 4, 6]
284
+ >>> undistribute([group_1, group_2])
285
+ [1, 2, 3, 4, 5, 6]
286
+
287
+ Handles non-uniform component lengths:
288
+
289
+ >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
290
+ >>> [list(c) for c in children]
291
+ [[1, 4, 7], [2, 5], [3, 6]]
292
+ >>> undistribute(children)
293
+ [1, 2, 3, 4, 5, 6, 7]
294
+
295
+ Also handles when some iterables are empty:
296
+
297
+ >>> children = distribute(5, [1, 2, 3])
298
+ >>> [list(c) for c in children]
299
+ [[1], [2], [3], [], []]
300
+ >>> undistribute(children)
301
+ [1, 2, 3]
302
+
303
+ """
304
+
305
+ return [
306
+ x
307
+ for x in itertools.chain.from_iterable(
308
+ itertools.zip_longest(*[list(x) for x in iterable])
309
+ )
310
+ if x is not None
311
+ ]
312
+
313
+
314
+ def retry_on_specific_exceptions(
315
+ on_exceptions: List[Type[Exception]],
316
+ max_retries: Optional[int] = None,
317
+ backoff_time: float = 3.0,
318
+ backoff_multiplier: float = 1.5,
319
+ on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
320
+ ):
321
+ """Retry on an LLM Provider's rate limit error with exponential backoff
322
+ For example, to use for OpenAI, do the following:
323
+ ```
324
+ from openai import RateLimitError
325
+
326
+ # Recommend specifying max_retries to avoid infinite loops!
327
+ @retry_on_specific_exceptions([RateLimitError], max_retries=3)
328
+ def completion(...):
329
+ # Wrap OpenAI completion function here
330
+ ...
331
+ ```
332
+ """
333
+
334
+ def decorator(func: Callable):
335
+ @wraps(func)
336
+ def wrapper(*args, **kwargs):
337
+ sleep_time = backoff_time
338
+ attempt = 0
339
+ while max_retries is None or attempt < max_retries:
340
+ try:
341
+ return func(*args, **kwargs)
342
+ except tuple(on_exceptions) as e:
343
+ if on_exception_callback is not None:
344
+ on_exception_callback(e, sleep_time)
345
+ time.sleep(sleep_time)
346
+ sleep_time *= backoff_multiplier
347
+ attempt += 1
348
+
349
+ return wrapper
350
+
351
+ return decorator
352
+
353
+
354
+ class Collator:
355
+ """
356
+ A class for reordering and batching elements of an array.
357
+
358
+ This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
359
+
360
+ Objects of this class have the group_by attribute which determines the method for grouping
361
+ the data while batching it. Three options include "gen_kwargs", "contexts", or None:
362
+ If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
363
+ If group_by == "contexts" then requests will be grouped by context + cont[:-1]
364
+ If None then requests will just be reordered by length descending.
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ arr: List,
370
+ sort_fn: Callable = lambda x: x,
371
+ group_fn: Callable = lambda x: x[1],
372
+ group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
373
+ ) -> None:
374
+ self._group_by = group_by
375
+ # 0 indices are enumerated indices. Apply functions to original arr.
376
+ self._sort_fn = lambda x: sort_fn(x[1])
377
+ self._group_fn = lambda x: group_fn(x[1])
378
+ self._reorder_indices: List = []
379
+ self._size = len(arr)
380
+ self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
381
+ enumerate(arr)
382
+ ) # [indices, (arr)]
383
+ if self._group_by == "contexts":
384
+ self._group_by_context()
385
+ elif self._group_by == "gen_kwargs":
386
+ self._group_by_index()
387
+
388
+ def _group_by_index(self) -> None:
389
+ """Group the elements of a list based on their indices."""
390
+ self._arr_with_indices = self.group(
391
+ self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
392
+ )
393
+
394
+ def _group_by_context(self) -> None:
395
+ """Group the array with indices by context."""
396
+ self._arr_with_indices = self.group(
397
+ self._arr_with_indices, fn=self._group_fn, group_by="contexts"
398
+ )
399
+
400
+ def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
401
+ """
402
+ Generates and yields batches from the reordered array. The method of grouping and batching
403
+ depends on the parameter `group_by`.
404
+ If `group_by` is set to "gen_kwargs", it will batch the
405
+ re-ordered values with same gen_kwargs for each batch.
406
+ If `group_by` is "contexts", it caches the requests by context before batching.
407
+ If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
408
+
409
+ Parameters:
410
+ - n (int): The size of each batch. Defaults to 1.
411
+ - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
412
+ each batch. Optional, defaults to None.
413
+
414
+ Returns:
415
+ Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
416
+ attribute.
417
+
418
+ Yields:
419
+ List of batched elements according to the `group_by` attribute.
420
+ """
421
+ if self._group_by == "gen_kwargs":
422
+ for (
423
+ key,
424
+ values,
425
+ ) in self._arr_with_indices.items(): # type: ignore
426
+ values = self._reorder(values)
427
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
428
+ yield from batch
429
+ elif self._group_by == "contexts":
430
+ # Get one sample from each key
431
+ values = self._reorder(
432
+ [value[0] for value in self._arr_with_indices.values()]
433
+ )
434
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
435
+ yield from batch
436
+ else:
437
+ values = self._reorder(self._arr_with_indices) # type: ignore
438
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
439
+ yield from batch
440
+
441
+ def get_cache(
442
+ self,
443
+ req_str: Tuple[str, str] = None,
444
+ cxt_toks: List[int] = None,
445
+ cont_toks: List[int] = None,
446
+ logits: torch.Tensor = None,
447
+ ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
448
+ """
449
+ Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
450
+
451
+ The behavior of this function varies depending on how the `group_by` attribute is set:
452
+
453
+ - When `group_by` is "contexts":
454
+ The function identifies single-token continuations by checking for keys that equate to
455
+ [context+continuation][-1] and logs the indices for re-ordering.
456
+ In this mode, this function can work in two scenarios:
457
+
458
+ 1. Cache Hit - Single Match:
459
+ If a single matching context-continuation pair is found in the cache,
460
+ the function yields the original arguments.
461
+
462
+ 2. Cache Hit - Multiple Matches:
463
+ If multiple matching context-continuation pairs are found in the cache,
464
+ the function expands the logits batch dimension to match the number of cache hits.
465
+ It updates the original requests and continuation tokens.
466
+
467
+ - When `group_by` is not set to "contexts":
468
+ This method yields the original arguments, logits and continuation tokens,
469
+ without checking for one-token continuations.
470
+
471
+ Parameters:
472
+ - req_str (tuple[str, str]): Original strings used for CachingLM.
473
+ - cxt_toks (list[int]): Full context tokens used for lookup.
474
+ - cont_toks (list[int]): Continuation tokens for which logits were generated.
475
+ - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
476
+
477
+ Yields:
478
+ - Iterator:
479
+ - req_str (tuple[str, str]): strings used for CachingLM.
480
+ - cont_toks (list[int]) : continuation tokens.
481
+ - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
482
+ """
483
+ if self._group_by == "contexts":
484
+ cache_hit: List[
485
+ Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
486
+ ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
487
+ if (cache_size := len(cache_hit)) == 1:
488
+ self._reorder_indices.extend(x[0] for x in cache_hit)
489
+ yield req_str, cont_toks, logits
490
+ else:
491
+ # If we have matching requests then expand the batch dimension (no-op) and
492
+ # yield each along with its corresponding args.
493
+ multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
494
+ indices, req_str, cont_toks = zip(
495
+ *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
496
+ )
497
+ self._reorder_indices.extend(indices)
498
+ for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
499
+ yield c_key, cont_tok, logit
500
+ else:
501
+ yield req_str, cont_toks, logits
502
+
503
+ def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
504
+ """
505
+ Reorders the elements in the array based on the sorting function.
506
+
507
+ Parameters:
508
+ - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.
509
+
510
+ Yields:
511
+ Iterator
512
+ """
513
+ arr = sorted(arr, key=self._sort_fn)
514
+ if not self._group_by == "contexts":
515
+ # If grouped by contexts then indices will be set in get_cache()
516
+ self._reorder_indices.extend([x[0] for x in arr])
517
+ yield from [x[1] for x in arr]
518
+
519
+ def get_original(self, newarr: List) -> List:
520
+ """
521
+ Restores the original order of elements from the reordered list.
522
+
523
+ Parameters:
524
+ - newarr (list): The reordered array.
525
+
526
+ Returns:
527
+ list: The array with elements restored to their original order.
528
+ """
529
+ res = [None] * self._size
530
+ cov = [False] * self._size
531
+
532
+ for ind, v in zip(self._reorder_indices, newarr):
533
+ res[ind] = v
534
+ cov[ind] = True
535
+
536
+ assert all(cov)
537
+
538
+ return res
539
+
540
+ def __len__(self):
541
+ return self._size
542
+
543
+ @staticmethod
544
+ def group(
545
+ arr: Iterable,
546
+ fn: Callable,
547
+ group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
548
+ ) -> dict:
549
+ """
550
+ Groups elements of an iterable based on a provided function.
551
+
552
+
553
+ The `group_by` parameter determines the method of grouping.
554
+ If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
555
+ If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
556
+
557
+ Parameters:
558
+ - arr (Iterable): The iterable to be grouped.
559
+ - fn (Callable): The function to determine the grouping.
560
+ - group_by ("gen_kwargs" or "contexts"): The grouping strategy. Defaults to "gen_kwargs".
561
+
562
+ Returns:
563
+ Iterator: An iterable of grouped elements.
564
+ """
565
+ res = collections.defaultdict(list)
566
+ for ob in arr:
567
+ # where ob == [context + cont]
568
+ if group_by == "contexts":
569
+ res[tuple(fn(ob))].append(ob)
570
+ else:
571
+ try:
572
+ hashable_dict = tuple(
573
+ (
574
+ key,
575
+ tuple(value)
576
+ if isinstance(value, collections.abc.Iterable)
577
+ else value,
578
+ )
579
+ for key, value in sorted(fn(ob).items())
580
+ )
581
+ res[hashable_dict].append(ob)
582
+ except (TypeError, AttributeError):
583
+ res[tuple(fn(ob))].append(ob)
584
+ return res
585
+
586
+ @staticmethod
587
+ def get_chunks(_iter, n: int = 0, fn=None):
588
+ """
589
+ Divides an iterable into chunks of specified size or based on a given function.
590
+ Useful for batching
591
+
592
+ Parameters:
593
+ - iter: The input iterable to be divided into chunks.
594
+ - n: An integer representing the size of each chunk. Default is 0.
595
+ - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
596
+
597
+ Returns:
598
+ An iterator that yields chunks of the input iterable.
599
+
600
+ Example usage:
601
+ ```
602
+ data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
603
+ for chunk in chunks(data, 3):
604
+ print(chunk)
605
+ ```
606
+ Output:
607
+ ```
608
+ [1, 2, 3]
609
+ [4, 5, 6]
610
+ [7, 8, 9]
611
+ [10]
612
+ ```
613
+ """
614
+ arr = []
615
+ _iter = tuple(_iter)
616
+ for i, x in enumerate(_iter):
617
+ arr.append(x)
618
+ if len(arr) == (fn(i, _iter) if fn else n):
619
+ yield arr
620
+ arr = []
621
+
622
+ if arr:
623
+ yield arr
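+ # Usage sketch for Collator; the requests and gen_kwargs below are illustrative:
+ #   >>> reqs = [("ctx a", {"max_gen_toks": 32}), ("a longer ctx b", {"max_gen_toks": 32})]
+ #   >>> collator = Collator(reqs, sort_fn=lambda x: len(x[0]), group_by="gen_kwargs")
+ #   >>> results = []
+ #   >>> for batch in collator.get_batched(n=2):
+ #   ...     results.extend(run_model(req) for req in batch)  # run_model is a placeholder
+ #   >>> ordered = collator.get_original(results)  # back in the original request order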
624
+
625
+
626
+ def configure_pad_token(
627
+ tokenizer: "PreTrainedTokenizerBase",
628
+ model_config: Optional["PretrainedConfig"] = None,
629
+ ) -> "PreTrainedTokenizerBase":
630
+ """
631
+ This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present.
632
+ Some tokenizers require special handling.
633
+
634
+ Args:
635
+ tokenizer: The tokenizer for which the padding token is to be handled.
636
+ model_config: The configuration of the model. Default is None.
637
+
638
+ Returns:
639
+ The tokenizer after the padding token has been handled.
640
+
641
+ Raises:
642
+ AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
643
+ """
644
+ if tokenizer.pad_token:
645
+ pass
646
+ elif tokenizer.unk_token:
647
+ tokenizer.pad_token_id = tokenizer.unk_token_id
648
+ elif tokenizer.eos_token:
649
+ tokenizer.pad_token_id = tokenizer.eos_token_id
650
+ else:
651
+ # handle special cases
652
+ if model_config and getattr(model_config, "model_type", None) == "qwen":
653
+ # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
654
+ tokenizer.pad_token = "<|endoftext|>"
655
+ elif (
656
+ tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
657
+ or tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
658
+ ):
659
+ # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0)
660
+ # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer
661
+ # ---
662
+ # Note that the world tokenizer class name, might change in the future for the final huggingface merge
663
+ # https://github.com/huggingface/transformers/pull/26963
664
+ assert tokenizer.pad_token_id == 0
665
+ else:
666
+ tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
667
+
668
+ return tokenizer
669
+
670
+
671
+ def replace_placeholders(
672
+ string: str, default_placeholder: str, image_token: str, max_images: int
673
+ ):
674
+ """
675
+ A utility function used for local multimodal models. It locates all `placeholder` string
676
+ occurrences in the given input `string_` and replaces the first `max_count` instances with
677
+ `replacement`, and all subsequent occurrences with the empty string.
678
+
679
+ This is used to replace <image> placeholder tags by model-specific image tokens like <|image_pad|>
680
+ and to allow for only the first `max_count` images to be passed to a model if desired.
681
+
682
+ :param string: The original string containing placeholders.
683
+ :param default_placeholder: The placeholder text to be replaced.
684
+ :param image_token: The token to replace the placeholder with.
685
+ :param max_images: The maximum number of replacements to make.
686
+ :return: The string with placeholders replaced.
687
+ """
688
+ count = 0
689
+ result = []
690
+
691
+ parts = string.split(default_placeholder)
692
+ for part in parts[:-1]: # Iterate through all but the last part
693
+ result.append(part)
694
+ if count < max_images:
695
+ result.append(image_token)
696
+ count += 1
697
+ elif default_placeholder != image_token:
698
+ result.append(default_placeholder)
699
+
700
+ # Add the last part of the string
701
+ result.append(parts[-1])
702
+ return "".join(result)
703
+
704
+
705
+ def flatten_image_list(images: List[List]):
706
+ """
707
+ Takes in a list of lists of images, and returns a single list of all images in order.
708
+ Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor.
709
+
710
+ :param images: A list of lists of PIL images.
711
+ :return: a list of PIL images, via concatenating all the sub-lists in order.
712
+ """
713
+ return [image for image_list in images for image in image_list]
714
+
715
+
716
+ def handle_stop_sequences(
717
+ until: Union[str, List[str], None], eos: Optional[str]
718
+ ) -> List[str]:
719
+ """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
720
+ if isinstance(until, str):
721
+ until = [until]
722
+ elif until is None:
723
+ until = []
724
+ elif not isinstance(until, list):
725
+ raise ValueError(
726
+ f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
727
+ )
728
+
729
+ if eos is not None and eos not in until:
730
+ until.append(eos)
731
+ return until
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import logging
3
+ import ast
4
+ import re
5
+ import numpy as np
6
+ import textwrap
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class CodeVerifier:
11
+ def __init__(self, model, tokenizer, device="cuda"):
12
+ self.model = model
13
+ self.tokenizer = tokenizer
14
+ self.device = device
15
+
16
+ self.yes_ids, self.no_ids = [], []
17
+ for t in ["Yes", " Yes", "YES"]:
18
+ ids = self.tokenizer.encode(t, add_special_tokens=False)
19
+ if len(ids) > 0: self.yes_ids.append(ids[-1])
20
+ for t in ["No", " No", "NO"]:
21
+ ids = self.tokenizer.encode(t, add_special_tokens=False)
22
+ if len(ids) > 0: self.no_ids.append(ids[-1])
23
+
24
+ self.yes_ids = list(set(self.yes_ids))
25
+ self.no_ids = list(set(self.no_ids))
26
+
27
+ def _extract_python_code(self, text):
28
+ text = text.strip()
29
+ match = re.search(r"```python\s*(.*?)```", text, re.DOTALL)
30
+ if match: return match.group(1)
31
+ match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL)
32
+ if match_generic: return match_generic.group(1)
33
+ return text
34
+
35
+ def check_syntax(self, code_str):
36
+ clean_code = self._extract_python_code(code_str)
37
+ try:
38
+ if len(clean_code.strip()) < 5: return False
39
+ ast.parse(clean_code)
40
+ return True
41
+ except:
42
+ return False
43
+
44
+ def compute_confidence(self, logits):
45
+ if logits is None: return 0.0
46
+ probs = torch.softmax(logits, dim=-1)
47
+ max_probs, _ = torch.max(probs, dim=-1)
48
+ log_probs = torch.log(max_probs + 1e-10)
49
+ return torch.exp(torch.mean(log_probs)).item()
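+ # The value above is the geometric mean of the per-position max probabilities,
+ # e.g. two positions with max probabilities 0.9 and 0.4 give sqrt(0.9 * 0.4) = 0.6.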
50
+
51
+ def svf_score(self, prompt, code_str, task_type="code"):
52
+
53
+ max_len = 2000
54
+ if len(code_str) > max_len:
55
+ if task_type == "reasoning":
56
+ truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):]
57
+ else:
58
+ truncated_code = code_str[-max_len:]
59
+ else:
60
+ truncated_code = code_str
61
+
62
+ if task_type == "code":
63
+ prompt_template = f"""
64
+ You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.
65
+
66
+ [Problem Statement]
67
+ {prompt}
68
+ [/Problem Statement]
69
+
70
+ [Proposed Python Solution]
71
+ ```python
72
+ {truncated_code}
73
+ ```
74
+ [/Proposed Python Solution]
75
+
76
+ **Analysis Steps:**
77
+ 1. Correctness: Does the core algorithm correctly solve the problem?
78
+ 2. Efficiency: Is the time complexity acceptable for the given constraints?
79
+ 3. Edge Cases & Constraints: Does the code handle all rules and edge cases?
80
+
81
+ **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
82
+ **Answer:** """
83
+
84
+ elif task_type == "math":
85
+ prompt_template = f"""
86
+ You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.
87
+
88
+ [Math Problem]
89
+ {prompt}
90
+ [/Math Problem]
91
+
92
+ [Proposed Mathematical Solution]
93
+ {truncated_code}
94
+ [/Proposed Mathematical Solution]
95
+
96
+ **Analysis Steps:**
97
+ 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly?
98
+ 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate?
99
+ 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem?
100
+
101
+ **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
102
+ **Answer:** """
103
+
104
+ elif task_type == "reasoning":
105
+ prompt_template = f"""
106
+ You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.
107
+
108
+ [Context and Question]
109
+ {prompt}
110
+ [/Context and Question]
111
+
112
+ [Proposed Answer]
113
+ {truncated_code}
114
+ [/Proposed Answer]
115
+
116
+ **Analysis Steps :**
117
+ 1. Faithfulness: Is the answer an exact, literal span from the context?
118
+ 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information?
119
+ 3. Accuracy: Does the provided context strictly support this answer?
120
+
121
+ **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
122
+ **Answer:** """
123
+
124
+ else:
125
+ prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"
126
+
127
+ verify_text = textwrap.dedent(prompt_template).strip()
128
+ input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)
129
+
130
+ max_pos = getattr(self.model.config, "max_position_embeddings",
131
+ getattr(self.model.config, "n_positions",
132
+ getattr(self.model.config, "max_sequence_length", 20480)))
133
+
134
+ if input_ids.shape[1] > max_pos - 16:
135
+ logger.warning("Verifier input is too long, truncating from the left.")
136
+ input_ids = input_ids[:, -(max_pos - 16):]
137
+
138
+ with torch.no_grad():
139
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
140
+ outputs = self.model(input_ids, 'full')
141
+ logits = outputs.logits[0, -1, :]
142
+
143
+ yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf'))
144
+ no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf'))
145
+
146
+ if yes_score == -float('inf') and no_score == -float('inf'): return 0.5
147
+
148
+ probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
149
+ return probs[0].item()
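+ # The returned value is P("Yes") under a two-way softmax over the best Yes/No
+ # logits: 1.0 means the verifier is confident the solution is correct, 0.5 means
+ # it is indifferent.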
150
+
151
+ def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"):
152
+ if mode == "svf":
153
+ return self.svf_score(prompt, code_str, task_type=task_type)
154
+ else:
155
+ return self.compute_confidence(current_logits)
Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py ADDED
@@ -0,0 +1,552 @@
1
+ import collections
2
+ import fnmatch
3
+ import functools
4
+ import hashlib
5
+ import importlib.util
6
+ import inspect
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ from dataclasses import asdict, is_dataclass
12
+ from itertools import islice
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Generator, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import yaml
18
+ from jinja2 import BaseLoader, Environment, StrictUndefined
19
+
20
+
21
+ SPACING = " " * 47
22
+
23
+ HIGHER_IS_BETTER_SYMBOLS = {
24
+ True: "↑",
25
+ False: "↓",
26
+ }
27
+
28
+
29
+ def setup_logging(verbosity=logging.INFO):
30
+ # Configure the root logger
31
+ class CustomFormatter(logging.Formatter):
32
+ def format(self, record):
33
+ if record.name.startswith("lm_eval."):
34
+ record.name = record.name[len("lm_eval.") :]
35
+ return super().format(record)
36
+
37
+ formatter = CustomFormatter(
38
+ "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
39
+ datefmt="%Y-%m-%d:%H:%M:%S",
40
+ )
41
+
42
+ log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
43
+
44
+ level_map = {
45
+ "DEBUG": logging.DEBUG,
46
+ "INFO": logging.INFO,
47
+ "WARNING": logging.WARNING,
48
+ "ERROR": logging.ERROR,
49
+ "CRITICAL": logging.CRITICAL,
50
+ }
51
+
52
+ log_level = level_map.get(str(log_level).upper(), logging.INFO)
53
+
54
+ if not logging.root.handlers:
55
+ handler = logging.StreamHandler()
56
+ handler.setFormatter(formatter)
57
+
58
+ root_logger = logging.getLogger()
59
+ root_logger.addHandler(handler)
60
+ root_logger.setLevel(log_level)
61
+
62
+ if log_level == logging.DEBUG:
63
+ third_party_loggers = ["urllib3", "filelock", "fsspec"]
64
+ for logger_name in third_party_loggers:
65
+ logging.getLogger(logger_name).setLevel(logging.INFO)
66
+ else:
67
+ logging.getLogger().setLevel(log_level)
68
+
69
+
70
+ def hash_string(string: str) -> str:
71
+ return hashlib.sha256(string.encode("utf-8")).hexdigest()
72
+
73
+
74
+ def escaped_split(text, sep_char, maxsplit=-1):
75
+ """Split text into a list on occurrences of the given separation
76
+ character `sep_char`. The separation character may be escaped by a
77
+ backslash to avoid splitting at that location.
78
+
79
+ The separation character must be a string of size 1.
80
+
81
+ If `maxsplit` is given, at most `maxsplit` splits are done (thus,
82
+ the list will have at most `maxsplit + 1` elements). If `maxsplit`
83
+ is not specified or less than 0, then there is no limit on the
84
+ number of splits (all possible splits are made).
85
+ """
86
+ assert len(sep_char) == 1, (
87
+ "separation string must be a single character for escaped splitting"
88
+ )
89
+
90
+ if maxsplit == 0:
91
+ return text
92
+ maxsplit = max(0, maxsplit)
93
+
94
+ return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
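+ # Example: a backslash-escaped comma is not treated as a separator.
+ #   >>> escaped_split(r"model=a\,b,dtype=float16", ",")
+ #   ['model=a\\,b', 'dtype=float16']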
95
+
96
+
97
+ def handle_arg_string(arg):
98
+ if arg.lower() == "true":
99
+ return True
100
+ elif arg.lower() == "false":
101
+ return False
102
+ elif arg.isnumeric():
103
+ return int(arg)
104
+ try:
105
+ return float(arg)
106
+ except ValueError:
107
+ return arg
108
+
109
+
110
+ def handle_non_serializable(o):
111
+ if isinstance(o, np.int64) or isinstance(o, np.int32):
112
+ return int(o)
113
+ elif isinstance(o, set):
114
+ return list(o)
115
+ else:
116
+ return str(o)
117
+
118
+
119
+ def sanitize_list(sub):
120
+ """
121
+ Takes possible nested list and recursively converts all inner component to strings
122
+ """
123
+ if isinstance(sub, list):
124
+ return [sanitize_list(item) for item in sub]
125
+ if isinstance(sub, tuple):
126
+ return tuple(sanitize_list(item) for item in sub)
127
+ else:
128
+ return str(sub)
129
+
130
+
131
+ def simple_parse_args_string(args_string: Optional[str]) -> dict:
132
+ """
133
+ Parses something like
134
+ args1=val1,arg2=val2
135
+ Into a dictionary
136
+ """
137
+ if args_string is None:
138
+ return {}
139
+ args_string = args_string.strip()
140
+ if not args_string:
141
+ return {}
142
+ arg_list = [arg for arg in args_string.split(",") if arg]
143
+ args_dict = {
144
+ kv[0]: handle_arg_string("=".join(kv[1:]))
145
+ for kv in [arg.split("=") for arg in arg_list]
146
+ }
147
+ return args_dict
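+ # Example: the comma-separated CLI string is parsed into a typed dict.
+ #   >>> simple_parse_args_string("pretrained=gpt2,batch_size=8,trust_remote_code=True")
+ #   {'pretrained': 'gpt2', 'batch_size': 8, 'trust_remote_code': True}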
148
+
149
+
150
+ def join_iters(iters):
151
+ for iter in iters:
152
+ yield from iter
153
+
154
+
155
+ def group(arr, fn):
156
+ res = collections.defaultdict(list)
157
+
158
+ for ob in arr:
159
+ res[fn(ob)].append(ob)
160
+
161
+ return list(res.values())
162
+
163
+
164
+ # Returns a list containing all values of the source_list that
165
+ # match at least one of the patterns
166
+ def pattern_match(patterns, source_list):
167
+ if isinstance(patterns, str):
168
+ patterns = [patterns]
169
+
170
+ task_names = set()
171
+ for pattern in patterns:
172
+ for matching in fnmatch.filter(source_list, pattern):
173
+ task_names.add(matching)
174
+ return sorted(list(task_names))
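+ # Example: wildcard patterns are expanded against the list of known task names.
+ #   >>> pattern_match(["hellaswag*"], ["hellaswag", "hellaswag_de", "arc_easy"])
+ #   ['hellaswag', 'hellaswag_de']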
175
+
176
+
177
+ def softmax(x) -> np.ndarray:
178
+ """Compute softmax values for each sets of scores in x."""
179
+ e_x = np.exp(x - np.max(x))
180
+ return e_x / e_x.sum()
181
+
182
+
183
+ def general_detokenize(string) -> str:
184
+ string = string.replace(" n't", "n't")
185
+ string = string.replace(" )", ")")
186
+ string = string.replace("( ", "(")
187
+ string = string.replace('" ', '"')
188
+ string = string.replace(' "', '"')
189
+ string = re.sub(r" (['.,])", r"\1", string)
190
+ return string
191
+
192
+
193
+ def get_file_task_name(filename: str) -> str:
194
+ """
195
+ Given the sample results filenames, extracts and returns the task name.
196
+ """
197
+ return filename[filename.find("_") + 1 : filename.rfind("_")]
198
+
199
+
200
+ def get_file_datetime(filename: str) -> str:
201
+ """
202
+ Given the results and sample results filenames, extracts and returns the datetime.
203
+ """
204
+ return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")
205
+
206
+
207
+ def sanitize_model_name(model_name: str) -> str:
208
+ """
209
+ Given the model name, returns a sanitized version of it.
210
+ """
211
+ return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
212
+
213
+
214
+ def sanitize_task_name(task_name: str) -> str:
215
+ """
216
+ Given the task name, returns a sanitized version of it.
217
+ """
218
+ return re.sub(r"\W", "_", task_name)
219
+
220
+
221
+ def get_latest_filename(filenames: List[str]) -> str:
222
+ """
223
+ Given a list of filenames, returns the filename with the latest datetime.
224
+ """
225
+ return max(filenames, key=lambda f: get_file_datetime(f))
226
+
227
+
228
+ def get_results_filenames(filenames: List[str]) -> List[str]:
229
+ """
230
+ Extracts filenames that correspond to aggregated results.
231
+ """
232
+ return [f for f in filenames if "/results_" in f and ".json" in f]
233
+
234
+
235
+ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
236
+ """
237
+ Extracts filenames that correspond to sample results.
238
+ """
239
+ return [f for f in filenames if "/samples_" in f and ".json" in f]
240
+
241
+
242
+ def get_rolling_token_windows(
243
+ token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
244
+ ) -> Generator[Tuple[List[int], List[int]], None, None]:
245
+ """
246
+ - context_len allows for a rolling window context, allowing each prediction window to potentially
247
+ condition on some context
248
+
249
+ :param token_list: list
250
+ List of tokens to be PREDICTED
251
+ :param max_seq_len: int
252
+ max_seq_len of model (or max_seq_len we want to use)
253
+ :param context_len: int
254
+ Amount of desired token context for prediction. Needs to be at least 1.
255
+ :param prefix_token: token
256
+ Dummy token like <eos> so the first token has something to condition on
257
+ :return: generator
258
+ Generator of tuples
259
+ (input_tokens, pred_tokens)
260
+ Note: Score only the last len(pred_tokens) logits of the LM
261
+ """
262
+ assert 1 <= context_len <= max_seq_len
263
+ if not token_list:
264
+ return
265
+ # +1 offset, going from input->preds
266
+ pred_len = max_seq_len - context_len + 1
267
+ predicted = 0
268
+
269
+ # Special handling for first window: predict all tokens
270
+ first_seq_len = min(max_seq_len, len(token_list))
271
+ yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
272
+ predicted += first_seq_len
273
+
274
+ while predicted < len(token_list):
275
+ window_pred_len = min(len(token_list) - predicted, pred_len)
276
+ window_end = predicted + window_pred_len
277
+
278
+ yield (
279
+ token_list[window_end - max_seq_len - 1 : window_end - 1],
280
+ token_list[window_end - window_pred_len : window_end],
281
+ )
282
+ predicted += window_pred_len
283
+
284
+
285
+ def make_disjoint_window(
286
+ pair: Tuple[List[int], List[int]],
287
+ ) -> Tuple[List[int], List[int]]:
288
+ """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
289
+ a, b = pair
290
+ return a[: len(a) - (len(b) - 1)], b
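+ # Worked example of the two helpers above (max_seq_len=3, context_len=1):
+ #   >>> windows = list(get_rolling_token_windows([1, 2, 3, 4, 5], prefix_token=0,
+ #   ...                                          max_seq_len=3, context_len=1))
+ #   >>> windows
+ #   [([0, 1, 2], [1, 2, 3]), ([2, 3, 4], [4, 5])]
+ #   >>> [make_disjoint_window(w) for w in windows]
+ #   [([0], [1, 2, 3]), ([2, 3], [4, 5])]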
291
+
292
+
293
+ class EnhancedJSONEncoder(json.JSONEncoder):
294
+ """
295
+ Provides a proper json encoding for the loggers and trackers json dumps.
296
+ Notably manages the json encoding of dataclasses.
297
+ """
298
+
299
+ def default(self, o):
300
+ if is_dataclass(o):
301
+ return asdict(o)
302
+ return super().default(o)
303
+
304
+
305
+ class Reorderer:
306
+ def __init__(self, arr: List[Any], fn: Callable) -> None:
307
+ """Reorder an array according to some function
308
+
309
+ Args:
310
+ arr (List[Any]): The initial array
311
+ fn (Callable[[Any], Any]): A function to determine the priority of elements
312
+ """
313
+ self.size = len(arr)
314
+ arr = list(enumerate(arr))
315
+ arr = group(arr, lambda x: fn(x[1]))
316
+ # arr = [([y[0] for y in x], x[0][1]) for x in arr]
317
+ # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
318
+ arr = [([y[0]], x[0][1]) for x in arr for y in x]
319
+ arr.sort(key=lambda x: fn(x[1]))
320
+
321
+ self.arr = arr
322
+
323
+ def get_reordered(self):
324
+ """Gets the reordered array
325
+
326
+ Returns:
327
+ List[Any]: The reordered array
328
+ """
329
+ return [x[1] for x in self.arr]
330
+
331
+ def get_original(self, newarr):
332
+ """Restores the original order of a new array based on the old array's order
333
+
334
+ Args:
335
+ newarr (List[Any]): The array to be restored
336
+
337
+ Returns:
338
+ List[Any]: The array restored to the original order
339
+ """
340
+ res = [None] * self.size
341
+ cov = [False] * self.size
342
+
343
+ for (inds, _), v in zip(self.arr, newarr):
344
+ for ind in inds:
345
+ res[ind] = v
346
+ cov[ind] = True
347
+
348
+ assert all(cov)
349
+
350
+ return res
351
+
352
+
353
+ def make_table(result_dict, column: str = "results", sort_results: bool = False):
354
+ """Generate table of results."""
355
+ from pytablewriter import LatexTableWriter, MarkdownTableWriter
356
+
357
+ if column == "results":
358
+ column_name = "Tasks"
359
+ elif column == "groups":
360
+ column_name = "Groups"
361
+
362
+ all_headers = [
363
+ column_name,
364
+ "Version",
365
+ "Filter",
366
+ "n-shot",
367
+ "Metric",
368
+ "",
369
+ "Value",
370
+ "",
371
+ "Stderr",
372
+ ]
373
+
374
+ md_writer = MarkdownTableWriter()
375
+ latex_writer = LatexTableWriter()
376
+ md_writer.headers = all_headers
377
+ latex_writer.headers = all_headers
378
+
379
+ values = []
380
+
381
+ keys = result_dict[column].keys()
382
+ if sort_results:
383
+ # sort entries alphabetically by task or group name.
384
+ # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
385
+ # sorting here would mess that up
386
+ keys = sorted(keys)
387
+ for k in keys:
388
+ dic = result_dict[column][k]
389
+ version = result_dict["versions"].get(k, " N/A")
390
+ n = str(result_dict.get("n-shot", " ").get(k, " "))
391
+ higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
392
+
393
+ if "alias" in dic:
394
+ k = dic.pop("alias")
395
+
396
+ metric_items = dic.items()
397
+ metric_items = sorted(metric_items)
398
+
399
+ for (mf), v in metric_items:
400
+ m, _, f = mf.partition(",")
401
+ if m.endswith("_stderr"):
402
+ continue
403
+
404
+ hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
405
+
406
+ v = "%.4f" % v if isinstance(v, float) else v
407
+
408
+ if m + "_stderr" + "," + f in dic:
409
+ se = dic[m + "_stderr" + "," + f]
410
+ se = " N/A" if se == "N/A" else "%.4f" % se
411
+ values.append([k, version, f, n, m, hib, v, "±", se])
412
+ else:
413
+ values.append([k, version, f, n, m, hib, v, "", ""])
414
+ k = ""
415
+ version = ""
416
+ md_writer.value_matrix = values
417
+ latex_writer.value_matrix = values
418
+
419
+ # todo: make latex table look good
420
+ # print(latex_writer.dumps())
421
+
422
+ return md_writer.dumps()
423
+
424
+
425
+ def positional_deprecated(fn):
426
+ """
427
+ A decorator to nudge users into passing only keyword args (`kwargs`) to the
428
+ wrapped function, `fn`.
429
+ """
430
+
431
+ @functools.wraps(fn)
432
+ def _wrapper(*args, **kwargs):
433
+ if len(args) != 1 if inspect.ismethod(fn) else 0:
434
+ print(
435
+ f"WARNING: using {fn.__name__} with positional arguments is "
436
+ "deprecated and will be disallowed in a future version of "
437
+ "lm-evaluation-harness!"
438
+ )
439
+ return fn(*args, **kwargs)
440
+
441
+ return _wrapper
442
+
443
+
444
+ def ignore_constructor(loader, node):
445
+ return node
446
+
447
+
448
+ def import_function(loader: yaml.Loader, node, yaml_path: Path):
449
+ function_name = loader.construct_scalar(node)
450
+
451
+ *module_name, function_name = function_name.split(".")
452
+ if isinstance(module_name, list):
453
+ module_name = ".".join(module_name)
454
+ module_path = yaml_path.parent / f"{module_name}.py"
455
+
456
+ spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
457
+
458
+ if spec is None:
459
+ raise ImportError(f"Could not import module {module_name} from {module_path}.")
460
+ module = importlib.util.module_from_spec(spec)
461
+
462
+ if spec.loader is None:
463
+ raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
464
+ spec.loader.exec_module(module)
465
+
466
+ function = getattr(module, function_name)
467
+ return function
468
+
469
+
470
+ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
471
+ if mode == "simple":
472
+ constructor_fn = ignore_constructor
473
+ elif mode == "full":
474
+ if yaml_path is None:
475
+ raise ValueError("yaml_path must be provided if mode is 'full'.")
476
+ # Attach yaml_path to the import function so that it can be used later
477
+ constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
478
+
479
+ loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
480
+ # Add the import_function constructor to the YAML loader
481
+ yaml.add_constructor("!function", constructor_fn, Loader=loader)
482
+ if yaml_config is None:
483
+ with open(yaml_path, "rb") as file:
484
+ yaml_config = yaml.load(file, Loader=loader)
485
+
486
+ if yaml_dir is None:
487
+ yaml_dir = os.path.dirname(yaml_path)
488
+
489
+ assert yaml_dir is not None
490
+
491
+ if "include" in yaml_config:
492
+ include_path = yaml_config["include"]
493
+ del yaml_config["include"]
494
+
495
+ if isinstance(include_path, str):
496
+ include_path = [include_path]
497
+
498
+ # Load from the last one first
499
+ include_path.reverse()
500
+ final_yaml_config = {}
501
+ for path in include_path:
502
+ # Assumes that path is a full path.
503
+ # If not found, assume the included yaml
504
+ # is in the same dir as the original yaml
505
+ if not os.path.isfile(path):
506
+ path = os.path.join(yaml_dir, path)
507
+
508
+ try:
509
+ included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
510
+ final_yaml_config.update(included_yaml_config)
511
+ except Exception as ex:
512
+ # Re-raise so the caller can see which included config failed to load
513
+ raise ex
514
+
515
+ final_yaml_config.update(yaml_config)
516
+ return final_yaml_config
517
+ return yaml_config
518
+
519
+
520
+ def regex_replace(string, pattern, repl, count: int = 0):
521
+ """Implements the `re.sub` function as a custom Jinja filter."""
522
+ return re.sub(pattern, repl, string, count=count)
523
+
524
+
525
+ env = Environment(
526
+ loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
527
+ )
528
+ env.filters["regex_replace"] = regex_replace
529
+
530
+
531
+ def apply_template(template: str, doc: dict) -> str:
532
+ rtemplate = env.from_string(template)
533
+ return rtemplate.render(**doc)
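+ # Example: doc fields are rendered into the Jinja template; StrictUndefined means
+ # a missing field raises instead of rendering silently as empty.
+ #   >>> apply_template("Q: {{question}}\nA:", {"question": "What is 2+2?"})
+ #   'Q: What is 2+2?\nA:'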
534
+
535
+
536
+ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
537
+ """
538
+ Method for creating a (potentially) sliced and limited
539
+ iterator from a raw document iterator. Used for splitting data
540
+ among ranks in multigpu setting or only pulling a sample of documents
541
+ """
542
+ return islice(raw_iterator, rank, limit, world_size)
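+ # Example: rank 1 of 3 ranks sees every third document, starting at index 1.
+ #   >>> list(create_iterator(iter(range(10)), rank=1, world_size=3))
+ #   [1, 4, 7]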
543
+
544
+
545
+ def weighted_f1_score(items):
546
+ from sklearn.metrics import f1_score
547
+
548
+ unzipped_list = list(zip(*items))
549
+ golds = unzipped_list[0]
550
+ preds = unzipped_list[1]
551
+ fscore = f1_score(golds, preds, average="weighted")
552
+ return fscore
Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml ADDED
@@ -0,0 +1,134 @@
1
+ [build-system]
2
+ requires = ["setuptools>=40.8.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "lm_eval"
7
+ version = "0.4.8"
8
+ authors = [
9
+ {name="EleutherAI", email="contact@eleuther.ai"}
10
+ ]
11
+ description = "A framework for evaluating language models"
12
+ readme = "README.md"
13
+ classifiers = [
14
+ "Development Status :: 3 - Alpha",
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ requires-python = ">=3.9"
20
+ license = { "text" = "MIT" }
21
+ dependencies = [
22
+ "accelerate>=0.26.0",
23
+ "evaluate",
24
+ "datasets>=2.16.0",
25
+ "evaluate>=0.4.0",
26
+ "jsonlines",
27
+ "numexpr",
28
+ "peft>=0.2.0",
29
+ "pybind11>=2.6.2",
30
+ "pytablewriter",
31
+ "rouge-score>=0.0.4",
32
+ "sacrebleu>=1.5.0",
33
+ "scikit-learn>=0.24.1",
34
+ "sqlitedict",
35
+ "torch>=1.8",
36
+ "tqdm-multiprocess",
37
+ "transformers>=4.1",
38
+ "zstandard",
39
+ "dill",
40
+ "word2number",
41
+ "more_itertools",
42
+ ]
43
+
44
+ [tool.setuptools.packages.find]
45
+ include = ["lm_eval*"]
46
+
47
+ # required to include yaml files in pip installation
48
+ [tool.setuptools.package-data]
49
+ lm_eval = ["**/*.yaml", "tasks/**/*"]
50
+
51
+ [project.scripts]
52
+ lm-eval = "lm_eval.__main__:cli_evaluate"
53
+ lm_eval = "lm_eval.__main__:cli_evaluate"
54
+
55
+ [project.urls]
56
+ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
57
+ Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
58
+
59
+ [project.optional-dependencies]
60
+ api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"]
61
+ audiolm_qwen = ["librosa", "soundfile"]
62
+ deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
63
+ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"]
64
+ gptq = ["auto-gptq[triton]>=0.6.0"]
65
+ gptqmodel = ["gptqmodel>=1.0.9"]
66
+ hf_transfer = ["hf_transfer"]
67
+ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
68
+ ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
69
+ ipex = ["optimum"]
70
+ japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
71
+ longbench = ["jieba", "fuzzywuzzy", "rouge"]
72
+ mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
73
+ math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
74
+ multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
75
+ neuronx = ["optimum[neuronx]"]
76
+ optimum = ["optimum[openvino]"]
77
+ promptsource = ["promptsource>=0.2.3"]
78
+ ruler = ["nltk", "wonderwords", "scipy"]
79
+ sae_lens = ["sae_lens"]
80
+ sentencepiece = ["sentencepiece>=0.1.98"]
81
+ sparseml = ["sparseml-nightly[llm]>=1.8.0.20240404"]
82
+ sparsify = ["sparsify"]
83
+ testing = ["pytest", "pytest-cov", "pytest-xdist"]
84
+ vllm = ["vllm>=0.4.2"]
85
+ wandb = ["wandb>=0.16.3", "pandas", "numpy"]
86
+ zeno = ["pandas", "zeno-client"]
87
+ all = [
88
+ "lm_eval[api]",
89
+ "lm_eval[audiolm_qwen]",
90
+ "lm_eval[deepsparse]",
91
+ "lm_eval[dev]",
92
+ "lm_eval[gptq]",
93
+ "lm_eval[gptqmodel]",
94
+ "lm_eval[hf_transfer]",
95
+ "lm_eval[ibm_watsonx_ai]",
96
+ "lm_eval[ifeval]",
97
+ "lm_eval[ipex]",
98
+ "lm_eval[japanese_leaderboard]",
99
+ "lm_eval[longbench]",
100
+ "lm_eval[mamba]",
101
+ "lm_eval[math]",
102
+ "lm_eval[multilingual]",
103
+ "lm_eval[neuronx]",
104
+ "lm_eval[optimum]",
105
+ "lm_eval[promptsource]",
106
+ "lm_eval[ruler]",
107
+ "lm_eval[sae_lens]",
108
+ "lm_eval[sentencepiece]",
109
+ "lm_eval[sparseml]",
110
+ "lm_eval[sparsify]",
111
+ "lm_eval[testing]",
112
+ "lm_eval[vllm]",
113
+ "lm_eval[wandb]",
114
+ "lm_eval[zeno]",
115
+ ]
116
+
117
+ [tool.pymarkdown]
118
+ plugins.md013.enabled = false # line-length
119
+ plugins.md024.allow_different_nesting = true # no-duplicate-headers
120
+ plugins.md025.enabled = false # single-header
121
+ plugins.md028.enabled = false # no-blanks-blockquote
122
+ plugins.md029.allow_extended_start_values = true # ol-prefix
123
+ plugins.md034.enabled = false # no-bare-urls
124
+
125
+ [tool.ruff.lint]
126
+ extend-select = ["I"]
127
+
128
+ [tool.ruff.lint.isort]
129
+ lines-after-imports = 2
130
+ known-first-party = ["lm_eval"]
131
+
132
+ [tool.ruff.lint.extend-per-file-ignores]
133
+ "__init__.py" = ["F401","F402","F403"]
134
+ "utils.py" = ["F401"]
Prism/Dream/Dream_Prism/eval_instruct/requirements.txt ADDED
@@ -0,0 +1 @@
1
+ -e .
Prism/Dream/Dream_Prism/eval_instruct/setup.py ADDED
@@ -0,0 +1,5 @@
1
+ import setuptools
2
+
3
+
4
+ # This is to make sure that the package supports editable installs
5
+ setuptools.setup()
Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py ADDED
@@ -0,0 +1,188 @@
1
+ import json
2
+ import re
3
+ import os
4
+ import glob
5
+ import math
6
+ import argparse
7
+ from collections import Counter
8
+
9
+
10
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
11
+
12
+
13
+ def last_boxed_only_string(string):
14
+ if not string: return None
15
+ idx = string.rfind("\\boxed")
16
+ if "\\boxed " in string:
17
+ return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
18
+ if idx < 0:
19
+ idx = string.rfind("\\fbox")
20
+ if idx < 0: return None
21
+ i = idx
22
+ right_brace_idx = None
23
+ num_left_braces_open = 0
24
+ while i < len(string):
25
+ if string[i] == "{":
26
+ num_left_braces_open += 1
27
+ if string[i] == "}":
28
+ num_left_braces_open -= 1
29
+ if num_left_braces_open == 0:
30
+ right_brace_idx = i
31
+ break
32
+ i += 1
33
+ return string[idx : right_brace_idx + 1] if right_brace_idx else None
34
+
35
+ def remove_boxed(s):
36
+ if not s: return None
37
+ if "\\boxed " in s: return s[len("\\boxed ") :]
38
+ if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1]
39
+ return s
40
+
41
+ def strip_string(string):
42
+ if string is None: return ""
43
+ string = str(string).strip()
44
+ while re.search(r"(\d),(\d{3})", string):
45
+ string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
46
+ string = string.replace("\n", "").replace("\\!", "")
47
+ string = string.replace("tfrac", "frac").replace("dfrac", "frac")
48
+ string = string.replace("\\left", "").replace("\\right", "")
49
+ string = string.replace("^{\\circ}", "").replace("^\\circ", "")
50
+ string = string.replace("\\$", "").replace("\\%", "").replace("\%", "")
51
+ if "=" in string and len(string.split("=")[0]) <= 3:
52
+ string = string.split("=")[1].strip()
53
+ string = string.replace(" ", "")
54
+ return string
55
+
56
+ def extract_answer_gsm8k(text):
57
+ if not text: return ""
58
+ boxed = last_boxed_only_string(text)
59
+ if boxed:
60
+ ans = remove_boxed(boxed)
61
+ if ans: return strip_string(ans)
62
+
63
+ tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
64
+ if tag_match:
65
+ return strip_string(tag_match.group(1))
66
+
67
+ nums = re.findall(r"-?\d+\.?\d*", text[-50:])
68
+ if nums:
69
+ return strip_string(nums[-1])
70
+
71
+ return ""
72
+
73
+ def extract_gold_gsm8k(target_str):
74
+ if "####" in target_str:
75
+ return strip_string(target_str.split("####")[-1])
76
+ return strip_string(target_str)
77
+
78
+ def is_equiv(pred, gold):
79
+ p = strip_string(pred)
80
+ g = strip_string(gold)
81
+ try:
82
+ return math.isclose(float(p), float(g), rel_tol=1e-4)
83
+ except:
84
+ return p == g
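+ # Worked example of the extraction and comparison pipeline above (the strings are
+ # illustrative, not taken from the dataset):
+ #   >>> extract_answer_gsm8k("... so the total is \\boxed{1,000} dollars.")
+ #   '1000'
+ #   >>> extract_gold_gsm8k("She earns 500*2 = 1000 #### 1,000")
+ #   '1000'
+ #   >>> is_equiv("1000", "1,000")
+ #   True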
85
+
86
+ def run_evaluation(target_path):
87
+ if os.path.isdir(target_path):
88
+ jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
89
+ else:
90
+ jsonl_files = [target_path]
91
+
92
+ for file_path in jsonl_files:
93
+ print(f">>> 正在评测: {file_path}")
94
+ detailed_results = []
95
+ correct_count = 0
96
+ total_count = 0
97
+ nfe_list = []
98
+ svf_list = []
99
+
100
+ with open(file_path, 'r', encoding='utf-8') as f:
101
+ for line in f:
102
+ if not line.strip(): continue
103
+ item = json.loads(line)
104
+ doc = item.get("doc", {})
105
+
106
+ ground_truth = extract_gold_gsm8k(str(item.get("target", "")))
107
+ nfe_list.append(item.get("nfe", 0))
108
+ svf_list.append(item.get("svf_calls", 0))
109
+
110
+ ans_stats = {}
111
+
112
+ trajectories = item.get("all_trajectories", [])
113
+ if not trajectories:
114
+ resps = item.get("resps", [])
115
+ for r in resps:
116
+ text = r[0] if isinstance(r, list) else r
117
+ trajectories.append({"resp": text, "score": 0.0})
118
+
119
+ for traj in trajectories:
120
+ raw_text = traj.get("resp", "")
121
+ score = traj.get("score", -float('inf'))
122
+ extracted = extract_answer_gsm8k(raw_text)
123
+
124
+ if not extracted: continue
125
+
126
+ norm = strip_string(extracted)
127
+ if norm not in ans_stats:
128
+ ans_stats[norm] = {"count": 0, "max_score": -float('inf'), "original": extracted}
129
+
130
+ ans_stats[norm]["count"] += 1
131
+ if score > ans_stats[norm]["max_score"]:
132
+ ans_stats[norm]["max_score"] = score
133
+ ans_stats[norm]["original"] = extracted
134
+
135
+ if not ans_stats:
136
+ best_pred = ""
137
+ else:
138
+ sorted_norms = sorted(
139
+ ans_stats.keys(),
140
+ key=lambda x: (ans_stats[x]["count"], ans_stats[x]["max_score"]),
141
+ reverse=True
142
+ )
143
+ best_norm = sorted_norms[0]
144
+ best_pred = ans_stats[best_norm]["original"]
145
+
146
+ ans_correct = is_equiv(best_pred, ground_truth)
147
+ if ans_correct:
148
+ correct_count += 1
149
+ total_count += 1
150
+
151
+ detailed_results.append({
152
+ "question": doc.get("question", "N/A"),
153
+ "final_voted_answer": best_pred,
154
+ "ground_truth": ground_truth,
155
+ "is_correct": ans_correct,
156
+ "nfe": item.get("nfe", 0),
157
+ "svf_calls": item.get("svf_calls", 0)
158
+ })
159
+
160
+ accuracy = (correct_count / total_count * 100) if total_count > 0 else 0
161
+
162
+ avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
163
+ avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0
164
+
165
+ print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")
166
+
167
+ output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
168
+ output_path = os.path.join(os.path.dirname(file_path), output_name)
169
+
170
+ final_report = {
171
+ "summary": {
172
+ "accuracy": f"{accuracy:.2f}%",
173
+ "correct": correct_count,
174
+ "total": total_count,
175
+ "nfe": avg_nfe,
176
+ "svf_calls": avg_svf
177
+ },
178
+ "details": detailed_results
179
+ }
180
+
181
+ with open(output_path, 'w', encoding='utf-8') as out_f:
182
+ json.dump(final_report, out_f, ensure_ascii=False, indent=4)
183
+
184
+ if __name__ == "__main__":
185
+ parser = argparse.ArgumentParser()
186
+ parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
187
+ args = parser.parse_args()
188
+ run_evaluation(args.res_path)
Prism/Dream/Dream_Prism/metrics/humaneval_eval.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import ast
5
+ import re
6
+ import glob
7
+ import argparse
8
+ import textwrap
9
+ import evaluate as hf_evaluate
10
+ from collections import Counter
11
+
12
+ os.environ["HF_ALLOW_CODE_EVAL"] = "1"
13
+
14
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
15
+
16
+ def strict_dedent(text: str) -> str:
17
+ lines = text.split('\n')
18
+ while lines and not lines[0].strip(): lines.pop(0)
19
+ while lines and not lines[-1].strip(): lines.pop()
20
+
21
+ if not lines:
22
+ return ""
23
+
24
+ min_indent = None
25
+ for line in lines:
26
+ if line.strip():
27
+ indent = len(line) - len(line.lstrip())
28
+ if min_indent is None or indent < min_indent:
29
+ min_indent = indent
30
+
31
+ if min_indent is None:
32
+ min_indent = 0
33
+
34
+ dedented_lines = []
35
+ for line in lines:
36
+ if line.strip():
37
+ if len(line) >= min_indent:
38
+ dedented_lines.append(line[min_indent:])
39
+ else:
40
+ dedented_lines.append(line.lstrip())
41
+ else:
42
+ dedented_lines.append("")
43
+
44
+ return "\n".join(dedented_lines)
45
+
46
+ def extract_python_code(text: str) -> str:
47
+ if not text:
48
+ return ""
49
+
50
+ text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")
51
+
52
+ tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
53
+ if tag_match:
54
+ text = tag_match.group(1)
55
+
56
+ code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL)
57
+ match = code_block_pattern.search(text)
58
+
59
+ if match:
60
+ content = match.group(1)
61
+ else:
62
+ if "```" in text:
63
+ content = text.split("```")[0]
64
+ else:
65
+ lines = text.split('\n')
66
+ cleaned_lines = []
67
+ stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Here are the tests:"]
68
+ for line in lines:
69
+ if any(sw in line for sw in stop_words):
70
+ break
71
+ cleaned_lines.append(line)
72
+ content = "\n".join(cleaned_lines)
73
+
74
+ return strict_dedent(content)
75
+
76
+ def normalize_code(code: str) -> str:
77
+ try:
78
+ tree = ast.parse(code)
79
+ for node in ast.walk(tree):
80
+ if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
81
+ if (node.body and isinstance(node.body[0], ast.Expr) and
82
+ isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
83
+ node.body.pop(0)
84
+ return ast.unparse(tree).strip()
85
+ except:
86
+ return re.sub(r"\s+", "", code)
87
+
88
+ def sanitize(prompt: str, completion: str, entry_point: str) -> str:
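+ # Assemble a runnable solution: if the completion already defines the entry point,
+ # keep it and prepend the prompt's import lines; otherwise treat the completion as
+ # a function body and indent it under the prompt's signature.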
89
+ if f"def {entry_point}" in completion:
90
+ imports = [line for line in prompt.split("\n") if line.startswith("import ") or line.startswith("from ")]
91
+ return "\n".join(imports) + "\n" + completion
92
+
93
+ clean_body = strict_dedent(completion)
94
+ if not clean_body:
95
+ return prompt
96
+
97
+ indented_body = "\n".join([" " + line if line.strip() else "" for line in clean_body.split('\n')])
98
+ return prompt.strip() + "\n" + indented_body
99
+
100
+ def perform_majority_voting(trajectories, prompt, entry_point):
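+ # Majority vote over sampled trajectories, keyed by normalized code. Candidates are
+ # ranked by (syntactic validity, vote count, best score); the bare prompt is returned
+ # when nothing usable could be extracted.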
101
+ candidate_stats = {}
102
+
103
+ for item in trajectories:
104
+ if isinstance(item, dict):
105
+ raw_text = item.get("resp", "")
106
+ score = item.get("score", 0.0)
107
+ else:
108
+ raw_text = str(item[0] if isinstance(item, list) else item)
109
+ score = 0.0
110
+
111
+ extracted_code = extract_python_code(raw_text)
112
+ full_code = sanitize(prompt, extracted_code, entry_point)
113
+
114
+ is_valid = False
115
+ try:
116
+ ast.parse(full_code)
117
+ is_valid = True
118
+ except:
119
+ is_valid = False
120
+
121
+ norm_key = normalize_code(full_code)
122
+ if not norm_key: continue
123
+
124
+ if norm_key not in candidate_stats:
125
+ candidate_stats[norm_key] = {
126
+ "count": 0,
127
+ "max_score": -float("inf"),
128
+ "code": full_code,
129
+ "is_valid": is_valid
130
+ }
131
+
132
+ candidate_stats[norm_key]["count"] += 1
133
+ candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score)
134
+
135
+ if not candidate_stats:
136
+ return prompt
137
+
138
+ sorted_candidates = sorted(
139
+ candidate_stats.values(),
140
+ key=lambda x: (x["is_valid"], x["count"], x["max_score"]),
141
+ reverse=True
142
+ )
143
+
144
+ return sorted_candidates[0]["code"]
145
+
146
+ def run_evaluation(target_path):
147
+ if os.path.isdir(target_path):
148
+ jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
149
+ else:
150
+ jsonl_files = [target_path]
151
+
152
+ try:
153
+ code_eval = hf_evaluate.load("code_eval")
154
+ except Exception as e:
155
+ print(f"Error loading code_eval: {e}")
156
+ return
157
+
158
+ for file_path in jsonl_files:
159
+ print(f">>> 正在评测文件: {file_path}")
160
+
161
+ all_voted_predictions = []
162
+ all_references = []
163
+ detailed_logs = []
164
+
165
+ nfe_sum = 0
166
+ svf_sum = 0
167
+ valid_samples = 0
168
+
169
+ with open(file_path, 'r', encoding='utf-8') as f:
170
+ for line in f:
171
+ if not line.strip(): continue
172
+ try:
173
+ data = json.loads(line)
174
+ except:
175
+ continue
176
+
177
+ doc = data.get("doc", {})
178
+ task_id = doc.get("task_id", f"Task_{valid_samples}")
179
+ prompt = doc.get("prompt", "")
180
+ entry_point = doc.get("entry_point", "solution")
181
+ test_code = doc.get("test", "") + f"\ncheck({entry_point})"
182
+
183
+ nfe_sum += data.get("nfe", 0)
184
+ svf_sum += data.get("svf_calls", 0)
185
+ valid_samples += 1
186
+
187
+ trajectories = data.get("all_trajectories", data.get("resps", []))
188
+ voted_code = perform_majority_voting(trajectories, prompt, entry_point)
189
+
190
+ all_voted_predictions.append([voted_code])
191
+ all_references.append(test_code)
192
+
193
+ detailed_logs.append({
194
+ "task_id": task_id,
195
+ "entry_point": entry_point,
196
+ "final_code": voted_code,
197
+ "nfe": data.get("nfe", 0),
198
+ "svf": data.get("svf_calls", 0),
199
+ })
200
+
201
+ if not all_voted_predictions: continue
202
+
203
+ print(f"执行测试中...")
204
+ pass_at_k, exec_results = code_eval.compute(
205
+ references=all_references,
206
+ predictions=all_voted_predictions,
207
+ k=[1]
208
+ )
209
+
210
+ accuracy = pass_at_k.get("pass@1", 0.0) * 100
211
+ avg_nfe = nfe_sum / valid_samples if valid_samples > 0 else 0
212
+ avg_svf = svf_sum / valid_samples if valid_samples > 0 else 0
213
+
214
+ for i, log in enumerate(detailed_logs):
215
+ res = exec_results.get(i, [])
216
+ log["passed"] = res[0][1].get("passed", False) if res else False
217
+ log["exec_msg"] = res[0][1].get("result", "failed") if res else "failed"
218
+
219
+ output_path = file_path.replace(".jsonl", "_voted_result.json")
220
+ final_report = {
221
+ "meta": {"file": file_path, "total_samples": valid_samples},
222
+ "metrics": {"accuracy": f"{accuracy:.2f}%", "avg_nfe": avg_nfe, "avg_svf": avg_svf},
223
+ "details": detailed_logs
224
+ }
225
+
226
+ with open(output_path, 'w', encoding='utf-8') as out_f:
227
+ json.dump(final_report, out_f, ensure_ascii=False, indent=4)
228
+ print(f"Accuracy: {accuracy:.2f}% | SVF: {avg_svf:.1f}\n")
229
+
230
+ if __name__ == "__main__":
231
+ parser = argparse.ArgumentParser()
232
+ parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
233
+ args = parser.parse_args()
234
+ run_evaluation(args.res_path)
Prism/Dream/Dream_Prism/metrics/math500_eval.py ADDED
@@ -0,0 +1,205 @@
1
+ import json
2
+ import re
3
+ import os
4
+ import math
5
+ import argparse
6
+ from collections import Counter
7
+
8
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
9
+
10
+ def extract_answer(text):
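+ # Pull the final answer from a response, in priority order: the last \boxed{...},
+ # an <answer>...</answer> tag, the text following "the answer is", and finally a
+ # number near the end of the string. The boolean flags how confident the match is.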
11
+ if not text:
12
+ return "", False
13
+ text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip()
14
+
15
+ boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
16
+ all_boxes = re.findall(boxed_pattern, text)
17
+ if all_boxes:
18
+ return all_boxes[-1], True
19
+
20
+ tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
21
+ if tag_match:
22
+ return tag_match.group(1).strip(), True
23
+
24
+ marker = "the answer is"
25
+ if marker in text.lower():
26
+ pos = text.lower().rfind(marker)
27
+ after_text = text[pos + len(marker):].strip()
28
+ after_text = re.sub(r"^[:\s]+", "", after_text)
29
+ return after_text.split('\n')[0].split('$')[0].strip(), True
30
+
31
+ tail = text[-50:].strip()
32
+ nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail)
33
+ if nums:
34
+ return nums[-1], False
35
+ return "", False
36
+
37
+ def normalize_math(string):
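+ # Normalize a math answer for comparison: strip tags, \text wrappers, units, degree
+ # and currency symbols, and thousands separators; rewrite \frac{a}{b} as a/b; drop
+ # whitespace and keep only the right-hand side of an equation.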
38
+ if not string: return ""
39
+ string = str(string).lower().strip()
40
+
41
+ string = string.replace("</reasoning>", "").replace("</answer>", "").replace("<answer>", "")
42
+ string = string.replace("...", "").replace("cannot be determined", "")
43
+
44
+ string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string)
45
+ string = re.sub(r"\\text\{([^}]*)\}", r"\1", string)
46
+ string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string)
47
+ string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string)
48
+ string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "")
49
+ string = string.replace("\\left", "").replace("\\right", "")
50
+ string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "")
51
+
52
+ units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)"
53
+ string = re.sub(units_pattern, "", string)
54
+ string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "")
55
+ string = string.replace("\\pi", "pi")
56
+ string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
57
+ string = string.rstrip(".:,; ").replace(" ", "")
58
+
59
+ if "=" in string:
60
+ string = string.split("=")[-1]
61
+
62
+ return string
63
+
64
+ def is_equiv(pred, gold):
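+ # Equivalence check: exact match after normalization, then numeric comparison with
+ # a small relative tolerance, then a fuzzy match that ignores most punctuation.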
65
+ if not pred: return False
66
+ p, g = normalize_math(pred), normalize_math(gold)
67
+ if p == g: return True
68
+
69
+ if "=" in pred:
70
+ if normalize_math(pred.split("=")[-1]) == g:
71
+ return True
72
+
73
+ try:
74
+ def to_float(s):
75
+ if '/' in s and s.count('/') == 1:
76
+ parts = s.split('/')
77
+ return float(parts[0]) / float(parts[1])
78
+ if '_' in s: s = s.split('_')[0]
79
+ return float(s)
80
+ return math.isclose(to_float(p), to_float(g), rel_tol=1e-4)
81
+ except:
82
+ p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p)
83
+ g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g)
84
+ return p_fuzzy == g_fuzzy if p_fuzzy else False
85
+
86
+ def run_evaluation(target_path):
87
+ jsonl_files = []
88
+ if os.path.isdir(target_path):
89
+ for root, dirs, files in os.walk(target_path):
90
+ for file in files:
91
+ if file.endswith(".jsonl") and not file.startswith("eval_voted_"):
92
+ jsonl_files.append(os.path.join(root, file))
93
+ else:
94
+ jsonl_files = [target_path]
95
+
96
+ for file_path in jsonl_files:
97
+ print(f">>> 正在评测: {file_path}")
98
+ detailed_results = []
99
+
100
+ voted_correct_count = 0
101
+ total_count = 0
102
+
103
+ nfe_list = []
104
+ svf_list = []
105
+
106
+ with open(file_path, 'r', encoding='utf-8') as f:
107
+ for line in f:
108
+ if not line.strip(): continue
109
+ try:
110
+ item = json.loads(line)
111
+ except:
112
+ continue
113
+
114
+ doc = item.get("doc", {})
115
+ ground_truth = str(item.get("target", doc.get("answer", "")))
116
+
117
+ current_nfe = item.get("nfe", 0)
118
+ nfe_list.append(current_nfe)
119
+ current_svf = item.get("svf_calls", 0)
120
+ svf_list.append(current_svf)
121
+
122
+ ans_stats = {}
123
+ trajectories = item.get("all_trajectories", [])
124
+
125
+ for traj in trajectories:
126
+ raw_text = traj.get("resp", "")
127
+ score = traj.get("score", 0)
128
+
129
+ extracted, _ = extract_answer(raw_text)
130
+ if not extracted: continue
131
+
132
+ norm = normalize_math(extracted)
133
+ if norm not in ans_stats:
134
+ ans_stats[norm] = {
135
+ "count": 0,
136
+ "max_score": -float('inf'),
137
+ "total_weight": 0.0,
138
+ "original": extracted
139
+ }
140
+
141
+ ans_stats[norm]["count"] += 1
142
+ if score > ans_stats[norm]["max_score"]:
143
+ ans_stats[norm]["max_score"] = score
144
+
145
+ try:
146
+ weight = math.exp(score)
147
+ except OverflowError:
148
+ weight = float('inf')
149
+ ans_stats[norm]["total_weight"] += weight
150
+
151
+ if not ans_stats:
152
+ best_pred = ""
153
+ else:
154
+ sorted_norms = sorted(
155
+ ans_stats.keys(),
156
+ key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]),
157
+ reverse=True
158
+ )
159
+ best_norm = sorted_norms[0]
160
+ best_pred = ans_stats[best_norm]["original"]
161
+
162
+ is_voted_correct = False
163
+ if best_pred and is_equiv(best_pred, ground_truth):
164
+ voted_correct_count += 1
165
+ is_voted_correct = True
166
+
167
+ total_count += 1
168
+
169
+ detailed_results.append({
170
+ "question": doc.get("problem", "N/A"),
171
+ "final_voted_answer": best_pred,
172
+ "ground_truth": ground_truth,
173
+ "is_voted_correct": is_voted_correct,
174
+ "nfe": current_nfe,
175
+ "svf_calls": current_svf
176
+ })
177
+
178
+ accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0
179
+
180
+ avg_nfe = sum(nfe_list) / len(nfe_list) if nfe_list else 0
181
+ avg_svf = sum(svf_list) / len(svf_list) if svf_list else 0
182
+
183
+ print(f"--- Accuracy : {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f} ---")
184
+
185
+ output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
186
+ output_path = os.path.join(os.path.dirname(file_path), output_name)
187
+
188
+ final_report = {
189
+ "summary": {
190
+ "Accuracy": f"{accuracy:.2f}%",
191
+ "correct_voted_count": voted_correct_count,
192
+ "total": total_count,
193
+ "avg_nfe": avg_nfe,
194
+ "avg_svf": avg_svf
195
+ },
196
+ "details": detailed_results
197
+ }
198
+ with open(output_path, 'w', encoding='utf-8') as out_f:
199
+ json.dump(final_report, out_f, ensure_ascii=False, indent=4)
200
+
201
+ if __name__ == "__main__":
202
+ parser = argparse.ArgumentParser()
203
+ parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
204
+ args = parser.parse_args()
205
+ run_evaluation(args.res_path)
Prism/Dream/Dream_Prism/metrics/mbpp_eval.py ADDED
@@ -0,0 +1,281 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import ast
5
+ import re
6
+ import glob
7
+ import argparse
8
+ import textwrap
9
+ import evaluate as hf_evaluate
10
+ from collections import Counter
11
+
12
+ os.environ["HF_ALLOW_CODE_EVAL"] = "1"
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
16
+
17
+ def strict_dedent(text: str) -> str:
18
+ lines = text.split('\n')
19
+ while lines and not lines[0].strip(): lines.pop(0)
20
+ while lines and not lines[-1].strip(): lines.pop()
21
+
22
+ if not lines:
23
+ return ""
24
+
25
+ min_indent = None
26
+ for line in lines:
27
+ if line.strip():
28
+ indent = len(line) - len(line.lstrip())
29
+ if min_indent is None or indent < min_indent:
30
+ min_indent = indent
31
+
32
+ if min_indent is None:
33
+ min_indent = 0
34
+
35
+ dedented_lines = []
36
+ for line in lines:
37
+ if line.strip():
38
+ if len(line) >= min_indent:
39
+ dedented_lines.append(line[min_indent:])
40
+ else:
41
+ dedented_lines.append(line.lstrip())
42
+ else:
43
+ dedented_lines.append("")
44
+
45
+ return "\n".join(dedented_lines)
46
+
47
+ def extract_python_code(text: str) -> str:
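+ # Same extraction strategy as the HumanEval script, plus MBPP-specific cleanup: the
+ # [DONE] marker is stripped, and when no fenced block is found extraction starts at
+ # the first def/import/from/class line.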
48
+ if not text:
49
+ return ""
50
+
51
+ text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")
52
+ text = text.replace("[DONE]", "")
53
+
54
+ tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
55
+ if tag_match:
56
+ text = tag_match.group(1)
57
+
58
+ code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL)
59
+ match = code_block_pattern.search(text)
60
+
61
+ if match:
62
+ content = match.group(1)
63
+ else:
64
+ if "```" in text:
65
+ content = text.split("```")[0]
66
+ else:
67
+ lines = text.split('\n')
68
+ start_idx = 0
69
+ stop_words = ["Here is", "Explanation", "Example", "Note", "python", "The code"]
70
+
71
+ for i, line in enumerate(lines):
72
+ stripped = line.strip()
73
+ if stripped.startswith(("def ", "import ", "from ", "class ")):
74
+ start_idx = i
75
+ break
76
+ if any(sw in line for sw in stop_words) and not stripped.endswith(":"):
77
+ continue
78
+
79
+ content = "\n".join(lines[start_idx:])
80
+
81
+ return strict_dedent(content)
82
+
83
+ def normalize_code(code: str) -> str:
84
+ try:
85
+ tree = ast.parse(code)
86
+ for node in ast.walk(tree):
87
+ if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
88
+ if (node.body and isinstance(node.body[0], ast.Expr) and
89
+ isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
90
+ node.body.pop(0)
91
+ return ast.unparse(tree).strip()
92
+ except:
93
+ return re.sub(r"\s+", "", code)
94
+
95
+ def perform_majority_voting(trajectories):
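+ # Vote over extracted programs keyed by normalized code, preferring syntactically
+ # valid candidates, then higher vote counts, then higher scores; returns "" when no
+ # candidate survives extraction.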
96
+ candidate_stats = {}
97
+
98
+ for item in trajectories:
99
+ if isinstance(item, dict):
100
+ raw_text = item.get("resp", "")
101
+ score = item.get("score", 0.0)
102
+ elif isinstance(item, (list, tuple)):
103
+ raw_text = item[0]
104
+ score = 0.0
105
+ else:
106
+ raw_text = str(item)
107
+ score = 0.0
108
+
109
+ extracted_code = extract_python_code(raw_text)
110
+
111
+ if not extracted_code.strip():
112
+ continue
113
+
114
+ is_valid = False
115
+ try:
116
+ ast.parse(extracted_code)
117
+ is_valid = True
118
+ except:
119
+ is_valid = False
120
+
121
+ norm_key = normalize_code(extracted_code)
122
+ if not norm_key: continue
123
+
124
+ if norm_key not in candidate_stats:
125
+ candidate_stats[norm_key] = {
126
+ "count": 0,
127
+ "max_score": -float("inf"),
128
+ "code": extracted_code,
129
+ "is_valid": is_valid
130
+ }
131
+
132
+ candidate_stats[norm_key]["count"] += 1
133
+ candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score)
134
+
135
+ if not candidate_stats:
136
+ return ""
137
+
138
+ sorted_candidates = sorted(
139
+ candidate_stats.values(),
140
+ key=lambda x: (x["is_valid"], x["count"], x["max_score"]),
141
+ reverse=True
142
+ )
143
+
144
+ return sorted_candidates[0]["code"]
145
+
146
+ def run_evaluation(target_path):
147
+ if os.path.isdir(target_path):
148
+ jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl"))
149
+ else:
150
+ jsonl_files = [target_path]
151
+
152
+ try:
153
+ code_eval = hf_evaluate.load("code_eval")
154
+ except Exception as e:
155
+ print(f"Error loading code_eval: {e}")
156
+ return
157
+
158
+ for file_path in jsonl_files:
159
+ print(f"\n>>> 正在评测 MBPP 文件: {file_path}")
160
+
161
+ all_voted_predictions = []
162
+ all_references = []
163
+ detailed_logs = []
164
+
165
+ nfe_total = 0
166
+ svf_total = 0
167
+ count_valid_samples = 0
168
+
169
+ with open(file_path, 'r', encoding='utf-8') as f:
170
+ lines = f.readlines()
171
+
172
+ for idx, line in enumerate(lines):
173
+ if not line.strip(): continue
174
+ try:
175
+ data = json.loads(line)
176
+ except json.JSONDecodeError:
177
+ continue
178
+
179
+ doc = data.get("doc", {})
180
+ task_id = doc.get("task_id", f"MBPP_{idx}")
181
+
182
+ test_list = doc.get("test_list", [])
183
+ test_setup = doc.get("test_setup_code", "")
184
+ challenge_tests = doc.get("challenge_test_list", [])
185
+
186
+ full_test_code = ""
187
+ if test_setup:
188
+ full_test_code += test_setup + "\n"
189
+ if test_list:
190
+ full_test_code += "\n".join(test_list) + "\n"
191
+ if challenge_tests:
192
+ full_test_code += "\n".join(challenge_tests)
193
+
194
+ current_nfe = data.get("nfe", 0)
195
+ current_svf = data.get("svf_calls", 0)
196
+
197
+ nfe_total += current_nfe
198
+ svf_total += current_svf
199
+ count_valid_samples += 1
200
+
201
+ trajectories = data.get("all_trajectories", [])
202
+ if not trajectories:
203
+ resps = data.get("resps", [])
204
+ trajectories = [{"resp": r} for r in resps]
205
+
206
+ voted_code = perform_majority_voting(trajectories)
207
+
208
+ if not voted_code:
209
+ voted_code = "def placeholder(): pass"
210
+
211
+ all_voted_predictions.append([voted_code])
212
+ all_references.append(full_test_code)
213
+
214
+ detailed_logs.append({
215
+ "task_id": task_id,
216
+ "final_code": voted_code,
217
+ "reference": full_test_code,
218
+ "nfe": current_nfe,
219
+ "svf": current_svf,
220
+ "traj_count": len(trajectories)
221
+ })
222
+
223
+ if not all_voted_predictions:
224
+ print("未找到有效数据。")
225
+ continue
226
+
227
+ print(f"正在执行代码测试 (共 {len(all_voted_predictions)} 题)...")
228
+
229
+ pass_at_k, exec_results = code_eval.compute(
230
+ references=all_references,
231
+ predictions=all_voted_predictions,
232
+ k=[1],
233
+ num_workers=4
234
+ )
235
+
236
+ accuracy = pass_at_k.get("pass@1", 0.0) * 100
237
+ avg_nfe = nfe_total / count_valid_samples if count_valid_samples > 0 else 0
238
+ avg_svf = svf_total / count_valid_samples if count_valid_samples > 0 else 0
239
+ print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f}")
240
+
241
+ for i, log in enumerate(detailed_logs):
242
+ res = exec_results.get(i, [])
243
+ if res and len(res) > 0:
244
+ is_passed = res[0][1].get("passed", False)
245
+ eval_result_str = res[0][1].get("result", "passed") if not is_passed else "passed"
246
+ else:
247
+ is_passed = False
248
+ eval_result_str = "Execution Failed"
249
+
250
+ log["passed"] = is_passed
251
+ log["exec_msg"] = eval_result_str
252
+
253
+ output_name = f"eval_mbpp_{os.path.basename(file_path).replace('.jsonl', '.json')}"
254
+ output_path = os.path.join(os.path.dirname(file_path), output_name)
255
+
256
+ final_report = {
257
+ "meta": {
258
+ "file": file_path,
259
+ "total_samples": count_valid_samples
260
+ },
261
+ "metrics": {
262
+ "accuracy": f"{accuracy:.2f}%",
263
+ "avg_nfe": avg_nfe,
264
+ "avg_svf": avg_svf
265
+ },
266
+ "details": detailed_logs
267
+ }
268
+
269
+ with open(output_path, 'w', encoding='utf-8') as out_f:
270
+ json.dump(final_report, out_f, ensure_ascii=False, indent=4)
271
+ print(f"结果已保存至: {output_path}\n")
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser(description="MBPP Metrics Evaluation Script")
275
+ parser.add_argument("-r", "--res_path", type=str, default=RES_PATH, help="Path to jsonl result file or directory")
276
+ args = parser.parse_args()
277
+
278
+ if os.path.exists(args.res_path):
279
+ run_evaluation(args.res_path)
280
+ else:
281
+ print(f"Path not found: {args.res_path}")
Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+
5
+ PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
6
+ MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
7
+ BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_gsm8k"
8
+
9
+ cd ${PROJECT_ROOT}
10
+ export CUDA_VISIBLE_DEVICES=0
11
+ export HF_ENDPOINT=https://hf-mirror.com
12
+ export HF_ALLOW_CODE_EVAL=1
13
+ export PYTHONPATH=.
14
+
15
+ TASK="gsm8k"
16
+ LENGTH=256
17
+ STEPS=256
18
+ PORT=12334
19
+ NAME="win_0.1-0.6_s2_k4"
20
+
21
+ mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
22
+
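+ # Launch lm_eval with the diffllm backend; the HTS sampling knobs (initial_N, final_K,
+ # survivor count, pruning interval, decay factor, reward mode) are forwarded through
+ # --gen_kwargs, and per-sample trajectories are streamed to ${NAME}/res.jsonl.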
23
+ accelerate launch --main_process_port ${PORT} -m lm_eval \
24
+ --model diffllm \
25
+ --tasks ${TASK} \
26
+ --batch_size 1 \
27
+ --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
28
+ --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
29
+ --num_fewshot 0 \
30
+ --confirm_run_unsafe_code \
31
+ --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_humaneval.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+
5
+ PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
6
+ MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
7
+ BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_humaneval"
8
+
9
+ cd ${PROJECT_ROOT}
10
+ export CUDA_VISIBLE_DEVICES=0
11
+ export HF_ENDPOINT=https://hf-mirror.com
12
+ export HF_ALLOW_CODE_EVAL=1
13
+ export PYTHONPATH=.
14
+
15
+ TASK="humaneval_instruct"
16
+ LENGTH=512
17
+ STEPS=512
18
+ PORT=12334
19
+ NAME="win_0.1-0.6_s2_k4"
20
+
21
+ mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
22
+
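+ # Same launch as the GSM8K script, adapted for code generation: 512-token generations,
+ # task_type=code, and a larger pruning_interval (20), presumably so candidates are
+ # pruned less frequently over the longer decode.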
23
+ accelerate launch --main_process_port ${PORT} -m lm_eval \
24
+ --model diffllm \
25
+ --tasks ${TASK} \
26
+ --batch_size 1 \
27
+ --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
28
+ --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=20,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
29
+ --num_fewshot 0 \
30
+ --confirm_run_unsafe_code \
31
+ --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_math500.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+
5
+ PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
6
+ MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
7
+ BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_math500"
8
+
9
+ cd ${PROJECT_ROOT}
10
+ export CUDA_VISIBLE_DEVICES=0
11
+ export HF_ENDPOINT=https://hf-mirror.com
12
+ export HF_ALLOW_CODE_EVAL=1
13
+ export PYTHONPATH=.
14
+
15
+ TASK="math500"
16
+ LENGTH=256
17
+ STEPS=256
18
+ PORT=12334
19
+ NAME="win_0.1-0.6_s2_k4"
20
+
21
+ mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
22
+
23
+ accelerate launch --main_process_port ${PORT} -m lm_eval \
24
+ --model diffllm \
25
+ --tasks ${TASK} \
26
+ --batch_size 1 \
27
+ --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
28
+ --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=10,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
29
+ --num_fewshot 0 \
30
+ --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/scripts/run_mbpp.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+
5
+ PROJECT_ROOT="<PATH_TO_YOUR_DREAM_ROOT>"
6
+ MODEL_PATH="<PATH_TO_YOUR_DREAM_V0_INSTRUCT_7B>"
7
+ BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_mbpp"
8
+
9
+ cd ${PROJECT_ROOT}
10
+ export CUDA_VISIBLE_DEVICES=0
11
+ export HF_ENDPOINT=https://hf-mirror.com
12
+ export HF_ALLOW_CODE_EVAL=1
13
+ export PYTHONPATH=.
14
+
15
+ TASK="mbpp"
16
+ LENGTH=512
17
+ STEPS=512
18
+ PORT=12334
19
+ NAME="win_0.1-0.6_s2_k4"
20
+
21
+ mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
22
+
23
+ accelerate launch --main_process_port ${PORT} -m lm_eval \
24
+ --model diffllm \
25
+ --tasks ${TASK} \
26
+ --batch_size 1 \
27
+ --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \
28
+ --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \
29
+ --num_fewshot 0 \
30
+ --confirm_run_unsafe_code \
31
+ --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/Dream/Dream_Prism/src/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc ADDED
Binary file (659 Bytes). View file
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py ADDED
@@ -0,0 +1,71 @@
1
+ """Semiconnectedness."""
2
+
3
+ import networkx as nx
4
+ from networkx.utils import not_implemented_for, pairwise
5
+
6
+ __all__ = ["is_semiconnected"]
7
+
8
+
9
+ @not_implemented_for("undirected")
10
+ @nx._dispatchable
11
+ def is_semiconnected(G):
12
+ r"""Returns True if the graph is semiconnected, False otherwise.
13
+
14
+ A graph is semiconnected if and only if for any pair of nodes, either one
15
+ is reachable from the other, or they are mutually reachable.
16
+
17
+ This function uses a theorem that states that a DAG is semiconnected
18
+ if for any topological sort, for node $v_n$ in that sort, there is an
19
+ edge $(v_i, v_{i+1})$. That allows us to check if a non-DAG `G` is
20
+ semiconnected by condensing the graph: i.e. constructing a new graph `H`
21
+ with nodes being the strongly connected components of `G`, and edges
22
+ (scc_1, scc_2) if there is a edge $(v_1, v_2)$ in `G` for some
23
+ $v_1 \in scc_1$ and $v_2 \in scc_2$. That results in a DAG, so we compute
24
+ the topological sort of `H` and check if for every $n$ there is an edge
25
+ $(scc_n, scc_{n+1})$.
26
+
27
+ Parameters
28
+ ----------
29
+ G : NetworkX graph
30
+ A directed graph.
31
+
32
+ Returns
33
+ -------
34
+ semiconnected : bool
35
+ True if the graph is semiconnected, False otherwise.
36
+
37
+ Raises
38
+ ------
39
+ NetworkXNotImplemented
40
+ If the input graph is undirected.
41
+
42
+ NetworkXPointlessConcept
43
+ If the graph is empty.
44
+
45
+ Examples
46
+ --------
47
+ >>> G = nx.path_graph(4, create_using=nx.DiGraph())
48
+ >>> print(nx.is_semiconnected(G))
49
+ True
50
+ >>> G = nx.DiGraph([(1, 2), (3, 2)])
51
+ >>> print(nx.is_semiconnected(G))
52
+ False
53
+
54
+ See Also
55
+ --------
56
+ is_strongly_connected
57
+ is_weakly_connected
58
+ is_connected
59
+ is_biconnected
60
+ """
61
+ if len(G) == 0:
62
+ raise nx.NetworkXPointlessConcept(
63
+ "Connectivity is undefined for the null graph."
64
+ )
65
+
66
+ if not nx.is_weakly_connected(G):
67
+ return False
68
+
69
+ H = nx.condensation(G)
70
+
71
+ return all(H.has_edge(u, v) for u, v in pairwise(nx.topological_sort(H)))
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from networkx.algorithms.operators.all import *
2
+ from networkx.algorithms.operators.binary import *
3
+ from networkx.algorithms.operators.product import *
4
+ from networkx.algorithms.operators.unary import *
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py ADDED
@@ -0,0 +1,324 @@
1
+ """Operations on many graphs."""
2
+
3
+ from itertools import chain, repeat
4
+
5
+ import networkx as nx
6
+
7
+ __all__ = ["union_all", "compose_all", "disjoint_union_all", "intersection_all"]
8
+
9
+
10
+ @nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
11
+ def union_all(graphs, rename=()):
12
+ """Returns the union of all graphs.
13
+
14
+ The graphs must be disjoint, otherwise an exception is raised.
15
+
16
+ Parameters
17
+ ----------
18
+ graphs : iterable
19
+ Iterable of NetworkX graphs
20
+
21
+ rename : iterable , optional
22
+ Node names of graphs can be changed by specifying the tuple
23
+ rename=('G-','H-') (for example). Node "u" in G is then renamed
24
+ "G-u" and "v" in H is renamed "H-v". Infinite generators (like itertools.count)
25
+ are also supported.
26
+
27
+ Returns
28
+ -------
29
+ U : a graph with the same type as the first graph in list
30
+
31
+ Raises
32
+ ------
33
+ ValueError
34
+ If `graphs` is an empty list.
35
+
36
+ NetworkXError
37
+ In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.
38
+
39
+ Notes
40
+ -----
41
+ For operating on mixed type graphs, they should be converted to the same type.
42
+ >>> G = nx.Graph()
43
+ >>> H = nx.DiGraph()
44
+ >>> GH = union_all([nx.DiGraph(G), H])
45
+
46
+ To force a disjoint union with node relabeling, use
47
+ disjoint_union_all(G,H) or convert_node_labels_to integers().
48
+
49
+ Graph, edge, and node attributes are propagated to the union graph.
50
+ If a graph attribute is present in multiple graphs, then the value
51
+ from the last graph in the list with that attribute is used.
52
+
53
+ Examples
54
+ --------
55
+ >>> G1 = nx.Graph([(1, 2), (2, 3)])
56
+ >>> G2 = nx.Graph([(4, 5), (5, 6)])
57
+ >>> result_graph = nx.union_all([G1, G2])
58
+ >>> result_graph.nodes()
59
+ NodeView((1, 2, 3, 4, 5, 6))
60
+ >>> result_graph.edges()
61
+ EdgeView([(1, 2), (2, 3), (4, 5), (5, 6)])
62
+
63
+ See Also
64
+ --------
65
+ union
66
+ disjoint_union_all
67
+ """
68
+ R = None
69
+ seen_nodes = set()
70
+
71
+ # rename graph to obtain disjoint node labels
72
+ def add_prefix(graph, prefix):
73
+ if prefix is None:
74
+ return graph
75
+
76
+ def label(x):
77
+ return f"{prefix}{x}"
78
+
79
+ return nx.relabel_nodes(graph, label)
80
+
81
+ rename = chain(rename, repeat(None))
82
+ graphs = (add_prefix(G, name) for G, name in zip(graphs, rename))
83
+
84
+ for i, G in enumerate(graphs):
85
+ G_nodes_set = set(G.nodes)
86
+ if i == 0:
87
+ # Union is the same type as first graph
88
+ R = G.__class__()
89
+ elif G.is_directed() != R.is_directed():
90
+ raise nx.NetworkXError("All graphs must be directed or undirected.")
91
+ elif G.is_multigraph() != R.is_multigraph():
92
+ raise nx.NetworkXError("All graphs must be graphs or multigraphs.")
93
+ elif not seen_nodes.isdisjoint(G_nodes_set):
94
+ raise nx.NetworkXError(
95
+ "The node sets of the graphs are not disjoint.\n"
96
+ "Use `rename` to specify prefixes for the graphs or use\n"
97
+ "disjoint_union(G1, G2, ..., GN)."
98
+ )
99
+
100
+ seen_nodes |= G_nodes_set
101
+ R.graph.update(G.graph)
102
+ R.add_nodes_from(G.nodes(data=True))
103
+ R.add_edges_from(
104
+ G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
105
+ )
106
+
107
+ if R is None:
108
+ raise ValueError("cannot apply union_all to an empty list")
109
+
110
+ return R
111
+
112
+
113
+ @nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
114
+ def disjoint_union_all(graphs):
115
+ """Returns the disjoint union of all graphs.
116
+
117
+ This operation forces distinct integer node labels starting with 0
118
+ for the first graph in the list and numbering consecutively.
119
+
120
+ Parameters
121
+ ----------
122
+ graphs : iterable
123
+ Iterable of NetworkX graphs
124
+
125
+ Returns
126
+ -------
127
+ U : A graph with the same type as the first graph in list
128
+
129
+ Raises
130
+ ------
131
+ ValueError
132
+ If `graphs` is an empty list.
133
+
134
+ NetworkXError
135
+ In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.
136
+
137
+ Examples
138
+ --------
139
+ >>> G1 = nx.Graph([(1, 2), (2, 3)])
140
+ >>> G2 = nx.Graph([(4, 5), (5, 6)])
141
+ >>> U = nx.disjoint_union_all([G1, G2])
142
+ >>> list(U.nodes())
143
+ [0, 1, 2, 3, 4, 5]
144
+ >>> list(U.edges())
145
+ [(0, 1), (1, 2), (3, 4), (4, 5)]
146
+
147
+ Notes
148
+ -----
149
+ For operating on mixed type graphs, they should be converted to the same type.
150
+
151
+ Graph, edge, and node attributes are propagated to the union graph.
152
+ If a graph attribute is present in multiple graphs, then the value
153
+ from the last graph in the list with that attribute is used.
154
+ """
155
+
156
+ def yield_relabeled(graphs):
157
+ first_label = 0
158
+ for G in graphs:
159
+ yield nx.convert_node_labels_to_integers(G, first_label=first_label)
160
+ first_label += len(G)
161
+
162
+ R = union_all(yield_relabeled(graphs))
163
+
164
+ return R
165
+
166
+
167
+ @nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True)
168
+ def compose_all(graphs):
169
+ """Returns the composition of all graphs.
170
+
171
+ Composition is the simple union of the node sets and edge sets.
172
+ The node sets of the supplied graphs need not be disjoint.
173
+
174
+ Parameters
175
+ ----------
176
+ graphs : iterable
177
+ Iterable of NetworkX graphs
178
+
179
+ Returns
180
+ -------
181
+ C : A graph with the same type as the first graph in list
182
+
183
+ Raises
184
+ ------
185
+ ValueError
186
+ If `graphs` is an empty list.
187
+
188
+ NetworkXError
189
+ In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.
190
+
191
+ Examples
192
+ --------
193
+ >>> G1 = nx.Graph([(1, 2), (2, 3)])
194
+ >>> G2 = nx.Graph([(3, 4), (5, 6)])
195
+ >>> C = nx.compose_all([G1, G2])
196
+ >>> list(C.nodes())
197
+ [1, 2, 3, 4, 5, 6]
198
+ >>> list(C.edges())
199
+ [(1, 2), (2, 3), (3, 4), (5, 6)]
200
+
201
+ Notes
202
+ -----
203
+ For operating on mixed type graphs, they should be converted to the same type.
204
+
205
+ Graph, edge, and node attributes are propagated to the union graph.
206
+ If a graph attribute is present in multiple graphs, then the value
207
+ from the last graph in the list with that attribute is used.
208
+ """
209
+ R = None
210
+
211
+ # add graph attributes, H attributes take precedent over G attributes
212
+ for i, G in enumerate(graphs):
213
+ if i == 0:
214
+ # create new graph
215
+ R = G.__class__()
216
+ elif G.is_directed() != R.is_directed():
217
+ raise nx.NetworkXError("All graphs must be directed or undirected.")
218
+ elif G.is_multigraph() != R.is_multigraph():
219
+ raise nx.NetworkXError("All graphs must be graphs or multigraphs.")
220
+
221
+ R.graph.update(G.graph)
222
+ R.add_nodes_from(G.nodes(data=True))
223
+ R.add_edges_from(
224
+ G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
225
+ )
226
+
227
+ if R is None:
228
+ raise ValueError("cannot apply compose_all to an empty list")
229
+
230
+ return R
231
+
232
+
233
+ @nx._dispatchable(graphs="[graphs]", returns_graph=True)
234
+ def intersection_all(graphs):
235
+ """Returns a new graph that contains only the nodes and the edges that exist in
236
+ all graphs.
237
+
238
+ Parameters
239
+ ----------
240
+ graphs : iterable
241
+ Iterable of NetworkX graphs
242
+
243
+ Returns
244
+ -------
245
+ R : A new graph with the same type as the first graph in list
246
+
247
+ Raises
248
+ ------
249
+ ValueError
250
+ If `graphs` is an empty list.
251
+
252
+ NetworkXError
253
+ In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs.
254
+
255
+ Notes
256
+ -----
257
+ For operating on mixed type graphs, they should be converted to the same type.
258
+
259
+ Attributes from the graph, nodes, and edges are not copied to the new
260
+ graph.
261
+
262
+ The resulting graph can be updated with attributes if desired.
263
+ For example, code which adds the minimum attribute for each node across all
264
+ graphs could work::
265
+
266
+ >>> g = nx.Graph()
267
+ >>> g.add_node(0, capacity=4)
268
+ >>> g.add_node(1, capacity=3)
269
+ >>> g.add_edge(0, 1)
270
+
271
+ >>> h = g.copy()
272
+ >>> h.nodes[0]["capacity"] = 2
273
+
274
+ >>> gh = nx.intersection_all([g, h])
275
+
276
+ >>> new_node_attr = {
277
+ ... n: min(*(anyG.nodes[n].get("capacity", float("inf")) for anyG in [g, h]))
278
+ ... for n in gh
279
+ ... }
280
+ >>> nx.set_node_attributes(gh, new_node_attr, "new_capacity")
281
+ >>> gh.nodes(data=True)
282
+ NodeDataView({0: {'new_capacity': 2}, 1: {'new_capacity': 3}})
283
+
284
+ Examples
285
+ --------
286
+ >>> G1 = nx.Graph([(1, 2), (2, 3)])
287
+ >>> G2 = nx.Graph([(2, 3), (3, 4)])
288
+ >>> R = nx.intersection_all([G1, G2])
289
+ >>> list(R.nodes())
290
+ [2, 3]
291
+ >>> list(R.edges())
292
+ [(2, 3)]
293
+
294
+ """
295
+ R = None
296
+
297
+ for i, G in enumerate(graphs):
298
+ G_nodes_set = set(G.nodes)
299
+ G_edges_set = set(G.edges)
300
+ if not G.is_directed():
301
+ if G.is_multigraph():
302
+ G_edges_set.update((v, u, k) for u, v, k in list(G_edges_set))
303
+ else:
304
+ G_edges_set.update((v, u) for u, v in list(G_edges_set))
305
+ if i == 0:
306
+ # create new graph
307
+ R = G.__class__()
308
+ node_intersection = G_nodes_set
309
+ edge_intersection = G_edges_set
310
+ elif G.is_directed() != R.is_directed():
311
+ raise nx.NetworkXError("All graphs must be directed or undirected.")
312
+ elif G.is_multigraph() != R.is_multigraph():
313
+ raise nx.NetworkXError("All graphs must be graphs or multigraphs.")
314
+ else:
315
+ node_intersection &= G_nodes_set
316
+ edge_intersection &= G_edges_set
317
+
318
+ if R is None:
319
+ raise ValueError("cannot apply intersection_all to an empty list")
320
+
321
+ R.add_nodes_from(node_intersection)
322
+ R.add_edges_from(edge_intersection)
323
+
324
+ return R
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py ADDED
@@ -0,0 +1,468 @@
1
+ """
2
+ Operations on graphs including union, intersection, difference.
3
+ """
4
+
5
+ import networkx as nx
6
+
7
+ __all__ = [
8
+ "union",
9
+ "compose",
10
+ "disjoint_union",
11
+ "intersection",
12
+ "difference",
13
+ "symmetric_difference",
14
+ "full_join",
15
+ ]
16
+ _G_H = {"G": 0, "H": 1}
17
+
18
+
19
+ @nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
20
+ def union(G, H, rename=()):
21
+ """Combine graphs G and H. The names of nodes must be unique.
22
+
23
+ A name collision between the graphs will raise an exception.
24
+
25
+ A renaming facility is provided to avoid name collisions.
26
+
27
+
28
+ Parameters
29
+ ----------
30
+ G, H : graph
31
+ A NetworkX graph
32
+
33
+ rename : iterable , optional
34
+ Node names of G and H can be changed by specifying the tuple
35
+ rename=('G-','H-') (for example). Node "u" in G is then renamed
36
+ "G-u" and "v" in H is renamed "H-v".
37
+
38
+ Returns
39
+ -------
40
+ U : A union graph with the same type as G.
41
+
42
+ See Also
43
+ --------
44
+ compose
45
+ :func:`~networkx.Graph.update`
46
+ disjoint_union
47
+
48
+ Notes
49
+ -----
50
+ To combine graphs that have common nodes, consider compose(G, H)
51
+ or the method, Graph.update().
52
+
53
+ disjoint_union() is similar to union() except that it avoids name clashes
54
+ by relabeling the nodes with sequential integers.
55
+
56
+ Edge and node attributes are propagated from G and H to the union graph.
57
+ Graph attributes are also propagated, but if they are present in both G and H,
58
+ then the value from H is used.
59
+
60
+ Examples
61
+ --------
62
+ >>> from pprint import pprint
63
+ >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
64
+ >>> H = nx.Graph([(0, 1), (0, 3), (1, 3), (1, 2)])
65
+ >>> U = nx.union(G, H, rename=("G", "H"))
66
+ >>> U.nodes
67
+ NodeView(('G0', 'G1', 'G2', 'H0', 'H1', 'H3', 'H2'))
68
+ >>> edgelist = list(U.edges)
69
+ >>> pprint(edgelist)
70
+ [('G0', 'G1'),
71
+ ('G0', 'G2'),
72
+ ('G1', 'G2'),
73
+ ('H0', 'H1'),
74
+ ('H0', 'H3'),
75
+ ('H1', 'H3'),
76
+ ('H1', 'H2')]
77
+
78
+
79
+ """
80
+ return nx.union_all([G, H], rename)
81
+
82
+
83
+ @nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
84
+ def disjoint_union(G, H):
85
+ """Combine graphs G and H. The nodes are assumed to be unique (disjoint).
86
+
87
+ This algorithm automatically relabels nodes to avoid name collisions.
88
+
89
+ Parameters
90
+ ----------
91
+ G,H : graph
92
+ A NetworkX graph
93
+
94
+ Returns
95
+ -------
96
+ U : A union graph with the same type as G.
97
+
98
+ See Also
99
+ --------
100
+ union
101
+ compose
102
+ :func:`~networkx.Graph.update`
103
+
104
+ Notes
105
+ -----
106
+ A new graph is created, of the same class as G. It is recommended
107
+ that G and H be either both directed or both undirected.
108
+
109
+ The nodes of G are relabeled 0 to len(G)-1, and the nodes of H are
110
+ relabeled len(G) to len(G)+len(H)-1.
111
+
112
+ Renumbering forces G and H to be disjoint, so no exception is ever raised for a name collision.
113
+ To preserve the check for common nodes, use union().
114
+
115
+ Edge and node attributes are propagated from G and H to the union graph.
116
+ Graph attributes are also propagated, but if they are present in both G and H,
117
+ then the value from H is used.
118
+
119
+ To combine graphs that have common nodes, consider compose(G, H)
120
+ or the method, Graph.update().
121
+
122
+ Examples
123
+ --------
124
+ >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
125
+ >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)])
126
+ >>> G.nodes[0]["key1"] = 5
127
+ >>> H.nodes[0]["key2"] = 10
128
+ >>> U = nx.disjoint_union(G, H)
129
+ >>> U.nodes(data=True)
130
+ NodeDataView({0: {'key1': 5}, 1: {}, 2: {}, 3: {'key2': 10}, 4: {}, 5: {}, 6: {}})
131
+ >>> U.edges
132
+ EdgeView([(0, 1), (0, 2), (1, 2), (3, 4), (4, 6), (5, 6)])
133
+ """
134
+ return nx.disjoint_union_all([G, H])
135
+
136
+
137
+ @nx._dispatchable(graphs=_G_H, returns_graph=True)
138
+ def intersection(G, H):
139
+ """Returns a new graph that contains only the nodes and the edges that exist in
140
+ both G and H.
141
+
142
+ Parameters
143
+ ----------
144
+ G,H : graph
145
+ A NetworkX graph. G and H can have different node sets but must be both graphs or both multigraphs.
146
+
147
+ Raises
148
+ ------
149
+ NetworkXError
150
+ If one is a MultiGraph and the other one is a graph.
151
+
152
+ Returns
153
+ -------
154
+ GH : A new graph with the same type as G.
155
+
156
+ Notes
157
+ -----
158
+ Attributes from the graph, nodes, and edges are not copied to the new
159
+ graph. If you want a new graph of the intersection of G and H
160
+ with the attributes (including edge data) from G use remove_nodes_from()
161
+ as follows
162
+
163
+ >>> G = nx.path_graph(3)
164
+ >>> H = nx.path_graph(5)
165
+ >>> R = G.copy()
166
+ >>> R.remove_nodes_from(n for n in G if n not in H)
167
+ >>> R.remove_edges_from(e for e in G.edges if e not in H.edges)
168
+
169
+ Examples
170
+ --------
171
+ >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)])
172
+ >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)])
173
+ >>> R = nx.intersection(G, H)
174
+ >>> R.nodes
175
+ NodeView((0, 1, 2))
176
+ >>> R.edges
177
+ EdgeView([(1, 2)])
178
+ """
179
+ return nx.intersection_all([G, H])
180
+
181
+
182
+ @nx._dispatchable(graphs=_G_H, returns_graph=True)
183
+ def difference(G, H):
184
+ """Returns a new graph that contains the edges that exist in G but not in H.
185
+
186
+ The node sets of H and G must be the same.
187
+
188
+ Parameters
189
+ ----------
190
+ G,H : graph
191
+ A NetworkX graph. G and H must have the same node sets.
192
+
193
+ Returns
194
+ -------
195
+ D : A new graph with the same type as G.
196
+
197
+ Notes
198
+ -----
199
+ Attributes from the graph, nodes, and edges are not copied to the new
200
+ graph. If you want a new graph of the difference of G and H with
201
+ the attributes (including edge data) from G use remove_nodes_from()
202
+ as follows:
203
+
204
+ >>> G = nx.path_graph(3)
205
+ >>> H = nx.path_graph(5)
206
+ >>> R = G.copy()
207
+ >>> R.remove_nodes_from(n for n in G if n in H)
208
+
209
+ Examples
210
+ --------
211
+ >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)])
212
+ >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)])
213
+ >>> R = nx.difference(G, H)
214
+ >>> R.nodes
215
+ NodeView((0, 1, 2, 3))
216
+ >>> R.edges
217
+ EdgeView([(0, 2), (1, 3)])
218
+ """
219
+ # create new graph
220
+ if not G.is_multigraph() == H.is_multigraph():
221
+ raise nx.NetworkXError("G and H must both be graphs or multigraphs.")
222
+ R = nx.create_empty_copy(G, with_data=False)
223
+
224
+ if set(G) != set(H):
225
+ raise nx.NetworkXError("Node sets of graphs not equal")
226
+
227
+ if G.is_multigraph():
228
+ edges = G.edges(keys=True)
229
+ else:
230
+ edges = G.edges()
231
+ for e in edges:
232
+ if not H.has_edge(*e):
233
+ R.add_edge(*e)
234
+ return R
235
+
236
+
237
+ @nx._dispatchable(graphs=_G_H, returns_graph=True)
238
+ def symmetric_difference(G, H):
239
+ """Returns new graph with edges that exist in either G or H but not both.
240
+
241
+ The node sets of H and G must be the same.
242
+
243
+ Parameters
244
+ ----------
245
+ G,H : graph
246
+ A NetworkX graph. G and H must have the same node sets.
247
+
248
+ Returns
249
+ -------
250
+ D : A new graph with the same type as G.
251
+
252
+ Notes
253
+ -----
254
+ Attributes from the graph, nodes, and edges are not copied to the new
255
+ graph.
256
+
257
+ Examples
258
+ --------
259
+ >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)])
260
+ >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)])
261
+ >>> R = nx.symmetric_difference(G, H)
262
+ >>> R.nodes
263
+ NodeView((0, 1, 2, 3))
264
+ >>> R.edges
265
+ EdgeView([(0, 2), (0, 3), (1, 3)])
266
+ """
267
+ # create new graph
268
+ if not G.is_multigraph() == H.is_multigraph():
269
+ raise nx.NetworkXError("G and H must both be graphs or multigraphs.")
270
+ R = nx.create_empty_copy(G, with_data=False)
271
+
272
+ if set(G) != set(H):
273
+ raise nx.NetworkXError("Node sets of graphs not equal")
274
+
275
+ gnodes = set(G) # set of nodes in G
276
+ hnodes = set(H) # set of nodes in H
277
+ nodes = gnodes.symmetric_difference(hnodes)
278
+ R.add_nodes_from(nodes)
279
+
280
+ if G.is_multigraph():
281
+ edges = G.edges(keys=True)
282
+ else:
283
+ edges = G.edges()
284
+ # we could copy the data here but then this function doesn't
285
+ # match intersection and difference
286
+ for e in edges:
287
+ if not H.has_edge(*e):
288
+ R.add_edge(*e)
289
+
290
+ if H.is_multigraph():
291
+ edges = H.edges(keys=True)
292
+ else:
293
+ edges = H.edges()
294
+ for e in edges:
295
+ if not G.has_edge(*e):
296
+ R.add_edge(*e)
297
+ return R
298
+
299
+
300
+ @nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
301
+ def compose(G, H):
302
+ """Compose graph G with H by combining nodes and edges into a single graph.
303
+
304
+ The node sets and edges sets do not need to be disjoint.
305
+
306
+ Composing preserves the attributes of nodes and edges.
307
+ Attribute values from H take precedent over attribute values from G.
308
+
309
+ Parameters
310
+ ----------
311
+ G, H : graph
312
+ A NetworkX graph
313
+
314
+ Returns
315
+ -------
316
+ C: A new graph with the same type as G
317
+
318
+ See Also
319
+ --------
320
+ :func:`~networkx.Graph.update`
321
+ union
322
+ disjoint_union
323
+
324
+ Notes
325
+ -----
326
+ It is recommended that G and H be either both directed or both undirected.
327
+
328
+ For MultiGraphs, the edges are identified by incident nodes AND edge-key.
329
+ This can cause surprises (i.e., edge `(1, 2)` may or may not be the same
330
+ in two graphs) if you use MultiGraph without keeping track of edge keys.
331
+
332
+ If combining the attributes of common nodes is not desired, consider union(),
333
+ which raises an exception for name collisions.
334
+
335
+ Examples
336
+ --------
337
+ >>> G = nx.Graph([(0, 1), (0, 2)])
338
+ >>> H = nx.Graph([(0, 1), (1, 2)])
339
+ >>> R = nx.compose(G, H)
340
+ >>> R.nodes
341
+ NodeView((0, 1, 2))
342
+ >>> R.edges
343
+ EdgeView([(0, 1), (0, 2), (1, 2)])
344
+
345
+ By default, the attributes from `H` take precedent over attributes from `G`.
346
+ If you prefer another way of combining attributes, you can update them after the compose operation:
347
+
348
+ >>> G = nx.Graph([(0, 1, {"weight": 2.0}), (3, 0, {"weight": 100.0})])
349
+ >>> H = nx.Graph([(0, 1, {"weight": 10.0}), (1, 2, {"weight": -1.0})])
350
+ >>> nx.set_node_attributes(G, {0: "dark", 1: "light", 3: "black"}, name="color")
351
+ >>> nx.set_node_attributes(H, {0: "green", 1: "orange", 2: "yellow"}, name="color")
352
+ >>> GcomposeH = nx.compose(G, H)
353
+
354
+ Normally, color attribute values of nodes of GcomposeH come from H. We can workaround this as follows:
355
+
356
+ >>> node_data = {
357
+ ... n: G.nodes[n]["color"] + " " + H.nodes[n]["color"]
358
+ ... for n in G.nodes & H.nodes
359
+ ... }
360
+ >>> nx.set_node_attributes(GcomposeH, node_data, "color")
361
+ >>> print(GcomposeH.nodes[0]["color"])
362
+ dark green
363
+
364
+ >>> print(GcomposeH.nodes[3]["color"])
365
+ black
366
+
367
+ Similarly, we can update edge attributes after the compose operation in a way we prefer:
368
+
369
+ >>> edge_data = {
370
+ ... e: G.edges[e]["weight"] * H.edges[e]["weight"] for e in G.edges & H.edges
371
+ ... }
372
+ >>> nx.set_edge_attributes(GcomposeH, edge_data, "weight")
373
+ >>> print(GcomposeH.edges[(0, 1)]["weight"])
374
+ 20.0
375
+
376
+ >>> print(GcomposeH.edges[(3, 0)]["weight"])
377
+ 100.0
378
+ """
379
+ return nx.compose_all([G, H])
380
+
381
+
382
+ @nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True)
383
+ def full_join(G, H, rename=(None, None)):
384
+ """Returns the full join of graphs G and H.
385
+
386
+ Full join is the union of G and H in which all edges between
387
+ G and H are added.
388
+ The node sets of G and H must be disjoint,
389
+ otherwise an exception is raised.
390
+
391
+ Parameters
392
+ ----------
393
+ G, H : graph
394
+ A NetworkX graph
395
+
396
+ rename : tuple , default=(None, None)
397
+ Node names of G and H can be changed by specifying the tuple
398
+ rename=('G-','H-') (for example). Node "u" in G is then renamed
399
+ "G-u" and "v" in H is renamed "H-v".
400
+
401
+ Returns
402
+ -------
403
+ U : The full join graph with the same type as G.
404
+
405
+ Notes
406
+ -----
407
+ It is recommended that G and H be either both directed or both undirected.
408
+
409
+ If G is directed, then edges from G to H are added as well as from H to G.
410
+
411
+ Note that full_join() does not produce parallel edges for MultiGraphs.
412
+
413
+ The full join operation of graphs G and H is the same as getting
414
+ their complement, performing a disjoint union, and finally getting
415
+ the complement of the resulting graph.
416
+
417
+ Graph, edge, and node attributes are propagated from G and H
418
+ to the union graph. If a graph attribute is present in both
419
+ G and H the value from H is used.
420
+
421
+ Examples
422
+ --------
423
+ >>> from pprint import pprint
424
+ >>> G = nx.Graph([(0, 1), (0, 2)])
425
+ >>> H = nx.Graph([(3, 4)])
426
+ >>> R = nx.full_join(G, H, rename=("G", "H"))
427
+ >>> R.nodes
428
+ NodeView(('G0', 'G1', 'G2', 'H3', 'H4'))
429
+ >>> edgelist = list(R.edges)
430
+ >>> pprint(edgelist)
431
+ [('G0', 'G1'),
432
+ ('G0', 'G2'),
433
+ ('G0', 'H3'),
434
+ ('G0', 'H4'),
435
+ ('G1', 'H3'),
436
+ ('G1', 'H4'),
437
+ ('G2', 'H3'),
438
+ ('G2', 'H4'),
439
+ ('H3', 'H4')]
440
+
441
+ See Also
442
+ --------
443
+ union
444
+ disjoint_union
445
+ """
446
+ R = union(G, H, rename)
447
+
448
+ def add_prefix(graph, prefix):
449
+ if prefix is None:
450
+ return graph
451
+
452
+ def label(x):
453
+ return f"{prefix}{x}"
454
+
455
+ return nx.relabel_nodes(graph, label)
456
+
457
+ G = add_prefix(G, rename[0])
458
+ H = add_prefix(H, rename[1])
459
+
460
+ for i in G:
461
+ for j in H:
462
+ R.add_edge(i, j)
463
+ if R.is_directed():
464
+ for i in H:
465
+ for j in G:
466
+ R.add_edge(i, j)
467
+
468
+ return R
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py ADDED
@@ -0,0 +1,633 @@
+ """
+ Graph products.
+ """
+
+ from itertools import product
+
+ import networkx as nx
+ from networkx.utils import not_implemented_for
+
+ __all__ = [
+     "tensor_product",
+     "cartesian_product",
+     "lexicographic_product",
+     "strong_product",
+     "power",
+     "rooted_product",
+     "corona_product",
+     "modular_product",
+ ]
+ _G_H = {"G": 0, "H": 1}
+
+
+ def _dict_product(d1, d2):
+     return {k: (d1.get(k), d2.get(k)) for k in set(d1) | set(d2)}
+
+
+ # Generators for producing graph products
+ def _node_product(G, H):
+     for u, v in product(G, H):
+         yield ((u, v), _dict_product(G.nodes[u], H.nodes[v]))
+
+
+ def _directed_edges_cross_edges(G, H):
+     if not G.is_multigraph() and not H.is_multigraph():
+         for u, v, c in G.edges(data=True):
+             for x, y, d in H.edges(data=True):
+                 yield (u, x), (v, y), _dict_product(c, d)
+     if not G.is_multigraph() and H.is_multigraph():
+         for u, v, c in G.edges(data=True):
+             for x, y, k, d in H.edges(data=True, keys=True):
+                 yield (u, x), (v, y), k, _dict_product(c, d)
+     if G.is_multigraph() and not H.is_multigraph():
+         for u, v, k, c in G.edges(data=True, keys=True):
+             for x, y, d in H.edges(data=True):
+                 yield (u, x), (v, y), k, _dict_product(c, d)
+     if G.is_multigraph() and H.is_multigraph():
+         for u, v, j, c in G.edges(data=True, keys=True):
+             for x, y, k, d in H.edges(data=True, keys=True):
+                 yield (u, x), (v, y), (j, k), _dict_product(c, d)
+
+
+ def _undirected_edges_cross_edges(G, H):
+     if not G.is_multigraph() and not H.is_multigraph():
+         for u, v, c in G.edges(data=True):
+             for x, y, d in H.edges(data=True):
+                 yield (v, x), (u, y), _dict_product(c, d)
+     if not G.is_multigraph() and H.is_multigraph():
+         for u, v, c in G.edges(data=True):
+             for x, y, k, d in H.edges(data=True, keys=True):
+                 yield (v, x), (u, y), k, _dict_product(c, d)
+     if G.is_multigraph() and not H.is_multigraph():
+         for u, v, k, c in G.edges(data=True, keys=True):
+             for x, y, d in H.edges(data=True):
+                 yield (v, x), (u, y), k, _dict_product(c, d)
+     if G.is_multigraph() and H.is_multigraph():
+         for u, v, j, c in G.edges(data=True, keys=True):
+             for x, y, k, d in H.edges(data=True, keys=True):
+                 yield (v, x), (u, y), (j, k), _dict_product(c, d)
+
+
+ def _edges_cross_nodes(G, H):
+     if G.is_multigraph():
+         for u, v, k, d in G.edges(data=True, keys=True):
+             for x in H:
+                 yield (u, x), (v, x), k, d
+     else:
+         for u, v, d in G.edges(data=True):
+             for x in H:
+                 if H.is_multigraph():
+                     yield (u, x), (v, x), None, d
+                 else:
+                     yield (u, x), (v, x), d
+
+
+ def _nodes_cross_edges(G, H):
+     if H.is_multigraph():
+         for x in G:
+             for u, v, k, d in H.edges(data=True, keys=True):
+                 yield (x, u), (x, v), k, d
+     else:
+         for x in G:
+             for u, v, d in H.edges(data=True):
+                 if G.is_multigraph():
+                     yield (x, u), (x, v), None, d
+                 else:
+                     yield (x, u), (x, v), d
+
+
+ def _edges_cross_nodes_and_nodes(G, H):
+     if G.is_multigraph():
+         for u, v, k, d in G.edges(data=True, keys=True):
+             for x in H:
+                 for y in H:
+                     yield (u, x), (v, y), k, d
+     else:
+         for u, v, d in G.edges(data=True):
+             for x in H:
+                 for y in H:
+                     if H.is_multigraph():
+                         yield (u, x), (v, y), None, d
+                     else:
+                         yield (u, x), (v, y), d
+
+
+ def _init_product_graph(G, H):
+     if G.is_directed() != H.is_directed():
+         msg = "G and H must be both directed or both undirected"
+         raise nx.NetworkXError(msg)
+     if G.is_multigraph() or H.is_multigraph():
+         GH = nx.MultiGraph()
+     else:
+         GH = nx.Graph()
+     if G.is_directed():
+         GH = GH.to_directed()
+     return GH
+
+
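The private helpers above pair up nodes from the two factors and merge their attribute dicts into value tuples. A small doctest-style illustration (a sketch only; it calls the module-private `_node_product` directly, which is not public API):

    >>> G = nx.Graph()
    >>> G.add_node(0, w=1)
    >>> H = nx.Graph()
    >>> H.add_node("a", w=2)
    >>> list(_node_product(G, H))
    [((0, 'a'), {'w': (1, 2)})]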
+ @nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
+ def tensor_product(G, H):
+     r"""Returns the tensor product of G and H.
+
+     The tensor product $P$ of the graphs $G$ and $H$ has a node set that
+     is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
+     $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$
+     and $(v,y)$ is an edge in $H$.
+
+     The tensor product is sometimes also referred to as the categorical product,
+     direct product, cardinal product or conjunction.
+
+     Parameters
+     ----------
+     G, H: graphs
+         Networkx graphs.
+
+     Returns
+     -------
+     P: NetworkX graph
+         The tensor product of G and H. P will be a multi-graph if either G
+         or H is a multi-graph, will be directed if G and H are directed,
+         and undirected if G and H are undirected.
+
+     Raises
+     ------
+     NetworkXError
+         If G and H are not both directed or both undirected.
+
+     Notes
+     -----
+     Node attributes in P are two-tuples of the G and H node attributes.
+     Missing attributes are assigned None.
+
+     Examples
+     --------
+     >>> G = nx.Graph()
+     >>> H = nx.Graph()
+     >>> G.add_node(0, a1=True)
+     >>> H.add_node("a", a2="Spam")
+     >>> P = nx.tensor_product(G, H)
+     >>> list(P)
+     [(0, 'a')]
+
+     Edge attributes and edge keys (for multigraphs) are also copied to the
+     new product graph.
+     """
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(_node_product(G, H))
+     GH.add_edges_from(_directed_edges_cross_edges(G, H))
+     if not GH.is_directed():
+         GH.add_edges_from(_undirected_edges_cross_edges(G, H))
+     return GH
+
+
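For simple undirected graphs without self-loops, each pair of a G-edge and an H-edge contributes exactly two edges to the product, so the tensor product has |V(G)|·|V(H)| nodes and 2·|E(G)|·|E(H)| edges. A doctest-style sketch of those counts:

    >>> T = nx.tensor_product(nx.path_graph(3), nx.cycle_graph(4))
    >>> T.number_of_nodes(), T.number_of_edges()
    (12, 16)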
+ @nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
+ def cartesian_product(G, H):
+     r"""Returns the Cartesian product of G and H.
+
+     The Cartesian product $P$ of the graphs $G$ and $H$ has a node set that
+     is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
+     $P$ has an edge $((u,v),(x,y))$ if and only if either $u$ is equal to $x$
+     and both $v$ and $y$ are adjacent in $H$ or if $v$ is equal to $y$ and
+     both $u$ and $x$ are adjacent in $G$.
+
+     Parameters
+     ----------
+     G, H: graphs
+         Networkx graphs.
+
+     Returns
+     -------
+     P: NetworkX graph
+         The Cartesian product of G and H. P will be a multi-graph if either G
+         or H is a multi-graph. It will be directed if G and H are directed,
+         and undirected if G and H are undirected.
+
+     Raises
+     ------
+     NetworkXError
+         If G and H are not both directed or both undirected.
+
+     Notes
+     -----
+     Node attributes in P are two-tuples of the G and H node attributes.
+     Missing attributes are assigned None.
+
+     Examples
+     --------
+     >>> G = nx.Graph()
+     >>> H = nx.Graph()
+     >>> G.add_node(0, a1=True)
+     >>> H.add_node("a", a2="Spam")
+     >>> P = nx.cartesian_product(G, H)
+     >>> list(P)
+     [(0, 'a')]
+
+     Edge attributes and edge keys (for multigraphs) are also copied to the
+     new product graph.
+     """
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(_node_product(G, H))
+     GH.add_edges_from(_edges_cross_nodes(G, H))
+     GH.add_edges_from(_nodes_cross_edges(G, H))
+     return GH
+
+
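A standard consequence of the definition is that the Cartesian product of two path graphs is a grid graph; a doctest-style sketch (assuming `nx.grid_2d_graph` behaves as documented):

    >>> C = nx.cartesian_product(nx.path_graph(3), nx.path_graph(4))
    >>> nx.is_isomorphic(C, nx.grid_2d_graph(3, 4))
    True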
+ @nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
+ def lexicographic_product(G, H):
+     r"""Returns the lexicographic product of G and H.
+
+     The lexicographical product $P$ of the graphs $G$ and $H$ has a node set
+     that is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
+     $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$
+     or $u==x$ and $(v,y)$ is an edge in $H$.
+
+     Parameters
+     ----------
+     G, H: graphs
+         Networkx graphs.
+
+     Returns
+     -------
+     P: NetworkX graph
+         The lexicographic product of G and H. P will be a multi-graph if either G
+         or H is a multi-graph. It will be directed if G and H are directed,
+         and undirected if G and H are undirected.
+
+     Raises
+     ------
+     NetworkXError
+         If G and H are not both directed or both undirected.
+
+     Notes
+     -----
+     Node attributes in P are two-tuples of the G and H node attributes.
+     Missing attributes are assigned None.
+
+     Examples
+     --------
+     >>> G = nx.Graph()
+     >>> H = nx.Graph()
+     >>> G.add_node(0, a1=True)
+     >>> H.add_node("a", a2="Spam")
+     >>> P = nx.lexicographic_product(G, H)
+     >>> list(P)
+     [(0, 'a')]
+
+     Edge attributes and edge keys (for multigraphs) are also copied to the
+     new product graph.
+     """
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(_node_product(G, H))
+     # Edges in G regardless of H designation
+     GH.add_edges_from(_edges_cross_nodes_and_nodes(G, H))
+     # For each x in G, only if there is an edge in H
+     GH.add_edges_from(_nodes_cross_edges(G, H))
+     return GH
+
+
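Counting the two edge families added above gives |E(G)|·|V(H)|² + |V(G)|·|E(H)| edges for simple undirected factors; a doctest-style sketch (here 1·3² + 2·2 = 13):

    >>> L = nx.lexicographic_product(nx.path_graph(2), nx.path_graph(3))
    >>> L.number_of_nodes(), L.number_of_edges()
    (6, 13)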
+ @nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True)
+ def strong_product(G, H):
+     r"""Returns the strong product of G and H.
+
+     The strong product $P$ of the graphs $G$ and $H$ has a node set that
+     is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$.
+     $P$ has an edge $((u,x), (v,y))$ if any of the following conditions
+     are met:
+
+     - $u=v$ and $(x,y)$ is an edge in $H$
+     - $x=y$ and $(u,v)$ is an edge in $G$
+     - $(u,v)$ is an edge in $G$ and $(x,y)$ is an edge in $H$
+
+     Parameters
+     ----------
+     G, H: graphs
+         Networkx graphs.
+
+     Returns
+     -------
+     P: NetworkX graph
+         The strong product of G and H. P will be a multi-graph if either G
+         or H is a multi-graph. It will be directed if G and H are directed,
+         and undirected if G and H are undirected.
+
+     Raises
+     ------
+     NetworkXError
+         If G and H are not both directed or both undirected.
+
+     Notes
+     -----
+     Node attributes in P are two-tuples of the G and H node attributes.
+     Missing attributes are assigned None.
+
+     Examples
+     --------
+     >>> G = nx.Graph()
+     >>> H = nx.Graph()
+     >>> G.add_node(0, a1=True)
+     >>> H.add_node("a", a2="Spam")
+     >>> P = nx.strong_product(G, H)
+     >>> list(P)
+     [(0, 'a')]
+
+     Edge attributes and edge keys (for multigraphs) are also copied to the
+     new product graph.
+     """
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(_node_product(G, H))
+     GH.add_edges_from(_nodes_cross_edges(G, H))
+     GH.add_edges_from(_edges_cross_nodes(G, H))
+     GH.add_edges_from(_directed_edges_cross_edges(G, H))
+     if not GH.is_directed():
+         GH.add_edges_from(_undirected_edges_cross_edges(G, H))
+     return GH
+
+
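The three conditions above combine the Cartesian and tensor edge sets, giving |V(G)|·|E(H)| + |V(H)|·|E(G)| + 2·|E(G)|·|E(H)| edges for simple undirected factors; a doctest-style sketch (the strong product of two 3-node paths is the 3x3 king graph, with 6 + 6 + 8 = 20 edges):

    >>> S = nx.strong_product(nx.path_graph(3), nx.path_graph(3))
    >>> S.number_of_nodes(), S.number_of_edges()
    (9, 20)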
+ @not_implemented_for("directed")
+ @not_implemented_for("multigraph")
+ @nx._dispatchable(returns_graph=True)
+ def power(G, k):
+     """Returns the specified power of a graph.
+
+     The $k$th power of a simple graph $G$, denoted $G^k$, is a
+     graph on the same set of nodes in which two distinct nodes $u$ and
+     $v$ are adjacent in $G^k$ if and only if the shortest path
+     distance between $u$ and $v$ in $G$ is at most $k$.
+
+     Parameters
+     ----------
+     G : graph
+         A NetworkX simple graph object.
+
+     k : positive integer
+         The power to which to raise the graph `G`.
+
+     Returns
+     -------
+     NetworkX simple graph
+         `G` to the power `k`.
+
+     Raises
+     ------
+     ValueError
+         If the exponent `k` is not positive.
+
+     NetworkXNotImplemented
+         If `G` is not a simple graph.
+
+     Examples
+     --------
+     The number of edges will never decrease when taking successive
+     powers:
+
+     >>> G = nx.path_graph(4)
+     >>> list(nx.power(G, 2).edges)
+     [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
+     >>> list(nx.power(G, 3).edges)
+     [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
+
+     The `k` th power of a cycle graph on *n* nodes is the complete graph
+     on *n* nodes, if `k` is at least ``n // 2``:
+
+     >>> G = nx.cycle_graph(5)
+     >>> H = nx.complete_graph(5)
+     >>> nx.is_isomorphic(nx.power(G, 2), H)
+     True
+     >>> G = nx.cycle_graph(8)
+     >>> H = nx.complete_graph(8)
+     >>> nx.is_isomorphic(nx.power(G, 4), H)
+     True
+
+     References
+     ----------
+     .. [1] J. A. Bondy, U. S. R. Murty, *Graph Theory*. Springer, 2008.
+
+     Notes
+     -----
+     This definition of "power graph" comes from Exercise 3.1.6 of
+     *Graph Theory* by Bondy and Murty [1]_.
+
+     """
+     if k <= 0:
+         raise ValueError("k must be a positive integer")
+     H = nx.Graph()
+     H.add_nodes_from(G)
+     # update BFS code to ignore self loops.
+     for n in G:
+         seen = {}  # level (number of hops) when seen in BFS
+         level = 1  # the current level
+         nextlevel = G[n]
+         while nextlevel:
+             thislevel = nextlevel  # advance to next level
+             nextlevel = {}  # and start a new list (fringe)
+             for v in thislevel:
+                 if v == n:  # avoid self loop
+                     continue
+                 if v not in seen:
+                     seen[v] = level  # set the level of vertex v
+                     nextlevel.update(G[v])  # add neighbors of v
+             if k <= level:
+                 break
+             level += 1
+         H.add_edges_from((n, nbr) for nbr in seen)
+     return H
+
+
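Generalizing the cycle examples in the docstring: for any connected simple graph, all shortest-path distances are at most the diameter, so raising the graph to at least its diameter yields the complete graph. A doctest-style sketch (assuming `nx.diameter` is available as usual):

    >>> G = nx.path_graph(5)
    >>> nx.is_isomorphic(nx.power(G, nx.diameter(G)), nx.complete_graph(5))
    True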
+ @not_implemented_for("multigraph")
+ @nx._dispatchable(graphs=_G_H, returns_graph=True)
+ def rooted_product(G, H, root):
+     """Return the rooted product of graphs G and H rooted at root in H.
+
+     A new graph is constructed representing the rooted product of
+     the input graphs, G and H, with a root in H.
+     A rooted product duplicates H for each node in G, with the root
+     of H corresponding to the node in G. Nodes are renamed as the direct
+     product of G and H. The result is a subgraph of the Cartesian product.
+
+     Parameters
+     ----------
+     G,H : graph
+         A NetworkX graph
+     root : node
+         A node in H
+
+     Returns
+     -------
+     R : The rooted product of G and H with a specified root in H
+
+     Notes
+     -----
+     The nodes of R are the Cartesian product of the nodes of G and H.
+     The nodes of G and H are not relabeled.
+     """
+     if root not in H:
+         raise nx.NodeNotFound("root must be a vertex in H")
+
+     R = nx.Graph()
+     R.add_nodes_from(product(G, H))
+
+     R.add_edges_from(((e[0], root), (e[1], root)) for e in G.edges())
+     R.add_edges_from(((g, e[0]), (g, e[1])) for g in G for e in H.edges())
+
+     return R
+
+
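Since one copy of H is attached per node of G and the copies only meet through their roots, the rooted product has |V(G)|·|V(H)| nodes and |E(G)| + |V(G)|·|E(H)| edges; a doctest-style sketch (here 4 + 4·2 = 12 edges):

    >>> R = nx.rooted_product(nx.cycle_graph(4), nx.path_graph(3), 0)
    >>> R.number_of_nodes(), R.number_of_edges()
    (12, 12)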
+ @not_implemented_for("directed")
+ @not_implemented_for("multigraph")
+ @nx._dispatchable(graphs=_G_H, returns_graph=True)
+ def corona_product(G, H):
+     r"""Returns the Corona product of G and H.
+
+     The corona product of $G$ and $H$ is the graph $C = G \circ H$ obtained by
+     taking one copy of $G$, called the center graph, $|V(G)|$ copies of $H$,
+     called the outer graph, and making the $i$-th vertex of $G$ adjacent to
+     every vertex of the $i$-th copy of $H$, where $1 ≤ i ≤ |V(G)|$.
+
+     Parameters
+     ----------
+     G, H: NetworkX graphs
+         The graphs to take the corona product of.
+         `G` is the center graph and `H` is the outer graph.
+
+     Returns
+     -------
+     C: NetworkX graph
+         The Corona product of G and H.
+
+     Raises
+     ------
+     NetworkXError
+         If G and H are not both directed or both undirected.
+
+     Examples
+     --------
+     >>> G = nx.cycle_graph(4)
+     >>> H = nx.path_graph(2)
+     >>> C = nx.corona_product(G, H)
+     >>> list(C)
+     [0, 1, 2, 3, (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
+     >>> print(C)
+     Graph with 12 nodes and 16 edges
+
+     References
+     ----------
+     .. [1] M. Tavakoli, F. Rahbarnia, and A. R. Ashrafi,
+        "Studying the corona product of graphs under some graph invariants,"
+        Transactions on Combinatorics, vol. 3, no. 3, pp. 43–49, Sep. 2014,
+        doi: 10.22108/toc.2014.5542.
+     .. [2] A. Faraji, "Corona Product in Graph Theory," Ali Faraji, May 11, 2021.
+        https://blog.alifaraji.ir/math/graph-theory/corona-product.html (accessed Dec. 07, 2021).
+     """
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(G)
+     GH.add_edges_from(G.edges)
+
+     for G_node in G:
+         # copy nodes of H into GH, call it H_i
+         GH.add_nodes_from((G_node, v) for v in H)
+
+         # copy edges of H_i based on H
+         GH.add_edges_from(
+             ((G_node, e0), (G_node, e1), d) for e0, e1, d in H.edges.data()
+         )
+
+         # create new edges between H_i and the corresponding node of G
+         GH.add_edges_from((G_node, (G_node, H_node)) for H_node in H)
+
+     return GH
+
+
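When the outer graph is a single node, the corona product simply attaches one pendant vertex to every node of the center graph; a doctest-style sketch of the resulting degree sequence (assuming `nx.empty_graph(1)` for the one-node outer graph):

    >>> C = nx.corona_product(nx.cycle_graph(3), nx.empty_graph(1))
    >>> sorted(d for _, d in C.degree)
    [1, 1, 1, 3, 3, 3]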
+ @nx._dispatchable(
+     graphs=_G_H, preserve_edge_attrs=True, preserve_node_attrs=True, returns_graph=True
+ )
+ def modular_product(G, H):
+     r"""Returns the Modular product of G and H.
+
+     The modular product of `G` and `H` is the graph $M = G \nabla H$,
+     consisting of the node set $V(M) = V(G) \times V(H)$ that is the Cartesian
+     product of the node sets of `G` and `H`. Further, M contains an edge ((u, v), (x, y)):
+
+     - if u is adjacent to x in `G` and v is adjacent to y in `H`, or
+     - if u is not adjacent to x in `G` and v is not adjacent to y in `H`.
+
+     More formally::
+
+         E(M) = {((u, v), (x, y)) | ((u, x) in E(G) and (v, y) in E(H)) or
+                 ((u, x) not in E(G) and (v, y) not in E(H))}
+
+     Parameters
+     ----------
+     G, H: NetworkX graphs
+         The graphs to take the modular product of.
+
+     Returns
+     -------
+     M: NetworkX graph
+         The Modular product of `G` and `H`.
+
+     Raises
+     ------
+     NetworkXNotImplemented
+         If `G` or `H` is a directed graph or a multigraph.
+
+     Examples
+     --------
+     >>> G = nx.cycle_graph(4)
+     >>> H = nx.path_graph(2)
+     >>> M = nx.modular_product(G, H)
+     >>> list(M)
+     [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
+     >>> print(M)
+     Graph with 8 nodes and 8 edges
+
+     Notes
+     -----
+     The *modular product* is defined in [1]_ and was first
+     introduced as the *weak modular product*.
+
+     The modular product reduces the problem of counting isomorphic subgraphs
+     in `G` and `H` to the problem of counting cliques in M. The subgraphs of
+     `G` and `H` that are induced by the nodes of a clique in M are
+     isomorphic [2]_ [3]_.
+
+     References
+     ----------
+     .. [1] R. Hammack, W. Imrich, and S. Klavžar,
+        "Handbook of Product Graphs", CRC Press, 2011.
+
+     .. [2] H. G. Barrow and R. M. Burstall,
+        "Subgraph isomorphism, matching relational structures and maximal
+        cliques", Information Processing Letters, vol. 4, issue 4, pp. 83-84,
+        1976, https://doi.org/10.1016/0020-0190(76)90049-1.
+
+     .. [3] V. G. Vizing, "Reduction of the problem of isomorphism and isomorphic
+        entrance to the task of finding the nondensity of a graph." Proc. Third
+        All-Union Conference on Problems of Theoretical Cybernetics. 1974.
+     """
+     if G.is_directed() or H.is_directed():
+         raise nx.NetworkXNotImplemented(
+             "Modular product not implemented for directed graphs"
+         )
+     if G.is_multigraph() or H.is_multigraph():
+         raise nx.NetworkXNotImplemented(
+             "Modular product not implemented for multigraphs"
+         )
+
+     GH = _init_product_graph(G, H)
+     GH.add_nodes_from(_node_product(G, H))
+
+     for u, v, c in G.edges(data=True):
+         for x, y, d in H.edges(data=True):
+             GH.add_edge((u, x), (v, y), **_dict_product(c, d))
+             GH.add_edge((v, x), (u, y), **_dict_product(c, d))
+
+     G = nx.complement(G)
+     H = nx.complement(H)
+
+     for u, v, c in G.edges(data=True):
+         for x, y, d in H.edges(data=True):
+             GH.add_edge((u, x), (v, y), **_dict_product(c, d))
+             GH.add_edge((v, x), (u, y), **_dict_product(c, d))
+
+     return GH
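As the Notes describe, a maximum clique in M corresponds to a maximum common induced subgraph of the two factors; matching a 3-node path against itself, that common subgraph is the whole path, so the largest clique has size 3. A doctest-style sketch (assuming `nx.find_cliques` for clique enumeration):

    >>> G = nx.path_graph(3)
    >>> M = nx.modular_product(G, G)
    >>> max(len(c) for c in nx.find_cliques(M))
    3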