Add files using upload-large-folder tool
Browse files- venv/Lib/site-packages/accelerate-1.6.0.dist-info/INSTALLER +1 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/LICENSE +201 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/METADATA +380 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/RECORD +177 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/REQUESTED +0 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/WHEEL +5 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/entry_points.txt +6 -0
- venv/Lib/site-packages/accelerate-1.6.0.dist-info/top_level.txt +1 -0
- venv/Lib/site-packages/accelerate/__init__.py +50 -0
- venv/Lib/site-packages/accelerate/accelerator.py +0 -0
- venv/Lib/site-packages/accelerate/big_modeling.py +637 -0
- venv/Lib/site-packages/accelerate/checkpointing.py +319 -0
- venv/Lib/site-packages/accelerate/data_loader.py +1429 -0
- venv/Lib/site-packages/accelerate/hooks.py +739 -0
- venv/Lib/site-packages/accelerate/inference.py +184 -0
- venv/Lib/site-packages/accelerate/launchers.py +301 -0
- venv/Lib/site-packages/accelerate/local_sgd.py +106 -0
- venv/Lib/site-packages/accelerate/logging.py +125 -0
- venv/Lib/site-packages/accelerate/memory_utils.py +22 -0
- venv/Lib/site-packages/accelerate/optimizer.py +212 -0
- venv/Lib/site-packages/accelerate/scheduler.py +98 -0
- venv/Lib/site-packages/accelerate/state.py +1330 -0
- venv/Lib/site-packages/accelerate/tracking.py +1089 -0
- venv/Lib/site-packages/adodbapi/__init__.py +82 -0
- venv/Lib/site-packages/adodbapi/ado_consts.py +283 -0
- venv/Lib/site-packages/adodbapi/adodbapi.py +1153 -0
- venv/Lib/site-packages/adodbapi/apibase.py +723 -0
- venv/Lib/site-packages/adodbapi/is64bit.py +34 -0
- venv/Lib/site-packages/adodbapi/license.txt +505 -0
- venv/Lib/site-packages/adodbapi/process_connect_string.py +137 -0
- venv/Lib/site-packages/adodbapi/readme.txt +88 -0
- venv/Lib/site-packages/adodbapi/schema_table.py +16 -0
- venv/Lib/site-packages/adodbapi/setup.py +68 -0
- venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/INSTALLER +1 -0
- venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/LICENSE +279 -0
- venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/METADATA +123 -0
- venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/RECORD +16 -0
- venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/WHEEL +4 -0
- venv/Lib/site-packages/aiohappyeyeballs/__init__.py +14 -0
- venv/Lib/site-packages/aiohappyeyeballs/_staggered.py +207 -0
- venv/Lib/site-packages/aiohappyeyeballs/impl.py +259 -0
- venv/Lib/site-packages/aiohappyeyeballs/py.typed +0 -0
- venv/Lib/site-packages/aiohappyeyeballs/types.py +17 -0
- venv/Lib/site-packages/aiohappyeyeballs/utils.py +97 -0
- venv/Lib/site-packages/aiohttp/abc.py +253 -0
- venv/Lib/site-packages/aiohttp/base_protocol.py +100 -0
- venv/Lib/site-packages/scipy-1.15.3-cp312-cp312-win_amd64.whl +0 -0
- venv/Lib/site-packages/six.py +1003 -0
- venv/Lib/site-packages/threadpoolctl.py +1292 -0
- venv/Lib/site-packages/typing_extensions.py +0 -0
venv/Lib/site-packages/accelerate-1.6.0.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/METADATA
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: accelerate
|
| 3 |
+
Version: 1.6.0
|
| 4 |
+
Summary: Accelerate
|
| 5 |
+
Home-page: https://github.com/huggingface/accelerate
|
| 6 |
+
Author: The HuggingFace team
|
| 7 |
+
Author-email: zach.mueller@huggingface.co
|
| 8 |
+
License: Apache
|
| 9 |
+
Keywords: deep learning
|
| 10 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Education
|
| 13 |
+
Classifier: Intended Audience :: Science/Research
|
| 14 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 15 |
+
Classifier: Operating System :: OS Independent
|
| 16 |
+
Classifier: Programming Language :: Python :: 3
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 18 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 19 |
+
Requires-Python: >=3.9.0
|
| 20 |
+
Description-Content-Type: text/markdown
|
| 21 |
+
License-File: LICENSE
|
| 22 |
+
Requires-Dist: numpy<3.0.0,>=1.17
|
| 23 |
+
Requires-Dist: packaging>=20.0
|
| 24 |
+
Requires-Dist: psutil
|
| 25 |
+
Requires-Dist: pyyaml
|
| 26 |
+
Requires-Dist: torch>=2.0.0
|
| 27 |
+
Requires-Dist: huggingface-hub>=0.21.0
|
| 28 |
+
Requires-Dist: safetensors>=0.4.3
|
| 29 |
+
Provides-Extra: deepspeed
|
| 30 |
+
Requires-Dist: deepspeed; extra == "deepspeed"
|
| 31 |
+
Provides-Extra: dev
|
| 32 |
+
Requires-Dist: black~=23.1; extra == "dev"
|
| 33 |
+
Requires-Dist: hf-doc-builder>=0.3.0; extra == "dev"
|
| 34 |
+
Requires-Dist: ruff~=0.11.2; extra == "dev"
|
| 35 |
+
Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "dev"
|
| 36 |
+
Requires-Dist: pytest-xdist; extra == "dev"
|
| 37 |
+
Requires-Dist: pytest-subtests; extra == "dev"
|
| 38 |
+
Requires-Dist: parameterized; extra == "dev"
|
| 39 |
+
Requires-Dist: pytest-order; extra == "dev"
|
| 40 |
+
Requires-Dist: datasets; extra == "dev"
|
| 41 |
+
Requires-Dist: diffusers; extra == "dev"
|
| 42 |
+
Requires-Dist: evaluate; extra == "dev"
|
| 43 |
+
Requires-Dist: torchdata>=0.8.0; extra == "dev"
|
| 44 |
+
Requires-Dist: torchpippy>=0.2.0; extra == "dev"
|
| 45 |
+
Requires-Dist: transformers; extra == "dev"
|
| 46 |
+
Requires-Dist: scipy; extra == "dev"
|
| 47 |
+
Requires-Dist: scikit-learn; extra == "dev"
|
| 48 |
+
Requires-Dist: tqdm; extra == "dev"
|
| 49 |
+
Requires-Dist: bitsandbytes; extra == "dev"
|
| 50 |
+
Requires-Dist: timm; extra == "dev"
|
| 51 |
+
Requires-Dist: rich; extra == "dev"
|
| 52 |
+
Provides-Extra: docs
|
| 53 |
+
Provides-Extra: quality
|
| 54 |
+
Requires-Dist: black~=23.1; extra == "quality"
|
| 55 |
+
Requires-Dist: hf-doc-builder>=0.3.0; extra == "quality"
|
| 56 |
+
Requires-Dist: ruff~=0.11.2; extra == "quality"
|
| 57 |
+
Provides-Extra: rich
|
| 58 |
+
Requires-Dist: rich; extra == "rich"
|
| 59 |
+
Provides-Extra: sagemaker
|
| 60 |
+
Requires-Dist: sagemaker; extra == "sagemaker"
|
| 61 |
+
Provides-Extra: test_dev
|
| 62 |
+
Requires-Dist: datasets; extra == "test-dev"
|
| 63 |
+
Requires-Dist: diffusers; extra == "test-dev"
|
| 64 |
+
Requires-Dist: evaluate; extra == "test-dev"
|
| 65 |
+
Requires-Dist: torchdata>=0.8.0; extra == "test-dev"
|
| 66 |
+
Requires-Dist: torchpippy>=0.2.0; extra == "test-dev"
|
| 67 |
+
Requires-Dist: transformers; extra == "test-dev"
|
| 68 |
+
Requires-Dist: scipy; extra == "test-dev"
|
| 69 |
+
Requires-Dist: scikit-learn; extra == "test-dev"
|
| 70 |
+
Requires-Dist: tqdm; extra == "test-dev"
|
| 71 |
+
Requires-Dist: bitsandbytes; extra == "test-dev"
|
| 72 |
+
Requires-Dist: timm; extra == "test-dev"
|
| 73 |
+
Provides-Extra: test_prod
|
| 74 |
+
Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "test-prod"
|
| 75 |
+
Requires-Dist: pytest-xdist; extra == "test-prod"
|
| 76 |
+
Requires-Dist: pytest-subtests; extra == "test-prod"
|
| 77 |
+
Requires-Dist: parameterized; extra == "test-prod"
|
| 78 |
+
Requires-Dist: pytest-order; extra == "test-prod"
|
| 79 |
+
Provides-Extra: test_trackers
|
| 80 |
+
Requires-Dist: wandb; extra == "test-trackers"
|
| 81 |
+
Requires-Dist: comet-ml; extra == "test-trackers"
|
| 82 |
+
Requires-Dist: tensorboard; extra == "test-trackers"
|
| 83 |
+
Requires-Dist: dvclive; extra == "test-trackers"
|
| 84 |
+
Requires-Dist: mlflow; extra == "test-trackers"
|
| 85 |
+
Requires-Dist: matplotlib; extra == "test-trackers"
|
| 86 |
+
Provides-Extra: testing
|
| 87 |
+
Requires-Dist: pytest<=8.0.0,>=7.2.0; extra == "testing"
|
| 88 |
+
Requires-Dist: pytest-xdist; extra == "testing"
|
| 89 |
+
Requires-Dist: pytest-subtests; extra == "testing"
|
| 90 |
+
Requires-Dist: parameterized; extra == "testing"
|
| 91 |
+
Requires-Dist: pytest-order; extra == "testing"
|
| 92 |
+
Requires-Dist: datasets; extra == "testing"
|
| 93 |
+
Requires-Dist: diffusers; extra == "testing"
|
| 94 |
+
Requires-Dist: evaluate; extra == "testing"
|
| 95 |
+
Requires-Dist: torchdata>=0.8.0; extra == "testing"
|
| 96 |
+
Requires-Dist: torchpippy>=0.2.0; extra == "testing"
|
| 97 |
+
Requires-Dist: transformers; extra == "testing"
|
| 98 |
+
Requires-Dist: scipy; extra == "testing"
|
| 99 |
+
Requires-Dist: scikit-learn; extra == "testing"
|
| 100 |
+
Requires-Dist: tqdm; extra == "testing"
|
| 101 |
+
Requires-Dist: bitsandbytes; extra == "testing"
|
| 102 |
+
Requires-Dist: timm; extra == "testing"
|
| 103 |
+
|
| 104 |
+
<!---
|
| 105 |
+
Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 106 |
+
|
| 107 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 108 |
+
you may not use this file except in compliance with the License.
|
| 109 |
+
You may obtain a copy of the License at
|
| 110 |
+
|
| 111 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 112 |
+
|
| 113 |
+
Unless required by applicable law or agreed to in writing, software
|
| 114 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 115 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 116 |
+
See the License for the specific language governing permissions and
|
| 117 |
+
limitations under the License.
|
| 118 |
+
-->
|
| 119 |
+
|
| 120 |
+
<p align="center">
|
| 121 |
+
<br>
|
| 122 |
+
<img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/accelerate_logo.png" width="400"/>
|
| 123 |
+
<br>
|
| 124 |
+
</p>
|
| 125 |
+
|
| 126 |
+
<p align="center">
|
| 127 |
+
<!-- Uncomment when CircleCI is set up
|
| 128 |
+
<a href="https://circleci.com/gh/huggingface/accelerate"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/master"></a>
|
| 129 |
+
-->
|
| 130 |
+
<a href="https://github.com/huggingface/accelerate/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/accelerate.svg?color=blue"></a>
|
| 131 |
+
<a href="https://huggingface.co/docs/accelerate/index.html"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/accelerate/index.html.svg?down_color=red&down_message=offline&up_message=online"></a>
|
| 132 |
+
<a href="https://github.com/huggingface/accelerate/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/accelerate.svg"></a>
|
| 133 |
+
<a href="https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg"></a>
|
| 134 |
+
</p>
|
| 135 |
+
|
| 136 |
+
<h3 align="center">
|
| 137 |
+
<p>Run your *raw* PyTorch training script on any kind of device</p>
|
| 138 |
+
</h3>
|
| 139 |
+
|
| 140 |
+
<h3 align="center">
|
| 141 |
+
<a href="https://hf.co/course"><img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/course_banner.png"></a>
|
| 142 |
+
</h3>
|
| 143 |
+
|
| 144 |
+
## Easy to integrate
|
| 145 |
+
|
| 146 |
+
🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.
|
| 147 |
+
|
| 148 |
+
🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.
|
| 149 |
+
|
| 150 |
+
Here is an example:
|
| 151 |
+
|
| 152 |
+
```diff
|
| 153 |
+
import torch
|
| 154 |
+
import torch.nn.functional as F
|
| 155 |
+
from datasets import load_dataset
|
| 156 |
+
+ from accelerate import Accelerator
|
| 157 |
+
|
| 158 |
+
+ accelerator = Accelerator()
|
| 159 |
+
- device = 'cpu'
|
| 160 |
+
+ device = accelerator.device
|
| 161 |
+
|
| 162 |
+
model = torch.nn.Transformer().to(device)
|
| 163 |
+
optimizer = torch.optim.Adam(model.parameters())
|
| 164 |
+
|
| 165 |
+
dataset = load_dataset('my_dataset')
|
| 166 |
+
data = torch.utils.data.DataLoader(dataset, shuffle=True)
|
| 167 |
+
|
| 168 |
+
+ model, optimizer, data = accelerator.prepare(model, optimizer, data)
|
| 169 |
+
|
| 170 |
+
model.train()
|
| 171 |
+
for epoch in range(10):
|
| 172 |
+
for source, targets in data:
|
| 173 |
+
source = source.to(device)
|
| 174 |
+
targets = targets.to(device)
|
| 175 |
+
|
| 176 |
+
optimizer.zero_grad()
|
| 177 |
+
|
| 178 |
+
output = model(source)
|
| 179 |
+
loss = F.cross_entropy(output, targets)
|
| 180 |
+
|
| 181 |
+
- loss.backward()
|
| 182 |
+
+ accelerator.backward(loss)
|
| 183 |
+
|
| 184 |
+
optimizer.step()
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
As you can see in this example, by adding 5 lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16).
|
| 188 |
+
|
| 189 |
+
In particular, the same code can then be run without modification on your local machine for debugging or your training environment.
|
| 190 |
+
|
| 191 |
+
🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify your training loop further:
|
| 192 |
+
|
| 193 |
+
```diff
|
| 194 |
+
import torch
|
| 195 |
+
import torch.nn.functional as F
|
| 196 |
+
from datasets import load_dataset
|
| 197 |
+
+ from accelerate import Accelerator
|
| 198 |
+
|
| 199 |
+
- device = 'cpu'
|
| 200 |
+
+ accelerator = Accelerator()
|
| 201 |
+
|
| 202 |
+
- model = torch.nn.Transformer().to(device)
|
| 203 |
+
+ model = torch.nn.Transformer()
|
| 204 |
+
optimizer = torch.optim.Adam(model.parameters())
|
| 205 |
+
|
| 206 |
+
dataset = load_dataset('my_dataset')
|
| 207 |
+
data = torch.utils.data.DataLoader(dataset, shuffle=True)
|
| 208 |
+
|
| 209 |
+
+ model, optimizer, data = accelerator.prepare(model, optimizer, data)
|
| 210 |
+
|
| 211 |
+
model.train()
|
| 212 |
+
for epoch in range(10):
|
| 213 |
+
for source, targets in data:
|
| 214 |
+
- source = source.to(device)
|
| 215 |
+
- targets = targets.to(device)
|
| 216 |
+
|
| 217 |
+
optimizer.zero_grad()
|
| 218 |
+
|
| 219 |
+
output = model(source)
|
| 220 |
+
loss = F.cross_entropy(output, targets)
|
| 221 |
+
|
| 222 |
+
- loss.backward()
|
| 223 |
+
+ accelerator.backward(loss)
|
| 224 |
+
|
| 225 |
+
optimizer.step()
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
Want to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples).
|
| 229 |
+
|
| 230 |
+
## Launching script
|
| 231 |
+
|
| 232 |
+
🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training!
|
| 233 |
+
On your machine(s) just run:
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
accelerate config
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
and answer the questions asked. This will generate a config file that will be used automatically to properly set the default options when doing
|
| 240 |
+
|
| 241 |
+
```bash
|
| 242 |
+
accelerate launch my_script.py --args_to_my_script
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):
|
| 246 |
+
|
| 247 |
+
```bash
|
| 248 |
+
accelerate launch examples/nlp_example.py
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
This CLI tool is **optional**, and you can still use `python my_script.py` or `torchrun my_script.py` at your convenience.
|
| 252 |
+
|
| 253 |
+
You can also directly pass in the arguments you would to `torchrun` as arguments to `accelerate launch` if you wish to not run `accelerate config`.
|
| 254 |
+
|
| 255 |
+
For example, here is how to launch on two GPUs:
|
| 256 |
+
|
| 257 |
+
```bash
|
| 258 |
+
accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
|
| 262 |
+
|
| 263 |
+
Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)
|
| 264 |
+
|
| 265 |
+
## Launching multi-CPU run using MPI
|
| 266 |
+
|
| 267 |
+
🤗 Here is another way to launch a multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
|
| 268 |
+
Once you have MPI set up on your cluster, just run:
|
| 269 |
+
```bash
|
| 270 |
+
accelerate config
|
| 271 |
+
```
|
| 272 |
+
Answer the questions that are asked, selecting to run using multi-CPU, and answer "yes" when asked if you want accelerate to launch mpirun.
|
| 273 |
+
Then, use `accelerate launch` with your script like:
|
| 274 |
+
```bash
|
| 275 |
+
accelerate launch examples/nlp_example.py
|
| 276 |
+
```
|
| 277 |
+
Alternatively, you can use mpirun directly, without using the CLI like:
|
| 278 |
+
```bash
|
| 279 |
+
mpirun -np 2 python examples/nlp_example.py
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
## Launching training using DeepSpeed
|
| 283 |
+
|
| 284 |
+
🤗 Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. However, if you desire to tweak your DeepSpeed related args from your Python script, we provide you the `DeepSpeedPlugin`.
|
| 285 |
+
|
| 286 |
+
```python
|
| 287 |
+
from accelerate import Accelerator, DeepSpeedPlugin
|
| 288 |
+
|
| 289 |
+
# deepspeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it
|
| 290 |
+
# Remember you still need to do gradient accumulation by yourself, just like you would have done without deepspeed
|
| 291 |
+
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
|
| 292 |
+
accelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin)
|
| 293 |
+
|
| 294 |
+
# How to save your 🤗 Transformer?
|
| 295 |
+
accelerator.wait_for_everyone()
|
| 296 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 297 |
+
unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
Note: DeepSpeed support is experimental for now. In case you get into some problem, please open an issue.
|
| 301 |
+
|
| 302 |
+
## Launching your training from a notebook
|
| 303 |
+
|
| 304 |
+
🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch a distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function` then in your last cell, add:
|
| 305 |
+
|
| 306 |
+
```python
|
| 307 |
+
from accelerate import notebook_launcher
|
| 308 |
+
|
| 309 |
+
notebook_launcher(training_function)
|
| 310 |
+
```
|
| 311 |
+
|
| 312 |
+
An example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb). [Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)
|
| 313 |
+
|
| 314 |
+
## Why should I use 🤗 Accelerate?
|
| 315 |
+
|
| 316 |
+
You should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object.
|
| 317 |
+
|
| 318 |
+
## Why shouldn't I use 🤗 Accelerate?
|
| 319 |
+
|
| 320 |
+
You shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that, 🤗 Accelerate is not one of them.
|
| 321 |
+
|
| 322 |
+
## Frameworks using 🤗 Accelerate
|
| 323 |
+
|
| 324 |
+
If you like the simplicity of 🤗 Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of 🤗 Accelerate are listed below:
|
| 325 |
+
|
| 326 |
+
* [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development.
|
| 327 |
+
* [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common "breakpoints" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76).
|
| 328 |
+
* [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic.
|
| 329 |
+
* [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms.
|
| 330 |
+
* [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses.
|
| 331 |
+
* [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering industry-leading WebUI, terminal usage support, and serves as the foundation for many commercial products.
|
| 332 |
+
* [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose to train and fine-tune the supported deep learning algorithms within the library.
|
| 333 |
+
* [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so.
|
| 334 |
+
* [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency; enabling users to understand exactly what is going on under the hood, but without having to write and maintain the boilerplate themselves!
|
| 335 |
+
* [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source browser-based easy-to-use interface based on the Gradio library for Stable Diffusion.
|
| 336 |
+
* [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training PyTorch models in a Keras style; a dynamic and beautiful plot is provided in the notebook to monitor your loss or metric.
|
| 337 |
+
* [transformers](https://github.com/huggingface/transformers) is a tool for helping train state-of-the-art machine learning models in PyTorch, TensorFlow, and JAX. (Accelerate is the backend for the PyTorch side).
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
## Installation
|
| 341 |
+
|
| 342 |
+
This repository is tested on Python 3.8+ and PyTorch 1.10.0+
|
| 343 |
+
|
| 344 |
+
You should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
| 345 |
+
|
| 346 |
+
First, create a virtual environment with the version of Python you're going to use and activate it.
|
| 347 |
+
|
| 348 |
+
Then, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) regarding the specific install command for your platform. Then 🤗 Accelerate can be installed using pip as follows:
|
| 349 |
+
|
| 350 |
+
```bash
|
| 351 |
+
pip install accelerate
|
| 352 |
+
```
|
| 353 |
+
|
| 354 |
+
## Supported integrations
|
| 355 |
+
|
| 356 |
+
- CPU only
|
| 357 |
+
- multi-CPU on one node (machine)
|
| 358 |
+
- multi-CPU on several nodes (machines)
|
| 359 |
+
- single GPU
|
| 360 |
+
- multi-GPU on one node (machine)
|
| 361 |
+
- multi-GPU on several nodes (machines)
|
| 362 |
+
- TPU
|
| 363 |
+
- FP16/BFloat16 mixed precision
|
| 364 |
+
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
|
| 365 |
+
- DeepSpeed support (Experimental)
|
| 366 |
+
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
|
| 367 |
+
- Megatron-LM support (Experimental)
|
| 368 |
+
|
| 369 |
+
## Citing 🤗 Accelerate
|
| 370 |
+
|
| 371 |
+
If you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry.
|
| 372 |
+
|
| 373 |
+
```bibtex
|
| 374 |
+
@Misc{accelerate,
|
| 375 |
+
title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
|
| 376 |
+
author = {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},
|
| 377 |
+
howpublished = {\url{https://github.com/huggingface/accelerate}},
|
| 378 |
+
year = {2022}
|
| 379 |
+
}
|
| 380 |
+
```
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/RECORD
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
../../Scripts/accelerate-config.exe,sha256=oMHvUIO20oc9e7mTWqdxwnwE2vt6jKFxV5U0295vjGQ,108433
|
| 2 |
+
../../Scripts/accelerate-estimate-memory.exe,sha256=ptuggVnh4A7ZaLZFPsmF8CMf3_QPEi7MFOshFkZtg_E,108435
|
| 3 |
+
../../Scripts/accelerate-launch.exe,sha256=n9Bd7LTGgPp2bVXE9A73FDCFI_brpmNFws404EoTC_g,108433
|
| 4 |
+
../../Scripts/accelerate-merge-weights.exe,sha256=PSkH501EplMRrCYdZqTr8qK-VUHOeIpuD9UzcDuJ6oQ,108432
|
| 5 |
+
../../Scripts/accelerate.exe,sha256=InYSaN6P9U5H6tX0Xjkm5K1FvyZwbIyQmXJY2WlrW6s,108441
|
| 6 |
+
accelerate-1.6.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 7 |
+
accelerate-1.6.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
| 8 |
+
accelerate-1.6.0.dist-info/METADATA,sha256=zT5ADQHZZeLT4qEiGMNSG4cT7hCnQplwyshDyeDyZNo,19421
|
| 9 |
+
accelerate-1.6.0.dist-info/RECORD,,
|
| 10 |
+
accelerate-1.6.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 11 |
+
accelerate-1.6.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
| 12 |
+
accelerate-1.6.0.dist-info/entry_points.txt,sha256=Vpy8gUGfZ-1VnM2229fb8CpJNLBdMH_wtJ9PQ7b_2tQ,296
|
| 13 |
+
accelerate-1.6.0.dist-info/top_level.txt,sha256=esVfdxTidsjQ90zsN_rPpjLFJ4ijRlx4mnLrG09hlt4,11
|
| 14 |
+
accelerate/__init__.py,sha256=r3I-pArsQK9ZrH3XgnjeCoXo4l-DEFOWQhjj3BguTZc,1504
|
| 15 |
+
accelerate/__pycache__/__init__.cpython-312.pyc,,
|
| 16 |
+
accelerate/__pycache__/accelerator.cpython-312.pyc,,
|
| 17 |
+
accelerate/__pycache__/big_modeling.cpython-312.pyc,,
|
| 18 |
+
accelerate/__pycache__/checkpointing.cpython-312.pyc,,
|
| 19 |
+
accelerate/__pycache__/data_loader.cpython-312.pyc,,
|
| 20 |
+
accelerate/__pycache__/hooks.cpython-312.pyc,,
|
| 21 |
+
accelerate/__pycache__/inference.cpython-312.pyc,,
|
| 22 |
+
accelerate/__pycache__/launchers.cpython-312.pyc,,
|
| 23 |
+
accelerate/__pycache__/local_sgd.cpython-312.pyc,,
|
| 24 |
+
accelerate/__pycache__/logging.cpython-312.pyc,,
|
| 25 |
+
accelerate/__pycache__/memory_utils.cpython-312.pyc,,
|
| 26 |
+
accelerate/__pycache__/optimizer.cpython-312.pyc,,
|
| 27 |
+
accelerate/__pycache__/scheduler.cpython-312.pyc,,
|
| 28 |
+
accelerate/__pycache__/state.cpython-312.pyc,,
|
| 29 |
+
accelerate/__pycache__/tracking.cpython-312.pyc,,
|
| 30 |
+
accelerate/accelerator.py,sha256=G952noNHGPrl-poK6qAj1OY32kGjmN5S13v8zy7H63E,173175
|
| 31 |
+
accelerate/big_modeling.py,sha256=IMiAtiuZQpwSyk2jQsoYC2uWzfRUSpCg7FiThSvjfKw,29702
|
| 32 |
+
accelerate/checkpointing.py,sha256=BaDOrpQzRI2U1BvN2vK4lepTRNNqbxGd4QPa1zOShoc,13612
|
| 33 |
+
accelerate/commands/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
|
| 34 |
+
accelerate/commands/__pycache__/__init__.cpython-312.pyc,,
|
| 35 |
+
accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc,,
|
| 36 |
+
accelerate/commands/__pycache__/env.cpython-312.pyc,,
|
| 37 |
+
accelerate/commands/__pycache__/estimate.cpython-312.pyc,,
|
| 38 |
+
accelerate/commands/__pycache__/launch.cpython-312.pyc,,
|
| 39 |
+
accelerate/commands/__pycache__/merge.cpython-312.pyc,,
|
| 40 |
+
accelerate/commands/__pycache__/test.cpython-312.pyc,,
|
| 41 |
+
accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc,,
|
| 42 |
+
accelerate/commands/__pycache__/tpu.cpython-312.pyc,,
|
| 43 |
+
accelerate/commands/__pycache__/utils.cpython-312.pyc,,
|
| 44 |
+
accelerate/commands/accelerate_cli.py,sha256=SkwFad6Z1ZsGjtm7TiXFq8je-akshp_0WxX_6rGSBw8,1972
|
| 45 |
+
accelerate/commands/config/__init__.py,sha256=iJK8dgj3pc5Vdr1E7UuGoFu-BlybyXLxYDoTg9gXngE,1645
|
| 46 |
+
accelerate/commands/config/__pycache__/__init__.cpython-312.pyc,,
|
| 47 |
+
accelerate/commands/config/__pycache__/cluster.cpython-312.pyc,,
|
| 48 |
+
accelerate/commands/config/__pycache__/config.cpython-312.pyc,,
|
| 49 |
+
accelerate/commands/config/__pycache__/config_args.cpython-312.pyc,,
|
| 50 |
+
accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc,,
|
| 51 |
+
accelerate/commands/config/__pycache__/default.cpython-312.pyc,,
|
| 52 |
+
accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc,,
|
| 53 |
+
accelerate/commands/config/__pycache__/update.cpython-312.pyc,,
|
| 54 |
+
accelerate/commands/config/cluster.py,sha256=w0L3zTyZp4sjDpCrM3NxOjxZ0kyPJZmzi06pFZmbM2c,37472
|
| 55 |
+
accelerate/commands/config/config.py,sha256=FuRlQvOjgATEtyqOSsGD-KEtOCvACOHjs2C-krrtldk,3035
|
| 56 |
+
accelerate/commands/config/config_args.py,sha256=xn6M8iJnlFycosDlbM0BE86r9RxfdDwHtIlk-UUq7UM,10082
|
| 57 |
+
accelerate/commands/config/config_utils.py,sha256=mdvZE9fpllfD8S4Blhqk3nLqQ5m14WJ0jQ1yh768H10,3177
|
| 58 |
+
accelerate/commands/config/default.py,sha256=sPgQVt_0zk68KlupQFqt8B6JUoPMFPxXmXr7xFM-EN8,6212
|
| 59 |
+
accelerate/commands/config/sagemaker.py,sha256=GjHE2-h4tRr1P_PFtMF3miiAtJlzkbHbMb6kFXqn8eo,10341
|
| 60 |
+
accelerate/commands/config/update.py,sha256=NXW1J7GkUHpg71QlIXsmMB_0z8S8IZo2FWax5POwrhc,2395
|
| 61 |
+
accelerate/commands/env.py,sha256=-B3FPX4S705A-P_tyLKm_JzGpz-TeKqFNPdNWDAdGIM,4156
|
| 62 |
+
accelerate/commands/estimate.py,sha256=Qduq4xudVyIede37BMEe1rNhXf-rfW-MHV2KtwxdfEA,12585
|
| 63 |
+
accelerate/commands/launch.py,sha256=7DI42Uw4kf_peOpY5TUA1V2yz7cuSO3cYnLgiI5G1Vs,47496
|
| 64 |
+
accelerate/commands/menu/__init__.py,sha256=uqSlBM0TFHBwzdv3p3SXfpAk1lZFp4h1a7mbBdscPHs,645
|
| 65 |
+
accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc,,
|
| 66 |
+
accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc,,
|
| 67 |
+
accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc,,
|
| 68 |
+
accelerate/commands/menu/__pycache__/input.cpython-312.pyc,,
|
| 69 |
+
accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc,,
|
| 70 |
+
accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc,,
|
| 71 |
+
accelerate/commands/menu/cursor.py,sha256=-lmpJVAzvNc0c3EOtSuLoKB59zqylVCbYyWLPnrOmvQ,2028
|
| 72 |
+
accelerate/commands/menu/helpers.py,sha256=KrSB5fJjH4MUEUAQJ6bYaN16AYcnl9UalDrPD3DYeeg,1483
|
| 73 |
+
accelerate/commands/menu/input.py,sha256=T8Mdd-Y_OURgqfDV9qZh4Wf6hmT22AneNtJzj4JA1Rk,2512
|
| 74 |
+
accelerate/commands/menu/keymap.py,sha256=eXj-suyYs1m5dEHoUKN4mKAMLc8DWHnwhP6G6JSU0jQ,4086
|
| 75 |
+
accelerate/commands/menu/selection_menu.py,sha256=bxy-DHaKKC6SCToOlMBv5_z0MdUzylEg6Sio9OuV3GM,4921
|
| 76 |
+
accelerate/commands/merge.py,sha256=quDKckN3vKn9nsGjdwfoojnfTMFdKRRUkY1DYuuNNmc,2388
|
| 77 |
+
accelerate/commands/test.py,sha256=YrPYEaAACOGZ6btn2MV6NbMSEdBUcMWADLbQWaZSHtk,2149
|
| 78 |
+
accelerate/commands/to_fsdp2.py,sha256=gfbhoUT4qFB3LVDMNmckElgLG0yWm8aj_aofszeiJmM,5991
|
| 79 |
+
accelerate/commands/tpu.py,sha256=KyxDP7IuveidZrbW4rx2s8Ku3o_ptI6tzwr_R7ck0os,5548
|
| 80 |
+
accelerate/commands/utils.py,sha256=aT8xUCe2pCkFII7yZxcfaohEjgBAzMUM7WiD4UuWSOY,4150
|
| 81 |
+
accelerate/data_loader.py,sha256=yArisKhfuIJzDD7vuOgZAqEJNUC8tgl2L8ay92rgtfY,64551
|
| 82 |
+
accelerate/hooks.py,sha256=lYtYSIqEQnZOImgj2UMTngQPkcQDEHS2klwak1oHD6w,32248
|
| 83 |
+
accelerate/inference.py,sha256=NLANdzXm5PwmDWbPYkFmoRoQSLLvuhfvIG33xfpapT0,7668
|
| 84 |
+
accelerate/launchers.py,sha256=QIqUVkDc-oTmWf00L8kas7u2RBEwOYoRi8M2Our0DAs,13721
|
| 85 |
+
accelerate/local_sgd.py,sha256=aCj_yqXK_FhhZRWEpzXIkgXBERH6fC3HyrC3nsOj1uA,4160
|
| 86 |
+
accelerate/logging.py,sha256=4XcgY_BV7Qn_enh2tZ-8fNtuaE_3n-LsYJbgwhRx_PI,5042
|
| 87 |
+
accelerate/memory_utils.py,sha256=3R5LoeHl6GgTZ-IMPrDZMdaEehWarGdPqODushb-6pg,862
|
| 88 |
+
accelerate/optimizer.py,sha256=QfgCkQ5dA-fLSi_Z7CBPRCObXA1rL9zxHg4tyKCEg2A,8113
|
| 89 |
+
accelerate/scheduler.py,sha256=des_4M_Tt1W8gCYZZbLla0GHBEgJY3Wx2EGBQPTzeiY,4238
|
| 90 |
+
accelerate/state.py,sha256=YYpuPqXeNjz5_Y71h0zmCu13cBuDmQ8lw6fAmoSWUFk,55457
|
| 91 |
+
accelerate/test_utils/__init__.py,sha256=8xikmLMAM6_6CwVF6tsdsv4XzgWkHAk2tZBdV9DxIH8,1749
|
| 92 |
+
accelerate/test_utils/__pycache__/__init__.cpython-312.pyc,,
|
| 93 |
+
accelerate/test_utils/__pycache__/examples.cpython-312.pyc,,
|
| 94 |
+
accelerate/test_utils/__pycache__/testing.cpython-312.pyc,,
|
| 95 |
+
accelerate/test_utils/__pycache__/training.cpython-312.pyc,,
|
| 96 |
+
accelerate/test_utils/examples.py,sha256=IN4n2lxA95hexE2rojsyyjhpXLbXnbmjTzd8UTws5_4,7257
|
| 97 |
+
accelerate/test_utils/scripts/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
|
| 98 |
+
accelerate/test_utils/scripts/__pycache__/__init__.cpython-312.pyc,,
|
| 99 |
+
accelerate/test_utils/scripts/__pycache__/test_cli.cpython-312.pyc,,
|
| 100 |
+
accelerate/test_utils/scripts/__pycache__/test_ddp_comm_hook.cpython-312.pyc,,
|
| 101 |
+
accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-312.pyc,,
|
| 102 |
+
accelerate/test_utils/scripts/__pycache__/test_merge_weights.cpython-312.pyc,,
|
| 103 |
+
accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-312.pyc,,
|
| 104 |
+
accelerate/test_utils/scripts/__pycache__/test_ops.cpython-312.pyc,,
|
| 105 |
+
accelerate/test_utils/scripts/__pycache__/test_script.cpython-312.pyc,,
|
| 106 |
+
accelerate/test_utils/scripts/__pycache__/test_sync.cpython-312.pyc,,
|
| 107 |
+
accelerate/test_utils/scripts/external_deps/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606
|
| 108 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-312.pyc,,
|
| 109 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-312.pyc,,
|
| 110 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc,,
|
| 111 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-312.pyc,,
|
| 112 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-312.pyc,,
|
| 113 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-312.pyc,,
|
| 114 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-312.pyc,,
|
| 115 |
+
accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-312.pyc,,
|
| 116 |
+
accelerate/test_utils/scripts/external_deps/test_checkpointing.py,sha256=XHaNRmnrARd1izXFjWGi5UjYGas-4vqayW51jAHBPCA,10699
|
| 117 |
+
accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py,sha256=Cg4-h0B4UcOQ5CxXjIdrsPVR5fFsWCv24DqZGjXEwW8,13790
|
| 118 |
+
accelerate/test_utils/scripts/external_deps/test_metrics.py,sha256=Ev2XKaiwmznoxKujskAAuISGChW646MOiyf0CXEPb9Y,12168
|
| 119 |
+
accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py,sha256=9Yn9Rc7d-yWr1fU0RagASPG5l8vrKeHVYbuYABbA-fU,12498
|
| 120 |
+
accelerate/test_utils/scripts/external_deps/test_performance.py,sha256=Di6LT19bCBLlWmCBSu_jjdqR2EqngXpvUOGDBx8GfZE,10432
|
| 121 |
+
accelerate/test_utils/scripts/external_deps/test_pippy.py,sha256=ocZntbmAduln2ma4LeEA9o-S8hla3YXCJ_A8hEcWHgs,4762
|
| 122 |
+
accelerate/test_utils/scripts/external_deps/test_zero3_integration.py,sha256=P9alBOHZ9Lfqs5LoRP7bCbXl-tnsNrBkvJZGseibBeA,1665
|
| 123 |
+
accelerate/test_utils/scripts/test_cli.py,sha256=qfk1aYFtdvYFCYPkl05602SNGvk08QTv0xZVVcFVtzM,833
|
| 124 |
+
accelerate/test_utils/scripts/test_ddp_comm_hook.py,sha256=k_-2MBjLKNdMGIcneTbuGd84K05Wp1GEQX6DUVF9UBw,3566
|
| 125 |
+
accelerate/test_utils/scripts/test_distributed_data_loop.py,sha256=RUWTwd7DIpr2fl7JtKOsvTjMiJioTxO8FdSr2Lw_5uI,15137
|
| 126 |
+
accelerate/test_utils/scripts/test_merge_weights.py,sha256=dssMnAoZt291vNLbPhPOTQUooh0leg_0erQh0uZH6aU,6125
|
| 127 |
+
accelerate/test_utils/scripts/test_notebook.py,sha256=qfIy3IvH74-kGn8nadBn_k7qrviqvsxy5ijsnUhuY6o,3894
|
| 128 |
+
accelerate/test_utils/scripts/test_ops.py,sha256=Bcs-h8EMJwULTfbizlFN5qkv3JraWEpoSZWMn-HswiI,6265
|
| 129 |
+
accelerate/test_utils/scripts/test_script.py,sha256=8-53hIVQXD28HQT4h2Ijy6yGCHfTWDAf1-HOi4UtDng,34219
|
| 130 |
+
accelerate/test_utils/scripts/test_sync.py,sha256=PDe8sYZLCL2LKjj_L9b-Bh2BjAjeii9EZ8sZNfuYx5s,18817
|
| 131 |
+
accelerate/test_utils/testing.py,sha256=x9RK70VgAMyHlo5xut7P85j-9kdAnlfQe_4jwSPpMv4,27807
|
| 132 |
+
accelerate/test_utils/training.py,sha256=jO5YEIr34jAcnJ_9WNp_x3zuHzSam_I6IgMvmcGm7yI,6456
|
| 133 |
+
accelerate/tracking.py,sha256=ucpsoYAT3pVXgOfwDdXf6qTugY2-tk-EINvZtfmRitM,42756
|
| 134 |
+
accelerate/utils/__init__.py,sha256=wjpXyvFxS-ed3Stwm_IHIlBmsmP7KRyAljQ_Qss-OWw,7802
|
| 135 |
+
accelerate/utils/__pycache__/__init__.cpython-312.pyc,,
|
| 136 |
+
accelerate/utils/__pycache__/ao.cpython-312.pyc,,
|
| 137 |
+
accelerate/utils/__pycache__/bnb.cpython-312.pyc,,
|
| 138 |
+
accelerate/utils/__pycache__/constants.cpython-312.pyc,,
|
| 139 |
+
accelerate/utils/__pycache__/dataclasses.cpython-312.pyc,,
|
| 140 |
+
accelerate/utils/__pycache__/deepspeed.cpython-312.pyc,,
|
| 141 |
+
accelerate/utils/__pycache__/environment.cpython-312.pyc,,
|
| 142 |
+
accelerate/utils/__pycache__/fsdp_utils.cpython-312.pyc,,
|
| 143 |
+
accelerate/utils/__pycache__/imports.cpython-312.pyc,,
|
| 144 |
+
accelerate/utils/__pycache__/launch.cpython-312.pyc,,
|
| 145 |
+
accelerate/utils/__pycache__/megatron_lm.cpython-312.pyc,,
|
| 146 |
+
accelerate/utils/__pycache__/memory.cpython-312.pyc,,
|
| 147 |
+
accelerate/utils/__pycache__/modeling.cpython-312.pyc,,
|
| 148 |
+
accelerate/utils/__pycache__/offload.cpython-312.pyc,,
|
| 149 |
+
accelerate/utils/__pycache__/operations.cpython-312.pyc,,
|
| 150 |
+
accelerate/utils/__pycache__/other.cpython-312.pyc,,
|
| 151 |
+
accelerate/utils/__pycache__/random.cpython-312.pyc,,
|
| 152 |
+
accelerate/utils/__pycache__/rich.cpython-312.pyc,,
|
| 153 |
+
accelerate/utils/__pycache__/torch_xla.cpython-312.pyc,,
|
| 154 |
+
accelerate/utils/__pycache__/tqdm.cpython-312.pyc,,
|
| 155 |
+
accelerate/utils/__pycache__/transformer_engine.cpython-312.pyc,,
|
| 156 |
+
accelerate/utils/__pycache__/versions.cpython-312.pyc,,
|
| 157 |
+
accelerate/utils/ao.py,sha256=koMiji7AG1kJMRMkJnwSnpuycfx4lPY3CNnpNx2ZqzM,4736
|
| 158 |
+
accelerate/utils/bnb.py,sha256=KCbg6LUt4eXvPHVnKh7rSVcPwDnzxY_Ii7yYmK5bNGw,20737
|
| 159 |
+
accelerate/utils/constants.py,sha256=hc24V0pgxWdBQwS6SXxDKwuIni2pCnzdfvMOX1XI9Os,3264
|
| 160 |
+
accelerate/utils/dataclasses.py,sha256=E7CnCbfskpzxzSorst95Via_XE39t0NP_UGYgJUris0,131486
|
| 161 |
+
accelerate/utils/deepspeed.py,sha256=QYIXv5LwHXw7wBFFo-7a0t86MbwNAfieJkkBaLGA6wI,14064
|
| 162 |
+
accelerate/utils/environment.py,sha256=h0zacbBkAp9szltTf5-aTr5NcbVsQp7wl6DFWp8XNuI,15257
|
| 163 |
+
accelerate/utils/fsdp_utils.py,sha256=Q2tc9EakwBjuYlyXvQrBLV97r6cdReRft6KeS1P_Vb4,28938
|
| 164 |
+
accelerate/utils/imports.py,sha256=YI1ebPJAuxarclENTfzvDPPGf6jeEKnVQ42taFPuqh0,16759
|
| 165 |
+
accelerate/utils/launch.py,sha256=nN4ykAtnEL3oITLTejABltdpS3OivcE2COmX-BnWuY4,31195
|
| 166 |
+
accelerate/utils/megatron_lm.py,sha256=FnIF-niZjvdMk9ymafZWEPjDho_Q_P98C69qc9g5r_E,58059
|
| 167 |
+
accelerate/utils/memory.py,sha256=lDHqW7Ue_CPmw_DWgNxX_B3HY71_srAFdgR10XiVRSM,6960
|
| 168 |
+
accelerate/utils/modeling.py,sha256=_xSTiH7zSsffZULSTJuzcDK6IaWImEMOcbq1xqeI7GY,92319
|
| 169 |
+
accelerate/utils/offload.py,sha256=VFaL8oSJzqZ_47VuUQ69xZi9bF2heRSFoOSnnOxbGXc,7825
|
| 170 |
+
accelerate/utils/operations.py,sha256=VWPYvtrO4UGX5JmisanXzLLUbhAeL8kQk0yYc66bQ2M,31055
|
| 171 |
+
accelerate/utils/other.py,sha256=iiLZcKEAlK2Sj_wt03gAEGKrk7_NZFwbmy9cgEppRPw,13231
|
| 172 |
+
accelerate/utils/random.py,sha256=Xv_ZJm9eaC2Q7rgZy9OpOunKuTingMiDQCH00qhNVxE,6220
|
| 173 |
+
accelerate/utils/rich.py,sha256=8JZX_uGMQX-BufdXxJpdne7BWd1KyLHSgbiGxrDMYr8,847
|
| 174 |
+
accelerate/utils/torch_xla.py,sha256=Pq1tuqN0X_pWDVza6YgjfO45uoJdoRVRForLeLQzFus,1908
|
| 175 |
+
accelerate/utils/tqdm.py,sha256=k8e9JnieTEQHCCNBaiBys7hPxWlEbyRASdIma-qy_X8,1657
|
| 176 |
+
accelerate/utils/transformer_engine.py,sha256=498Y3z2BkbybYLtBiuF_TJgt8Iii943s4wgRAV8FDC4,6372
|
| 177 |
+
accelerate/utils/versions.py,sha256=UgmcbjBm--6CIx1ZamSAMjAK_B_2l48LbeaNygqej8M,2149
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/REQUESTED
ADDED
|
File without changes
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: setuptools (75.1.0)
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
| 5 |
+
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/entry_points.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
accelerate = accelerate.commands.accelerate_cli:main
|
| 3 |
+
accelerate-config = accelerate.commands.config:main
|
| 4 |
+
accelerate-estimate-memory = accelerate.commands.estimate:main
|
| 5 |
+
accelerate-launch = accelerate.commands.launch:main
|
| 6 |
+
accelerate-merge-weights = accelerate.commands.merge:main
|
venv/Lib/site-packages/accelerate-1.6.0.dist-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
venv/Lib/site-packages/accelerate/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
__version__ = "1.6.0"
|
| 15 |
+
|
| 16 |
+
from .accelerator import Accelerator
|
| 17 |
+
from .big_modeling import (
|
| 18 |
+
cpu_offload,
|
| 19 |
+
cpu_offload_with_hook,
|
| 20 |
+
disk_offload,
|
| 21 |
+
dispatch_model,
|
| 22 |
+
init_empty_weights,
|
| 23 |
+
init_on_device,
|
| 24 |
+
load_checkpoint_and_dispatch,
|
| 25 |
+
)
|
| 26 |
+
from .data_loader import skip_first_batches
|
| 27 |
+
from .inference import prepare_pippy
|
| 28 |
+
from .launchers import debug_launcher, notebook_launcher
|
| 29 |
+
from .state import PartialState
|
| 30 |
+
from .utils import (
|
| 31 |
+
AutocastKwargs,
|
| 32 |
+
DataLoaderConfiguration,
|
| 33 |
+
DDPCommunicationHookType,
|
| 34 |
+
DeepSpeedPlugin,
|
| 35 |
+
DistributedDataParallelKwargs,
|
| 36 |
+
DistributedType,
|
| 37 |
+
FullyShardedDataParallelPlugin,
|
| 38 |
+
GradScalerKwargs,
|
| 39 |
+
InitProcessGroupKwargs,
|
| 40 |
+
ProfileKwargs,
|
| 41 |
+
find_executable_batch_size,
|
| 42 |
+
infer_auto_device_map,
|
| 43 |
+
is_rich_available,
|
| 44 |
+
load_checkpoint_in_model,
|
| 45 |
+
synchronize_rng_states,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_rich_available():
|
| 50 |
+
from .utils import rich
|
venv/Lib/site-packages/accelerate/accelerator.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
venv/Lib/site-packages/accelerate/big_modeling.py
ADDED
|
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
from functools import wraps
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from .hooks import (
|
| 25 |
+
AlignDevicesHook,
|
| 26 |
+
CpuOffload,
|
| 27 |
+
UserCpuOffloadHook,
|
| 28 |
+
add_hook_to_module,
|
| 29 |
+
attach_align_device_hook,
|
| 30 |
+
attach_align_device_hook_on_blocks,
|
| 31 |
+
)
|
| 32 |
+
from .utils import (
|
| 33 |
+
OffloadedWeightsLoader,
|
| 34 |
+
check_cuda_p2p_ib_support,
|
| 35 |
+
check_device_map,
|
| 36 |
+
extract_submodules_state_dict,
|
| 37 |
+
find_tied_parameters,
|
| 38 |
+
get_balanced_memory,
|
| 39 |
+
infer_auto_device_map,
|
| 40 |
+
is_bnb_available,
|
| 41 |
+
is_mlu_available,
|
| 42 |
+
is_musa_available,
|
| 43 |
+
is_npu_available,
|
| 44 |
+
is_sdaa_available,
|
| 45 |
+
is_xpu_available,
|
| 46 |
+
load_checkpoint_in_model,
|
| 47 |
+
offload_state_dict,
|
| 48 |
+
parse_flag_from_env,
|
| 49 |
+
retie_parameters,
|
| 50 |
+
)
|
| 51 |
+
from .utils.other import recursive_getattr
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@contextmanager
def init_empty_weights(include_buffers: Optional[bool] = None):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing. When left to `None`, the
            value is resolved with `parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)`.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
    called.

    </Tip>
    """
    if include_buffers is None:
        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
    # Delegate to the generic device initializer, targeting the meta device, and forward whatever it yields.
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@contextmanager
def init_on_device(device: torch.device, include_buffers: Optional[bool] = None):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the specified device while initializing. When left to `None`,
            the value is resolved with `parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)`.

    Example:

    ```python
    import torch
    import torch.nn as nn
    from accelerate import init_on_device

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    if include_buffers is None:
        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)

    if include_buffers:
        # Parameters *and* buffers are requested on `device`: entering the device itself as a context manager is
        # enough, no patching of `nn.Module` is needed.
        with device:
            yield
        return

    # Only parameters must land on `device`: temporarily replace `nn.Module.register_parameter` so that every
    # parameter is moved right after it has been registered. Buffers are left where they were created.
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            # Rebuild the parameter with its own class and instance attributes so parameter subclasses keep
            # their extra state, while the underlying data is moved to `device`.
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        # Always restore the original method, even if the body of the `with` block raised.
        nn.Module.register_parameter = old_register_parameter
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def cpu_offload(
    model: nn.Module,
    execution_device: Optional[torch.device] = None,
    offload_buffers: bool = False,
    state_dict: Optional[dict[str, torch.Tensor]] = None,
    preload_module_classes: Optional[list[str]] = None,
):
    """
    Turn on full CPU offload for `model`: a single copy of its state dict is kept and used as the weights map of the
    alignment hooks, which are attached with `offload=True` and the given execution device.

    Args:
        model (`torch.nn.Module`):
            The model to offload.
        execution_device (`torch.device`, *optional*):
            The device on which the forward pass of the model will be executed (should be a GPU). Defaults to the
            device of the model's first parameter.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to offload the buffers together with the model parameters.
        state_dict (`Dict[str, torch.Tensor]`, *optional*):
            The state dict of the model that will be kept on CPU. When omitted, it is built by moving every entry of
            `model.state_dict()` to `"cpu"`.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.

    Returns:
        The same `model` object, with the hooks attached.
    """
    if execution_device is None:
        first_param = next(iter(model.parameters()))
        execution_device = first_param.device
    if state_dict is None:
        state_dict = {}
        for key, tensor in model.state_dict().items():
            state_dict[key] = tensor.to("cpu")

    # Top-level hook first, then the offloading hooks on the module tree.
    root_hook = AlignDevicesHook(io_same_device=True)
    add_hook_to_module(model, root_hook, append=True)
    attach_align_device_hook(
        model,
        execution_device=execution_device,
        offload=True,
        offload_buffers=offload_buffers,
        weights_map=state_dict,
        preload_module_classes=preload_module_classes,
    )

    return model
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def cpu_offload_with_hook(
    model: torch.nn.Module,
    execution_device: Optional[Union[int, str, torch.device]] = None,
    prev_module_hook: Optional[UserCpuOffloadHook] = None,
):
    """
    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
    the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.

    Args:
        model (`torch.nn.Module`):
            The model to offload.
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
        prev_module_hook (`UserCpuOffloadHook`, *optional*):
            The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
            offload method will be called just before the forward of the model to which this hook is attached.

    Returns:
        A tuple `(model, hook)` where `model` is the input model with the offload hook attached and `hook` is the
        `UserCpuOffloadHook` wrapping it.

    Example:

    ```py
    model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
    model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
    model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)

    hid_1 = model_1(input)
    for i in range(50):
        # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
        hid_2 = model_2(hid_1)
    # model2 is offloaded to the CPU just before this forward.
    hid_3 = model_3(hid_2)

    # For model3, you need to manually call the hook offload method.
    hook_3.offload()
    ```
    """
    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
    add_hook_to_module(model, hook, append=True)
    # The user-facing hook bundles the model with its low-level hook for the caller.
    user_hook = UserCpuOffloadHook(model, hook)
    return model, user_hook
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def disk_offload(
    model: nn.Module,
    offload_dir: Union[str, os.PathLike],
    execution_device: Optional[torch.device] = None,
    offload_buffers: bool = False,
    preload_module_classes: Optional[list[str]] = None,
):
    """
    Turn on full disk offload for `model`: its state dict is written to `offload_dir` (unless that folder already
    holds an `index.json`) and the alignment hooks are attached with `offload=True`, using a weights loader that reads
    from that folder.

    Args:
        model (`torch.nn.Module`): The model to offload.
        offload_dir (`str` or `os.PathLike`):
            The folder in which to offload the model weights (or where the model weights are already offloaded).
        execution_device (`torch.device`, *optional*):
            The device on which the forward pass of the model will be executed (should be a GPU). Defaults to the
            device of the model's first parameter.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to offload the buffers together with the model parameters.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.

    Returns:
        The same `model` object, with the hooks attached.
    """
    # An existing folder that contains `index.json` is treated as already offloaded and is reused as is.
    already_offloaded = os.path.isdir(offload_dir) and os.path.isfile(os.path.join(offload_dir, "index.json"))
    if not already_offloaded:
        offload_state_dict(offload_dir, model.state_dict())
    if execution_device is None:
        first_param = next(iter(model.parameters()))
        execution_device = first_param.device
    weights_map = OffloadedWeightsLoader(save_folder=offload_dir)

    root_hook = AlignDevicesHook(io_same_device=True)
    add_hook_to_module(model, root_hook, append=True)
    attach_align_device_hook(
        model,
        execution_device=execution_device,
        offload=True,
        offload_buffers=offload_buffers,
        weights_map=weights_map,
        preload_module_classes=preload_module_classes,
    )

    return model
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def dispatch_model(
    model: nn.Module,
    device_map: dict[str, Union[str, int, torch.device]],
    main_device: Optional[torch.device] = None,
    state_dict: Optional[dict[str, torch.Tensor]] = None,
    offload_dir: Optional[Union[str, os.PathLike]] = None,
    offload_index: Optional[dict[str, str]] = None,
    offload_buffers: bool = False,
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    force_hooks: bool = False,
):
    """
    Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
    the CPU or even the disk.

    Args:
        model (`torch.nn.Module`):
            The model to dispatch.
        device_map (`Dict[str, Union[str, int, torch.device]]`):
            A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
            `"disk"` is accepted even if it's not a proper value for `torch.device`.
        main_device (`str`, `int` or `torch.device`, *optional*):
            The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
            `"disk"`.
        state_dict (`Dict[str, torch.Tensor]`, *optional*):
            The state dict of the part of the model that will be kept on CPU.
        offload_dir (`str` or `os.PathLike`):
            The folder in which to offload the model weights (or where the model weights are already offloaded).
        offload_index (`Dict`, *optional*):
            A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
            to the index saved in `save_folder`.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to offload the buffers with the model parameters.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        force_hooks (`bool`, *optional*, defaults to `False`):
            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
            single device.

    Returns:
        The same `model` object, with its `hf_device_map` attribute set to a plain `dict` copy of `device_map`.

    Raises:
        ValueError: If some modules are mapped to `"disk"` while neither `offload_dir` nor `offload_index` is given,
            or if the whole model is mapped to `"disk"` and hooks are not forced.
    """
    # Error early if the device map is incomplete.
    check_device_map(model, device_map)

    # We need to force hook for quantized model that can't be moved with to()
    if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
        # since bnb 0.43.2, we can move 4-bit model
        if getattr(model, "is_loaded_in_8bit", False) or (
            getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
        ):
            force_hooks = True

    # `device_map` is not modified below, so the set of distinct target devices is computed once.
    distinct_devices = set(device_map.values())

    # We attach hooks if the device_map has at least 2 different devices or if
    # force_hooks is set to `True`. Otherwise, the model is already loaded
    # in the unique device and the user can decide where to dispatch the model.
    # If the model is quantized, we always force-dispatch the model
    if (len(distinct_devices) > 1) or force_hooks:
        if main_device is None:
            if distinct_devices == {"cpu"} or distinct_devices == {"cpu", "disk"}:
                main_device = "cpu"
            else:
                main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

        if main_device != "cpu":
            cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
            if state_dict is None and len(cpu_modules) > 0:
                state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)

        disk_modules = [name for name, device in device_map.items() if device == "disk"]
        if offload_dir is None and offload_index is None and len(disk_modules) > 0:
            raise ValueError(
                "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
                f"need to be offloaded: {', '.join(disk_modules)}."
            )
        # Write the disk-mapped weights to `offload_dir` unless an index was passed or one is already saved there.
        if (
            len(disk_modules) > 0
            and offload_index is None
            and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
        ):
            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
            offload_state_dict(offload_dir, disk_state_dict)

        # Modules mapped to "cpu"/"disk" are executed on the main device; the others on their own device.
        execution_device = {
            name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
        }
        execution_device[""] = main_device
        # When the main device is "cpu" or "mps", only "disk" modules count as offloaded.
        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
        offload = {name: device in offloaded_devices for name, device in device_map.items()}
        save_folder = offload_dir if len(disk_modules) > 0 else None
        if state_dict is not None or save_folder is not None or offload_index is not None:
            device = main_device if offload_index is not None else None
            weights_map = OffloadedWeightsLoader(
                state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
            )
        else:
            weights_map = None

        # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
        # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
        # original pointer) on each devices.
        tied_params = find_tied_parameters(model)

        tied_params_map = {}
        for group in tied_params:
            for param_name in group:
                # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
                # to care about views of tensors through storage_offset.
                data_ptr = recursive_getattr(model, param_name).data_ptr()
                tied_params_map[data_ptr] = {}

                # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.

        attach_align_device_hook_on_blocks(
            model,
            execution_device=execution_device,
            offload=offload,
            offload_buffers=offload_buffers,
            weights_map=weights_map,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            tied_params_map=tied_params_map,
        )

        # warn if there is any params on the meta device
        # (sorted so that the wording of the message does not depend on set iteration order)
        offloaded_devices_str = " and ".join(
            sorted(device for device in distinct_devices if device in ("cpu", "disk"))
        )
        if len(offloaded_devices_str) > 0:
            logger.warning(
                f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
            )

        # Attaching the hook may break tied weights, so we retie them
        retie_parameters(model, tied_params)

        # Wrap the device-moving methods of the model so that calling them logs a warning, and raises if
        # some parameters are still on the meta device.
        def add_warning(fn, model):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
                if str(fn.__name__) == "to":
                    # For `.to(...)`, only warn when the parsed arguments include a target device.
                    to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
                    if to_device is not None:
                        logger.warning(warning_msg)
                else:
                    logger.warning(warning_msg)
                for param in model.parameters():
                    if param.device == torch.device("meta"):
                        raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
                return fn(*args, **kwargs)

            return wrapper

        # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
        model.to = add_warning(model.to, model)
        if is_npu_available():
            model.npu = add_warning(model.npu, model)
        elif is_mlu_available():
            model.mlu = add_warning(model.mlu, model)
        elif is_sdaa_available():
            model.sdaa = add_warning(model.sdaa, model)
        elif is_musa_available():
            model.musa = add_warning(model.musa, model)
        elif is_xpu_available():
            model.xpu = add_warning(model.xpu, model)
        else:
            model.cuda = add_warning(model.cuda, model)

        # Check if we are using multi-gpus with RTX 4000 series
        use_multi_gpu = len([device for device in distinct_devices if device not in ("cpu", "disk")]) > 1
        if use_multi_gpu and not check_cuda_p2p_ib_support():
            logger.warning(
                "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
                "This can affect the multi-gpu inference when using accelerate device_map. "
                "Please make sure to update your driver to the latest version which resolves this."
            )
    else:
        device = list(device_map.values())[0]
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if is_npu_available() and isinstance(device, int):
            device = f"npu:{device}"
        elif is_mlu_available() and isinstance(device, int):
            device = f"mlu:{device}"
        elif is_sdaa_available() and isinstance(device, int):
            device = f"sdaa:{device}"
        elif is_musa_available() and isinstance(device, int):
            device = f"musa:{device}"
        if device != "disk":
            model.to(device)
        else:
            raise ValueError(
                "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
            )
    # Convert OrderedDict back to dict for easier usage
    model.hf_device_map = dict(device_map)
    return model
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def load_checkpoint_and_dispatch(
    model: nn.Module,
    checkpoint: Union[str, os.PathLike],
    device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,
    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
    no_split_module_classes: Optional[list[str]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = None,
    offload_buffers: bool = False,
    dtype: Optional[Union[str, torch.dtype]] = None,
    offload_state_dict: Optional[bool] = None,
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    force_hooks: bool = False,
    strict: bool = False,
):
    """
    Load a (possibly sharded) checkpoint into `model`, potentially sending weights to a given device as they are
    loaded, then attach the hooks that make the model run properly even when it is split across devices.

    Args:
        model (`torch.nn.Module`): The model in which we want to load a checkpoint.
        checkpoint (`str` or `os.PathLike`):
            The folder checkpoint to load. It can be:
            - a path to a file containing a whole model state dict
            - a path to a `.json` file containing the index to a sharded checkpoint
            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
            name, once a given module name is inside, every submodule of it will be sent to the same device.

            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. The
            other accepted strings are `"balanced"`, `"balanced_low_0"` and `"sequential"`. For more information about
            each option see [here](../concept_guides/big_model_inference#designing-a-device-map). Defaults to None,
            which means [`dispatch_model`] will not be called.
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
            and the available CPU RAM if unset.
        no_split_module_classes (`List[str]`, *optional*):
            A list of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        offload_folder (`str` or `os.PathLike`, *optional*):
            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
            well as the parameters.
        dtype (`str` or `torch.dtype`, *optional*):
            If provided, the weights will be converted to that type when loaded.
        offload_state_dict (`bool`, *optional*):
            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
            the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
            picked contains `"disk"` values.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        force_hooks (`bool`, *optional*, defaults to `False`):
            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
            single device.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
            state_dict.

    Example:

    ```python
    >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    >>> from huggingface_hub import hf_hub_download
    >>> from transformers import AutoConfig, AutoModelForCausalLM

    >>> # Download the Weights
    >>> checkpoint = "EleutherAI/gpt-j-6B"
    >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")

    >>> # Create a model and initialize it with empty weights
    >>> config = AutoConfig.from_pretrained(checkpoint)
    >>> with init_empty_weights():
    ...     model = AutoModelForCausalLM.from_config(config)

    >>> # Load the checkpoint and dispatch it to the right devices
    >>> model = load_checkpoint_and_dispatch(
    ...     model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
    ... )
    ```
    """
    if isinstance(device_map, str):
        # A string names a placement strategy: validate it, then turn it into a concrete device map.
        strategy = device_map
        if strategy not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            raise ValueError(
                "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'."
            )
        if strategy != "sequential":
            max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                dtype=dtype,
                low_zero=(strategy == "balanced_low_0"),
            )
        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
            dtype=dtype,
            offload_buffers=offload_buffers,
        )

    # Unless the caller chose explicitly, offload the state dict as soon as the map sends something to disk.
    if offload_state_dict is None and device_map is not None:
        if "disk" in device_map.values():
            offload_state_dict = True

    load_checkpoint_in_model(
        model,
        checkpoint,
        device_map=device_map,
        offload_folder=offload_folder,
        dtype=dtype,
        offload_state_dict=offload_state_dict,
        offload_buffers=offload_buffers,
        strict=strict,
    )

    if device_map is not None:
        return dispatch_model(
            model,
            device_map=device_map,
            offload_dir=offload_folder,
            offload_buffers=offload_buffers,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            force_hooks=force_hooks,
        )
    # No device map: the checkpoint is loaded but the model is not dispatched.
    return model
|
venv/Lib/site-packages/accelerate/checkpointing.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import random
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from safetensors.torch import load_model
|
| 21 |
+
from torch.cuda.amp import GradScaler
|
| 22 |
+
|
| 23 |
+
from .utils import (
|
| 24 |
+
MODEL_NAME,
|
| 25 |
+
OPTIMIZER_NAME,
|
| 26 |
+
RNG_STATE_NAME,
|
| 27 |
+
SAFE_MODEL_NAME,
|
| 28 |
+
SAFE_WEIGHTS_NAME,
|
| 29 |
+
SAMPLER_NAME,
|
| 30 |
+
SCALER_NAME,
|
| 31 |
+
SCHEDULER_NAME,
|
| 32 |
+
WEIGHTS_NAME,
|
| 33 |
+
get_pretty_name,
|
| 34 |
+
is_cuda_available,
|
| 35 |
+
is_hpu_available,
|
| 36 |
+
is_mlu_available,
|
| 37 |
+
is_musa_available,
|
| 38 |
+
is_sdaa_available,
|
| 39 |
+
is_torch_xla_available,
|
| 40 |
+
is_xpu_available,
|
| 41 |
+
load,
|
| 42 |
+
save,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if is_torch_xla_available():
|
| 47 |
+
import torch_xla.core.xla_model as xm
|
| 48 |
+
|
| 49 |
+
from .logging import get_logger
|
| 50 |
+
from .state import PartialState
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def save_accelerator_state(
    output_dir: str,
    model_states: list[dict],
    optimizers: list,
    schedulers: list,
    dataloaders: list,
    process_index: int,
    step: int,
    scaler: GradScaler = None,
    save_on_each_node: bool = False,
    safe_serialization: bool = True,
):
    """
    Write the states of the models, optimizers, schedulers, dataloaders, scaler, and RNG generators to a directory.

    <Tip>

    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
    `pickle`.

    </Tip>

    Args:
        output_dir (`str` or `os.PathLike`):
            The name of the folder to save all relevant weights and states.
        model_states (`List[dict]`):
            A list of model states, one per model, written in list order.
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances; their `state_dict()` is saved.
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers; their `state_dict()` is saved.
        dataloaders (`List[torch.utils.data.DataLoader]`):
            A list of dataloader instances to save their sampler states
        process_index (`int`):
            The current process index in the Accelerator state; used to name the RNG state file.
        step (`int`):
            The current step in the internal step tracker; stored alongside the RNG states.
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save.
        save_on_each_node (`bool`, *optional*):
            Whether to save on every node, or only the main node.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

    Returns:
        `pathlib.Path`: `output_dir` converted to a `Path`.
    """
    output_dir = Path(output_dir)

    def _bin_name(stem, position):
        # The first object keeps the bare name; later ones get an `_{position}` suffix.
        return f"{stem}.bin" if position == 0 else f"{stem}_{position}.bin"

    # Model states
    base_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    for idx, model_state in enumerate(model_states):
        # Additional models get `_{idx}` inserted in front of every dot of the base name.
        weights_name = base_weights_name if idx == 0 else base_weights_name.replace(".", f"_{idx}.")
        output_model_file = output_dir / weights_name
        save(
            model_state,
            output_model_file,
            save_on_each_node=save_on_each_node,
            safe_serialization=safe_serialization,
        )
        logger.info(f"Model weights saved in {output_model_file}")

    # Optimizer states (never written with safetensors)
    for idx, optimizer in enumerate(optimizers):
        optimizer_state = optimizer.state_dict()
        output_optimizer_file = output_dir / _bin_name(OPTIMIZER_NAME, idx)
        save(optimizer_state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")

    # Scheduler states (never written with safetensors)
    for idx, lr_scheduler in enumerate(schedulers):
        scheduler_state = lr_scheduler.state_dict()
        output_scheduler_file = output_dir / _bin_name(SCHEDULER_NAME, idx)
        save(scheduler_state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")

    # DataLoader states
    # Imported locally rather than at module level.
    from .data_loader import IterableDatasetShard, SeedableRandomSampler

    for idx, loader in enumerate(dataloaders):
        output_sampler_file = output_dir / _bin_name(SAMPLER_NAME, idx)

        # Only our own seedable sampler is persisted.
        if isinstance(loader.dataset, IterableDatasetShard):
            current_sampler = loader.get_sampler()
            if isinstance(current_sampler, SeedableRandomSampler):
                save(
                    current_sampler,
                    output_sampler_file,
                    save_on_each_node=save_on_each_node,
                    safe_serialization=False,
                )

        # Stateful dataloaders additionally dump their own state dict.
        if getattr(loader, "use_stateful_dataloader", False):
            dl_state_name = "dl_state_dict.bin" if idx == 0 else f"dl_state_dict_{idx}.bin"
            torch.save(loader.state_dict(), output_dir / dl_state_name)

        logger.info(f"Sampler state for dataloader {idx} saved in {output_sampler_file}")

    # GradScaler state
    if scaler is not None:
        output_scaler_file = output_dir / SCALER_NAME
        torch.save(scaler.state_dict(), output_scaler_file)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")

    # Random number generator states, one file per process
    rng_states = {
        "step": step,
        "random_state": random.getstate(),
        "numpy_random_seed": np.random.get_state(),
        "torch_manual_seed": torch.get_rng_state(),
    }
    if is_xpu_available():
        rng_states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    # MLU, SDAA and MUSA are treated as mutually exclusive.
    if is_mlu_available():
        rng_states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
    elif is_sdaa_available():
        rng_states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
    elif is_musa_available():
        rng_states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
    if is_hpu_available():
        rng_states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
    if is_cuda_available():
        rng_states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_torch_xla_available():
        rng_states["xm_seed"] = xm.get_rng_state()

    output_states_file = output_dir / f"{RNG_STATE_NAME}_{process_index}.pkl"
    torch.save(rng_states, output_states_file)
    logger.info(f"Random states saved in {output_states_file}")
    return output_dir
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def load_accelerator_state(
    input_dir,
    models,
    optimizers,
    schedulers,
    dataloaders,
    process_index,
    scaler=None,
    map_location=None,
    **load_model_func_kwargs,
):
    """
    Loads states of the models, optimizers, schedulers, dataloaders, scaler, and RNG generators from a given
    directory.

    Args:
        input_dir (`str` or `os.PathLike`):
            The name of the folder to load all relevant weights and states.
        models (`List[torch.nn.Module]`):
            A list of model instances
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        dataloaders (`List[torch.utils.data.DataLoader]`):
            A list of dataloader instances whose sampler state (and, for stateful dataloaders, state dict) is
            restored.
        process_index (`int`):
            The current process index in the Accelerator state
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional *GradScaler* instance to load
        map_location (`str`, *optional*):
            What device to load the model and optimizer states onto. Should be one of either "cpu" or "on_device".
            `None` is treated as "cpu".
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.

    Returns:
        `dict`: Contains the `Accelerator` attributes to override while loading the state.

    Raises:
        `TypeError`: If `map_location` is not `None`, `"cpu"` or `"on_device"`.
    """
    # stores the `Accelerator` attributes to override
    override_attributes = dict()
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        map_location = "cpu"
    elif map_location == "on_device":
        map_location = PartialState().device

    input_dir = Path(input_dir)

    # Model states
    for i, model in enumerate(models):
        ending = f"_{i}" if i > 0 else ""
        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
        if input_model_file.exists():
            load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
        else:
            # No safetensors file: fall back to the torch `.bin` checkpoint
            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
            state_dict = load(input_model_file, map_location=map_location)
            model.load_state_dict(state_dict, **load_model_func_kwargs)
    logger.info("All model weights loaded successfully")

    # Optimizer states
    for i, opt in enumerate(optimizers):
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        input_optimizer_file = input_dir.joinpath(optimizer_name)
        optimizer_state = load(input_optimizer_file, map_location=map_location)
        opt.load_state_dict(optimizer_state)
    logger.info("All optimizer states loaded successfully")

    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        input_scheduler_file = input_dir.joinpath(scheduler_name)
        scheduler_state = load(input_scheduler_file)
        scheduler.load_state_dict(scheduler_state)
    logger.info("All scheduler states loaded successfully")

    # DataLoader states
    # Imported locally rather than at module level.
    from .data_loader import IterableDatasetShard, SeedableRandomSampler

    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        input_sampler_file = input_dir.joinpath(sampler_name)
        # Only load if we have our custom sampler
        if isinstance(dataloader.dataset, IterableDatasetShard):
            sampler = dataloader.get_sampler()
            if isinstance(sampler, SeedableRandomSampler):
                dataloader.set_sampler(load(input_sampler_file))
        if getattr(dataloader, "use_stateful_dataloader", False):
            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
            input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
            # The state dict file is optional; skip silently when it is absent.
            if input_dataloader_state_dict_file.exists():
                state_dict = load(input_dataloader_state_dict_file)
                dataloader.load_state_dict(state_dict)
    logger.info("All dataloader sampler states loaded successfully")

    # GradScaler state
    if scaler is not None:
        input_scaler_file = input_dir.joinpath(SCALER_NAME)
        scaler_state = load(input_scaler_file)
        scaler.load_state_dict(scaler_state)
        logger.info("GradScaler state loaded successfully")

    # Random states (best effort: any failure is logged, not raised)
    try:
        states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
        if "step" in states:
            override_attributes["step"] = states["step"]
        random.setstate(states["random_state"])
        np.random.set_state(states["numpy_random_seed"])
        torch.set_rng_state(states["torch_manual_seed"])
        if is_xpu_available():
            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
        if is_mlu_available():
            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
        elif is_sdaa_available():
            torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
        elif is_musa_available():
            torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
        elif is_cuda_available():
            # Only touch CUDA when it is actually present. The saver writes the CUDA key only in that
            # case, so an unconditional restore would fail on CPU/XPU/HPU hosts and abort the remaining
            # restores below.
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
        # The saver also records the HPU generator state; restore it when present. The key check keeps
        # checkpoints that lack the entry loadable.
        if is_hpu_available() and "torch_hpu_manual_seed" in states:
            torch.hpu.set_rng_state_all(states["torch_hpu_manual_seed"])
        if is_torch_xla_available():
            xm.set_rng_state(states["xm_seed"])
        logger.info("All random states loaded successfully")
    except Exception:
        logger.info("Could not load random states")

    return override_attributes
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
    """
    Write `obj.state_dict()` to `{path}/custom_checkpoint_{index}.pkl`.

    Args:
        obj: Any object exposing a `state_dict()` method.
        path (`str` or `os.PathLike`): Folder that receives the checkpoint file.
        index (`int`, *optional*, defaults to 0): Number used in the checkpoint file name.
        save_on_each_node (`bool`, *optional*, defaults to `False`): Forwarded unchanged to `save`.
    """
    # NOTE: open question from the original author — is `get_pretty_name` the right way to obtain a
    # qual_name-like label for `obj`?
    checkpoint_file = Path(path).joinpath(f"custom_checkpoint_{index}.pkl")
    pretty_name = get_pretty_name(obj)
    logger.info(f"Saving the state of {pretty_name} to {checkpoint_file}")
    custom_state = obj.state_dict()
    save(custom_state, checkpoint_file, save_on_each_node=save_on_each_node)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def load_custom_state(obj, path, index: int = 0):
    """
    Restore `obj` from `{path}/custom_checkpoint_{index}.pkl` through `obj.load_state_dict`.

    The file is always read onto the CPU with `weights_only=False`, i.e. with full unpickling, so it should only be
    pointed at checkpoints from a trusted source.

    Args:
        obj: Any object exposing a `load_state_dict()` method.
        path (`str` or `os.PathLike`): Folder containing the checkpoint file.
        index (`int`, *optional*, defaults to 0): Number used in the checkpoint file name.
    """
    checkpoint_file = f"{path}/custom_checkpoint_{index}.pkl"
    pretty_name = get_pretty_name(obj)
    logger.info(f"Loading the state of {pretty_name} from {checkpoint_file}")
    # Resolve the bound method first so a missing `load_state_dict` fails before any file is read.
    apply_state = obj.load_state_dict
    apply_state(load(checkpoint_file, map_location="cpu", weights_only=False))
|
venv/Lib/site-packages/accelerate/data_loader.py
ADDED
|
@@ -0,0 +1,1429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import importlib
|
| 16 |
+
import math
|
| 17 |
+
from contextlib import suppress
|
| 18 |
+
from typing import Callable, Optional, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from packaging import version
|
| 22 |
+
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
|
| 23 |
+
|
| 24 |
+
from .logging import get_logger
|
| 25 |
+
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
|
| 26 |
+
from .utils import (
|
| 27 |
+
RNGType,
|
| 28 |
+
broadcast,
|
| 29 |
+
broadcast_object_list,
|
| 30 |
+
compare_versions,
|
| 31 |
+
concatenate,
|
| 32 |
+
find_batch_size,
|
| 33 |
+
get_data_structure,
|
| 34 |
+
initialize_tensors,
|
| 35 |
+
is_torch_version,
|
| 36 |
+
is_torchdata_stateful_dataloader_available,
|
| 37 |
+
send_to_device,
|
| 38 |
+
slice_tensors,
|
| 39 |
+
synchronize_rng_states,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
# Keyword arguments of `DataLoader` (with their default values) as of the minimum supported version, 2.0.
_PYTORCH_DATALOADER_KWARGS = dict(
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
    generator=None,
    prefetch_factor=2,
    persistent_workers=False,
    pin_memory_device="",
)

# Keyword arguments added by later releases, keyed by the version that introduced them.
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": dict(in_order=True)}

# Merge in every group of newer kwargs that the installed torch version supports.
for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
    if not is_torch_version(">=", v):
        continue
    _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SeedableRandomSampler(RandomSampler):
    """
    A `RandomSampler` whose generator is re-seeded at the start of every `__iter__` call.

    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
    and be fully reproducible on multiple iterations.

    The seed used for an iteration is `initial_seed + epoch`. `initial_seed` is the `data_seed` keyword argument when
    given, otherwise `torch.random.initial_seed()`; `epoch` starts at 0 and is incremented after each complete pass.
    If a custom `generator` is passed, it is the object that gets re-seeded.
    """

    def __init__(self, *args, **kwargs):
        # `data_seed` is consumed here; everything else is forwarded to `RandomSampler`.
        data_seed = kwargs.pop("data_seed", None)
        super().__init__(*args, **kwargs)

        if data_seed is None:
            data_seed = torch.random.initial_seed()
        self.initial_seed = data_seed
        self.epoch = 0

    def __iter__(self):
        if self.generator is None:
            # Lazily create a generator on the default device (CPU when torch cannot report one).
            default_device = torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
            self.generator = torch.Generator(device=default_device)
            self.generator.manual_seed(self.initial_seed)

        # Allow `self.epoch` to modify the seed of the generator
        self.generator.manual_seed(self.epoch + self.initial_seed)
        parent_indices = super().__iter__()
        yield from parent_indices
        # Only reached once the pass has been fully consumed.
        self.set_epoch(self.epoch + 1)

    def set_epoch(self, epoch: int):
        """Sets the current iteration of the sampler."""
        self.epoch = epoch
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class BatchSamplerShard(BatchSampler):
|
| 110 |
+
"""
|
| 111 |
+
Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
|
| 112 |
+
always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
|
| 113 |
+
Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
|
| 114 |
+
at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
batch_sampler (`torch.utils.data.sampler.BatchSampler`):
|
| 118 |
+
The batch sampler to split in several shards.
|
| 119 |
+
num_processes (`int`, *optional*, defaults to 1):
|
| 120 |
+
The number of processes running concurrently.
|
| 121 |
+
process_index (`int`, *optional*, defaults to 0):
|
| 122 |
+
The index of the current process.
|
| 123 |
+
split_batches (`bool`, *optional*, defaults to `False`):
|
| 124 |
+
Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
|
| 125 |
+
yielding different full batches on each process.
|
| 126 |
+
|
| 127 |
+
On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
|
| 128 |
+
|
| 129 |
+
- the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
|
| 130 |
+
this argument is set to `False`.
|
| 131 |
+
- the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
|
| 132 |
+
then `[6, 7]` if this argument is set to `True`.
|
| 133 |
+
even_batches (`bool`, *optional*, defaults to `True`):
|
| 134 |
+
Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
|
| 135 |
+
multiple of (original batch size / number of processes).
|
| 136 |
+
|
| 137 |
+
<Tip warning={true}>
|
| 138 |
+
|
| 139 |
+
`BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
|
| 140 |
+
equal to `False`
|
| 141 |
+
|
| 142 |
+
</Tip>"""
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
batch_sampler: BatchSampler,
|
| 147 |
+
num_processes: int = 1,
|
| 148 |
+
process_index: int = 0,
|
| 149 |
+
split_batches: bool = False,
|
| 150 |
+
even_batches: bool = True,
|
| 151 |
+
):
|
| 152 |
+
if split_batches and batch_sampler.batch_size % num_processes != 0:
|
| 153 |
+
raise ValueError(
|
| 154 |
+
f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
|
| 155 |
+
f"needs to be a round multiple of the number of processes ({num_processes})."
|
| 156 |
+
)
|
| 157 |
+
self.batch_sampler = batch_sampler
|
| 158 |
+
self.num_processes = num_processes
|
| 159 |
+
self.process_index = process_index
|
| 160 |
+
self.split_batches = split_batches
|
| 161 |
+
self.even_batches = even_batches
|
| 162 |
+
self.batch_size = getattr(batch_sampler, "batch_size", None)
|
| 163 |
+
self.drop_last = getattr(batch_sampler, "drop_last", False)
|
| 164 |
+
if self.batch_size is None and self.even_batches:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
"You need to use `even_batches=False` when the batch sampler has no batch size. If you "
|
| 167 |
+
"are not calling this method directly, set `accelerator.even_batches=False` instead."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
@property
def total_length(self):
    """Number of batches produced by the wrapped batch sampler, before any sharding."""
    wrapped_sampler = self.batch_sampler
    return len(wrapped_sampler)
|
| 173 |
+
|
| 174 |
+
def __len__(self):
|
| 175 |
+
if self.split_batches:
|
| 176 |
+
# Split batches does not change the length of the batch sampler
|
| 177 |
+
return len(self.batch_sampler)
|
| 178 |
+
if len(self.batch_sampler) % self.num_processes == 0:
|
| 179 |
+
# If the length is a round multiple of the number of processes, it's easy.
|
| 180 |
+
return len(self.batch_sampler) // self.num_processes
|
| 181 |
+
length = len(self.batch_sampler) // self.num_processes
|
| 182 |
+
if self.drop_last:
|
| 183 |
+
# Same if we drop the remainder.
|
| 184 |
+
return length
|
| 185 |
+
elif self.even_batches:
|
| 186 |
+
# When we even batches we always get +1
|
| 187 |
+
return length + 1
|
| 188 |
+
else:
|
| 189 |
+
# Otherwise it depends on the process index.
|
| 190 |
+
return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
|
| 191 |
+
|
| 192 |
+
def __iter__(self):
|
| 193 |
+
return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
|
| 194 |
+
|
| 195 |
+
def _iter_with_split(self):
|
| 196 |
+
initial_data = []
|
| 197 |
+
batch_length = self.batch_sampler.batch_size // self.num_processes
|
| 198 |
+
for idx, batch in enumerate(self.batch_sampler):
|
| 199 |
+
if idx == 0:
|
| 200 |
+
initial_data = batch
|
| 201 |
+
if len(batch) == self.batch_size:
|
| 202 |
+
# If the batch is full, we yield the part of it this process is responsible of.
|
| 203 |
+
yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
|
| 204 |
+
|
| 205 |
+
# If drop_last is True of the last batch was full, iteration is over, otherwise...
|
| 206 |
+
if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
|
| 207 |
+
if not self.even_batches:
|
| 208 |
+
if len(batch) > batch_length * self.process_index:
|
| 209 |
+
yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
|
| 210 |
+
else:
|
| 211 |
+
# For degenerate cases where the dataset has less than num_process * batch_size samples
|
| 212 |
+
while len(initial_data) < self.batch_size:
|
| 213 |
+
initial_data += initial_data
|
| 214 |
+
batch = batch + initial_data
|
| 215 |
+
yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
|
| 216 |
+
|
| 217 |
+
def _iter_with_no_split(self):
    """
    Yield whole batches of the wrapped sampler, dealt out to the processes in turn.

    The batch at position `idx` belongs to the process for which `idx % num_processes == process_index`. A batch
    is only released once the last batch of its round (position `num_processes - 1` within the round) has been
    seen and is full, or when no batch size is known. What happens to an incomplete final round depends on
    `drop_last` and `even_batches`, as commented below.
    """
    # Indices of the first round of batches, reused to complete the final round when `even_batches` is set.
    initial_data = []
    # Batch reserved for this process in the round currently being read.
    batch_to_yield = []
    for idx, batch in enumerate(self.batch_sampler):
        # We gather the initial indices in case we need to circle back at the end.
        if not self.drop_last and idx < self.num_processes:
            initial_data += batch
        # We identify the batch to yield but wait until we are sure every process gets a full batch before actually
        # yielding it.
        if idx % self.num_processes == self.process_index:
            batch_to_yield = batch
        if idx % self.num_processes == self.num_processes - 1 and (
            self.batch_size is None or len(batch) == self.batch_size
        ):
            yield batch_to_yield
            batch_to_yield = []

    # If drop_last is True, iteration is over, otherwise...
    # (`initial_data` is still empty when the sampler produced nothing, which also guards the uses of `batch`
    # and `idx` below.)
    if not self.drop_last and len(initial_data) > 0:
        if not self.even_batches:
            # Uneven mode: hand over whatever was reserved for this process, even if it is short.
            if len(batch_to_yield) > 0:
                yield batch_to_yield
        else:
            # ... we yield the complete batch we had saved before if it has the proper length
            if len(batch_to_yield) == self.batch_size:
                yield batch_to_yield

            # For degenerate cases where the dataset has less than num_process * batch_size samples
            while len(initial_data) < self.num_processes * self.batch_size:
                initial_data += initial_data

            # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
            if len(batch) == self.batch_size:
                batch = []
                idx += 1

            # Make sure we yield a multiple of self.num_processes batches
            # `cycle_index` walks through `initial_data` so that each padded batch uses fresh indices.
            cycle_index = 0
            while idx % self.num_processes != 0 or len(batch) > 0:
                # Top up the (possibly empty) current batch to exactly `batch_size` indices.
                end_index = cycle_index + self.batch_size - len(batch)
                batch += initial_data[cycle_index:end_index]
                if idx % self.num_processes == self.process_index:
                    yield batch
                cycle_index = end_index
                batch = []
                idx += 1
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (depending of the value of
    `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
    `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
    be too small or loop with indices from the beginning.

    Args:
        dataset (`torch.utils.data.dataset.IterableDataset`):
            The iterable dataset to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
            `split_batches=True`).
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:

            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
              argument is set to `False`.
            - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
              this argument is set to `True`.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
    ):
        if split_batches and batch_size > 1 and batch_size % num_processes != 0:
            raise ValueError(
                f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches
        # `__iter__` seeds the dataset generator with the epoch; give it a value so that iterating before any call
        # to `set_epoch` does not fail on a missing attribute.
        self.epoch = 0

    def set_epoch(self, epoch):
        """Record `epoch` and forward it to the wrapped dataset when it defines `set_epoch`."""
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def _batch_sizes(self):
        """Return `(samples read per global batch, samples this process yields per global batch)`."""
        if self.split_batches:
            return self.batch_size, self.batch_size // self.num_processes
        return self.batch_size * self.num_processes, self.batch_size

    def __len__(self):
        """Return the number of samples `__iter__` yields on this process (requires a sized dataset)."""
        # We will just raise the downstream error if the underlying dataset is not sized
        real_batch_size, process_batch_size = self._batch_sizes()
        if self.drop_last:
            num_batches = len(self.dataset) // real_batch_size
        else:
            num_batches = math.ceil(len(self.dataset) / real_batch_size)
        # Every global batch contributes `process_batch_size` samples to this process, in both sharding modes.
        return num_batches * process_batch_size

    def __iter__(self):
        """Yield the samples of each global batch that belong to this process."""
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.epoch)
        real_batch_size, process_batch_size = self._batch_sizes()
        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class DataLoaderStateMixin:
    """
    Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
    end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
    useful information that might be needed.

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
        - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
          batch size

    <Tip warning={true}>

        Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in
        `self.gradient_state`.

    </Tip>

    """

    def __init_subclass__(cls, **kwargs):
        # Stay cooperative: other bases in the MRO may define `__init_subclass__` as well.
        super().__init_subclass__(**kwargs)
        # Class-level defaults so the attributes exist even before `reset()` is called on an instance.
        cls.end_of_dataloader = False
        cls.remainder = -1

    def reset(self):
        "Restores the default state: not at the end of the dataloader, remainder unknown (-1)"
        self.end_of_dataloader = False
        self.remainder = -1

    def begin(self):
        "Prepares the gradient state for the current dataloader"
        self.reset()
        # Best effort: when the length or the total batch size cannot be determined, `remainder` stays at -1.
        with suppress(Exception):
            if not self._drop_last:
                # Only fall back to `len()` when the dataset does not advertise its total length, so that
                # datasets without `__len__` can still provide one through `total_dataset_length`.
                length = getattr(self.dataset, "total_dataset_length", None)
                if length is None:
                    length = len(self.dataset)
                self.remainder = length % self.total_batch_size
        self.gradient_state._add_dataloader(self)

    def end(self):
        "Cleans up the gradient state after exiting the dataloader"
        self.gradient_state._remove_dataloader(self)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class DataLoaderAdapter:
    """
    Wrapper around a PyTorch `DataLoader` (or a variant of it) meant to be used with the `Accelerator`. For
    compatibility reasons, instances report the class of the dataloader they wrap (see `__class__`) and forward
    unknown attributes to it, so they can be used as a drop-in.
    """

    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
        self.use_stateful_dataloader = use_stateful_dataloader
        stateful_available = is_torchdata_stateful_dataloader_available()
        if stateful_available:
            from torchdata.stateful_dataloader import StatefulDataLoader

        if not use_stateful_dataloader:
            loader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
        else:
            if not stateful_available:
                raise ImportError(
                    "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
                )
            installed_torchdata = version.parse(importlib.metadata.version("torchdata"))
            # `in_order` is removed from the keyword arguments for torchdata < 0.11 combined with torch >= 2.6.0.
            drop_in_order = (
                "in_order" in kwargs
                and compare_versions(installed_torchdata, "<", "0.11")
                and is_torch_version(">=", "2.6.0")
            )
            if drop_in_order:
                kwargs.pop("in_order")
            loader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
        self.base_dataloader = loader

        # Take an initial snapshot when the wrapped dataloader is able to report its state.
        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()

    def __getattr__(self, name):
        # Anything but `base_dataloader` itself is looked up on the wrapped dataloader.
        if name != "base_dataloader":
            return getattr(self.base_dataloader, name)
        # Looking up a missing `base_dataloader` through itself would recurse forever.
        raise AttributeError()

    def state_dict(self):
        """Return the last state snapshot stored in `self.dl_state_dict`."""
        return self.dl_state_dict

    def load_state_dict(self, state_dict):
        """Forward `state_dict` to the wrapped dataloader."""
        self.base_dataloader.load_state_dict(state_dict)

    @property
    def __class__(self):
        """
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
        object.
        """
        return self.base_dataloader.__class__

    def __len__(self):
        return len(self.base_dataloader)

    def adjust_state_dict_for_prefetch(self):
        """
        Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
        `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
        overridden.

        This should modify `self.dl_state_dict` directly
        """
        partial_state = PartialState()
        # Nothing to correct outside of a distributed setup.
        if partial_state.distributed_type == DistributedType.NO:
            return

        # The state dict will be off by a factor of `n-1` batch too many during DDP,
        # so we need to adjust it here
        factor = partial_state.num_processes - 1
        snapshot = self.dl_state_dict
        for counter in ("_sampler_iter_yielded", "_num_yielded"):
            if snapshot[counter] > 0:
                snapshot[counter] -= factor

        sampler_state = snapshot["_index_sampler_state"]
        if sampler_state is None:
            return
        if "samples_yielded" in sampler_state and sampler_state["samples_yielded"] > 0:
            sampler_state["samples_yielded"] -= self.batch_size * factor

    def _update_state_dict(self):
        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
        # what it wants to yield.
        #
        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
        if not hasattr(self.base_dataloader, "state_dict"):
            return
        self.dl_state_dict = self.base_dataloader.state_dict()
        # Potentially modify the state_dict to adjust for prefetching
        self.adjust_state_dict_for_prefetch()
        # Then tag if we are at the end of the dataloader
        self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        device (`torch.device`, *optional*):
            If passed, the device to put all batches on.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: an optional `torch.Generator`
        synchronized_generator (`torch.Generator`, *optional*):
            A random number generator to keep synchronized across processes.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
        **kwargs (additional keyword arguments, *optional*):
            All other keyword arguments to pass to the regular `DataLoader` initialization.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        device=None,
        rng_types=None,
        synchronized_generator=None,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        torch_device_mesh=None,
        **kwargs,
    ):
        # NOTE: `torch_device_mesh` is accepted but neither stored nor forwarded by this constructor.
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.device = device
        self.rng_types = rng_types
        self.synchronized_generator = synchronized_generator
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        # Number of completed passes; used as the epoch handed to `set_epoch` at the start of each pass.
        self.iteration = 0

    def __iter__(self):
        """
        Yield the batches of the wrapped dataloader, moved to `self.device` when one is set.

        The wrapped iterator is kept one batch ahead so that `end_of_dataloader` is already `True` when the last
        batch is yielded. The first `skip_batches` batches are fetched but not yielded. An empty wrapped dataloader
        yields nothing.
        """
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()

        self.set_epoch(self.iteration)
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            # Empty dataloader: there is no batch to hand out. Run the same end-of-pass bookkeeping as a regular
            # pass and stop, instead of yielding `None` and then failing on the unbound `current_batch`.
            self.end_of_dataloader = True
            self._update_state_dict()
            self.iteration += 1
            self.end()
            return

        batch_index = 0
        while True:
            try:
                # But we still move it to the device so it is done before `StopIteration` is reached
                if self.device is not None:
                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
                self._update_state_dict()
                next_batch = next(dataloader_iter)
                if batch_index >= self.skip_batches:
                    yield current_batch
                batch_index += 1
                current_batch = next_batch
            except StopIteration:
                # `current_batch` is the last one: flag the end before handing it out.
                self.end_of_dataloader = True
                self._update_state_dict()
                if batch_index >= self.skip_batches:
                    yield current_batch
                break

        self.iteration += 1
        self.end()

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderShard, *args[1:])

    def set_epoch(self, epoch: int):
        """Record `epoch` as the current iteration and forward it to the sampler or dataset that supports it."""
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        # We support if a custom `Dataset` implementation has `set_epoch`
        # or in general HF datasets `Datasets`
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    @property
    def total_batch_size(self):
        # When the `sampler` slot already holds a batch sampler, that is the one carrying the batch size.
        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
        return (
            batch_sampler.batch_size
            if getattr(batch_sampler, "split_batches", False)
            else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
        )

    @property
    def total_dataset_length(self):
        # Prefer the length advertised by the dataset through `total_length`, when it has one.
        if hasattr(self.dataset, "total_length"):
            return self.dataset.total_length
        else:
            return len(self.dataset)

    def get_sampler(self):
        """Return the result of the module-level `get_sampler` helper for this dataloader."""
        return get_sampler(self)

    def set_sampler(self, sampler):
        """Replace the index sampler on whichever (batch) sampler wrappers currently hold one."""
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            # A wrapping batch sampler keeps the original one in `batch_sampler`; update it as well.
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
if is_torch_xla_available():
    import torch_xla.distributed.parallel_loader as xpl

    class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
        """
        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.

        XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
        prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
        thread only.

        **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
        """

        def __init__(self, dataloader: DataLoaderShard, device: torch.device):
            super().__init__(dataloader, device)
            wrapped = self._loader
            # Take over the RNG synchronization from the wrapped dataloader.
            self._rng_types = wrapped.rng_types
            wrapped.rng_types = None
            self.device = device

        def __iter__(self):
            rng_types = self._rng_types
            if rng_types is not None:
                synchronize_rng_states(rng_types, self._loader.synchronized_generator)

            return super().__iter__()

        def set_epoch(self, epoch: int):
            """Forward `epoch` to the wrapped dataloader when it defines `set_epoch`."""
            if not hasattr(self.dataloader, "set_epoch"):
                return
            self.dataloader.set_epoch(epoch)

        @property
        def total_batch_size(self):
            """Total batch size reported by the wrapped dataloader."""
            wrapped = self._loader
            return wrapped.total_batch_size

        @property
        def total_dataset_length(self):
            """Total dataset length reported by the wrapped dataloader."""
            wrapped = self._loader
            return wrapped.total_dataset_length

        @property
        def batch_sampler(self):
            """Batch sampler of the wrapped dataloader."""
            wrapped = self._loader
            return wrapped.batch_sampler

        @property
        def dataloader(self):
            """The wrapped dataloader itself."""
            return self._loader
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
| 697 |
+
"""
|
| 698 |
+
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
|
| 699 |
+
their part of the batch.
|
| 700 |
+
|
| 701 |
+
Args:
|
| 702 |
+
split_batches (`bool`, *optional*, defaults to `False`):
|
| 703 |
+
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
|
| 704 |
+
yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
|
| 705 |
+
`num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
|
| 706 |
+
the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
|
| 707 |
+
`dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
|
| 708 |
+
size of the `dataloader` is a round multiple of `batch_size`.
|
| 709 |
+
skip_batches (`int`, *optional*, defaults to 0):
|
| 710 |
+
The number of batches to skip at the beginning of an iteration.
|
| 711 |
+
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
|
| 712 |
+
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
|
| 713 |
+
|
| 714 |
+
**Available attributes:**
|
| 715 |
+
|
| 716 |
+
- **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
|
| 717 |
+
Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
|
| 718 |
+
number of processes
|
| 719 |
+
|
| 720 |
+
- **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
|
| 721 |
+
"""
|
| 722 |
+
|
| 723 |
+
def __init__(
    self,
    dataset,
    split_batches: bool = False,
    skip_batches=0,
    use_stateful_dataloader=False,
    _drop_last: bool = False,
    _non_blocking: bool = False,
    slice_fn=None,
    torch_device_mesh=None,
    **kwargs,
):
    """
    Build the underlying dataloader and record the dispatch configuration.

    Raises:
        ValueError: If the device mesh has a `"tp"` dimension together with a `"dp"` or `"fsdp"` one.
    """
    from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

    # We need to save the shuffling state of the DataPipe before the dataloader is built, so that it can be
    # re-applied right after.
    shuffle = dataset._shuffle_enabled if isinstance(dataset, ShufflerIterDataPipe) else False
    super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
    self.split_batches = split_batches
    if shuffle:
        torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)

    self.gradient_state = GradientState()
    self.state = PartialState()
    self._drop_last = _drop_last
    self._non_blocking = _non_blocking
    self.skip_batches = skip_batches
    self.torch_device_mesh = torch_device_mesh

    if slice_fn is None:
        slice_fn = slice_tensors
    self.slice_fn = slice_fn
    self.iteration = 0

    # If a device mesh is provided, extract each supported dimension (dp, fsdp, tp). A device mesh may hold any
    # number of dimensions; only these three are handled here.
    #
    # The mesh is only used when tp is involved, alone or combined with other dimensions:
    # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
    # Otherwise the default behaviour without a device mesh is sufficient, since multi-dimensional parallelism
    # without tp needs different batches on each process anyway, irrespective of dp or fsdp.
    self.submesh_tp = None
    self.submesh_dp = None
    self.submesh_fsdp = None
    mesh = self.torch_device_mesh
    if mesh and "tp" in mesh.mesh_dim_names:
        self.submesh_tp = mesh["tp"]
        if "dp" in mesh.mesh_dim_names:
            self.submesh_dp = mesh["dp"]
        if "fsdp" in mesh.mesh_dim_names:
            self.submesh_fsdp = mesh["fsdp"]
    if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
        raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
|
| 777 |
+
|
| 778 |
+
def _fetch_batches(self, iterator):
    """
    Pull the next batch to dispatch on process 0 and share its description with every process.

    Args:
        iterator: Iterator over the base dataloader. It is only consumed on process 0.

    Returns:
        `tuple`: `(batch, batch_info)`. `batch` is the fetched (and, when `split_batches` is `False`,
        concatenated) batch on process 0 and `None` on the other processes. `batch_info` is the two-element list
        `[data_structure, stop_iteration]` that was broadcast from process 0.

    Raises:
        RuntimeError: If the batches fetched on process 0 cannot be concatenated.
    """
    batches, batch = None, None
    # On process 0, we gather the batch to dispatch.
    if self.state.process_index == 0:
        # Supporting TP only is simpler: the same batch of samples is dispatched to all ranks, which avoids
        # handling multiple tp rank groups when a TP + DP combination is involved.
        try:
            # For the TP case avoid using split_batches, since it would require the dataloader to produce
            # duplicates of batches.
            if self.split_batches:
                # One batch of the main iterator is dispatched and split.
                if self.submesh_tp:
                    logger.warning(
                        "Use of split_batches for TP would need the dataloader to produce duplicate batches, "
                        "otherwise, use dispatch_batches=True instead."
                    )
                self._update_state_dict()
                batch = next(iterator)
            else:
                # num_processes batches of the main iterator are concatenated then dispatched and split.
                # We add the batches one by one so we have the remainder available when drop_last=False.
                batches = []
                if self.submesh_tp:
                    # With tp, extract a single batch and then replicate it.
                    self._update_state_dict()
                    batch = next(iterator)
                    batches = [batch] * self.state.num_processes
                else:
                    for _ in range(self.state.num_processes):
                        self._update_state_dict()
                        batches.append(next(iterator))
                try:
                    batch = concatenate(batches, dim=0)
                except RuntimeError as e:
                    raise RuntimeError(
                        "You can't use batches of different size with `dispatch_batches=True` or when using an "
                        "`IterableDataset`. Either pass `dispatch_batches=False` and have each process fetch its "
                        "own batch or pass `split_batches=True`. By doing so, the main process will fetch a full "
                        "batch and slice it into `num_processes` batches for each process."
                    ) from e
            # In both cases, we need to get the structure of the batch that we will broadcast on other
            # processes to initialize the tensors with the right shape.
            # data_structure, stop_iteration
            batch_info = [get_data_structure(batch), False]
        except StopIteration:
            batch_info = [None, True]
    else:
        batch_info = [None, self._stop_iteration]

    # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
    broadcast_object_list(batch_info)
    self._stop_iteration = batch_info[1]
    if self._stop_iteration:
        # If drop_last is False and split_batches is False, we may have a remainder to take care of.
        if not self.split_batches and not self._drop_last:
            if self.state.process_index == 0 and len(batches) > 0:
                # Batches collected before the iterator ran out form a last, shorter batch.
                batch = concatenate(batches, dim=0)
                batch_info = [get_data_structure(batch), False]
            else:
                batch_info = [None, True]
            broadcast_object_list(batch_info)
    return batch, batch_info
|
| 843 |
+
|
| 844 |
+
def __iter__(self):
    """
    Yield this process's slice of every batch fetched on process 0.

    Each step obtains a batch through `_fetch_batches`, moves it to `self.state.device`, broadcasts it from
    process 0 and slices out the part belonging to the current process. The following batch is fetched before
    the current one is yielded so that the end of the data is known in time to set `end_of_dataloader` and
    `remainder`. The first `skip_batches` batches are processed but not yielded.

    Raises:
        ValueError: If the batch is `None` after the broadcast.
    """
    self.begin()
    self.set_epoch(self.iteration)

    # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
    # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
    # But, we only iterate through the DataLoader on process 0.
    main_iterator = None
    if is_torch_version(">=", "2.0.1") or self.state.process_index == 0:
        main_iterator = self.base_dataloader.__iter__()

    rank = self.state.process_index
    world_size = self.state.num_processes

    self._stop_iteration = False
    stop_iteration = False
    first_batch = None
    batch_index = 0
    pending, pending_info = self._fetch_batches(main_iterator)

    while not stop_iteration:
        batch, batch_info = pending, pending_info

        if rank != 0:
            # Initialize tensors on other processes than process 0.
            batch = initialize_tensors(batch_info[0])
        batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
        # Broadcast the batch before splitting it.
        batch = broadcast(batch, from_process=0)

        if first_batch is None and not self._drop_last:
            # We keep at least num processes elements of the first batch to be able to complete the last batch
            first_batch = self.slice_fn(
                batch,
                slice(0, world_size),
                process_index=rank,
                num_processes=world_size,
            )

        if batch is None:
            raise ValueError(
                f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
            )

        observed_batch_size = find_batch_size(batch)
        batch_size = observed_batch_size // world_size

        stop_iteration = self._stop_iteration
        if not stop_iteration:
            # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
            # the dataloader since the number of batches is a round multiple of the number of processes.
            pending, pending_info = self._fetch_batches(main_iterator)
            # pending_info[0] is None when there are no more batches, otherwise we still need to process them.
            if self._stop_iteration and pending_info[0] is None:
                stop_iteration = True

        if stop_iteration and not self._drop_last and observed_batch_size % world_size != 0:
            # If the last batch is not complete, let's add the first batch to it.
            batch = concatenate([batch, first_batch], dim=0)
            # Batch size computation above is wrong, it's off by 1 so we fix it.
            batch_size += 1

        start = rank * batch_size
        batch = self.slice_fn(
            batch,
            slice(start, start + batch_size),
            process_index=rank,
            num_processes=world_size,
        )

        if stop_iteration:
            self.end_of_dataloader = True
            self._update_state_dict()
            self.remainder = observed_batch_size
        if batch_index >= self.skip_batches:
            yield batch
        batch_index += 1

    self.iteration += 1
    self.end()
|
| 919 |
+
|
| 920 |
+
def set_epoch(self, epoch: int):
    """
    Record `epoch` as the current iteration and forward it to the sampler or dataset.

    The epoch goes to `self.batch_sampler.sampler.set_epoch` when that exists, otherwise to
    `self.dataset.set_epoch` when that exists; if neither does, only `self.iteration` is updated.

    Args:
        epoch (`int`): The epoch number to apply.
    """
    # In case it is manually passed in, the user can set it to what they like
    if epoch != self.iteration:
        self.iteration = epoch

    inner_sampler = getattr(self.batch_sampler, "sampler", None)
    if hasattr(self.batch_sampler, "sampler") and hasattr(inner_sampler, "set_epoch"):
        inner_sampler.set_epoch(epoch)
        return
    if hasattr(self.dataset, "set_epoch"):
        self.dataset.set_epoch(epoch)
|
| 928 |
+
|
| 929 |
+
def __len__(self):
    """
    Return the number of batches produced per process.

    With `split_batches` this is the length of the base dataloader. Otherwise the base length is divided by
    the number of processes, rounded down when `_drop_last` is set and rounded up when it is not.
    """
    base_length = len(self.base_dataloader)
    if self.split_batches:
        return base_length

    world_size = self.state.num_processes
    if self._drop_last:
        return base_length // world_size
    return math.ceil(base_length / world_size)
|
| 937 |
+
|
| 938 |
+
def __reduce__(self):
    """
    Make a `DataLoaderDispatcher` picklable.

    Default pickling is broken because `DataLoaderAdapter` changes the instance's `__class__`, so the
    reconstruction callable reported by the parent is replaced with `DataLoaderDispatcher` while the remaining
    items of the reduce tuple are kept unchanged.
    """
    _, *remaining = super().__reduce__()
    return (DataLoaderDispatcher, *remaining)
|
| 946 |
+
|
| 947 |
+
@property
def total_batch_size(self):
    """
    Batch size observed across all processes.

    Returns `self.dataset.batch_size` when `split_batches` is set, otherwise that value multiplied by
    `self.dataset.num_processes`.
    """
    # NOTE(review): both values are read from `self.dataset`, so the dataset object must expose
    # `batch_size` (and `num_processes` when not splitting) - confirm this is intended.
    dataset_batch_size = self.dataset.batch_size
    if self.split_batches:
        return dataset_batch_size
    return dataset_batch_size * self.dataset.num_processes
|
| 952 |
+
|
| 953 |
+
@property
def total_dataset_length(self):
    """Number of items in the underlying dataset, as reported by `len(self.dataset)`."""
    dataset = self.dataset
    return len(dataset)
|
| 956 |
+
|
| 957 |
+
def get_sampler(self):
    """Return the sampler of this dataloader by delegating to the module-level `get_sampler` function."""
    found_sampler = get_sampler(self)
    return found_sampler
|
| 959 |
+
|
| 960 |
+
def set_sampler(self, sampler):
    """
    Replace the sampler used by this dataloader.

    When `self.sampler` is itself a `BatchSampler`, the new sampler is stored on it. Otherwise it is stored on
    `self.batch_sampler` and, if that object wraps another `batch_sampler`, on the wrapped one as well.

    Args:
        sampler: The sampler to install.
    """
    if isinstance(self.sampler, BatchSampler):
        self.sampler.sampler = sampler
        return

    outer = self.batch_sampler
    outer.sampler = sampler
    if hasattr(outer, "batch_sampler"):
        outer.batch_sampler.sampler = sampler
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def get_sampler(dataloader):
    """
    Get the sampler associated to the dataloader

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader whose sampler should be looked up.
    Returns:
        `torch.utils.data.Sampler`: The sampler associated to the dataloader, or `None` when the object that
        should hold it has no `sampler` attribute.
    """
    # When the dataloader's `sampler` is a `BatchSampler`, the real sampler lives inside it; otherwise it is
    # looked up on the dataloader's `batch_sampler`.
    if isinstance(dataloader.sampler, BatchSampler):
        holder = dataloader.sampler
    else:
        holder = dataloader.batch_sampler
    return getattr(holder, "sampler", None)
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
def prepare_data_loader(
    dataloader: DataLoader,
    device: Optional[torch.device] = None,
    num_processes: Optional[int] = None,
    process_index: Optional[int] = None,
    split_batches: bool = False,
    put_on_device: bool = False,
    rng_types: Optional[list[Union[str, RNGType]]] = None,
    dispatch_batches: Optional[bool] = None,
    even_batches: bool = True,
    slice_fn_for_dispatch: Optional[Callable] = None,
    use_seedable_sampler: bool = False,
    data_seed: Optional[int] = None,
    non_blocking: bool = False,
    use_stateful_dataloader: bool = False,
    torch_device_mesh=None,
) -> DataLoader:
    """
    Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

    Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
        device (`torch.device`, *optional*):
            The target device for the returned `DataLoader`.
        num_processes (`int`, *optional*):
            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
        process_index (`int`, *optional*):
            The index of the current process. Will default to the value given by [`~state.PartialState`].
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration).

            Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
            this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
            otherwise.

            Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
            `num_processes`.
        put_on_device (`bool`, *optional*, defaults to `False`):
            Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
            dictionaries of tensors).
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

            The list passed in is never modified.
        dispatch_batches (`bool`, *optional*):
            If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
            are split and broadcast to each process. Will default to `True` when the underlying dataset is an
            `IterableDataset` and `put_on_device` is `True`, `False` otherwise.
        even_batches (`bool`, *optional*, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        slice_fn_for_dispatch (`Callable`, *optional*):
            If passed, this function will be used to slice tensors across `num_processes`. Will default to
            [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
            ignored otherwise.
        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
            reproducibility. Comes at a cost of potentially different performances due to different shuffling
            algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
            `self.set_epoch`
        data_seed (`int`, *optional*, defaults to `None`):
            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
            will use the current default seed from torch.
        non_blocking (`bool`, *optional*, defaults to `False`):
            If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
            `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            If set to true, the dataloader prepared by the Accelerator will be backed by
            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
            PyTorch device mesh.

    Returns:
        `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
        assigned to the current process.

    Raises:
        ValueError: If `dispatch_batches=True` is requested without `put_on_device=True`, if `split_batches=True`
            is used and no batch size can be found, or if that batch size is not a round multiple of the number of
            processes.

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>
    """

    def _new_generator():
        # Build a generator on torch's default device when that API exists, on CPU otherwise.
        generator_device = torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
        return torch.Generator(device=generator_device)

    if dispatch_batches is None:
        if not put_on_device:
            dispatch_batches = False
        else:
            dispatch_batches = isinstance(dataloader.dataset, IterableDataset)

    if dispatch_batches and not put_on_device:
        raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")

    # Grab defaults from PartialState
    state = PartialState()
    if num_processes is None:
        num_processes = state.num_processes

    if process_index is None:
        process_index = state.process_index

    if torch_device_mesh:
        if state.distributed_type == DistributedType.DEEPSPEED:
            # In DeepSpeed, the optimizer sharing level in DP is determined by the config file.
            # Only considers "dp" and "tp".
            # Given a device mesh (dp, tp) = (2, 3):
            # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
            # - Processes with the same DP rank will receive the same batch.
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
                process_index = process_index // submesh_tp_size
                num_processes = num_processes // submesh_tp_size
        else:
            # when device mesh is used, specifically with TP
            # then there is need to update process_index and num_processes
            # to bring in the effect of generating same batch across TP ranks
            # and different batch across FSDP and DP ranks.
            # Example:
            # if device mesh is (dp,fsdp,tp) = (2, 2, 3)
            # ranks would range from 0...11
            # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
            # processes with same ranks/ids would receive the same batch
            submesh_fsdp_size = 1
            submesh_dp_size = 1
            submesh_tp_size = 1
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
            if "dp" in torch_device_mesh.mesh_dim_names:
                submesh_dp_size = torch_device_mesh["dp"].size()
            if "fsdp" in torch_device_mesh.mesh_dim_names:
                submesh_fsdp_size = torch_device_mesh["fsdp"].size()
            process_index = process_index // submesh_tp_size
            num_processes = submesh_fsdp_size * submesh_dp_size

    # Sanity check
    if split_batches:
        if dataloader.batch_size is not None:
            batch_size_for_check = dataloader.batch_size
        elif hasattr(dataloader.batch_sampler, "batch_size"):
            # For custom batch_sampler
            batch_size_for_check = dataloader.batch_sampler.batch_size
        else:
            raise ValueError(
                "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
                "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
                "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
                f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
            )

        if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
            # Report the size that was actually checked; `dataloader.batch_size` may be None here.
            raise ValueError(
                f"To use a `DataLoader` in `split_batches` mode, the batch size ({batch_size_for_check}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )

    new_dataset = dataloader.dataset
    # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
    new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    synchronized_generator = None

    sampler = get_sampler(dataloader)
    if isinstance(sampler, RandomSampler) and use_seedable_sampler:
        # When iterating through the dataloader during distributed processes
        # we want to ensure that on each process we are iterating through the same
        # samples in the same order if a seed is set. This requires a tweak
        # to the `torch.utils.data.RandomSampler` class (if used).
        sampler = SeedableRandomSampler(
            data_source=sampler.data_source,
            replacement=sampler.replacement,
            num_samples=sampler._num_samples,
            generator=getattr(sampler, "generator", _new_generator()),
            data_seed=data_seed,
        )

    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
        generator = _new_generator()
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator.manual_seed(seed)
        dataloader.generator = generator
        dataloader.sampler.generator = generator

    # No change if no multiprocess
    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
        if isinstance(new_dataset, IterableDataset):
            if getattr(dataloader.dataset, "generator", None) is not None:
                synchronized_generator = dataloader.dataset.generator
            new_dataset = IterableDatasetShard(
                new_dataset,
                batch_size=dataloader.batch_size,
                drop_last=dataloader.drop_last,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
            )
        else:
            if not use_seedable_sampler and hasattr(sampler, "generator"):
                if sampler.generator is None:
                    sampler.generator = _new_generator()
                    seed = int(torch.empty((), dtype=torch.int64).random_().item())
                    sampler.generator.manual_seed(seed)
                synchronized_generator = sampler.generator
            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
            new_batch_sampler = BatchSamplerShard(
                batch_sampler,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
                even_batches=even_batches,
            )

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
        # There is no generator to synchronize; drop the entry from a copy so the caller's list is untouched.
        rng_types = list(rng_types)
        rng_types.remove("generator")

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = (
            dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
        )

    if dispatch_batches:
        kwargs.pop("generator")
        dataloader = DataLoaderDispatcher(
            new_dataset,
            split_batches=split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            slice_fn=slice_fn_for_dispatch,
            use_stateful_dataloader=use_stateful_dataloader,
            torch_device_mesh=torch_device_mesh,
            **kwargs,
        )
    elif sampler_is_batch_sampler:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            rng_types=rng_types,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            synchronized_generator=synchronized_generator,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )
    else:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            batch_sampler=new_batch_sampler,
            rng_types=rng_types,
            synchronized_generator=synchronized_generator,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )

    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
        dataloader.set_sampler(sampler)
    if state.distributed_type == DistributedType.XLA:
        return MpDeviceLoaderWrapper(dataloader, device)
    return dataloader
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    Should not be used if the original dataloader is a `StatefulDataLoader`.

    Args:
        batch_sampler: The batch sampler whose first batches are skipped.
        skip_batches (`int`, *optional*, defaults to 0): The number of leading batches to drop.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        """Yield every batch of the wrapped sampler whose position is at least `skip_batches`."""
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        """Length of the wrapped batch sampler, skipped batches included."""
        return len(self.batch_sampler)

    def __len__(self):
        """Number of batches left after skipping, never below zero."""
        # Skipping more batches than exist yields nothing; a negative value would make `len()` raise.
        return max(len(self.batch_sampler) - self.skip_batches, 0)
|
| 1311 |
+
|
| 1312 |
+
|
| 1313 |
+
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
    `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to the parent initializer.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()

    def __iter__(self):
        """Iterate over the base dataloader, yielding only batches at position `skip_batches` or later."""
        self.begin()
        for index, batch in enumerate(self.base_dataloader.__iter__()):
            if index >= self.skip_batches:
                self._update_state_dict()
                yield batch
        self.end()

    def __len__(self):
        """Number of batches left after skipping, never below zero."""
        # Skipping more batches than exist yields nothing; a negative value would make `len()` raise.
        return max(len(self.base_dataloader) - self.skip_batches, 0)

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (SkipDataLoader, *args[1:])
|
| 1351 |
+
|
| 1352 |
+
|
| 1353 |
+
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
    the original dataloader is a `StatefulDataLoader`.

    Args:
        dataloader: The dataloader to rebuild. Under `DistributedType.XLA` its `device` and `dataloader`
            attributes are read to unwrap it first.
        num_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.

    Returns:
        A new dataloader of the same flavour as the input (`DataLoaderDispatcher`, `DataLoaderShard`,
        `SkipDataLoader` or plain `DataLoader`), wrapped again in `MpDeviceLoaderWrapper` under XLA.
    """
    state = PartialState()
    if state.distributed_type == DistributedType.XLA:
        device = dataloader.device
        dataloader = dataloader.dataloader

    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        # No batch sampler to wrap: the batches have to be skipped while iterating.
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if isinstance(dataloader, DataLoaderDispatcher):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        dataloader = DataLoaderDispatcher(
            dataset,
            split_batches=dataloader.split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader._drop_last,
            # Carry over the dispatch settings stored on the original instance so the rebuilt dataloader
            # transfers, slices and replicates batches the same way.
            _non_blocking=dataloader._non_blocking,
            slice_fn=dataloader.slice_fn,
            torch_device_mesh=dataloader.torch_device_mesh,
            **kwargs,
        )
    elif isinstance(dataloader, DataLoaderShard):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        elif sampler_is_batch_sampler:
            kwargs["sampler"] = new_batch_sampler
            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        # NOTE(review): `_drop_last` / `_non_blocking` of the original shard are not forwarded here - confirm
        # against `DataLoaderShard.__init__` whether they should be.
        dataloader = DataLoaderShard(
            dataset,
            device=dataloader.device,
            rng_types=dataloader.rng_types,
            synchronized_generator=dataloader.synchronized_generator,
            **kwargs,
        )
    elif new_batch_sampler is None:
        # Need to manually skip batches in the dataloader
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    if state.distributed_type == DistributedType.XLA:
        dataloader = MpDeviceLoaderWrapper(dataloader, device)

    return dataloader
|
venv/Lib/site-packages/accelerate/hooks.py
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
from collections.abc import Mapping
|
| 17 |
+
from typing import Optional, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from .state import PartialState
|
| 23 |
+
from .utils import (
|
| 24 |
+
PrefixedDataset,
|
| 25 |
+
find_device,
|
| 26 |
+
named_module_tensors,
|
| 27 |
+
send_to_device,
|
| 28 |
+
set_module_tensor_to_device,
|
| 29 |
+
)
|
| 30 |
+
from .utils.imports import (
|
| 31 |
+
is_mlu_available,
|
| 32 |
+
is_musa_available,
|
| 33 |
+
is_npu_available,
|
| 34 |
+
)
|
| 35 |
+
from .utils.memory import clear_device_cache
|
| 36 |
+
from .utils.modeling import get_non_persistent_buffers
|
| 37 |
+
from .utils.other import recursive_getattr
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Attribute names that `remove_hook_from_module` pops from a module's instance `__dict__`; that function describes
# them as warning hooks added by `dispatch_model`.
_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ModelHook:
    """
    Base class for hooks whose callbacks run right before and right after the forward method of a model. Unlike the
    hooks that ship with PyTorch, these callbacks also receive the keyword arguments of the call.

    Class attribute:
        - **no_grad** (`bool`, *optional*, defaults to `False`) -- When `True`, the actual forward pass is executed
          under the `torch.no_grad()` context manager.
    """

    no_grad = False

    def init_hook(self, module):
        """
        Callback run when the hook gets attached to a module.

        Args:
            module (`torch.nn.Module`): The module this hook is being attached to.

        Returns:
            `torch.nn.Module`: The module to keep using. The base implementation hands `module` back untouched.
        """
        return module

    def pre_forward(self, module, *args, **kwargs):
        """
        Callback run immediately before the forward method of the model.

        Args:
            module (`torch.nn.Module`): The module whose forward pass is about to be executed.
            args (`Tuple[Any]`): The positional arguments the module was called with.
            kwargs (`Dict[Str, Any]`): The keyword arguments the module was called with.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: The (possibly processed) `args` and `kwargs`, as a pair. The base
            implementation returns them unchanged.
        """
        return args, kwargs

    def post_forward(self, module, output):
        """
        Callback run immediately after the forward method of the model.

        Args:
            module (`torch.nn.Module`): The module whose forward pass has just been executed.
            output (`Any`): What the forward pass returned.

        Returns:
            `Any`: The (possibly processed) `output`. The base implementation returns it unchanged.
        """
        return output

    def detach_hook(self, module):
        """
        Callback run when the hook gets detached from a module.

        Args:
            module (`torch.nn.Module`): The module this hook is being removed from.

        Returns:
            `torch.nn.Module`: The module to keep using. The base implementation hands `module` back untouched.
        """
        return module
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class SequentialHook(ModelHook):
    """
    A composite hook: it holds several hooks and, for every event, runs each of them in the order they were given,
    feeding the result of one into the next.
    """

    def __init__(self, *hooks):
        # Kept as the tuple of positional arguments: registration order is execution order.
        self.hooks = hooks

    def init_hook(self, module):
        # Each sub-hook receives the module returned by the previous one.
        current = module
        for sub_hook in self.hooks:
            current = sub_hook.init_hook(current)
        return current

    def pre_forward(self, module, *args, **kwargs):
        # Thread the (args, kwargs) pair through every sub-hook in turn.
        for sub_hook in self.hooks:
            args, kwargs = sub_hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        # Thread the output through every sub-hook in turn.
        result = output
        for sub_hook in self.hooks:
            result = sub_hook.post_forward(module, result)
        return result

    def detach_hook(self, module):
        # Each sub-hook receives the module returned by the previous one.
        current = module
        for sub_hook in self.hooks:
            current = sub_hook.detach_hook(current)
        return current
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
    """
    Attaches a hook to a module by replacing the module's `forward` with a wrapper that calls the hook around the
    original forward. Use `remove_hook_from_module` to undo this and get the original `forward` back.

    <Tip warning={true}>

    By default, a hook already present on the module is replaced by the new one. Pass `append=True` to chain them
    instead: the existing hook and the new hook are then combined into a `SequentialHook`.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if module already contains a hook) or not.

    Returns:
        `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
        be discarded).
    """
    existing_hook = getattr(module, "_hf_hook", None)
    if append and existing_hook is not None:
        # Detach the current hook first, then re-attach it chained with the new one.
        remove_hook_from_module(module)
        hook = SequentialHook(existing_hook, hook)

    already_wrapped = hasattr(module, "_hf_hook") and hasattr(module, "_old_forward")
    if already_wrapped:
        # A hook is being replaced: keep wrapping the forward that was saved the first time.
        original_forward = module._old_forward
    else:
        original_forward = module.forward
        module._old_forward = original_forward

    module = hook.init_hook(module)
    module._hf_hook = hook

    def new_forward(module, *args, **kwargs):
        # The hook is looked up on the module at call time rather than captured from the enclosing scope.
        args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
        if not module._hf_hook.no_grad:
            output = module._old_forward(*args, **kwargs)
        else:
            with torch.no_grad():
                output = module._old_forward(*args, **kwargs)
        return module._hf_hook.post_forward(module, output)

    # `update_wrapper` copies the metadata of the original forward onto the wrapper.
    wrapped_forward = functools.update_wrapper(functools.partial(new_forward, module), original_forward)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = wrapped_forward
    else:
        module.forward = wrapped_forward

    return module
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def remove_hook_from_module(module: nn.Module, recurse=False):
    """
    Detaches whatever hook `add_hook_to_module` attached to a module and restores the saved `forward`.

    Args:
        module (`torch.nn.Module`): The module to detach the hook from.
        recurse (`bool`, *optional*): Whether to also remove the hooks of all child modules, recursively.

    Returns:
        `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
        be discarded).
    """
    if hasattr(module, "_hf_hook"):
        # Give the hook a chance to clean up before forgetting it.
        module._hf_hook.detach_hook(module)
        del module._hf_hook

    if hasattr(module, "_old_forward"):
        saved_forward = module._old_forward
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        patch_target = module.__class__ if "GraphModuleImpl" in str(type(module)) else module
        patch_target.forward = saved_forward
        del module._old_forward

    # Remove accelerate added warning hooks from dispatch_model
    for attribute_name in _accelerate_added_attributes:
        module.__dict__.pop(attribute_name, None)

    if recurse:
        for submodule in module.children():
            remove_hook_from_module(submodule, recurse)

    return module
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class AlignDevicesHook(ModelHook):
    """
    A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
    associated module, potentially offloading the weights after the forward pass.

    Args:
        execution_device (`int`, `str` or `torch.device`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass.
        offload (`bool`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass.
        io_same_device (`bool`, *optional*, defaults to `False`):
            Whether or not the output should be placed on the same device as the input was.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        place_submodules (`bool`, *optional*, defaults to `False`):
            Whether to place the submodules on `execution_device` during the `init_hook` event.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        tied_params_map (`Dict[int, Dict[torch.device, torch.Tensor]]`, *optional*):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """

    def __init__(
        self,
        execution_device: Optional[Union[int, str, torch.device]] = None,
        offload: bool = False,
        io_same_device: bool = False,
        weights_map: Optional[Mapping] = None,
        offload_buffers: bool = False,
        place_submodules: bool = False,
        skip_keys: Optional[Union[str, list[str]]] = None,
        tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
    ):
        self.execution_device = execution_device
        self.offload = offload
        self.io_same_device = io_same_device
        self.weights_map = weights_map
        self.offload_buffers = offload_buffers
        self.place_submodules = place_submodules
        self.skip_keys = skip_keys

        # Will contain the input device when `io_same_device=True`.
        self.input_device = None
        # These two dictionaries are not read anywhere else in this class; they are kept as-is.
        self.param_original_devices = {}
        self.buffer_original_devices = {}
        self.tied_params_names = set()

        # Filled by `init_hook` (offload mode) with the device each tensor lived on before being offloaded. Initialized
        # here so that `detach_hook` is a no-op instead of an `AttributeError` if it runs before `init_hook`.
        self.original_devices = {}
        # Rebuilt by every `pre_forward` (offload mode). Initialized here so that `post_forward` does not raise an
        # `AttributeError` if it runs before any `pre_forward`.
        self.tied_pointers_to_remove = set()

        # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
        # for tied weights already loaded on the target execution device.
        self.tied_params_map = tied_params_map

    def __repr__(self):
        return (
            f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
            f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
            f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
        )

    def init_hook(self, module):
        """
        Places the tensors of `module` when the hook is attached.

        Without offloading (and with an execution device set), every tensor of the module is moved to the execution
        device. With offloading, the original device of each tensor is recorded, a CPU `weights_map` is built if none
        was given, the offloaded tensors are replaced by `"meta"` tensors, and the buffers that are not offloaded are
        moved to the execution device.

        Args:
            module (`torch.nn.Module`): The module attached to this hook.

        Returns:
            `torch.nn.Module`: The same module.
        """
        # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
        if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
            self.tied_params_map = None

        if not self.offload and self.execution_device is not None:
            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
        elif self.offload:
            # Remember where every tensor was, so `detach_hook` can put it back.
            self.original_devices = {
                name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
            }
            if self.weights_map is None:
                self.weights_map = {
                    name: param.to("cpu")
                    for name, param in named_module_tensors(
                        module, include_buffers=self.offload_buffers, recurse=self.place_submodules
                    )
                }
            for name, _ in named_module_tensors(
                module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
            ):
                # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
                # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
                # to add on the fly pointers to `tied_params_map` in the pre_forward call.
                if (
                    self.tied_params_map is not None
                    and recursive_getattr(module, name).data_ptr() in self.tied_params_map
                ):
                    self.tied_params_names.add(name)

                set_module_tensor_to_device(module, name, "meta")

            if not self.offload_buffers and self.execution_device is not None:
                # Buffers are not offloaded: keep all of them on the execution device.
                for name, _ in module.named_buffers(recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )
            elif self.offload_buffers and self.execution_device is not None:
                # Buffers are offloaded, except the non-persistent ones, which are placed on the execution device.
                for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )

        return module

    def pre_forward(self, module, *args, **kwargs):
        """
        Loads offloaded tensors on the execution device (when offloading) and moves the inputs there.

        Args:
            module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`): The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: `args` and `kwargs` sent to the execution device (`skip_keys` is
            applied to `kwargs` only).
        """
        if self.io_same_device:
            self.input_device = find_device([args, kwargs])
        if self.offload:
            self.tied_pointers_to_remove = set()

            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                fp16_statistics = None
                value = self.weights_map[name]
                # For an int8 weight that has a matching "SCB" entry in the map, pass that entry along as
                # `fp16_statistics`.
                if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
                    if value.dtype == torch.int8:
                        fp16_statistics = self.weights_map[name.replace("weight", "SCB")]

                # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
                # that are loaded on device at this point, as we will need to remove them as well from the dictionary
                # self.tied_params_map in order to allow to free memory.
                if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
                    self.tied_params_map[value.data_ptr()] = {}

                if (
                    value is not None
                    and self.tied_params_map is not None
                    and value.data_ptr() in self.tied_params_map
                    and self.execution_device not in self.tied_params_map[value.data_ptr()]
                ):
                    self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))

                set_module_tensor_to_device(
                    module,
                    name,
                    self.execution_device,
                    value=value,
                    fp16_statistics=fp16_statistics,
                    tied_params_map=self.tied_params_map,
                )

        return send_to_device(args, self.execution_device), send_to_device(
            kwargs, self.execution_device, skip_keys=self.skip_keys
        )

    def post_forward(self, module, output):
        """
        Offloads the tensors again (when offloading) and optionally moves the output back to the input device.

        Args:
            module (`torch.nn.Module`): The module whose forward pass has been executed just before this event.
            output (`Any`): The output of the module.

        Returns:
            `Any`: `output`, sent to the input device when `io_same_device=True` and an input device was found.
        """
        if self.offload:
            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                set_module_tensor_to_device(module, name, "meta")
            if type(module).__name__ == "Linear8bitLt":
                module.state.SCB = None
                module.state.CxB = None

            # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
            # this dictionary to allow the garbage collector to do its job.
            for value_pointer, device in self.tied_pointers_to_remove:
                if isinstance(device, int):
                    # An integer device index is turned into a "<backend>:<index>" string for these backends.
                    if is_npu_available():
                        device = f"npu:{device}"
                    elif is_mlu_available():
                        device = f"mlu:{device}"
                    elif is_musa_available():
                        device = f"musa:{device}"
                if device in self.tied_params_map[value_pointer]:
                    del self.tied_params_map[value_pointer][device]
            self.tied_pointers_to_remove = set()
        if self.io_same_device and self.input_device is not None:
            output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

        return output

    def detach_hook(self, module):
        """
        Restores offloaded tensors on their original device when the hook is detached.

        Args:
            module (`torch.nn.Module`): The module detached from this hook.

        Returns:
            `torch.nn.Module`: The same module.
        """
        if self.offload:
            for name, device in self.original_devices.items():
                # Tensors that were already on the meta device before offloading are left alone.
                if device != torch.device("meta"):
                    set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
        return module
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def attach_execution_device_hook(
    module: torch.nn.Module,
    execution_device: Union[int, str, torch.device],
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Walks a model and attaches an `AlignDevicesHook` to every submodule that has no hook yet and a non-empty state
    dict, so that they all run on the given execution device.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`int`, `str` or `torch.device`):
            The device on which inputs and model weights should be placed before the forward pass.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    lacks_hook = not hasattr(module, "_hf_hook")
    if lacks_hook and len(module.state_dict()) > 0:
        add_hook_to_module(
            module,
            AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
        )

    # Do not descend below a preload module: the recursion stops here.
    is_preload_module = preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
    if is_preload_module:
        return

    for submodule in module.children():
        attach_execution_device_hook(
            submodule,
            execution_device,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            tied_params_map=tied_params_map,
        )
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def attach_align_device_hook(
    module: torch.nn.Module,
    execution_device: Optional[torch.device] = None,
    offload: bool = False,
    weights_map: Optional[Mapping] = None,
    offload_buffers: bool = False,
    module_name: str = "",
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Walks a model and attaches an `AlignDevicesHook` to every submodule that directly owns parameters and/or buffers.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`torch.device`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass.
        offload (`bool`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        module_name (`str`, *optional*, defaults to `""`):
            The name of the module.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    # A hook goes on this module if it has any direct tensor, or if it is a preload class being offloaded.
    direct_tensors = named_module_tensors(module)
    full_offload = (
        offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
    )
    has_direct_tensors = len(list(direct_tensors)) > 0

    if has_direct_tensors or full_offload:
        prefixed_weights_map = None
        if weights_map is not None:
            # Expose the global weights map through this module's own name prefix.
            prefix = "" if len(module_name) == 0 else f"{module_name}."
            prefixed_weights_map = PrefixedDataset(weights_map, prefix)
        hook = AlignDevicesHook(
            execution_device=execution_device,
            offload=offload,
            weights_map=prefixed_weights_map,
            offload_buffers=offload_buffers,
            place_submodules=full_offload,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook, append=True)

    # With a full offload, the hook on this module is created with `place_submodules=True`: no need to go deeper.
    if full_offload:
        return

    # Otherwise, repeat on every child, passing down its dotted name.
    for short_name, submodule in module.named_children():
        qualified_name = short_name if len(module_name) == 0 else f"{module_name}.{short_name}"
        attach_align_device_hook(
            submodule,
            execution_device=execution_device,
            offload=offload,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=qualified_name,
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def remove_hook_from_submodules(module: nn.Module):
    """
    Removes the hook of a model and, recursively, the hooks of all of its submodules.

    Args:
        module (`torch.nn.Module`): The module on which to remove all hooks.
    """
    # Handle this module first, then descend into each direct child.
    remove_hook_from_module(module)
    for submodule in module.children():
        remove_hook_from_submodules(submodule)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def attach_align_device_hook_on_blocks(
    module: nn.Module,
    execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
    offload: Union[bool, dict[str, bool]] = False,
    weights_map: Mapping = None,
    offload_buffers: bool = False,
    module_name: str = "",
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Attaches `AlignDevicesHook` to all blocks of a given model as needed.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass. It can be one device
            for the whole module, or a dictionary mapping module name to device.
        offload (`bool` or `Dict[str, bool]`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
            module, or a mapping from module name to boolean.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        module_name (`str`, *optional*, defaults to `""`):
            The name of the module.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    # If one device and one offload, we've got one hook.
    # Both checks use `Mapping` (the same test as the normalisation below) so that a non-`dict` mapping passed as
    # `offload` is not mistaken for a truthy scalar.
    if not isinstance(execution_device, Mapping) and not isinstance(offload, Mapping):
        if not offload:
            hook = AlignDevicesHook(
                execution_device=execution_device,
                io_same_device=True,
                skip_keys=skip_keys,
                place_submodules=True,
                tied_params_map=tied_params_map,
            )
            add_hook_to_module(module, hook)
        else:
            attach_align_device_hook(
                module,
                execution_device=execution_device,
                offload=True,
                weights_map=weights_map,
                offload_buffers=offload_buffers,
                module_name=module_name,
                skip_keys=skip_keys,
                # Forwarded so the documented parameter is honoured here as it is in the per-module branch below.
                preload_module_classes=preload_module_classes,
                tied_params_map=tied_params_map,
            )
        return

    # From here on at least one of the two arguments is a mapping: broadcast the scalar one over the other's keys so
    # both can be looked up by module name.
    if not isinstance(execution_device, Mapping):
        execution_device = {key: execution_device for key in offload.keys()}
    if not isinstance(offload, Mapping):
        offload = {key: offload for key in execution_device.keys()}

    if module_name in execution_device and module_name in offload and not offload[module_name]:
        # This block is listed and stays on its execution device.
        hook = AlignDevicesHook(
            execution_device=execution_device[module_name],
            offload_buffers=offload_buffers,
            io_same_device=(module_name == ""),
            place_submodules=True,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook)
        attach_execution_device_hook(
            module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
        )
    elif module_name in execution_device and module_name in offload:
        # This block is listed and offloaded.
        attach_align_device_hook(
            module,
            execution_device=execution_device[module_name],
            offload=True,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=module_name,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            tied_params_map=tied_params_map,
        )
        # The call above may not have put a hook on `module` itself; make sure the block has one.
        if not hasattr(module, "_hf_hook"):
            hook = AlignDevicesHook(
                execution_device=execution_device[module_name],
                io_same_device=(module_name == ""),
                skip_keys=skip_keys,
                tied_params_map=tied_params_map,
            )
            add_hook_to_module(module, hook)
        attach_execution_device_hook(
            module,
            execution_device[module_name],
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
    elif module_name == "":
        # The root module is not listed: it still gets a hook so that inputs and outputs share a device.
        hook = AlignDevicesHook(
            execution_device=execution_device.get(""),
            io_same_device=True,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook)

    # Recurse on all children, extending the dotted module name as we go down.
    for child_name, child in module.named_children():
        child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
        attach_align_device_hook_on_blocks(
            child,
            execution_device=execution_device,
            offload=offload,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=child_name,
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class CpuOffload(ModelHook):
    """
    Keeps a model on the CPU until its forward pass runs. Nothing moves the model back to the CPU once the forward is
    done: to offload it again, `init_hook` has to be called another time.

    Args:
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
        prev_module_hook (`UserCpuOffloadHook`, *optional*):
            The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
            passed, its offload method will be called just before the forward of the model to which this hook is
            attached.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        prev_module_hook: Optional["UserCpuOffloadHook"] = None,
    ):
        self.prev_module_hook = prev_module_hook

        # Only query the process state when the caller did not choose a device.
        if execution_device is None:
            execution_device = PartialState().default_device
        self.execution_device = execution_device

    def init_hook(self, module):
        # Park the module on the CPU; the return value is whatever `module.to` returns.
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        previous = self.prev_module_hook
        if previous is not None:
            # Offload the previous pipeline model, then clear the device cache, before loading this one.
            previous.offload()
            clear_device_cache()

        target = self.execution_device
        module.to(target)
        moved_args = send_to_device(args, target)
        moved_kwargs = send_to_device(kwargs, target)
        return moved_args, moved_kwargs
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class UserCpuOffloadHook:
    """
    Pairs a model with a `ModelHook` and exposes two convenience methods: one that calls the hook's init method on
    the model, and one that removes the hook from the model entirely.
    """

    def __init__(self, model, hook):
        self.model, self.hook = model, hook

    def offload(self):
        # Re-run the hook's init step on the stored model; its return value is discarded.
        hook, model = self.hook, self.model
        hook.init_hook(model)

    def remove(self):
        # Detach the hook from the stored model.
        remove_hook_from_module(self.model)
|
venv/Lib/site-packages/accelerate/inference.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import math
|
| 15 |
+
from types import MethodType
|
| 16 |
+
from typing import Any, Optional, Union
|
| 17 |
+
|
| 18 |
+
from .state import PartialState
|
| 19 |
+
from .utils import (
|
| 20 |
+
calculate_maximum_sizes,
|
| 21 |
+
convert_bytes,
|
| 22 |
+
copy_tensor_to_devices,
|
| 23 |
+
ignorant_find_batch_size,
|
| 24 |
+
infer_auto_device_map,
|
| 25 |
+
is_pippy_available,
|
| 26 |
+
pad_input_tensors,
|
| 27 |
+
send_to_device,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Builds the device map of `model` for PiPPy.

    With a single process the map is inferred directly and `max_memory` is not used. Otherwise, when `max_memory` is
    not given, every process index in `range(num_processes)` is assigned the same memory budget derived from the
    sizes reported by `calculate_maximum_sizes`.
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)

    if max_memory is None:
        total_size, shared_sizes = calculate_maximum_sizes(model)

        # Even share per process, rendered by `convert_bytes` as "<number> <unit>".
        per_process = convert_bytes((total_size + shared_sizes[0]) / num_processes)
        amount, unit = per_process.split(" ")

        # Round up and add 10% to deal with potential extra shared memory instances.
        budget = f"{math.ceil(float(amount)) * 1.1} {unit}"
        max_memory = dict.fromkeys(range(num_processes), budget)

    return infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def find_pippy_batch_size(args, kwargs):
    """
    Returns the first non-`None` result of `ignorant_find_batch_size`, trying the positional inputs in order and then
    the keyword values. Returns `None` when neither yields one (or when both containers are `None`).
    """
    if args is not None:
        for positional in args:
            size = ignorant_find_batch_size(positional)
            if size is not None:
                return size

    if kwargs is not None:
        for named in kwargs.values():
            size = ignorant_find_batch_size(named)
            if size is not None:
                return size

    return None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Marks every entry of `split_points` as a split at the beginning of that module, traces `model` into a pipeline
    with the example `args` and `kwargs`, builds the stage for the current local process and wraps it in a
    `ScheduleGPipe` using `num_chunks`.

    Returns the resulting schedule.
    """
    # Imported lazily to reduce import time from general modules, and to isolate outside dependencies.
    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

    state = PartialState()

    # PiPPy needs the split points annotated on the model.
    split_spec = dict.fromkeys(split_points, SplitPoint.BEGINNING)
    pipe = pipeline(model, mb_args=args, mb_kwargs=kwargs, split_spec=split_spec)

    local_stage = pipe.build_stage(state.local_process_index, device=state.device)
    return ScheduleGPipe(local_stage, num_chunks)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    """
    Runs one pipeline step through `forward`, choosing the call shape from the current process:

    - single process: `forward` receives the inputs and its result is returned;
    - local main process: the inputs are padded when the detected batch size differs from `num_chunks`, then fed to
      `forward` (result discarded);
    - last process: `forward()` is called without inputs and its result is kept;
    - any other process: `forward()` is called without inputs (result discarded).

    When `gather_output` is truthy the result is passed through `copy_tensor_to_devices` before being returned.

    Raises:
        ValueError: on the local main process, if no batch size can be found in `args` or `kwargs`.
    """
    state = PartialState()
    result = None

    if state.num_processes == 1:
        result = forward(*args, **kwargs)
    elif state.is_local_main_process:
        batch_size = find_pippy_batch_size(args, kwargs)
        if batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        if batch_size != num_chunks:
            args = pad_input_tensors(args, batch_size, num_chunks)
            kwargs = pad_input_tensors(kwargs, batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        result = forward()
    else:
        forward()

    # Each node will get a copy of the full output which is only on the last GPU.
    return copy_tensor_to_devices(result) if gather_output else result
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def prepare_pippy(
    model,
    split_points: Optional[Union[str, list[str]]] = "auto",
    no_split_module_classes: Optional[list[str]] = None,
    example_args: Optional[tuple[Any]] = (),
    example_kwargs: Optional[dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Should be a list of layer names in the model to split by otherwise. A single layer
            name given as a string is treated as a one-element list.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
            this method if possible.
        example_kwargs (dict of model inputs)
            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
            recommended unless the prior condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have num_chunks >= num_gpus.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.

    Returns:
        The same `model` object, with its `forward` replaced by the pipeline forward and the attributes
        `hf_split_points`, `pippy_stage`, `_original_forward` and `_original_call` set on it.

    Raises:
        ImportError: if `is_pippy_available()` is false.
        ValueError: if `split_points="auto"` and the generated device map assigns no module to one of the stages.
    """
    if not is_pippy_available():
        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        # Stage 0 starts at the beginning of the model; every later stage starts at the first module mapped to it.
        for stage_index in range(1, num_chunks):
            first_module = next((name for name, device in device_map.items() if device == stage_index), None)
            if first_module is None:
                # A bare `next(...)` would leak an opaque `StopIteration` here.
                raise ValueError(
                    f"Could not find a split point for pipeline stage {stage_index}: the generated device map does "
                    f"not assign any module to it. Try a smaller `num_chunks` (got {num_chunks}) or pass "
                    "`split_points` explicitly."
                )
            split_points.append(first_module)
    elif isinstance(split_points, str):
        # A lone layer name would otherwise be iterated character by character when building the split spec.
        split_points = [split_points]
    model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage

    def forward(*args, **kwargs):
        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
|
venv/Lib/site-packages/accelerate/launchers.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import tempfile
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from .state import AcceleratorState, PartialState
|
| 22 |
+
from .utils import (
|
| 23 |
+
PrecisionType,
|
| 24 |
+
PrepareForLaunch,
|
| 25 |
+
are_libraries_initialized,
|
| 26 |
+
check_cuda_p2p_ib_support,
|
| 27 |
+
get_gpu_info,
|
| 28 |
+
is_mps_available,
|
| 29 |
+
is_torch_version,
|
| 30 |
+
patch_environment,
|
| 31 |
+
)
|
| 32 |
+
from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_launch():
    """Check that a `PartialState` can be constructed; the instance itself is discarded."""
    PartialState()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def notebook_launcher(
    function,
    args=(),
    num_processes=None,
    mixed_precision="no",
    use_port="29500",
    master_addr="127.0.0.1",
    node_rank=0,
    num_nodes=1,
    rdzv_backend="static",
    rdzv_endpoint="",
    rdzv_conf=None,
    rdzv_id="none",
    max_restarts=0,
    monitor_interval=0.1,
    log_line_prefix_template=None,
):
    """
    Launches a training function, using several processes or multiple nodes if it's possible in the current environment
    (TPU with multiple cores for instance).

    <Tip warning={true}>

    To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
    any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.

    Setting `ACCELERATE_DEBUG_MODE="1"` (or `"true"`) in your environment will run a test before truly launching to
    ensure that none of those calls have been made.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute. If it accepts arguments, the first argument should be the index of the
            process run.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*):
            The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
            the number of GPUs available otherwise.
        mixed_precision (`str`, *optional*, defaults to `"no"`):
            If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
        use_port (`str`, *optional*, defaults to `"29500"`):
            The port to use to communicate between processes when launching a multi-GPU training.
        master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
            The address to use for communication between processes.
        node_rank (`int`, *optional*, defaults to 0):
            The rank of the current node.
        num_nodes (`int`, *optional*, defaults to 1):
            The number of nodes to use for training.
        rdzv_backend (`str`, *optional*, defaults to `"static"`):
            The rendezvous method to use, such as 'static' (the default) or 'c10d'
        rdzv_endpoint (`str`, *optional*, defaults to `""`):
            The endpoint of the rdzv sync. storage.
        rdzv_conf (`Dict`, *optional*, defaults to `None`):
            Additional rendezvous configuration. The dictionary passed in is copied, never modified.
        rdzv_id (`str`, *optional*, defaults to `"none"`):
            The unique run id of the job.
        max_restarts (`int`, *optional*, defaults to 0):
            The maximum amount of restarts that elastic agent will conduct on workers before failure.
        monitor_interval (`float`, *optional*, defaults to 0.1):
            The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
        log_line_prefix_template (`str`, *optional*, defaults to `None`):
            The prefix template for elastic launch logging. Available from PyTorch 2.2.0.

    Raises:
        ValueError: if `mixed_precision` is not a known precision mode, if `num_processes` is missing where it is
            required, if `node_rank >= num_nodes`, or if an `Accelerator` state already exists before a TPU or
            multi-GPU launch.
        RuntimeError: if a library known to initialize CUDA was already imported, or if a spawned process fails.

    Example:

    ```python
    # Assume this is defined in a Jupyter Notebook on an instance with two GPUs
    from accelerate import notebook_launcher


    def train(*args):
        # Your training function here
        ...


    notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
    ```
    """
    # Are we in a google colab or a Kaggle Kernel?
    in_colab = False
    in_kaggle = False
    if any(key.startswith("KAGGLE") for key in os.environ.keys()):
        in_kaggle = True
    elif "IPython" in sys.modules:
        in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())

    try:
        mixed_precision = PrecisionType(mixed_precision.lower())
    except ValueError as err:
        # `mixed_precision` still holds the caller's string here because the assignment above did not complete.
        # (`args` is the tuple of function arguments and has no `mixed_precision` attribute.)
        raise ValueError(
            f"Unknown mixed_precision mode: {mixed_precision.lower()}. Choose between {PrecisionType.list()}."
        ) from err

    if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
        # TPU launch
        import torch_xla.distributed.xla_multiprocessing as xmp
        from torch_xla import device_count

        if len(AcceleratorState._shared_state) > 0:
            raise ValueError(
                "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                "your training function. Restart your notebook and make sure no cells initializes an "
                "`Accelerator`."
            )

        launcher = PrepareForLaunch(function, distributed_type="XLA")
        print(f"Launching a training on {device_count()} TPU cores.")
        xmp.spawn(launcher, args=args, start_method="fork")
    elif in_colab and get_gpu_info()[1] < 2:
        # No need for a distributed launch otherwise as it's either CPU or one GPU.
        if torch.cuda.is_available():
            print("Launching training on one GPU.")
        else:
            print("Launching training on one CPU.")
        function(*args)
    else:
        if num_processes is None:
            raise ValueError(
                "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
            )
        if node_rank >= num_nodes:
            raise ValueError("The node_rank must be less than the number of nodes.")
        if num_processes > 1:
            # Multi-GPU launch
            from torch.distributed.launcher.api import LaunchConfig, elastic_launch
            from torch.multiprocessing import start_processes
            from torch.multiprocessing.spawn import ProcessRaisedException

            if len(AcceleratorState._shared_state) > 0:
                raise ValueError(
                    "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
                    "inside your training function. Restart your notebook and make sure no cells initializes an "
                    "`Accelerator`."
                )
            # Check for specific libraries known to initialize CUDA that users constantly use
            problematic_imports = are_libraries_initialized("bitsandbytes")
            if len(problematic_imports) > 0:
                err = (
                    "Could not start distributed process. Libraries known to initialize CUDA upon import have been "
                    "imported already. Please keep these imports inside your training function to try and help with this:"
                )
                for lib_name in problematic_imports:
                    err += f"\n\t* `{lib_name}`"
                raise RuntimeError(err)

            patched_env = dict(
                nproc=num_processes,
                node_rank=node_rank,
                world_size=num_nodes * num_processes,
                master_addr=master_addr,
                master_port=use_port,
                mixed_precision=mixed_precision,
            )

            # Check for CUDA P2P and IB issues
            if not check_cuda_p2p_ib_support():
                patched_env["nccl_p2p_disable"] = "1"
                patched_env["nccl_ib_disable"] = "1"

            # torch.distributed will expect a few environment variable to be here. We set the ones common to each
            # process here (the other ones will be set be the launcher).
            with patch_environment(**patched_env):
                # First dummy launch. Both the documented value ("1") and the historical one ("true") enable it.
                if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() in ("1", "true"):
                    launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
                    try:
                        start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
                    except ProcessRaisedException as e:
                        err = "An issue was found when verifying a stable environment for the notebook launcher."
                        if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                            raise RuntimeError(
                                f"{err} "
                                "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                                "Please review your imports and test them when running the `notebook_launcher()` to identify "
                                "which one is problematic and causing CUDA to be initialized."
                            ) from e
                        else:
                            raise RuntimeError(f"{err} The following error was raised: {e}") from e
                # Now the actual launch
                # NOTE(review): `launcher` is not used by the elastic launch below, which receives `function`
                # directly; kept because it is unknown here whether constructing it has side effects.
                launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
                print(f"Launching training on {num_processes} GPUs.")
                try:
                    # Work on a private copy so the caller's dictionary is never modified.
                    rdzv_conf = {} if rdzv_conf is None else dict(rdzv_conf)
                    if rdzv_backend == "static":
                        rdzv_conf["rank"] = node_rank
                        if not rdzv_endpoint:
                            rdzv_endpoint = f"{master_addr}:{use_port}"
                    launch_config_kwargs = dict(
                        min_nodes=num_nodes,
                        max_nodes=num_nodes,
                        nproc_per_node=num_processes,
                        run_id=rdzv_id,
                        rdzv_endpoint=rdzv_endpoint,
                        rdzv_backend=rdzv_backend,
                        rdzv_configs=rdzv_conf,
                        max_restarts=max_restarts,
                        monitor_interval=monitor_interval,
                        start_method="fork",
                    )
                    if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
                        launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
                    elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
                except ProcessRaisedException as e:
                    if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                        raise RuntimeError(
                            "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
                            "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                            "Please review your imports and test them when running the `notebook_launcher()` to identify "
                            "which one is problematic and causing CUDA to be initialized."
                        ) from e
                    else:
                        raise RuntimeError(f"An issue was found when launching the training: {e}") from e

        else:
            # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
            if is_mps_available():
                os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
                print("Launching training on MPS.")
            elif torch.cuda.is_available():
                print("Launching training on one GPU.")
            else:
                print("Launching training on CPU.")
            function(*args)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def debug_launcher(function, args=(), num_processes=2):
    """
    Run a training function in several forked CPU processes, for debugging.

    <Tip warning={true}>

    This helper exists for internal testing and debugging only and is not meant for real trainings. It will only use
    the CPU.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*, defaults to 2):
            The number of processes to use for training.
    """
    from torch.multiprocessing import start_processes

    # The temporary file only lives for the duration of the launch; its path is handed to the workers through the
    # environment.
    with tempfile.NamedTemporaryFile() as rdv_file:
        # torch.distributed expects a few environment variables. The ones shared by every process are set here; the
        # remaining, per-process ones are set by the launcher.
        shared_env = dict(
            world_size=num_processes,
            master_addr="127.0.0.1",
            master_port="29500",
            accelerate_mixed_precision="no",
            accelerate_debug_rdv_file=rdv_file.name,
            accelerate_use_cpu="yes",
        )
        with patch_environment(**shared_env):
            start_processes(
                PrepareForLaunch(function, debug=True),
                args=args,
                nprocs=num_processes,
                start_method="fork",
            )
|
venv/Lib/site-packages/accelerate/local_sgd.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from accelerate import Accelerator, DistributedType
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
    on each device, and averages the model weights every `local_sgd_steps` calls to `step()` (and once more when the
    context manager exits).

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.


    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)

    """

    def __enter__(self):
        """
        Enter the model's `no_sync()` context (when local SGD is enabled) and return `self`.
        """
        if self.enabled:
            # Keep the context object so that `__exit__` can close the very same one.
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        """
        Average the model parameters one last time and leave the `no_sync()` context.
        """
        if self.enabled:
            try:
                # Average all models on exit
                self._sync_and_avg_model_params()
            finally:
                # Always leave `no_sync()`, even if the final averaging raised; otherwise the model would stay in
                # its "do not synchronize" state after the `with` block.
                self.model_sync_obj.__exit__(type, value, tb)

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            accelerator (`Accelerator`):
                Accelerator object.
            local_sgd_steps (`int`):
                A number of local SGD steps (before model parameters are synchronized).
            enabled (`bool`):
                Local SGD is disabled if this parameter set to `False`. It is also disabled, whatever the value of
                this flag, when `accelerator.distributed_type` is `DistributedType.NO`.

        Raises:
            NotImplementedError: If `accelerator.distributed_type` is not one of the supported types listed below.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_XPU,
            DistributedType.MULTI_MLU,
            DistributedType.MULTI_HPU,
            DistributedType.MULTI_SDAA,
            DistributedType.MULTI_MUSA,
            DistributedType.MULTI_NPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        self.num_steps = 0
        # The remaining attributes are only needed (and only set) when local SGD is actually active.
        if self.enabled:
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.

        The step counter is incremented even when local SGD is disabled; parameters are averaged every
        `local_sgd_steps` steps when it is enabled.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + Average model parameters across all processes
        """

        # Wait for every process first, then replace each parameter by its mean-reduced value.
        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            for param in self.model.parameters():
                param.data = self.accelerator.reduce(param.data, reduction="mean")
|
venv/Lib/site-packages/accelerate/logging.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from .state import PartialState
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    An adapter to assist with logging in multiprocess.

    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
    or only the main executed one. Default is `main_process_only=True`.

    Does not require an `Accelerator` object to be created first.
    """

    @staticmethod
    def _should_log(main_process_only):
        "Check if log should be performed"
        state = PartialState()
        return not main_process_only or state.is_main_process

    def log(self, level, msg, *args, **kwargs):
        """
        Delegates logger call after checking if we should log.

        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
        or only the main executed one. Default is `True` if not passed

        Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
        read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
        break with the previous behavior.

        `in_order` is ignored if `main_process_only=True` is passed explicitly. Otherwise `in_order=True` logs on
        every process, and every process (the main one included) must make the call because each of them waits on
        the same barriers.

        Raises:
            RuntimeError: If the accelerate state has not been initialized yet.
        """
        if PartialState._shared_state == {}:
            raise RuntimeError(
                "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
            )
        # `None` means "not passed", which lets `in_order=True` on its own request ordered logging on all processes.
        main_process_only = kwargs.pop("main_process_only", None)
        in_order = kwargs.pop("in_order", False)
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if in_order and not main_process_only:
            # Ordered logging: every process walks the same loop and hits the same number of barriers, logging
            # only on its own turn. Previously the main process logged immediately and skipped these barriers,
            # leaving the other processes waiting on it.
            state = PartialState()
            for i in range(state.num_processes):
                if i == state.process_index:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                state.wait_for_everyone()
            return

        if main_process_only is None:
            main_process_only = True
        if self._should_log(main_process_only):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)

    # NOTE(review): `lru_cache` on a method keys on `self` and keeps the adapter alive for the cache's lifetime;
    # it also requires every argument to be hashable.
    @functools.lru_cache(None)
    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
        switch to another type of cache that includes the caller frame information in the hashing function.
        """
        self.warning(*args, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_logger(name: str, log_level: str = None):
    """
    Build a multiprocess-aware logger adapter around `logging.getLogger(name)`.

    To emit a record on every process, pass `main_process_only=False` to the logging call. To additionally have the
    processes emit one after the other, also pass `in_order=True`.

    Args:
        name (`str`):
            The name for the logger, such as `__file__`
        log_level (`str`, *optional*):
            The log level to use. If not passed, the `ACCELERATE_LOG_LEVEL` environment variable is used instead;
            if that is not set either, no level is changed. When a level is found it is upper-cased and applied to
            both the named logger and the root logger.

    Example:

    ```python
    >>> from accelerate.logging import get_logger
    >>> from accelerate import Accelerator

    >>> logger = get_logger(__name__)

    >>> accelerator = Accelerator()
    >>> logger.info("My log", main_process_only=False)
    >>> logger.debug("My log", main_process_only=True)

    >>> logger = get_logger(__name__, log_level="DEBUG")
    >>> logger.info("My log")
    >>> logger.debug("My second log")

    >>> array = ["a", "b", "c", "d"]
    >>> letter_at_rank = array[accelerator.process_index]
    >>> logger.info(letter_at_rank, in_order=True)
    ```
    """
    # An explicit argument wins over the environment variable.
    level = log_level if log_level is not None else os.environ.get("ACCELERATE_LOG_LEVEL")
    base_logger = logging.getLogger(name)
    if level is not None:
        level_name = level.upper()
        base_logger.setLevel(level_name)
        base_logger.root.setLevel(level_name)
    return MultiProcessAdapter(base_logger, {})
|
venv/Lib/site-packages/accelerate/memory_utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Deprecation shim: importing this module emits a FutureWarning that points at the replacement import.
warnings.warn(
    "memory_utils has been reorganized to utils.memory. Import `find_executable_batch_size` from the main `__init__`: "
    "`from accelerate import find_executable_batch_size` to avoid this warning.",
    FutureWarning,
)
|
venv/Lib/site-packages/accelerate/optimizer.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from .state import AcceleratorState, GradientState
|
| 20 |
+
from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if is_torch_xla_available():
|
| 24 |
+
import torch_xla.core.xla_model as xm
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def move_to_device(state, device):
    """
    Recursively move every tensor found in `state` to `device`.

    Lists and tuples are rebuilt through `honor_type`, dicts are rebuilt with their own type, tensors are moved with
    `Tensor.to(device)`, and any other object is returned unchanged.
    """
    if isinstance(state, (list, tuple)):
        moved_items = (move_to_device(item, device) for item in state)
        return honor_type(state, moved_items)
    if isinstance(state, dict):
        moved_mapping = {key: move_to_device(value, device) for key, value in state.items()}
        return type(state)(moved_mapping)
    if isinstance(state, torch.Tensor):
        return state.to(device)
    return state
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AcceleratedOptimizer(torch.optim.Optimizer):
    """
    Internal wrapper around a torch optimizer.

    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
    accumulation.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        device_placement (`bool`, *optional*, defaults to `True`):
            Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
            `optimizer` on the right device.
        scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
            The scaler to use in the step function if training with mixed precision.
    """

    def __init__(self, optimizer, device_placement=True, scaler=None):
        # NOTE: `torch.optim.Optimizer.__init__` is not called here; `state`, `param_groups` and `defaults` are
        # delegated to the wrapped optimizer through the properties below.
        self.optimizer = optimizer
        self.scaler = scaler
        self.accelerator_state = AcceleratorState()
        self.gradient_state = GradientState()
        self.device_placement = device_placement
        self._is_overflow = False

        if self.scaler is not None:
            # With a scaler, `step` temporarily swaps in a patched `optimizer.step` that records whether the scaler
            # actually called it.
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)

        # Handle device placement
        if device_placement:
            state_dict = self.optimizer.state_dict()
            if self.accelerator_state.distributed_type == DistributedType.XLA:
                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
            else:
                state_dict = move_to_device(state_dict, self.accelerator_state.device)
            self.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        return self.optimizer.state

    @state.setter
    def state(self, state):
        self.optimizer.state = state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        self.optimizer.param_groups = param_groups

    @property
    def defaults(self):
        return self.optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optimizer.defaults = defaults

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        return self.optimizer.state_dict()

    def zero_grad(self, set_to_none=None):
        """
        Zero the gradients of the wrapped optimizer, but only when gradients are being synchronized.

        Raises:
            ValueError: If `set_to_none` is given but the wrapped optimizer's `zero_grad` does not accept it.
        """
        if self.gradient_state.sync_gradients:
            accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
            if accept_arg:
                if set_to_none is None:
                    set_to_none = True
                self.optimizer.zero_grad(set_to_none=set_to_none)
            else:
                if set_to_none is not None:
                    raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
                self.optimizer.zero_grad()

    def train(self):
        """
        Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()
        elif (
            hasattr(self.optimizer, "optimizer")
            and hasattr(self.optimizer.optimizer, "train")
            and callable(self.optimizer.optimizer.train)
        ):
            # the deepspeed optimizer further wraps the optimizer
            self.optimizer.optimizer.train()

    def eval(self):
        """
        Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
        """
        # NOTE(review): unlike `train`, there is no fallback to a further-wrapped `optimizer.optimizer` here;
        # confirm whether that asymmetry is intentional.
        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
            self.optimizer.eval()

    def step(self, closure=None):
        """
        Step the wrapped optimizer when gradients are being synchronized.

        With a scaler, the step goes through `scaler.step` / `scaler.update`, and `step_was_skipped` afterwards
        reports whether the scaler skipped the underlying optimizer step.
        """
        lomo_available = is_lomo_available()
        if lomo_available:
            from lomo_optim import AdaLomo, Lomo

        if (
            not self.gradient_state.is_xla_gradients_synced
            and self.accelerator_state.distributed_type == DistributedType.XLA
        ):
            gradients = xm._fetch_gradients(self.optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            self.gradient_state.is_xla_gradients_synced = True

        # `step` should be a no-op for LOMO optimizers.
        if lomo_available and isinstance(self.optimizer, (Lomo, AdaLomo)):
            return

        if self.gradient_state.sync_gradients:
            if self.scaler is not None:
                self.optimizer.step = self._optimizer_patched_step_method
                try:
                    self.scaler.step(self.optimizer, closure)
                    self.scaler.update()

                    # If the optimizer step was skipped, gradient overflow was detected.
                    self._is_overflow = not self._accelerate_step_called
                finally:
                    # Always undo the monkey-patch and reset the indicator, even when the scaler raised, so a
                    # failed step cannot leave the wrapped optimizer patched or the flag stale.
                    self.optimizer.step = self._optimizer_original_step_method
                    self._accelerate_step_called = False
            else:
                self.optimizer.step(closure)
        if self.accelerator_state.distributed_type == DistributedType.XLA:
            self.gradient_state.is_xla_gradients_synced = False

    def _switch_parameters(self, parameters_map):
        # Replace each parameter by its mapped counterpart; parameters missing from the map are kept.
        for param_group in self.optimizer.param_groups:
            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]

    @property
    def step_was_skipped(self):
        """Whether or not the optimizer step was skipped."""
        return self._is_overflow

    def __getstate__(self):
        # The step-patching attributes are left out of the pickled state and rebuilt in `__setstate__`.
        _ignored_keys = [
            "_accelerate_step_called",
            "_optimizer_original_step_method",
            "_optimizer_patched_step_method",
        ]
        return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}

    def __setstate__(self, state):
        self.__dict__.update(state)
        if self.scaler is not None:
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
    """
    Wrap `method` so that calling it flags `accelerated_optimizer._accelerate_step_called` as `True`.

    The flag is set before `method` runs; arguments and the return value are passed through untouched.
    """

    def patched_step(*step_args, **step_kwargs):
        accelerated_optimizer._accelerate_step_called = True
        result = method(*step_args, **step_kwargs)
        return result

    return patched_step
|
venv/Lib/site-packages/accelerate/scheduler.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
|
| 16 |
+
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
from .state import AcceleratorState, GradientState
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Installed at import time: ignore every `UserWarning` whose originating module name matches
# "torch.optim.lr_scheduler". `AcceleratedScheduler` steps the wrapped scheduler itself, so those warnings are noise.
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AcceleratedScheduler:
    """
    Wraps a learning rate scheduler so that it only advances when the optimizer(s) really performed a training step.
    This keeps the schedule from running ahead when a step was skipped because gradients overflowed (in mixed
    precision training).

    When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always
    step the scheduler to account for it.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler to wrap.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizers used.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether or not the scheduler should be stepped at each optimizer step.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether or not the dataloaders split one batch across the different processes (so batch size is the same
            regardless of the number of processes) or create batches on each process (so batch size is the original
            batch size multiplied by the number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        # A single optimizer is normalized to a one-element list.
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        self.scheduler = scheduler
        self.optimizers = optimizers
        self.split_batches = split_batches
        self.step_with_optimizer = step_with_optimizer
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        """
        Advance the wrapped scheduler when appropriate; `*args` / `**kwargs` are forwarded to `scheduler.step`.
        """
        scheduler = self.scheduler

        # Scheduler not tied to the optimizer: step unconditionally.
        if not self.step_with_optimizer:
            scheduler.step(*args, **kwargs)
            return

        # In the middle of gradient accumulation: no real step, only (optionally) bump the internal counter.
        grad_state = self.gradient_state
        if not grad_state.sync_gradients:
            if grad_state.adjust_scheduler:
                scheduler._step_count += 1
            return

        # If any optimizer skipped its step, the scheduler must not move.
        if any(opt.step_was_skipped for opt in self.optimizers):
            return

        if self.split_batches:
            # Batches are split across processes, so the dataloader batch size is unchanged: one scheduler step
            # per training step.
            scheduler.step(*args, **kwargs)
            return

        # Otherwise the effective batch size was multiplied by `num_processes`, so the scheduler is stepped
        # `num_processes` times per training step.
        for _ in range(AcceleratorState().num_processes):
            # Special case when using OneCycle and `drop_last` was not used: never step past `total_steps`.
            if hasattr(scheduler, "total_steps") and scheduler._step_count > scheduler.total_steps:
                continue
            scheduler.step(*args, **kwargs)

    # Plain delegation to the wrapped scheduler
    def get_last_lr(self):
        return self.scheduler.get_last_lr()

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        return self.scheduler.print_lr(*args, **kwargs)
|
venv/Lib/site-packages/accelerate/state.py
ADDED
|
@@ -0,0 +1,1330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
import threading
|
| 20 |
+
import warnings
|
| 21 |
+
import weakref
|
| 22 |
+
from contextlib import contextmanager
|
| 23 |
+
from functools import partial
|
| 24 |
+
from typing import Any, Callable
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from .utils import (
|
| 29 |
+
DistributedType,
|
| 30 |
+
DynamoBackend,
|
| 31 |
+
GradientAccumulationPlugin,
|
| 32 |
+
check_cuda_fp8_capability,
|
| 33 |
+
check_cuda_p2p_ib_support,
|
| 34 |
+
deepspeed_required,
|
| 35 |
+
get_ccl_version,
|
| 36 |
+
get_cpu_distributed_information,
|
| 37 |
+
get_int_from_env,
|
| 38 |
+
is_ccl_available,
|
| 39 |
+
is_datasets_available,
|
| 40 |
+
is_deepspeed_available,
|
| 41 |
+
is_fp8_available,
|
| 42 |
+
is_habana_gaudi1,
|
| 43 |
+
is_hpu_available,
|
| 44 |
+
is_ipex_available,
|
| 45 |
+
is_mlu_available,
|
| 46 |
+
is_mps_available,
|
| 47 |
+
is_musa_available,
|
| 48 |
+
is_npu_available,
|
| 49 |
+
is_sdaa_available,
|
| 50 |
+
is_torch_xla_available,
|
| 51 |
+
is_xccl_available,
|
| 52 |
+
is_xpu_available,
|
| 53 |
+
parse_choice_from_env,
|
| 54 |
+
parse_flag_from_env,
|
| 55 |
+
set_numa_affinity,
|
| 56 |
+
)
|
| 57 |
+
from .utils.dataclasses import SageMakerDistributedType
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if is_torch_xla_available():
|
| 61 |
+
import torch_xla.core.xla_model as xm
|
| 62 |
+
|
| 63 |
+
if is_mlu_available(check_device=False):
|
| 64 |
+
import torch_mlu # noqa: F401
|
| 65 |
+
|
| 66 |
+
if is_sdaa_available(check_device=False):
|
| 67 |
+
import torch_sdaa # noqa: F401
|
| 68 |
+
|
| 69 |
+
if is_musa_available(check_device=False):
|
| 70 |
+
import torch_musa # noqa: F401
|
| 71 |
+
|
| 72 |
+
if is_npu_available(check_device=False):
|
| 73 |
+
import torch_npu # noqa: F401
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
logger = logging.getLogger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def is_initialized() -> bool:
    """
    Module-level counterpart of `AcceleratorState.initialized`: tells whether the `AcceleratorState` has been
    initialized from `Accelerator`, i.e. whether its shared state dictionary is non-empty.
    """
    state_is_empty = AcceleratorState._shared_state == {}
    return not state_is_empty
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# No-op function: accepts any arguments and does nothing
|
| 88 |
+
def do_nothing(*args, **kwargs):
    """Accept any positional and keyword arguments, ignore them and return `None`."""
    return
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ThreadLocalSharedDict(threading.local):
    """
    Descriptor that holds a dict shared between instances of a class in the same thread.

    Note: Descriptors have slightly different semantics than just a dict field on its own.
    `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
    underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
    the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
    object with a dict instead. Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).

    See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html

    This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).

    See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
    """

    def __init__(self, thread_local: bool = False):
        # `thread_local` is accepted but not read here; every construction starts from an empty dict.
        self._storage = dict()

    def __get__(self, obj, objtype=None):
        """Return the stored dict, whether accessed through an instance or through the owning class."""
        storage = self._storage
        return storage

    def __set__(self, obj, value):
        """Replace the stored dict with `value` when assigned through an instance of the owning class."""
        self._storage = value
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# A plain (process-global) dict is used unless torch_xla is available, in which case the
# thread-local descriptor is used instead.
SharedDict = ThreadLocalSharedDict if is_torch_xla_available() else dict
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Inspired by Alex Martelli's 'Borg'.
|
| 124 |
+
class PartialState:
|
| 125 |
+
"""
|
| 126 |
+
Singleton class that has information about the current training environment and functions to help with process
|
| 127 |
+
control. Designed to be used when only process control and device execution states are needed. Does *not* need to
|
| 128 |
+
be initialized from `Accelerator`.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
cpu (`bool`, *optional*):
|
| 132 |
+
Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
|
| 133 |
+
`True` and force the execution on the CPU.
|
| 134 |
+
kwargs (additional keyword arguments, *optional*):
|
| 135 |
+
Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
|
| 136 |
+
found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
|
| 137 |
+
|
| 138 |
+
**Available attributes:**
|
| 139 |
+
|
| 140 |
+
- **device** (`torch.device`) -- The device to use.
|
| 141 |
+
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
|
| 142 |
+
in use.
|
| 143 |
+
- **local_process_index** (`int`) -- The index of the current process on the current server.
|
| 144 |
+
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
|
| 145 |
+
of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
|
| 146 |
+
- **num_processes** (`int`) -- The number of processes currently launched in parallel.
|
| 147 |
+
- **process_index** (`int`) -- The index of the current process.
|
| 148 |
+
- **is_last_process** (`bool`) -- Whether or not the current process is the last one.
|
| 149 |
+
- **is_main_process** (`bool`) -- Whether or not the current process is the main one.
|
| 150 |
+
- **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
|
| 151 |
+
- **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
|
| 152 |
+
|
| 153 |
+
Example:
|
| 154 |
+
```python
|
| 155 |
+
from accelerate.utils import InitProcessGroupKwargs
|
| 156 |
+
|
| 157 |
+
# To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
|
| 158 |
+
kwargs = InitProcessGroupKwargs(...).to_kwargs()
|
| 159 |
+
state = PartialState(**kwargs)
|
| 160 |
+
```
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
_shared_state = SharedDict()
|
| 164 |
+
_known_attrs = [
|
| 165 |
+
"_cpu",
|
| 166 |
+
"_mixed_precision",
|
| 167 |
+
"_shared_state",
|
| 168 |
+
"backend",
|
| 169 |
+
"debug",
|
| 170 |
+
"device",
|
| 171 |
+
"distributed_type",
|
| 172 |
+
"fork_launched",
|
| 173 |
+
"local_process_index",
|
| 174 |
+
"num_processes",
|
| 175 |
+
"process_index",
|
| 176 |
+
]
|
| 177 |
+
|
| 178 |
+
def __init__(self, cpu: bool = False, **kwargs):
    """
    Attach this instance to the shared state and, on first construction only, set up the distributed environment.

    Args:
        cpu (`bool`, *optional*, defaults to `False`):
            Forwarded to `_prepare_backend`; when `True` the DeepSpeed / generic process-group initialization
            branch is skipped.
        kwargs (additional keyword arguments, *optional*):
            `_use_sagemaker_dp` and `backend` are popped and consumed here; everything else is forwarded to the
            process-group initialization call.

    Raises:
        ValueError: If the requested `backend` differs from the one selected by `_prepare_backend`, or if a
            multi-node XPU/CPU launch has no `MASTER_ADDR` set (and the backend is not `"mpi"`).
        ImportError: If DeepSpeed is requested through the environment but is not installed.
        NotImplementedError: If the device is CUDA, `check_cuda_p2p_ib_support()` fails and the NCCL
            P2P/IB disable variables are not both set.
    """
    # Borg pattern: every instance shares the same attribute dictionary.
    self.__dict__ = self._shared_state
    if not self.initialized:
        self._cpu = cpu
        self.backend = None
        env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
        self.device = torch.device(env_device) if env_device is not None else None
        self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
        use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
        dist_information = None
        if use_sagemaker_dp is None:
            use_sagemaker_dp = (
                os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
                and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
            )

        # Sets up self.backend + imports
        original_backend = kwargs.pop("backend", None)
        backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
        if original_backend is not None and backend != original_backend:
            raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
        self.backend = backend
        self.distributed_type = distributed_type
        use_deepspeed = False
        if not cpu and self.backend != "xla":
            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                # Deal with spawning deepspeed
                if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
                    if not is_deepspeed_available():
                        raise ImportError(
                            "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
                        )
                    from deepspeed import comm as dist

                    if not dist.is_initialized():
                        if self.backend == "tccl":
                            local_rank = os.environ.get("LOCAL_RANK", -1)
                            torch.sdaa.set_device(f"sdaa:{local_rank}")
                        dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
                    # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
                    use_deepspeed = True
                # Deal with all other backends but XPU and CPU, that gets handled special later
                elif (
                    self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
                    and not torch.distributed.is_initialized()
                ):
                    if self.backend == "tccl":
                        local_rank = os.environ.get("LOCAL_RANK", -1)
                        torch.sdaa.set_device(f"sdaa:{local_rank}")
                    torch.distributed.init_process_group(backend=self.backend, **kwargs)

        # XPU and CPU require special env configs to be set
        if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
            dist_information = get_cpu_distributed_information()
            os.environ["RANK"] = str(dist_information.rank)
            os.environ["WORLD_SIZE"] = str(dist_information.world_size)
            os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
            os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
            if not os.environ.get("MASTER_PORT", None):
                os.environ["MASTER_PORT"] = "29500"
            if (
                not os.environ.get("MASTER_ADDR", None)
                and dist_information.local_world_size != dist_information.world_size
                and self.backend != "mpi"
            ):
                raise ValueError(
                    "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
                    "please try exporting rank 0's hostname as `MASTER_ADDR`"
                )
            kwargs["rank"] = dist_information.rank
            kwargs["world_size"] = dist_information.world_size

            if (
                self.distributed_type == DistributedType.MULTI_CPU
                and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
            ):
                import psutil

                # `psutil.cpu_count(logical=False)` returns `None` when the number of physical cores
                # cannot be determined; fall back to one core instead of failing on `None / int`.
                physical_cores = psutil.cpu_count(logical=False) or 1
                num_cpu_threads_per_process = int(physical_cores / dist_information.local_world_size)
                if num_cpu_threads_per_process == 0:
                    num_cpu_threads_per_process = 1
                torch.set_num_threads(num_cpu_threads_per_process)
                warnings.warn(
                    f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
                    " performance."
                )

            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend=self.backend, **kwargs)

        # No backend == no distributed training
        if self.backend is None:
            self.distributed_type = DistributedType.NO
            self.num_processes = 1
            self.process_index = 0
            self.local_process_index = 0
        elif self.backend == "xla":
            # XLA needs device setting first for `set_replication`
            self.set_device()
            xm.set_replication(self.device, xm.get_xla_supported_devices())
            self.num_processes = xm.xrt_world_size()
            self.process_index = xm.get_ordinal()
            if is_torch_xla_available(check_is_tpu=True):
                self.local_process_index = xm.get_local_ordinal()
            else:
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
        else:
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
            self.local_process_index = (
                int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
            )
        self.set_device()
        # Now we can change to DeepSpeed
        if use_deepspeed:
            self.distributed_type = DistributedType.DEEPSPEED

        # Set CPU affinity if enabled
        if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
            set_numa_affinity(self.local_process_index)

        # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
        if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
            if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
                raise NotImplementedError(
                    "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
                    'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
                    "will do this automatically."
                )

    # Important: This should be the *only* code outside of `self.initialized!`
    self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
|
| 312 |
+
|
| 313 |
+
def __repr__(self) -> str:
|
| 314 |
+
return (
|
| 315 |
+
f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
|
| 316 |
+
f"Num processes: {self.num_processes}\n"
|
| 317 |
+
f"Process index: {self.process_index}\n"
|
| 318 |
+
f"Local process index: {self.local_process_index}\n"
|
| 319 |
+
f"Device: {self.device}\n"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
@staticmethod
|
| 323 |
+
def _reset_state():
|
| 324 |
+
"Resets `_shared_state`, is used internally and should not be called"
|
| 325 |
+
PartialState._shared_state.clear()
|
| 326 |
+
|
| 327 |
+
@property
|
| 328 |
+
def initialized(self) -> bool:
|
| 329 |
+
"Returns whether the `PartialState` has been initialized"
|
| 330 |
+
return self._shared_state != {}
|
| 331 |
+
|
| 332 |
+
@property
def use_distributed(self):
    """
    Whether the Accelerator is configured for distributed training: the distributed type is not
    `DistributedType.NO` and more than one process is in use.
    """
    has_distributed_type = self.distributed_type != DistributedType.NO
    if not has_distributed_type:
        return has_distributed_type
    return self.num_processes > 1
|
| 338 |
+
|
| 339 |
+
@property
|
| 340 |
+
def is_last_process(self) -> bool:
|
| 341 |
+
"Returns whether the current process is the last one"
|
| 342 |
+
return self.process_index == self.num_processes - 1
|
| 343 |
+
|
| 344 |
+
@property
def is_main_process(self) -> bool:
    """
    Whether the current process is the main process: process index 0, except under
    `DistributedType.MEGATRON_LM` where the last process is used instead.
    """
    if self.distributed_type != DistributedType.MEGATRON_LM:
        return self.process_index == 0
    return self.is_last_process
|
| 350 |
+
|
| 351 |
+
@property
def is_local_main_process(self) -> bool:
    """
    Whether the current process is the main process on the local node: local process index 0, except
    under `DistributedType.MEGATRON_LM` where the last process is used instead.
    """
    if self.distributed_type != DistributedType.MEGATRON_LM:
        return self.local_process_index == 0
    return self.is_last_process
|
| 359 |
+
|
| 360 |
+
def wait_for_everyone(self):
    """
    Block the current process until every other process has reached this point. Useful to do before saving a
    model. For distributed types other than the ones handled below, this does nothing.

    Example:

    ```python
    >>> # Assuming two GPU processes
    >>> import time
    >>> from accelerate.state import PartialState

    >>> state = PartialState()
    >>> if state.is_main_process:
    ...     time.sleep(2)
    >>> else:
    ...     print("I'm waiting for the main process to finish its sleep...")
    >>> state.wait_for_everyone()
    >>> # Should print on every process at the same time
    >>> print("Everyone is here")
    ```
    """
    # Distributed types synchronized through a `torch.distributed` barrier.
    barrier_backed_types = (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_XPU,
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_HPU,
        DistributedType.DEEPSPEED,
        DistributedType.FSDP,
    )
    if self.distributed_type in barrier_backed_types:
        torch.distributed.barrier()
        return
    # XLA synchronizes through a named rendezvous instead.
    if self.distributed_type == DistributedType.XLA:
        xm.rendezvous("accelerate.utils.wait_for_everyone")
|
| 397 |
+
|
| 398 |
+
def _goes_first(self, is_main: bool):
|
| 399 |
+
if not is_main:
|
| 400 |
+
self.wait_for_everyone()
|
| 401 |
+
|
| 402 |
+
yield
|
| 403 |
+
|
| 404 |
+
if is_main:
|
| 405 |
+
self.wait_for_everyone()
|
| 406 |
+
|
| 407 |
+
@contextmanager
|
| 408 |
+
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
|
| 409 |
+
"""
|
| 410 |
+
Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
|
| 411 |
+
distributed inference, such as with different prompts.
|
| 412 |
+
|
| 413 |
+
Note that when using a `dict`, all keys need to have the same number of elements.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
|
| 417 |
+
The input to split between processes.
|
| 418 |
+
apply_padding (`bool`, `optional`, defaults to `False`):
|
| 419 |
+
Whether to apply padding by repeating the last element of the input so that all processes have the same
|
| 420 |
+
number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
|
| 421 |
+
in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
Example:
|
| 425 |
+
|
| 426 |
+
```python
|
| 427 |
+
# Assume there are two processes
|
| 428 |
+
from accelerate import PartialState
|
| 429 |
+
|
| 430 |
+
state = PartialState()
|
| 431 |
+
with state.split_between_processes(["A", "B", "C"]) as inputs:
|
| 432 |
+
print(inputs)
|
| 433 |
+
# Process 0
|
| 434 |
+
["A", "B"]
|
| 435 |
+
# Process 1
|
| 436 |
+
["C"]
|
| 437 |
+
|
| 438 |
+
with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
|
| 439 |
+
print(inputs)
|
| 440 |
+
# Process 0
|
| 441 |
+
["A", "B"]
|
| 442 |
+
# Process 1
|
| 443 |
+
["C", "C"]
|
| 444 |
+
```
|
| 445 |
+
"""
|
| 446 |
+
if self.num_processes == 1:
|
| 447 |
+
yield inputs
|
| 448 |
+
return
|
| 449 |
+
length = len(inputs)
|
| 450 |
+
# Nested dictionary of any types
|
| 451 |
+
if isinstance(inputs, dict):
|
| 452 |
+
length = len(inputs[list(inputs.keys())[0]])
|
| 453 |
+
if not all(len(v) == length for v in inputs.values()):
|
| 454 |
+
raise ValueError("All values in the dictionary must have the same length")
|
| 455 |
+
num_samples_per_process, num_extras = divmod(length, self.num_processes)
|
| 456 |
+
start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
|
| 457 |
+
end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
|
| 458 |
+
|
| 459 |
+
def _split_values(inputs, start_index, end_index):
|
| 460 |
+
if isinstance(inputs, (list, tuple, torch.Tensor)):
|
| 461 |
+
if start_index >= len(inputs):
|
| 462 |
+
result = inputs[-1:]
|
| 463 |
+
else:
|
| 464 |
+
result = inputs[start_index:end_index]
|
| 465 |
+
if apply_padding:
|
| 466 |
+
if isinstance(result, torch.Tensor):
|
| 467 |
+
from accelerate.utils import pad_across_processes, send_to_device
|
| 468 |
+
|
| 469 |
+
# The tensor needs to be on the device before we can pad it
|
| 470 |
+
tensorized_result = send_to_device(result, self.device)
|
| 471 |
+
result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
|
| 472 |
+
else:
|
| 473 |
+
result += [result[-1]] * (num_samples_per_process + 1 - len(result))
|
| 474 |
+
return result
|
| 475 |
+
elif isinstance(inputs, dict):
|
| 476 |
+
for key in inputs.keys():
|
| 477 |
+
inputs[key] = _split_values(inputs[key], start_index, end_index)
|
| 478 |
+
return inputs
|
| 479 |
+
else:
|
| 480 |
+
if is_datasets_available():
|
| 481 |
+
from datasets import Dataset
|
| 482 |
+
|
| 483 |
+
if isinstance(inputs, Dataset):
|
| 484 |
+
if start_index >= len(inputs):
|
| 485 |
+
start_index = len(inputs) - 1
|
| 486 |
+
if end_index > len(inputs):
|
| 487 |
+
end_index = len(inputs)
|
| 488 |
+
result_idcs = list(range(start_index, end_index))
|
| 489 |
+
if apply_padding:
|
| 490 |
+
result_idcs += [end_index - 1] * (num_samples_per_process + 1 - len(result_idcs))
|
| 491 |
+
return inputs.select(result_idcs)
|
| 492 |
+
return inputs
|
| 493 |
+
|
| 494 |
+
yield _split_values(inputs, start_index, end_index)
|
| 495 |
+
|
| 496 |
+
@contextmanager
def main_process_first(self):
    """
    Context manager that lets the main process run the with block first.

    The other processes will enter the with block after the main process exits.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> with accelerator.main_process_first():
    ...     # This will be printed first by process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {accelerator.process_index}")
    ```
    """
    runs_first = self.is_main_process
    yield from self._goes_first(runs_first)
|
| 516 |
+
|
| 517 |
+
@contextmanager
def local_main_process_first(self):
    """
    Context manager that lets the local main process run the with block first.

    The other processes will enter the with block after the local main process exits.

    Example:

    ```python
    >>> from accelerate.state import PartialState

    >>> state = PartialState()
    >>> with state.local_main_process_first():
    ...     # This will be printed first by local process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {state.local_process_index}")
    ```
    """
    runs_first = self.is_local_main_process
    yield from self._goes_first(runs_first)
|
| 537 |
+
|
| 538 |
+
def on_main_process(self, function: Callable[..., Any] = None):
    """
    Decorator restricting the wrapped function to the main process.

    On the main process, or whenever the run is not distributed, `function` is returned unchanged; every other
    process receives `do_nothing` instead.

    Args:
        function (`Callable`): The function to decorate.

    Raises:
        ValueError: If the state has not been initialized yet.

    Example:

    ```python
    >>> from accelerate.state import PartialState

    >>> state = PartialState()


    >>> @state.on_main_process
    ... def print_something():
    ...     print("This will be printed by process 0 only.")


    >>> print_something()
    "This will be printed by process 0 only"
    ```
    """
    if not self.initialized:
        raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
    runs_here = self.is_main_process or not self.use_distributed
    return function if runs_here else do_nothing
|
| 567 |
+
|
| 568 |
+
def on_local_main_process(self, function: Callable[..., Any] = None):
    """
    Decorator restricting the wrapped function to the local main process.

    On the local main process, or whenever the run is not distributed, `function` is returned unchanged; every
    other process receives `do_nothing` instead.

    Args:
        function (`Callable`): The function to decorate.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_local_main_process
    def print_something():
        print("This will be printed by process 0 only on each server.")


    print_something()
    # On server 1:
    "This will be printed by process 0 only"
    # On server 2:
    "This will be printed by process 0 only"
    ```
    """
    runs_here = self.is_local_main_process or not self.use_distributed
    return function if runs_here else do_nothing
|
| 598 |
+
|
| 599 |
+
def on_last_process(self, function: Callable[..., Any]):
    """
    Decorator restricting the wrapped function to the last process.

    On the last process, or whenever the run is not distributed, `function` is returned unchanged; every other
    process receives `do_nothing` instead.

    Args:
        function (`Callable`): The function to decorate.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_last_process
    def print_something():
        print(f"Printed on process {state.process_index}")


    print_something()
    "Printed on process 3"
    ```
    """
    runs_here = self.is_last_process or not self.use_distributed
    return function if runs_here else do_nothing
|
| 626 |
+
|
| 627 |
+
def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
    """
    Decorator restricting the wrapped function to the process whose index is `process_index`.

    When used with arguments only (`@state.on_process(process_index=2)`), `function` is `None` and a partial of
    this method bound to `process_index` is returned so it can then be applied to the function.

    Args:
        function (`Callable`, `optional`):
            The function to decorate.
        process_index (`int`, `optional`):
            The index of the process on which to run the function.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_process(process_index=2)
    def print_something():
        print(f"Printed on process {state.process_index}")


    print_something()
    "Printed on process 2"
    ```
    """
    if function is None:
        # Called as a decorator factory: defer until the function is supplied.
        return partial(self.on_process, process_index=process_index)
    is_target = self.process_index == process_index
    if not is_target and self.use_distributed:
        return do_nothing
    return function
|
| 659 |
+
|
| 660 |
+
def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
    """
    Decorator restricting the wrapped function to the process with the given index on the current node.

    When used with arguments only (`@state.on_local_process(local_process_index=2)`), `function` is `None` and a
    partial of this method bound to `local_process_index` is returned so it can then be applied to the function.

    Args:
        function (`Callable`, *optional*):
            The function to decorate.
        local_process_index (`int`, *optional*):
            The index of the local process on which to run the function.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_local_process(local_process_index=2)
    def print_something():
        print(f"Printed on process {accelerator.local_process_index}")


    print_something()
    # On server 1:
    "Printed on process 2"
    # On server 2:
    "Printed on process 2"
    ```
    """
    if function is None:
        # Called as a decorator factory: defer until the function is supplied.
        return partial(self.on_local_process, local_process_index=local_process_index)
    is_target = self.local_process_index == local_process_index
    if not is_target and self.use_distributed:
        return do_nothing
    return function
|
| 695 |
+
|
| 696 |
+
def print(self, *args, **kwargs):
    """Forward the arguments to the builtin `print`, but only on the local main process."""
    if not self.is_local_main_process:
        return
    print(*args, **kwargs)
|
| 699 |
+
|
| 700 |
+
@property
def default_device(self) -> torch.device:
    """
    Returns the default device, chosen by probing backends in this order:

        - MPS if `is_mps_available()` (this also sets `PYTORCH_ENABLE_MPS_FALLBACK=1` in the environment)
        - MLU if `is_mlu_available()`
        - SDAA if `is_sdaa_available()`
        - MUSA if `is_musa_available()`
        - NPU if `is_npu_available()`
        - HPU if `is_hpu_available()`
        - CUDA if `torch.cuda.is_available()`
        - XPU if `is_xpu_available()`
        - CPU otherwise
    """
    if is_mps_available():
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        return torch.device("mps")

    # Probes run lazily, top to bottom; the first one that succeeds wins.
    # NPU should be checked before CUDA when using `transfer_to_npu`
    # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
    ordered_probes = (
        (is_mlu_available, "mlu"),
        (is_sdaa_available, "sdaa"),
        (is_musa_available, "musa"),
        (is_npu_available, "npu"),
        (is_hpu_available, "hpu"),
        (torch.cuda.is_available, "cuda"),
        (is_xpu_available, "xpu"),
    )
    for is_available, device_name in ordered_probes:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
|
| 734 |
+
|
| 735 |
+
def _prepare_backend(
    self, cpu: bool = False, sagemaker_dp=False, backend: str = None
) -> tuple[str, DistributedType]:
    """
    Prepares any imports needed before initializing the distributed backend and picks the backend to use.

    Args:
        cpu (`bool`, defaults to `False`):
            When `True`, the accelerator-specific branch is skipped.
        sagemaker_dp (`bool`, defaults to `False`):
            When `True`, imports the SageMaker data-parallel module and selects the `smddp` backend.
        backend (`str`, *optional*):
            A caller-requested backend. It is only honoured by the branches that check `backend is None`
            (HPU, CUDA, XPU) and by the CPU/XPU fallback (`ccl` / `mpi`); other branches overwrite it.

    Returns:
        `tuple[str, DistributedType]`: The selected backend name (may be `None` when nothing distributed was
        detected) and the distributed type (`DistributedType.NO` when nothing distributed was detected).
    """
    distributed_type = None
    if sagemaker_dp:
        import smdistributed.dataparallel.torch.torch_smddp  # noqa

        backend = "smddp"
        distributed_type = DistributedType.MULTI_GPU
    elif is_torch_xla_available():
        backend = "xla"
        distributed_type = DistributedType.XLA

    elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
        # Exactly one accelerator family must win. This has to be a single if/elif chain:
        # previously the SDAA check started a new `if`, so an MLU match could fall through
        # to the later branches and have its distributed type overwritten.
        if is_mlu_available():
            backend = "cncl"
            distributed_type = DistributedType.MULTI_MLU
        elif is_sdaa_available():
            backend = "tccl"
            distributed_type = DistributedType.MULTI_SDAA
        elif is_musa_available():
            backend = "mccl"
            distributed_type = DistributedType.MULTI_MUSA
        # NPU should be checked before CUDA when using `transfer_to_npu`
        # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
        elif is_npu_available():
            backend = "hccl"
            distributed_type = DistributedType.MULTI_NPU
        elif is_hpu_available(init_hccl=True):
            if backend is None:
                backend = "hccl"
            distributed_type = DistributedType.MULTI_HPU
        elif torch.cuda.is_available():
            if backend is None:
                backend = "nccl"
            distributed_type = DistributedType.MULTI_GPU
        elif is_xpu_available() and is_xccl_available():
            if backend is None:
                backend = "xccl"
            distributed_type = DistributedType.MULTI_XPU

    # Fallback: a multi-process launch was detected but no accelerator branch matched.
    if distributed_type is None and (
        int(os.environ.get("LOCAL_RANK", -1)) != -1
        or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
    ):
        if not cpu and is_xpu_available():
            distributed_type = DistributedType.MULTI_XPU
        else:
            distributed_type = DistributedType.MULTI_CPU

        if (
            backend in (None, "ccl")
            and is_ccl_available()
            and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
        ):
            # NOTE(review): this is a lexicographic string comparison, not a version comparison
            # (e.g. "1.9" >= "1.12" is True) -- confirm the range of versions that can reach this point.
            if get_ccl_version() >= "1.12":
                import oneccl_bindings_for_pytorch  # noqa: F401
            else:
                import torch_ccl  # noqa: F401

            backend = "ccl"
        elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
            backend = "mpi"
        else:
            backend = "gloo"
    if distributed_type is None:
        distributed_type = DistributedType.NO

    return backend, distributed_type
|
| 805 |
+
|
| 806 |
+
def set_device(self):
    """
    Populates `self.device` from the current distributed environment.

    Does nothing when `self.device` is already set. Without a distributed setup the device is CPU (when
    `self._cpu` is set) or `self.default_device`. Otherwise the device family is derived from the name of
    `self.distributed_type`, and for the indexed families the device index is the local process index modulo
    the number of visible devices.

    Raises:
        ValueError: If the derived device family is not one of the supported ones.
    """
    if self.device is not None:
        return

    if self.distributed_type == DistributedType.NO:
        if self._cpu:
            self.device = torch.device("cpu")
        else:
            self.device = self.default_device
        return

    # e.g. "DistributedType.MULTI_GPU" -> "MULTI_GPU" -> "gpu"
    member_name = str(self.distributed_type).split(".")[-1]
    device = member_name.replace("MULTI_", "").lower()
    supported = ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa")
    if device not in supported:
        raise ValueError(
            f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
        )

    if device == "xla":
        self.device = xm.xla_device()
        return
    if device == "hpu":
        self.device = torch.device("hpu", torch.hpu.current_device())
        return

    # Remaining families are exposed as `torch.<family>` modules; "gpu" maps to `torch.cuda`.
    family = "cuda" if device == "gpu" else device
    device_module = getattr(torch, family)
    device_index = self.local_process_index % device_module.device_count()
    self.device = torch.device(family, device_index)
    device_module.set_device(self.device)
|
| 831 |
+
|
| 832 |
+
def destroy_process_group(self, group=None):
    """
    Tears down a process group; with no `group` given, the default process group is destroyed.

    Nothing happens when `self.fork_launched` is truthy and `group` is `None`, or when `torch.distributed` has
    not been initialized.
    """
    launched_by_fork = self.fork_launched
    if launched_by_fork and group is None:
        return
    dist = torch.distributed
    # needed when using torch.distributed.init_process_group
    if dist.is_initialized():
        dist.destroy_process_group(group)
|
| 841 |
+
|
| 842 |
+
def __getattr__(self, name: str):
    """
    Fallback attribute hook that only customizes the `AttributeError` message.

    Python calls this only after normal lookup of `name` has failed, so it always raises. Names listed in
    `self._known_attrs` get a hint that the state was reset and not reinitialized.

    Raises:
        AttributeError: Always.
    """
    if name in self._known_attrs:
        message = (
            f"`PartialState` object has no attribute `{name}`. "
            "This happens if `PartialState._reset_state()` was called and "
            "an `Accelerator` or `PartialState` was not reinitialized."
        )
    else:
        # Mirror the wording of a typical AttributeError
        message = f"'PartialState' object has no attribute '{name}'"
    raise AttributeError(message)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
class AcceleratorState:
    """
    Singleton class that has information about the current training environment.

    All instances share one attribute dictionary (`_shared_state`), so constructing the class again returns a
    view onto the same state.

    **Available attributes:**

        - **device** (`torch.device`) -- The device to use.
        - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
          in use.
        - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
        - **local_process_index** (`int`) -- The index of the current process on the current server.
        - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
          of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
        - **num_processes** (`int`) -- The number of processes currently launched in parallel.
        - **process_index** (`int`) -- The index of the current process.
        - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
        - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
        - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
        - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
    """

    # Attribute dictionary shared by every instance (assigned to `self.__dict__` in `__init__`).
    _shared_state = SharedDict()
    # Names that get the "state was reset" hint in `__getattr__` when they are missing.
    _known_attrs = PartialState._known_attrs + [
        "deepspeed_plugin",
        "use_ipex",
        "fsdp_plugin",
        "megatron_lm_plugin",
        "dynamo_plugin",
    ]

    def __init__(
        self,
        mixed_precision: str = None,
        cpu: bool = False,
        dynamo_plugin=None,
        deepspeed_plugin=None,
        fsdp_plugin=None,
        torch_tp_plugin=None,
        megatron_lm_plugin=None,
        _from_accelerator: bool = False,
        **kwargs,
    ):
        """
        Binds this instance to the shared state and, on first use, resolves precision and distributed type.

        Args:
            mixed_precision (`str`, *optional*):
                Requested mixed precision; when `None`, read from `ACCELERATE_MIXED_PRECISION` (default "no").
            cpu (`bool`, defaults to `False`):
                Force CPU; also turned on by the `ACCELERATE_USE_CPU` environment flag.
            dynamo_plugin, deepspeed_plugin, fsdp_plugin, torch_tp_plugin, megatron_lm_plugin:
                Plugin objects stored on the state and used to refine `distributed_type`.
            _from_accelerator (`bool`, defaults to `False`):
                Must be `True` on first initialization, otherwise a `ValueError` is raised.
            **kwargs:
                Forwarded to `PartialState` when it has not been created yet.

        Raises:
            ValueError: On first initialization without `_from_accelerator`, when `fp8` is requested but no FP8
                library is available, or (via `_check_initialized`) when an already-initialized state is asked
                to change `cpu` / `mixed_precision`.
        """
        self.__dict__ = self._shared_state
        if parse_flag_from_env("ACCELERATE_USE_CPU"):
            cpu = True
        # Make sure the underlying PartialState exists, then mirror its attributes here.
        if PartialState._shared_state == {}:
            PartialState(cpu, **kwargs)
        self.__dict__.update(PartialState._shared_state)
        self._check_initialized(mixed_precision, cpu)
        if not self.initialized:
            self.deepspeed_plugins = None
            self.use_ipex = None
            self.torch_tp_plugin = torch_tp_plugin
            mixed_precision = (
                parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                if mixed_precision is None
                else mixed_precision.lower()
            )
            if mixed_precision == "fp8":
                # this is confusing, why is is_fp8_available only checks for library availability ?
                if not is_fp8_available():
                    raise ValueError(
                        "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
                    )
                elif torch.cuda.is_available() and not check_cuda_fp8_capability():
                    logger.warning(
                        f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
                        "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
                        "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
                    )
                    mixed_precision = "fp16"
                elif is_habana_gaudi1():
                    logger.warning(
                        "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
                        "Gaudi2 or higher). Will use BF16 instead."
                    )
                    mixed_precision = "bf16"

            self.dynamo_plugin = dynamo_plugin
            if not _from_accelerator:
                raise ValueError(
                    "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                    "before using any functionality from the `accelerate` library."
                )
            # deepspeed handles mixed_precision using deepspeed_config
            self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
            if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                if mixed_precision == "bf16":
                    # NOTE(review): any non-empty value of ACCELERATE_DOWNCAST_BF16 (even "0" or "false")
                    # is truthy here -- confirm that is intended.
                    if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                        os.environ["XLA_USE_BF16"] = str(0)
                        os.environ["XLA_DOWNCAST_BF16"] = str(1)
                        self.downcast_bfloat = True
                    else:
                        os.environ["XLA_USE_BF16"] = str(1)
                        os.environ["XLA_DOWNCAST_BF16"] = str(0)
                        self.downcast_bfloat = False
            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
                self.deepspeed_plugins = deepspeed_plugin
                self.distributed_type = DistributedType.DEEPSPEED
            elif self.distributed_type in [
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
            ]:
                # The three checks below are independent `if`s, so a later match overrides an earlier one.
                # NOTE(review): `fsdp_plugin` / `megatron_lm_plugin` are dereferenced without a None check when
                # only the environment flag is set -- presumably the caller always supplies them; confirm.
                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
                    self.distributed_type = DistributedType.FSDP
                    if self._mixed_precision != "no":
                        fsdp_plugin.set_mixed_precision(self._mixed_precision)
                    self.fsdp_plugin = fsdp_plugin
                if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
                    DistributedType.MULTI_XPU,
                ]:
                    self.distributed_type = DistributedType.MEGATRON_LM
                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                    self.megatron_lm_plugin = megatron_lm_plugin
                if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
                    self.distributed_type = DistributedType.TP
            # NOTE(review): MULTI_XPU is already matched by the branch above, so it cannot reach this one.
            elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                if is_ipex_available():
                    # check if user disables it explicitly
                    self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
                else:
                    self.use_ipex = False
            # Enable TF32 matmuls when a dynamo backend is active without mixed precision (CUDA, then MUSA).
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "cuda"
            ):
                torch.backends.cuda.matmul.allow_tf32 = True
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "musa"
            ):
                torch.backends.musa.matmul.allow_tf32 = True
            # Propagate the refined distributed type back to PartialState's shared state.
            PartialState._shared_state["distributed_type"] = self.distributed_type

    @property
    def initialized(self) -> bool:
        """Whether this class's shared state differs from `PartialState._shared_state`."""
        return self._shared_state != PartialState._shared_state

    def __repr__(self):
        """Return the `PartialState` representation followed by the mixed precision (and DeepSpeed config, if any)."""
        repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
        if self.distributed_type == DistributedType.DEEPSPEED:
            repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
        return repr

    def _check_initialized(self, mixed_precision=None, cpu=None):
        """
        Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized.

        Raises:
            ValueError: If the state is initialized and `cpu` is requested on a non-CPU device, or a different
                `mixed_precision` is requested while not using DeepSpeed.
        """
        if self.initialized:
            err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
            if cpu and self.device.type != "cpu":
                raise ValueError(err.format(flag="cpu=True"))
            if (
                mixed_precision is not None
                and mixed_precision != self._mixed_precision
                and self.distributed_type != DistributedType.DEEPSPEED
            ):
                raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))

    @property
    def mixed_precision(self):
        """
        The effective mixed precision mode.

        With DeepSpeed it is derived from the active plugin's `deepspeed_config` ("fp16", "bf16" or "no");
        otherwise it is the value stored at initialization.
        """
        if self.distributed_type == DistributedType.DEEPSPEED:
            config = self.deepspeed_plugin.deepspeed_config
            if config.get("fp16", {}).get("enabled", False):
                mixed_precision = "fp16"
            elif config.get("bf16", {}).get("enabled", False):
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"
        else:
            mixed_precision = self._mixed_precision
        return mixed_precision

    @staticmethod
    def _reset_state(reset_partial_state: bool = False):
        "Resets `_shared_state` (and optionally `PartialState`'s), is used internally and should not be called"
        AcceleratorState._shared_state.clear()
        if reset_partial_state:
            PartialState._reset_state()

    def destroy_process_group(self, group=None):
        """
        Destroys the process group. If one is not specified, the default process group is destroyed.

        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
        """
        PartialState().destroy_process_group(group)

    @property
    def fork_launched(self):
        """The `fork_launched` value of the underlying `PartialState`."""
        return PartialState().fork_launched

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return PartialState().use_distributed

    @property
    def is_fsdp2(self) -> bool:
        """Whether the distributed type is FSDP and the FSDP plugin reports `fsdp_version == 2`."""
        return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return PartialState().is_last_process

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        return PartialState().is_main_process

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return PartialState().is_local_main_process

    def wait_for_everyone(self):
        """Delegates to `PartialState().wait_for_everyone()`."""
        PartialState().wait_for_everyone()

    @contextmanager
    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.


        Example:

        ```python
        # Assume there are two processes
        from accelerate.state import AcceleratorState

        state = AcceleratorState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        # Thin wrapper: the actual splitting is done by PartialState.
        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
            yield inputs

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().main_process_first():
            yield

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go first inside a with block.

        The other processes will enter the with block after the local main process exits.
        """
        with PartialState().local_main_process_first():
            yield

    @property
    def deepspeed_plugin(self):
        """
        Returns the currently active DeepSpeedPlugin.

        If not using deepspeed, returns `None`.
        """
        # To maintain original behavior, return None if not using deepspeed.
        if self.distributed_type != DistributedType.DEEPSPEED:
            return None
        # Imported here rather than at module level (presumably to avoid an import cycle -- TODO confirm).
        from accelerate.utils.deepspeed import get_active_deepspeed_plugin

        return get_active_deepspeed_plugin(self)

    @deepspeed_required
    def get_deepspeed_plugin(self, name: str):
        """
        Returns the DeepSpeedPlugin registered under the key `name` in `self.deepspeed_plugins`.
        """
        return self.deepspeed_plugins[name]

    @deepspeed_required
    def select_deepspeed_plugin(self, name: str = None):
        """
        Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
        """
        # Unselect every other plugin first, then select the requested one.
        for key, plugin in self.deepspeed_plugins.items():
            if key != name:
                plugin._unselect()
        self.deepspeed_plugins[name].select(_from_accelerator_state=True)

    def print(self, *args, **kwargs):
        """Delegates to `PartialState().print(...)`."""
        PartialState().print(*args, **kwargs)

    def __getattr__(self, name: str):
        """Always raises `AttributeError`; known attribute names get a hint that the state was reset."""
        # By this point we know that no attributes of `self` contain `name`,
        # so we just modify the error message
        if name in self._known_attrs:
            raise AttributeError(
                f"`AcceleratorState` object has no attribute `{name}`. "
                "This happens if `AcceleratorState._reset_state()` was called and "
                "an `Accelerator` or `PartialState` was not reinitialized."
            )
        # Raise a typical AttributeError
        raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
class GradientState:
|
| 1191 |
+
"""
|
| 1192 |
+
Singleton class that has information related to gradient synchronization for gradient accumulation
|
| 1193 |
+
|
| 1194 |
+
**Available attributes:**
|
| 1195 |
+
|
| 1196 |
+
- **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
|
| 1197 |
+
- **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
|
| 1198 |
+
- **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
|
| 1199 |
+
- **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
|
| 1200 |
+
- **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
|
| 1201 |
+
being iterated over
|
| 1202 |
+
- **num_steps** (`int`) -- The number of steps to accumulate over
|
| 1203 |
+
- **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
|
| 1204 |
+
accumulation
|
| 1205 |
+
- **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
|
| 1206 |
+
iteration and the number of total steps reset
|
| 1207 |
+
- **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
|
| 1208 |
+
as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
|
| 1209 |
+
after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
|
| 1210 |
+
is_xla_gradients_synced is always true.
|
| 1211 |
+
"""
|
| 1212 |
+
|
| 1213 |
+
_shared_state = SharedDict()
|
| 1214 |
+
|
| 1215 |
+
def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
    """
    Binds the instance to the shared state and seeds it on first use.

    Args:
        gradient_accumulation_plugin (`GradientAccumulationPlugin`, *optional*):
            Source of the plugin kwargs. On an already-initialized state, kwargs that differ from the stored
            ones replace them.
    """
    self.__dict__ = self._shared_state
    has_plugin = gradient_accumulation_plugin is not None

    if not self.initialized:
        # First construction: populate the shared defaults.
        self.sync_gradients = True
        self._dataloader_references_ref = [None]
        self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs() if has_plugin else {}
        self._is_xla_gradients_synced = False

    # Plugin args are different and can be updated
    if has_plugin and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
        self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
|
| 1228 |
+
|
| 1229 |
+
@property
def num_steps(self) -> int:
    """Number of steps to accumulate over; falls back to 1 when the plugin kwargs do not set it."""
    plugin_options = self.plugin_kwargs
    return plugin_options.get("num_steps", 1)
|
| 1233 |
+
|
| 1234 |
+
@property
def adjust_scheduler(self) -> bool:
    """Whether the scheduler should be adjusted; falls back to `False` when the plugin kwargs do not set it."""
    plugin_options = self.plugin_kwargs
    return plugin_options.get("adjust_scheduler", False)
|
| 1238 |
+
|
| 1239 |
+
@property
def sync_with_dataloader(self) -> bool:
    """
    Whether gradients should be synced at the end of the dataloader iteration and the number of total steps
    reset; falls back to `True` when the plugin kwargs do not set it.
    """
    plugin_options = self.plugin_kwargs
    return plugin_options.get("sync_with_dataloader", True)
|
| 1243 |
+
|
| 1244 |
+
@property
def initialized(self) -> bool:
    """Whether the `GradientState` has been initialized, i.e. its shared state is no longer empty."""
    shared = GradientState._shared_state
    return shared != {}
|
| 1248 |
+
|
| 1249 |
+
@property
def end_of_dataloader(self) -> bool:
    """Whether the active dataloader has reached its end; `False` when not inside a dataloader."""
    if self.in_dataloader:
        return self.active_dataloader.end_of_dataloader
    return False
|
| 1255 |
+
|
| 1256 |
+
@property
def remainder(self) -> int:
    """Number of extra samples added by padding the active dataloader; `-1` when not inside a dataloader."""
    if self.in_dataloader:
        return self.active_dataloader.remainder
    return -1
|
| 1262 |
+
|
| 1263 |
+
def __repr__(self):
    """Return a four-line, newline-terminated summary of the gradient state."""
    summary_lines = [
        f"Sync Gradients: {self.sync_gradients}\n",
        f"At end of current dataloader: {self.end_of_dataloader}\n",
        f"Extra samples added: {self.remainder}\n",
        f"Gradient accumulation plugin: {self.plugin_kwargs}\n",
    ]
    return "".join(summary_lines)
|
| 1270 |
+
|
| 1271 |
+
@property
def is_xla_gradients_synced(self):
    """
    Return whether XLA gradients are considered synced.

    Always `True` when the `ACCELERATE_USE_FSDP` environment flag is set; otherwise the stored
    `_is_xla_gradients_synced` value.
    """
    fsdp_enabled = parse_flag_from_env("ACCELERATE_USE_FSDP", default=False)
    return True if fsdp_enabled else self._is_xla_gradients_synced
|
| 1277 |
+
|
| 1278 |
+
@is_xla_gradients_synced.setter
def is_xla_gradients_synced(self, is_synced):
    """Store `is_synced` in the backing `_is_xla_gradients_synced` attribute."""
    setattr(self, "_is_xla_gradients_synced", is_synced)
|
| 1282 |
+
|
| 1283 |
+
def _set_sync_gradients(self, sync_gradients):
    """
    Private function that records whether gradients should be synchronized. Users should not have to call this.

    When syncing is enabled, torch-XLA on TPU is available and the distributed type is XLA, `xm.mark_step()` is
    also called.
    """
    self.sync_gradients = sync_gradients
    # Allow grad-sync to automatically work on TPUs
    if not self.sync_gradients:
        return
    if not is_torch_xla_available(check_is_tpu=True):
        return
    if PartialState().distributed_type == DistributedType.XLA:
        xm.mark_step()
|
| 1293 |
+
|
| 1294 |
+
def _add_dataloader(self, dataloader):
    """
    Private function that appends `dataloader` to `self.dataloader_references`. Users should not have to call
    this.
    """
    # Re-assign rather than calling `.append`: only assignment goes through the property setter that
    # stores the entries as weak references.
    self.dataloader_references = [*self.dataloader_references, dataloader]
|
| 1299 |
+
|
| 1300 |
+
def _remove_dataloader(self, dataloader):
    """
    Private function that drops every entry equal to `dataloader` from `self.dataloader_references`. Users
    should not have to call this.
    """
    remaining = []
    for candidate in self.dataloader_references:
        if candidate != dataloader:
            remaining.append(candidate)
    # Assign (instead of mutating in place) so that the property setter is triggered.
    self.dataloader_references = remaining
|
| 1306 |
+
|
| 1307 |
+
@property
def active_dataloader(self):
    """Return the last entry of `dataloader_references`."""
    references = self.dataloader_references
    return references[-1]
|
| 1310 |
+
|
| 1311 |
+
@property
def dataloader_references(self):
    """
    Return the tracked dataloaders as a new list, dereferencing each stored weak reference.

    `None` placeholders are passed through unchanged.
    """
    # Weak references are stored to avoid circular references that would prevent garbage collection.
    resolved = []
    for ref in self._dataloader_references_ref:
        resolved.append(None if ref is None else ref())
    return resolved
|
| 1315 |
+
|
| 1316 |
+
@dataloader_references.setter
def dataloader_references(self, references):
    """Store `references` as weak references, keeping `None` entries as they are."""
    weak_refs = []
    for dataloader in references:
        weak_refs.append(dataloader if dataloader is None else weakref.ref(dataloader))
    self._dataloader_references_ref = weak_refs
|
| 1321 |
+
|
| 1322 |
+
@property
def in_dataloader(self) -> bool:
    """Return `True` when `active_dataloader` is not `None`."""
    current = self.active_dataloader
    return current is not None
|
| 1326 |
+
|
| 1327 |
+
@staticmethod
def _reset_state():
    """Empty the class-level `_shared_state` dictionary in place. Internal use only; should not be called."""
    shared = GradientState._shared_state
    shared.clear()
|
venv/Lib/site-packages/accelerate/tracking.py
ADDED
|
@@ -0,0 +1,1089 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Expectation:
|
| 16 |
+
# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import time
|
| 21 |
+
from functools import wraps
|
| 22 |
+
from typing import Any, Optional, Union
|
| 23 |
+
|
| 24 |
+
import yaml
|
| 25 |
+
from packaging import version
|
| 26 |
+
|
| 27 |
+
from .logging import get_logger
|
| 28 |
+
from .state import PartialState
|
| 29 |
+
from .utils import (
|
| 30 |
+
LoggerType,
|
| 31 |
+
compare_versions,
|
| 32 |
+
is_aim_available,
|
| 33 |
+
is_clearml_available,
|
| 34 |
+
is_comet_ml_available,
|
| 35 |
+
is_dvclive_available,
|
| 36 |
+
is_mlflow_available,
|
| 37 |
+
is_tensorboard_available,
|
| 38 |
+
is_wandb_available,
|
| 39 |
+
listify,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Logger types whose availability check passes, probed in a fixed order.
_available_trackers = [
    logger_type
    for is_available, logger_type in (
        (is_tensorboard_available, LoggerType.TENSORBOARD),
        (is_wandb_available, LoggerType.WANDB),
        (is_comet_ml_available, LoggerType.COMETML),
        (is_aim_available, LoggerType.AIM),
        (is_mlflow_available, LoggerType.MLFLOW),
        (is_clearml_available, LoggerType.CLEARML),
        (is_dvclive_available, LoggerType.DVCLIVE),
    )
    if is_available()
]

logger = get_logger(__name__)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def on_main_process(function):
    """
    Decorator that runs the wrapped method on the main process only when the instance's `main_process_only`
    attribute is truthy; otherwise the method runs unchanged.

    The attribute is read each time the method is called, not at decoration time, so decorating does not
    trigger the initialization of the `PartialState`.
    """

    @wraps(function)
    def execute_on_main_process(self, *args, **kwargs):
        target = function
        if getattr(self, "main_process_only", False):
            target = PartialState().on_main_process(function)
        return target(self, *args, **kwargs)

    return execute_on_main_process
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_available_trackers():
    """Return the module-level list of trackers detected as available (the list object itself, not a copy)."""
    trackers = _available_trackers
    return trackers
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class GeneralTracker:
    """
    A base Tracker class to be used for all logging integration implementations.

    Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
    [`Accelerator`].

    Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:

    `name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory`
    (`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
    tracking mechanism used by a tracker class (such as the `run` for wandb)

    Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
    other functions should occur on the main process or across all processes (by default will use `True`)
    """

    main_process_only = True

    def __init__(self, _blank=False):
        """
        Check that the subclass defines the required attributes.

        Args:
            _blank (`bool`, *optional*, defaults to `False`):
                When `True`, the attribute check is skipped.

        Raises:
            NotImplementedError: If any of `name`, `requires_logging_directory` or `tracker` is missing.
        """
        if _blank:
            return

        missing = [f"`{attr}`" for attr in ("name", "requires_logging_directory") if not hasattr(self, attr)]
        # `tracker` is a @property that relies on post-init state, so look the name up with `dir()`
        # instead of `hasattr()`, which would evaluate the property.
        if "tracker" not in dir(self):
            missing.append("`tracker`")

        if missing:
            err = ", ".join(missing)
            raise NotImplementedError(
                "The implementation for this tracker class is missing the following "
                "required attributes. Please define them in the class definition: "
                f"{err}"
            )

    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
        functionality of a tracking API.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        """
        pass

    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
        special behavior for the `step` parameter.

        Args:
            values (Dictionary `str` to `str`, `float`, or `int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        """
        pass

    def finish(self):
        """
        Should run any finalizing functions within the tracking API. If the API should not have one, just don't
        overwrite that method.
        """
        pass
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class TensorBoardTracker(GeneralTracker):
    """
    A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run
        logging_dir (`str`, `os.PathLike`):
            Location for TensorBoard logs to be stored.
        **kwargs (additional keyword arguments, *optional*):
            Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
    """

    name = "tensorboard"
    requires_logging_directory = True

    @on_main_process
    def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
        # Prefer the writer bundled with torch and fall back to the standalone tensorboardX package.
        try:
            from torch.utils import tensorboard
        except ModuleNotFoundError:
            import tensorboardX as tensorboard
        super().__init__()
        self.run_name = run_name
        self.logging_dir = os.path.join(logging_dir, run_name)
        self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
        logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """The underlying `SummaryWriter` instance."""
        return self.writer

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
        hyperparameters in a yaml file for future use.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.

        Raises:
            yaml.representer.RepresenterError: Re-raised (after logging an error) when `values` cannot be dumped.
        """
        self.writer.add_hparams(values, metric_dict={})
        self.writer.flush()
        # The yaml copy goes into a sub-directory named after the current timestamp.
        project_run_name = time.time()
        dir_name = os.path.join(self.logging_dir, str(project_run_name))
        os.makedirs(dir_name, exist_ok=True)
        with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
            try:
                yaml.dump(values, outfile)
            except yaml.representer.RepresenterError:
                logger.error("Serialization to store hyperparameters failed")
                raise
        logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to either `SummaryWriter.add_scalar`,
                `SummaryWriter.add_text`, or `SummaryWriter.add_scalars` method based on the contents of `values`.
        """
        values = listify(values)
        # Dispatch on the value type; values of any other type are skipped.
        for k, v in values.items():
            if isinstance(v, (int, float)):
                self.writer.add_scalar(k, v, global_step=step, **kwargs)
            elif isinstance(v, str):
                self.writer.add_text(k, v, global_step=step, **kwargs)
            elif isinstance(v, dict):
                self.writer.add_scalars(k, v, global_step=step, **kwargs)
        self.writer.flush()
        logger.debug("Successfully logged to TensorBoard")

    @on_main_process
    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `images` to the current run.

        Args:
            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
                `PIL.Image`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `SummaryWriter.add_images` method.
        """
        for k, v in values.items():
            self.writer.add_images(k, v, global_step=step, **kwargs)
        logger.debug("Successfully logged images to TensorBoard")

    @on_main_process
    def finish(self):
        """
        Closes `TensorBoard` writer
        """
        self.writer.close()
        logger.debug("TensorBoard writer closed")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class WandBTracker(GeneralTracker):
    """
    A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run.
        **kwargs (additional keyword arguments, *optional*):
            Additional key word arguments passed along to the `wandb.init` method.
    """

    name = "wandb"
    requires_logging_directory = False
    main_process_only = False

    @on_main_process
    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        self.run_name = run_name

        import wandb

        # `run_name` is passed to `wandb.init` as the project name.
        self.run = wandb.init(project=self.run_name, **kwargs)
        logger.debug(f"Initialized WandB project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """The object returned by `wandb.init`."""
        return self.run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Records `values` as the run's hyperparameters. Meant to be called at the beginning of the experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Key-value pairs to store as initial hyperparameters. The values need to have type `bool`, `str`,
                `float`, `int`, or `None`.
        """
        import wandb

        wandb.config.update(values, allow_val_change=True)
        logger.debug("Stored initial configuration hyperparameters to WandB")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Sends `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Key-value pairs to log. The values need to have type `str`, `float`, `int` or `dict` of `str` to
                `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `wandb.log` method.
        """
        run = self.run
        run.log(values, step=step, **kwargs)
        logger.debug("Successfully logged to WandB")

    @on_main_process
    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Sends `images` to the current run, one `log` call per key.

        Args:
            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
                Key-value pairs to log. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `wandb.log` method.
        """
        import wandb

        for key, images in values.items():
            wrapped = [wandb.Image(image) for image in images]
            self.log({key: wrapped}, step=step, **kwargs)
        logger.debug("Successfully logged images to WandB")

    @on_main_process
    def log_table(
        self,
        table_name: str,
        columns: list[str] = None,
        data: list[list[Any]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
        **kwargs,
    ):
        """
        Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
        with `columns` and `data` or with `dataframe`.

        Args:
            table_name (`str`):
                The name to give to the logged table on the wandb workspace
            columns (list of `str`, *optional*):
                The name of the columns on the table
            data (List of List of Any data type, *optional*):
                The data to be logged in the table
            dataframe (Any data type, *optional*):
                The data to be logged in the table
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        """
        import wandb

        table = wandb.Table(columns=columns, data=data, dataframe=dataframe)
        self.log({table_name: table}, step=step, **kwargs)

    @on_main_process
    def finish(self):
        """
        Finishes the `wandb` run.
        """
        self.run.finish()
        logger.debug("WandB run closed")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class CometMLTracker(GeneralTracker):
    """
    A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.

    API keys must be stored in a Comet config file.

    Note:
        For `comet_ml` versions < 3.41.0, additional keyword arguments are passed to `comet_ml.Experiment` instead:
        https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#comet_ml.Experiment.__init__

    Args:
        run_name (`str`):
            The name of the experiment run.
        **kwargs (additional keyword arguments, *optional*):
            Additional key word arguments passed along to the `comet_ml.start` method:
            https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/
    """

    name = "comet_ml"
    requires_logging_directory = False

    @on_main_process
    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        self.run_name = run_name

        import comet_ml

        comet_version = version.parse(comet_ml.__version__)
        supports_start = compare_versions(comet_version, ">=", "3.41.0")
        if not supports_start:
            logger.info("Update `comet_ml` (>=3.41.0) for experiment reuse and offline support.")
            self.writer = comet_ml.Experiment(project_name=run_name, **kwargs)
        else:
            self.writer = comet_ml.start(project_name=run_name, **kwargs)

        logger.debug(f"Initialized CometML project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """The Comet experiment object created in `__init__`."""
        return self.writer

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Records `values` as the run's hyperparameters. Meant to be called at the beginning of the experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Key-value pairs to store as initial hyperparameters. The values need to have type `bool`, `str`,
                `float`, `int`, or `None`.
        """
        self.writer.log_parameters(values)
        logger.debug("Stored initial configuration hyperparameters to Comet")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Sends `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Key-value pairs to log. The values need to have type `str`, `float`, `int` or `dict` of `str` to
                `float`/`int`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
                or `Experiment.log_metrics` method based on the contents of `values`.
        """
        experiment = self.writer
        if step is not None:
            experiment.set_step(step)
        # Numbers become metrics, strings become "other" entries, dicts become metric groups;
        # values of any other type are skipped.
        for key, value in values.items():
            if isinstance(value, (int, float)):
                experiment.log_metric(key, value, step=step, **kwargs)
            elif isinstance(value, str):
                experiment.log_other(key, value, **kwargs)
            elif isinstance(value, dict):
                experiment.log_metrics(value, step=step, **kwargs)
        logger.debug("Successfully logged to Comet")

    @on_main_process
    def finish(self):
        """
        Ends the `comet-ml` experiment by calling the writer's `end` method.
        """
        self.writer.end()
        logger.debug("Comet run flushed")
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class AimTracker(GeneralTracker):
    """
    A `Tracker` class that supports `aim`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run.
        logging_dir (`str` or `os.PathLike`, *optional*, defaults to `"."`):
            Passed to `aim.Run` as its `repo` argument.
        **kwargs (additional keyword arguments, *optional*):
            Additional key word arguments passed along to the `Run.__init__` method.
    """

    name = "aim"
    requires_logging_directory = True

    @on_main_process
    def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
        self.run_name = run_name

        # `aim` is imported inside the method rather than at module level.
        from aim import Run

        self.writer = Run(repo=logging_dir, **kwargs)
        self.writer.name = self.run_name
        logger.debug(f"Initialized Aim project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """The `aim.Run` object created in `__init__`."""
        return self.writer

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (`dict`):
                Values to be stored as initial hyperparameters as key-value pairs. They are stored under the
                `"hparams"` key of the run.
        """
        self.writer["hparams"] = values

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (`dict`):
                Values to be logged as key-value pairs.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `Run.track` method.
        """
        # Note: replace this with the dictionary support when merged
        for key, value in values.items():
            self.writer.track(value, name=key, step=step, **kwargs)

    @on_main_process
    def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[dict[str, dict]] = None):
        """
        Logs `images` to the current run.

        Args:
            values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
                Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
                tuple is provided, the first element should be the image and the second element should be the caption.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs (`Dict[str, dict]`, *optional*):
                Additional key word arguments passed along to `aim.Image` and `Run.track`, looked up under the keys
                `aim_image` and `track`, respectively.
        """
        import aim

        # A missing `kwargs` behaves like an empty mapping: no extra arguments for either call.
        options = {} if kwargs is None else kwargs
        aim_image_kw = options.get("aim_image", {})
        track_kw = options.get("track", {})

        for key, value in values.items():
            # A tuple carries `(image, caption)`; anything else is the image itself with an empty caption.
            if isinstance(value, tuple):
                img, caption = value
            else:
                img, caption = value, ""
            aim_image = aim.Image(img, caption=caption, **aim_image_kw)
            self.writer.track(aim_image, name=key, step=step, **track_kw)

    @on_main_process
    def finish(self):
        """
        Closes `aim` writer
        """
        self.writer.close()
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class MLflowTracker(GeneralTracker):
    """
    A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.

    Args:
        experiment_name (`str`, *optional*):
            Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
        logging_dir (`str` or `os.PathLike`, *optional*):
            Location for mlflow logs to be stored. Used as `artifact_location` when the experiment has to be
            created.
        run_id (`str`, *optional*):
            If specified, get the run with the specified UUID and log parameters and metrics under that run. The run's
            end time is unset and its status is set to running, but the run's other attributes (source_version,
            source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
        tags (`Dict[str, str]`, *optional*):
            An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
            run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
            set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
        nested_run (`bool`, *optional*, defaults to `False`):
            Controls whether run is nested in parent run. True creates a nested run. Environment variable
            MLFLOW_NESTED_RUN has priority over this argument; the strings `""`, `"0"`, `"false"`, `"no"`, `"n"`,
            `"off"` and `"f"` (case-insensitive) disable nesting, any other string enables it.
        run_name (`str`, *optional*):
            Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
        description (`str`, *optional*):
            An optional string that populates the description box of the run. If a run is being resumed, the
            description is set on the resumed run. If a new run is being created, the description is set on the new
            run.
    """

    name = "mlflow"
    requires_logging_directory = False

    @on_main_process
    def __init__(
        self,
        experiment_name: Optional[str] = None,
        logging_dir: Optional[Union[str, os.PathLike]] = None,
        run_id: Optional[str] = None,
        tags: Optional[Union[dict[str, Any], str]] = None,
        nested_run: Optional[bool] = False,
        run_name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        # Environment variables take priority over the constructor arguments.
        experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name)
        run_id = os.environ.get("MLFLOW_RUN_ID", run_id)
        tags = os.environ.get("MLFLOW_TAGS", tags)
        if isinstance(tags, str):
            tags = json.loads(tags)

        nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run)
        if isinstance(nested_run, str):
            # An environment value is always a string, and every non-empty string is truthy, so
            # "false" or "0" would otherwise switch nesting ON. Only the usual negative spellings
            # turn it off; any other string keeps the previous "enabled" meaning.
            nested_run = nested_run.strip().lower() not in ("", "0", "false", "no", "n", "off", "f")

        import mlflow

        exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'")
        if len(exps) > 0:
            if len(exps) > 1:
                logger.warning("Multiple experiments with the same name found. Using first one.")
            experiment_id = exps[0].experiment_id
        else:
            experiment_id = mlflow.create_experiment(
                name=experiment_name,
                artifact_location=logging_dir,
                tags=tags,
            )

        self.active_run = mlflow.start_run(
            run_id=run_id,
            experiment_id=experiment_id,
            run_name=run_name,
            nested=nested_run,
            tags=tags,
            description=description,
        )

        logger.debug(f"Initialized mlflow experiment {experiment_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """The run object returned by `mlflow.start_run` in `__init__`."""
        return self.active_run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Entries whose string form is longer than MLflow's `MAX_PARAM_VAL_LENGTH` are skipped with a warning. The
        dictionary passed in is left untouched.

        Args:
            values (`dict`):
                Values to be stored as initial hyperparameters as key-value pairs.
        """
        import mlflow

        max_value_length = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        batch_size = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        # Collect the loggable entries in a new list instead of deleting from `values`, so the
        # caller's dictionary keeps all of its keys.
        values_list = []
        for name, value in values.items():
            # internally, all values are converted to str in MLflow
            if len(str(value)) > max_value_length:
                logger.warning_once(
                    f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
                    f" log_param() only accepts values no longer than {max_value_length} characters so we dropped this attribute."
                )
                continue
            values_list.append((name, value))

        # MLflow cannot log more than 100 values in one go, so we have to split it
        for i in range(0, len(values_list), batch_size):
            mlflow.log_params(dict(values_list[i : i + batch_size]))

        logger.debug("Stored initial configuration hyperparameters to MLflow")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None):
        """
        Logs `values` to the current run.

        Args:
            values (`dict`):
                Values to be logged as key-value pairs. Only `int` and `float` values are logged; anything else is
                dropped with a warning.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
        """
        metrics = {}
        for k, v in values.items():
            if isinstance(v, (int, float)):
                metrics[k] = v
            else:
                logger.warning_once(
                    f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
                    "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                )
        import mlflow

        mlflow.log_metrics(metrics, step=step)
        logger.debug("Successfully logged to mlflow")

    @on_main_process
    def log_figure(self, figure: Any, artifact_file: str, **save_kwargs):
        """
        Logs a figure to the current run.

        Args:
            figure (Any):
                The figure to be logged.
            artifact_file (`str`):
                The run-relative artifact file path in posixpath format to which the figure is saved.
            **save_kwargs:
                Additional keyword arguments passed to the underlying `mlflow.log_figure` function.
        """
        import mlflow

        mlflow.log_figure(figure=figure, artifact_file=artifact_file, **save_kwargs)
        logger.debug("Successfully logged image to mlflow")

    @on_main_process
    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):
        """
        Logs artifacts (all content of a directory) to the current run.

        Args:
            local_dir (`str`):
                Path to the directory to be logged as an artifact.
            artifact_path (`str`, *optional*):
                Directory within the run's artifact directory where the artifact will be logged. If omitted, the
                artifact will be logged to the root of the run's artifact directory.
        """
        import mlflow

        mlflow.log_artifacts(local_dir=local_dir, artifact_path=artifact_path)
        logger.debug("Successfully logged artofact to mlflow")

    @on_main_process
    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        """
        Logs an artifact (file) to the current run.

        Args:
            local_path (`str`):
                Path to the file to be logged as an artifact.
            artifact_path (`str`, *optional*):
                Directory within the run's artifact directory where the artifact will be logged. If omitted, the
                artifact will be logged to the root of the run's artifact directory.
        """
        import mlflow

        mlflow.log_artifact(local_path=local_path, artifact_path=artifact_path)
        logger.debug("Successfully logged artofact to mlflow")

    @on_main_process
    def finish(self):
        """
        End the active MLflow run.
        """
        import mlflow

        mlflow.end_run()
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
class ClearMLTracker(GeneralTracker):
    """
    A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.

    Args:
        run_name (`str`, *optional*):
            Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
            argument.
        **kwargs (additional keyword arguments, *optional*):
            Kwargs passed along to the `Task.init` method.
    """

    name = "clearml"
    requires_logging_directory = False

    @on_main_process
    def __init__(self, run_name: str = None, **kwargs):
        from clearml import Task

        # Reuse a task that already exists (e.g. created by the user) instead of creating one.
        existing_task = Task.current_task()
        self._initialized_externally = bool(existing_task)
        if self._initialized_externally:
            self.task = existing_task
            return

        project_default = os.environ.get("CLEARML_PROJECT", run_name)
        kwargs.setdefault("project_name", project_default)
        task_default = os.environ.get("CLEARML_TASK", run_name)
        kwargs.setdefault("task_name", task_default)
        self.task = Task.init(**kwargs)

    @property
    def tracker(self):
        """The ClearML task held by this tracker."""
        return self.task

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.

        Args:
            values (`dict`):
                Values to be stored as initial hyperparameters as key-value pairs.

        Returns:
            Whatever `Task.connect_configuration` returns.
        """
        connected = self.task.connect_configuration(values)
        return connected

    @on_main_process
    def log(self, values: dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
        """
        Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
        ints or floats; other values are dropped with a warning.

        Args:
            values (`Dict[str, Union[int, float]]`):
                Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
                be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
                Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
            step (`int`, *optional*):
                If specified, the values will be reported as scalars, with the iteration number equal to `step`.
                Otherwise they will be reported as single values.
            kwargs:
                Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
                `clearml.Logger.report_scalar` methods.
        """
        clearml_logger = self.task.get_logger()
        for key, value in values.items():
            if isinstance(value, (int, float)):
                if step is None:
                    clearml_logger.report_single_value(name=key, value=value, **kwargs)
                else:
                    title, series = ClearMLTracker._get_title_series(key)
                    clearml_logger.report_scalar(title=title, series=series, value=value, iteration=step, **kwargs)
            else:
                logger.warning_once(
                    "Accelerator is attempting to log a value of "
                    f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
                    "This invocation of ClearML logger's report_scalar() "
                    "is incorrect so we dropped this attribute."
                )

    @on_main_process
    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `images` to the current run.

        Args:
            values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`):
                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
                `PIL.Image`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `clearml.Logger.report_image` method.
        """
        clearml_logger = self.task.get_logger()
        for key, image in values.items():
            title, series = ClearMLTracker._get_title_series(key)
            clearml_logger.report_image(title=title, series=series, iteration=step, image=image, **kwargs)

    @on_main_process
    def log_table(
        self,
        table_name: str,
        columns: list[str] = None,
        data: list[list[Any]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
        **kwargs,
    ):
        """
        Log a Table to the task. Can be defined either with `columns` and `data` or with `dataframe`.

        Args:
            table_name (`str`):
                The name of the table
            columns (list of `str`, *optional*):
                The name of the columns on the table
            data (List of List of Any data type, *optional*):
                The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
                the name of the columns of the table
            dataframe (Any data type, *optional*):
                The data to be logged in the table. Takes precedence over `columns` and `data`.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to the `clearml.Logger.report_table` method.

        Raises:
            ValueError: If both `dataframe` and `data` are `None`.
        """
        if dataframe is not None:
            to_report = dataframe
        elif data is None:
            raise ValueError(
                "`ClearMLTracker.log_table` requires that `data` to be supplied if `dataframe` is `None`"
            )
        elif columns:
            # The column names become the first row of the reported table.
            to_report = [columns] + data
        else:
            to_report = data
        title, series = ClearMLTracker._get_title_series(table_name)
        clearml_logger = self.task.get_logger()
        clearml_logger.report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)

    @on_main_process
    def finish(self):
        """
        Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
        function is a noop
        """
        if not self.task or self._initialized_externally:
            return
        self.task.close()

    @staticmethod
    def _get_title_series(name):
        """Split `name` into `(title, series)` using an `eval_`/`test_`/`train_` prefix; the series defaults to `"train"`."""
        for series in ("eval", "test", "train"):
            marker = series + "_"
            if name.startswith(marker):
                return name[len(marker) :], series
        return name, "train"
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
class DVCLiveTracker(GeneralTracker):
    """
    A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.

    Args:
        run_name (`str`, *optional*):
            Ignored for dvclive. See `kwargs` instead.
        live (*optional*):
            An existing object to use as the tracker. When it is `None`, a new `dvclive.Live` is created from
            `kwargs`.
        kwargs:
            Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).

    Example:

    ```py
    from accelerate import Accelerator

    accelerator = Accelerator(log_with="dvclive")
    accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
    ```
    """

    name = "dvclive"
    requires_logging_directory = False

    @on_main_process
    def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
        from dvclive import Live

        super().__init__()
        if live is not None:
            self.live = live
        else:
            self.live = Live(**kwargs)

    @property
    def tracker(self):
        """The `Live` object held by this tracker."""
        return self.live

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
        hyperparameters in a yaml file for future use.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
                Values to be stored as initial hyperparameters as key-value pairs. They are handed to
                `Live.log_params` unchanged.
        """
        self.live.log_params(values)

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run and then advances the live step.

        Args:
            values (Dictionary `str` to `str`, `float`, or `int`):
                Values to be logged as key-value pairs. Values that `Metric.could_log` rejects are dropped with a
                warning.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to `dvclive.Live.log_metric()`.
        """
        from dvclive.plots import Metric

        if step is not None:
            self.live.step = step

        for key, value in values.items():
            if not Metric.could_log(value):
                logger.warning_once(
                    "Accelerator attempted to log a value of "
                    f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
                    "This invocation of DVCLive's Live.log_metric() "
                    "is incorrect so we dropped this attribute."
                )
                continue
            self.live.log_metric(key, value, **kwargs)

        self.live.next_step()

    @on_main_process
    def finish(self):
        """
        Closes `dvclive.Live()`.
        """
        self.live.end()
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
# Registry from tracker name to the tracker class implementing it.
# `filter_trackers` indexes it with `str(log_type)` to read the class's `requires_logging_directory` flag.
LOGGER_TYPE_TO_CLASS = {
    "aim": AimTracker,
    "comet_ml": CometMLTracker,
    "mlflow": MLflowTracker,
    "tensorboard": TensorBoardTracker,
    "wandb": WandBTracker,
    "clearml": ClearMLTracker,
    "dvclive": DVCLiveTracker,
}
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
def filter_trackers(
    log_with: list[Union[str, LoggerType, GeneralTracker]],
    logging_dir: Union[str, os.PathLike] = None,
):
    """
    Takes in a list of potential tracker types and checks that:
        - The tracker wanted is available in that environment
        - Filters out repeats of tracker types
        - If `all` is in `log_with`, will return all trackers in the environment
        - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`

    Args:
        log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
            A list of loggers to be setup for experiment tracking (a single value is also accepted and is wrapped
            in a list). Should be one or several of:

            - `"all"`
            - `"aim"`
            - `"tensorboard"`
            - `"wandb"`
            - `"comet_ml"`
            - `"mlflow"`
            - `"clearml"`
            - `"dvclive"`
            If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
            also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
        logging_dir (`str`, `os.PathLike`, *optional*):
            A path to a directory for storing logs of locally-compatible loggers.

    Returns:
        `list`: The custom `GeneralTracker` instances and the `LoggerType` values that were accepted. Empty when
        `log_with` is `None`.

    Raises:
        ValueError: If an entry is neither a supported logger type nor a `GeneralTracker`, or if an available
            tracker requires a logging directory and `logging_dir` is `None`.
    """
    loggers = []
    if log_with is None:
        return loggers

    if not isinstance(log_with, (list, tuple)):
        log_with = [log_with]

    if "all" in log_with or LoggerType.ALL in log_with:
        # Keep the custom tracker instances and add every tracker available in the environment.
        return [o for o in log_with if isinstance(o, GeneralTracker)] + get_available_trackers()

    # Looked up at most once, and only if a named (non-custom) tracker is actually requested.
    available_trackers = None
    for log_type in log_with:
        if isinstance(log_type, GeneralTracker):
            loggers.append(log_type)
            continue
        if log_type not in LoggerType:
            raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")

        log_type = LoggerType(log_type)
        if log_type in loggers:
            # Already selected; repeats are ignored.
            continue

        if available_trackers is None:
            available_trackers = get_available_trackers()
        if log_type not in available_trackers:
            logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
            continue

        tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
        if tracker_init.requires_logging_directory and logging_dir is None:
            raise ValueError(f"Logging with `{log_type}` requires a `logging_dir` to be passed in.")
        loggers.append(log_type)

    return loggers
|
venv/Lib/site-packages/adodbapi/__init__.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nopycln: file # undecidable cases due to explicit re-exports https://github.com/hadialqattan/pycln/issues/205
|
| 2 |
+
"""adodbapi - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
|
| 5 |
+
* https://sourceforge.net/projects/adodbapi
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
# Re-exports to keep backward compatibility with existing code
|
| 11 |
+
from .adodbapi import (
|
| 12 |
+
Connection as Connection,
|
| 13 |
+
Cursor as Cursor,
|
| 14 |
+
__version__,
|
| 15 |
+
connect as connect,
|
| 16 |
+
dateconverter,
|
| 17 |
+
)
|
| 18 |
+
from .apibase import (
|
| 19 |
+
BINARY as BINARY,
|
| 20 |
+
DATETIME as DATETIME,
|
| 21 |
+
NUMBER as NUMBER,
|
| 22 |
+
ROWID as ROWID,
|
| 23 |
+
STRING as STRING,
|
| 24 |
+
DatabaseError as DatabaseError,
|
| 25 |
+
DataError as DataError,
|
| 26 |
+
Error as Error,
|
| 27 |
+
FetchFailedError as FetchFailedError,
|
| 28 |
+
IntegrityError as IntegrityError,
|
| 29 |
+
InterfaceError as InterfaceError,
|
| 30 |
+
InternalError as InternalError,
|
| 31 |
+
NotSupportedError as NotSupportedError,
|
| 32 |
+
OperationalError as OperationalError,
|
| 33 |
+
ProgrammingError as ProgrammingError,
|
| 34 |
+
Warning as Warning,
|
| 35 |
+
apilevel as apilevel,
|
| 36 |
+
paramstyle as paramstyle,
|
| 37 |
+
threadsafety as threadsafety,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def Binary(aString):
    """Construct an object capable of holding a binary (long) string value.

    The result is `bytes(aString)`.
    """
    binary_value = bytes(aString)
    return binary_value
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def Date(year, month, day):
    """Construct an object holding a date value.

    Delegates to the `Date` method of the module's `dateconverter`.
    """
    date_value = dateconverter.Date(year, month, day)
    return date_value
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def Time(hour, minute, second):
    """Construct an object holding a time value.

    Delegates to the `Time` method of the module's `dateconverter`.
    """
    time_value = dateconverter.Time(hour, minute, second)
    return time_value
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def Timestamp(year, month, day, hour, minute, second):
    """Construct an object holding a time stamp value.

    Delegates to the `Timestamp` method of the module's `dateconverter`.
    """
    timestamp_value = dateconverter.Timestamp(year, month, day, hour, minute, second)
    return timestamp_value
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    `ticks` is a number of seconds since the epoch (see the documentation of the standard
    Python time module for details). It is broken down with `time.gmtime`, and the year,
    month and day fields are passed to `Date`.
    """
    year, month, day = time.gmtime(ticks)[:3]
    return Date(year, month, day)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    `ticks` is a number of seconds since the epoch (see the documentation of the standard
    Python time module for details). It is broken down with `time.gmtime`, and the hour,
    minute and second fields are passed to `Time`.
    """
    hour, minute, second = time.gmtime(ticks)[3:6]
    return Time(hour, minute, second)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp value from the given ticks value.

    `ticks` is a number of seconds since the epoch (see the documentation of the standard
    Python time module for details). It is broken down with `time.gmtime`, and the first
    six fields (year through second) are passed to `Timestamp`.
    """
    year, month, day, hour, minute, second = time.gmtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Version banner: the literal prefix followed by the `__version__` imported from `.adodbapi` above.
version = "adodbapi v" + __version__
|
venv/Lib/site-packages/adodbapi/ado_consts.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ADO enumerated constants documented on MSDN:
# https://learn.microsoft.com/en-us/sql/ado/reference/ado-api/ado-enumerated-constants
# TODO: Update to https://learn.microsoft.com/en-us/sql/ado/reference/ado-api/ado-enumerated-constants

# IsolationLevelEnum
# Several names below deliberately share one value (they are aliases of each other):
#   adXactBrowse == adXactReadUncommitted, adXactCursorStability == adXactReadCommitted,
#   adXactIsolated == adXactSerializable.
adXactUnspecified = -1
adXactBrowse = 0x100
adXactChaos = 0x10
adXactCursorStability = 0x1000
adXactIsolated = 0x100000
adXactReadCommitted = 0x1000
adXactReadUncommitted = 0x100
adXactRepeatableRead = 0x10000
adXactSerializable = 0x100000

# CursorLocationEnum
adUseClient = 3
adUseServer = 2

# CursorTypeEnum
adOpenDynamic = 2
adOpenForwardOnly = 0
adOpenKeyset = 1
adOpenStatic = 3
adOpenUnspecified = -1

# CommandTypeEnum
adCmdText = 1
adCmdStoredProc = 4
# NOTE(review): by its name adSchemaTables looks like a schema-query constant rather than a
# command type; confirm against the ADO reference before relying on this grouping.
adSchemaTables = 20
|
| 31 |
+
|
| 32 |
+
# ParameterDirectionEnum
adParamInput = 1
adParamInputOutput = 3
adParamOutput = 2
adParamReturnValue = 4
adParamUnknown = 0
# Suffix of the constant name for each direction value; used by `ado_direction_name`.
directions = {
    0: "Unknown",
    1: "Input",
    2: "Output",
    3: "InputOutput",
    4: "Return",
}


def ado_direction_name(ado_dir):
    """Return a printable name for a parameter direction value.

    Args:
        ado_dir: A direction value, normally one of the keys of `directions`.

    Returns:
        `"adParam"` followed by the entry of `directions` for `ado_dir`, or
        `"unknown direction (<ado_dir>)"` when `ado_dir` is not a key of `directions`
        (including values that cannot be used as a dictionary key at all).
    """
    try:
        return "adParam" + directions[ado_dir]
    except (KeyError, TypeError):
        # KeyError: value not in the table. TypeError: unhashable value.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return f"unknown direction ({ado_dir})"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ObjectStateEnum
|
| 55 |
+
adStateClosed = 0
|
| 56 |
+
adStateOpen = 1
|
| 57 |
+
adStateConnecting = 2
|
| 58 |
+
adStateExecuting = 4
|
| 59 |
+
adStateFetching = 8
|
| 60 |
+
|
| 61 |
+
# FieldAttributeEnum
|
| 62 |
+
adFldMayBeNull = 0x40
|
| 63 |
+
|
| 64 |
+
# ConnectModeEnum
|
| 65 |
+
adModeUnknown = 0
|
| 66 |
+
adModeRead = 1
|
| 67 |
+
adModeWrite = 2
|
| 68 |
+
adModeReadWrite = 3
|
| 69 |
+
adModeShareDenyRead = 4
|
| 70 |
+
adModeShareDenyWrite = 8
|
| 71 |
+
adModeShareExclusive = 12
|
| 72 |
+
adModeShareDenyNone = 16
|
| 73 |
+
adModeRecursive = 0x400000
|
| 74 |
+
|
| 75 |
+
# XactAttributeEnum
|
| 76 |
+
adXactCommitRetaining = 131072
|
| 77 |
+
adXactAbortRetaining = 262144
|
| 78 |
+
|
| 79 |
+
ado_error_TIMEOUT = -2147217871
|
| 80 |
+
|
| 81 |
+
# DataTypeEnum - ADO Data types documented at:
|
| 82 |
+
# http://msdn2.microsoft.com/en-us/library/ms675318.aspx
|
| 83 |
+
# TODO: Update to https://learn.microsoft.com/en-us/sql/ado/reference/ado-api/datatypeenum
|
| 84 |
+
adArray = 0x2000
|
| 85 |
+
adEmpty = 0x0
|
| 86 |
+
adBSTR = 0x8
|
| 87 |
+
adBigInt = 0x14
|
| 88 |
+
adBinary = 0x80
|
| 89 |
+
adBoolean = 0xB
|
| 90 |
+
adChapter = 0x88
|
| 91 |
+
adChar = 0x81
|
| 92 |
+
adCurrency = 0x6
|
| 93 |
+
adDBDate = 0x85
|
| 94 |
+
adDBTime = 0x86
|
| 95 |
+
adDBTimeStamp = 0x87
|
| 96 |
+
adDate = 0x7
|
| 97 |
+
adDecimal = 0xE
|
| 98 |
+
adDouble = 0x5
|
| 99 |
+
adError = 0xA
|
| 100 |
+
adFileTime = 0x40
|
| 101 |
+
adGUID = 0x48
|
| 102 |
+
adIDispatch = 0x9
|
| 103 |
+
adIUnknown = 0xD
|
| 104 |
+
adInteger = 0x3
|
| 105 |
+
adLongVarBinary = 0xCD
|
| 106 |
+
adLongVarChar = 0xC9
|
| 107 |
+
adLongVarWChar = 0xCB
|
| 108 |
+
adNumeric = 0x83
|
| 109 |
+
adPropVariant = 0x8A
|
| 110 |
+
adSingle = 0x4
|
| 111 |
+
adSmallInt = 0x2
|
| 112 |
+
adTinyInt = 0x10
|
| 113 |
+
adUnsignedBigInt = 0x15
|
| 114 |
+
adUnsignedInt = 0x13
|
| 115 |
+
adUnsignedSmallInt = 0x12
|
| 116 |
+
adUnsignedTinyInt = 0x11
|
| 117 |
+
adUserDefined = 0x84
|
| 118 |
+
adVarBinary = 0xCC
|
| 119 |
+
adVarChar = 0xC8
|
| 120 |
+
adVarNumeric = 0x8B
|
| 121 |
+
adVarWChar = 0xCA
|
| 122 |
+
adVariant = 0xC
|
| 123 |
+
adWChar = 0x82
|
| 124 |
+
# Additional constants used by introspection but not ADO itself
|
| 125 |
+
AUTO_FIELD_MARKER = -1000
|
| 126 |
+
|
| 127 |
+
adTypeNames = {
|
| 128 |
+
adBSTR: "adBSTR",
|
| 129 |
+
adBigInt: "adBigInt",
|
| 130 |
+
adBinary: "adBinary",
|
| 131 |
+
adBoolean: "adBoolean",
|
| 132 |
+
adChapter: "adChapter",
|
| 133 |
+
adChar: "adChar",
|
| 134 |
+
adCurrency: "adCurrency",
|
| 135 |
+
adDBDate: "adDBDate",
|
| 136 |
+
adDBTime: "adDBTime",
|
| 137 |
+
adDBTimeStamp: "adDBTimeStamp",
|
| 138 |
+
adDate: "adDate",
|
| 139 |
+
adDecimal: "adDecimal",
|
| 140 |
+
adDouble: "adDouble",
|
| 141 |
+
adEmpty: "adEmpty",
|
| 142 |
+
adError: "adError",
|
| 143 |
+
adFileTime: "adFileTime",
|
| 144 |
+
adGUID: "adGUID",
|
| 145 |
+
adIDispatch: "adIDispatch",
|
| 146 |
+
adIUnknown: "adIUnknown",
|
| 147 |
+
adInteger: "adInteger",
|
| 148 |
+
adLongVarBinary: "adLongVarBinary",
|
| 149 |
+
adLongVarChar: "adLongVarChar",
|
| 150 |
+
adLongVarWChar: "adLongVarWChar",
|
| 151 |
+
adNumeric: "adNumeric",
|
| 152 |
+
adPropVariant: "adPropVariant",
|
| 153 |
+
adSingle: "adSingle",
|
| 154 |
+
adSmallInt: "adSmallInt",
|
| 155 |
+
adTinyInt: "adTinyInt",
|
| 156 |
+
adUnsignedBigInt: "adUnsignedBigInt",
|
| 157 |
+
adUnsignedInt: "adUnsignedInt",
|
| 158 |
+
adUnsignedSmallInt: "adUnsignedSmallInt",
|
| 159 |
+
adUnsignedTinyInt: "adUnsignedTinyInt",
|
| 160 |
+
adUserDefined: "adUserDefined",
|
| 161 |
+
adVarBinary: "adVarBinary",
|
| 162 |
+
adVarChar: "adVarChar",
|
| 163 |
+
adVarNumeric: "adVarNumeric",
|
| 164 |
+
adVarWChar: "adVarWChar",
|
| 165 |
+
adVariant: "adVariant",
|
| 166 |
+
adWChar: "adWChar",
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def ado_type_name(ado_type):
    """Return the symbolic name of an ADO data type constant.

    Looks the value up in ``adTypeNames``; a value without an entry is
    reported as "unknown type (<value>)".
    """
    if ado_type in adTypeNames:
        return adTypeNames[ado_type]
    return f"unknown type ({ado_type})"
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# here in decimal, sorted by value
|
| 175 |
+
# adEmpty 0 Specifies no value (DBTYPE_EMPTY).
|
| 176 |
+
# adSmallInt 2 Indicates a two-byte signed integer (DBTYPE_I2).
|
| 177 |
+
# adInteger 3 Indicates a four-byte signed integer (DBTYPE_I4).
|
| 178 |
+
# adSingle 4 Indicates a single-precision floating-point value (DBTYPE_R4).
|
| 179 |
+
# adDouble 5 Indicates a double-precision floating-point value (DBTYPE_R8).
|
| 180 |
+
# adCurrency 6 Indicates a currency value (DBTYPE_CY). Currency is a fixed-point number
|
| 181 |
+
# with four digits to the right of the decimal point. It is stored in an eight-byte signed integer scaled by 10,000.
|
| 182 |
+
# adDate 7 Indicates a date value (DBTYPE_DATE). A date is stored as a double, the whole part of which is
|
| 183 |
+
# the number of days since December 30, 1899, and the fractional part of which is the fraction of a day.
|
| 184 |
+
# adBSTR 8 Indicates a null-terminated character string (Unicode) (DBTYPE_BSTR).
|
| 185 |
+
# adIDispatch 9 Indicates a pointer to an IDispatch interface on a COM object (DBTYPE_IDISPATCH).
|
| 186 |
+
# adError 10 Indicates a 32-bit error code (DBTYPE_ERROR).
|
| 187 |
+
# adBoolean 11 Indicates a boolean value (DBTYPE_BOOL).
|
| 188 |
+
# adVariant 12 Indicates an Automation Variant (DBTYPE_VARIANT).
|
| 189 |
+
# adIUnknown 13 Indicates a pointer to an IUnknown interface on a COM object (DBTYPE_IUNKNOWN).
|
| 190 |
+
# adDecimal 14 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_DECIMAL).
|
| 191 |
+
# adTinyInt 16 Indicates a one-byte signed integer (DBTYPE_I1).
|
| 192 |
+
# adUnsignedTinyInt 17 Indicates a one-byte unsigned integer (DBTYPE_UI1).
|
| 193 |
+
# adUnsignedSmallInt 18 Indicates a two-byte unsigned integer (DBTYPE_UI2).
|
| 194 |
+
# adUnsignedInt 19 Indicates a four-byte unsigned integer (DBTYPE_UI4).
|
| 195 |
+
# adBigInt 20 Indicates an eight-byte signed integer (DBTYPE_I8).
|
| 196 |
+
# adUnsignedBigInt 21 Indicates an eight-byte unsigned integer (DBTYPE_UI8).
|
| 197 |
+
# adFileTime 64 Indicates a 64-bit value representing the number of 100-nanosecond intervals since
|
| 198 |
+
# January 1, 1601 (DBTYPE_FILETIME).
|
| 199 |
+
# adGUID 72 Indicates a globally unique identifier (GUID) (DBTYPE_GUID).
|
| 200 |
+
# adBinary 128 Indicates a binary value (DBTYPE_BYTES).
|
| 201 |
+
# adChar 129 Indicates a string value (DBTYPE_STR).
|
| 202 |
+
# adWChar 130 Indicates a null-terminated Unicode character string (DBTYPE_WSTR).
|
| 203 |
+
# adNumeric 131 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_NUMERIC).
|
| 204 |
+
# adUserDefined 132 Indicates a user-defined variable (DBTYPE_UDT).
|
| 205 |
+
# adUserDefined 132 Indicates a user-defined variable (DBTYPE_UDT).
|
| 206 |
+
# adDBDate 133 Indicates a date value (yyyymmdd) (DBTYPE_DBDATE).
|
| 207 |
+
# adDBTime 134 Indicates a time value (hhmmss) (DBTYPE_DBTIME).
|
| 208 |
+
# adDBTimeStamp 135 Indicates a date/time stamp (yyyymmddhhmmss plus a fraction in billionths) (DBTYPE_DBTIMESTAMP).
|
| 209 |
+
# adChapter 136 Indicates a four-byte chapter value that identifies rows in a child rowset (DBTYPE_HCHAPTER).
|
| 210 |
+
# adPropVariant 138 Indicates an Automation PROPVARIANT (DBTYPE_PROP_VARIANT).
|
| 211 |
+
# adVarNumeric 139 Indicates a numeric value (Parameter object only).
|
| 212 |
+
# adVarChar 200 Indicates a string value (Parameter object only).
|
| 213 |
+
# adLongVarChar 201 Indicates a long string value (Parameter object only).
|
| 214 |
+
# adVarWChar 202 Indicates a null-terminated Unicode character string (Parameter object only).
|
| 215 |
+
# adLongVarWChar 203 Indicates a long null-terminated Unicode string value (Parameter object only).
|
| 216 |
+
# adVarBinary 204 Indicates a binary value (Parameter object only).
|
| 217 |
+
# adLongVarBinary 205 Indicates a long binary value (Parameter object only).
|
| 218 |
+
# adArray (Does not apply to ADOX.) 0x2000 A flag value, always combined with another data type constant,
|
| 219 |
+
# that indicates an array of that other data type.
|
| 220 |
+
|
| 221 |
+
# Error codes to names
|
| 222 |
+
adoErrors = {
|
| 223 |
+
0xE7B: "adErrBoundToCommand",
|
| 224 |
+
0xE94: "adErrCannotComplete",
|
| 225 |
+
0xEA4: "adErrCantChangeConnection",
|
| 226 |
+
0xC94: "adErrCantChangeProvider",
|
| 227 |
+
0xE8C: "adErrCantConvertvalue",
|
| 228 |
+
0xE8D: "adErrCantCreate",
|
| 229 |
+
0xEA3: "adErrCatalogNotSet",
|
| 230 |
+
0xE8E: "adErrColumnNotOnThisRow",
|
| 231 |
+
0xD5D: "adErrDataConversion",
|
| 232 |
+
0xE89: "adErrDataOverflow",
|
| 233 |
+
0xE9A: "adErrDelResOutOfScope",
|
| 234 |
+
0xEA6: "adErrDenyNotSupported",
|
| 235 |
+
0xEA7: "adErrDenyTypeNotSupported",
|
| 236 |
+
0xCB3: "adErrFeatureNotAvailable",
|
| 237 |
+
0xEA5: "adErrFieldsUpdateFailed",
|
| 238 |
+
0xC93: "adErrIllegalOperation",
|
| 239 |
+
0xCAE: "adErrInTransaction",
|
| 240 |
+
0xE87: "adErrIntegrityViolation",
|
| 241 |
+
0xBB9: "adErrInvalidArgument",
|
| 242 |
+
0xE7D: "adErrInvalidConnection",
|
| 243 |
+
0xE7C: "adErrInvalidParamInfo",
|
| 244 |
+
0xE82: "adErrInvalidTransaction",
|
| 245 |
+
0xE91: "adErrInvalidURL",
|
| 246 |
+
0xCC1: "adErrItemNotFound",
|
| 247 |
+
0xBCD: "adErrNoCurrentRecord",
|
| 248 |
+
0xE83: "adErrNotExecuting",
|
| 249 |
+
0xE7E: "adErrNotReentrant",
|
| 250 |
+
0xE78: "adErrObjectClosed",
|
| 251 |
+
0xD27: "adErrObjectInCollection",
|
| 252 |
+
0xD5C: "adErrObjectNotSet",
|
| 253 |
+
0xE79: "adErrObjectOpen",
|
| 254 |
+
0xBBA: "adErrOpeningFile",
|
| 255 |
+
0xE80: "adErrOperationCancelled",
|
| 256 |
+
0xE96: "adErrOutOfSpace",
|
| 257 |
+
0xE88: "adErrPermissionDenied",
|
| 258 |
+
0xE9E: "adErrPropConflicting",
|
| 259 |
+
0xE9B: "adErrPropInvalidColumn",
|
| 260 |
+
0xE9C: "adErrPropInvalidOption",
|
| 261 |
+
0xE9D: "adErrPropInvalidValue",
|
| 262 |
+
0xE9F: "adErrPropNotAllSettable",
|
| 263 |
+
0xEA0: "adErrPropNotSet",
|
| 264 |
+
0xEA1: "adErrPropNotSettable",
|
| 265 |
+
0xEA2: "adErrPropNotSupported",
|
| 266 |
+
0xBB8: "adErrProviderFailed",
|
| 267 |
+
0xE7A: "adErrProviderNotFound",
|
| 268 |
+
0xBBB: "adErrReadFile",
|
| 269 |
+
0xE93: "adErrResourceExists",
|
| 270 |
+
0xE92: "adErrResourceLocked",
|
| 271 |
+
0xE97: "adErrResourceOutOfScope",
|
| 272 |
+
0xE8A: "adErrSchemaViolation",
|
| 273 |
+
0xE8B: "adErrSignMismatch",
|
| 274 |
+
0xE81: "adErrStillConnecting",
|
| 275 |
+
0xE7F: "adErrStillExecuting",
|
| 276 |
+
0xE90: "adErrTreePermissionDenied",
|
| 277 |
+
0xE8F: "adErrURLDoesNotExist",
|
| 278 |
+
0xE99: "adErrURLNamedRowDoesNotExist",
|
| 279 |
+
0xE98: "adErrUnavailable",
|
| 280 |
+
0xE84: "adErrUnsafeOperation",
|
| 281 |
+
0xE95: "adErrVolumeNotFound",
|
| 282 |
+
0xBBC: "adErrWriteFile",
|
| 283 |
+
}
|
venv/Lib/site-packages/adodbapi/adodbapi.py
ADDED
|
@@ -0,0 +1,1153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""adodbapi - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2002 Henrik Ekelund, versions 2.1 and later by Vernon Cole
|
| 4 |
+
* https://sourceforge.net/projects/pywin32
|
| 5 |
+
* https://github.com/mhammond/pywin32
|
| 6 |
+
* https://sourceforge.net/projects/adodbapi
|
| 7 |
+
|
| 8 |
+
This library is free software; you can redistribute it and/or
|
| 9 |
+
modify it under the terms of the GNU Lesser General Public
|
| 10 |
+
License as published by the Free Software Foundation; either
|
| 11 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 12 |
+
|
| 13 |
+
This library is distributed in the hope that it will be useful,
|
| 14 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 15 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 16 |
+
Lesser General Public License for more details.
|
| 17 |
+
|
| 18 |
+
You should have received a copy of the GNU Lesser General Public
|
| 19 |
+
License along with this library; if not, write to the Free Software
|
| 20 |
+
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
| 21 |
+
|
| 22 |
+
django adaptations and refactoring by Adam Vandenberg
|
| 23 |
+
|
| 24 |
+
DB-API 2.0 specification: https://peps.python.org/pep-0249/
|
| 25 |
+
|
| 26 |
+
This module source should run correctly in CPython versions 2.7 and later,
|
| 27 |
+
or CPython 3.4 or later.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
__version__ = "2.6.2.0"
|
| 31 |
+
version = "adodbapi v" + __version__
|
| 32 |
+
|
| 33 |
+
import copy
|
| 34 |
+
import decimal
|
| 35 |
+
import os
|
| 36 |
+
import sys
|
| 37 |
+
import weakref
|
| 38 |
+
|
| 39 |
+
from . import ado_consts as adc, apibase as api, process_connect_string
|
| 40 |
+
|
| 41 |
+
# Verbosity level is read once, at import time, from ADODBAPI_VERBOSE.
try:
    verbose = int(os.environ["ADODBAPI_VERBOSE"])
except (KeyError, ValueError):
    # KeyError: variable not set; ValueError: set to a non-integer string.
    verbose = False
if verbose:
    print(version)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import pythoncom
|
| 50 |
+
import pywintypes
|
| 51 |
+
from win32com.client import Dispatch
|
| 52 |
+
except ImportError:
|
| 53 |
+
import warnings
|
| 54 |
+
|
| 55 |
+
warnings.warn("pywin32 package required for adodbapi.", ImportWarning)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def getIndexedValue(obj, index):
    """Fetch an item from *obj* by calling it with *index*.

    The collection is invoked as a callable (``obj(index)``) rather than
    subscripted.
    """
    item = obj(index)
    return item
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
from collections.abc import Mapping
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ----------------- The .connect method -----------------
|
| 66 |
+
def make_COM_connecter():
    """Create and return a new "ADODB.Connection" COM object.

    ``pythoncom.CoInitialize()`` is called before the object is dispatched.

    Raises:
        api.InterfaceError: if CoInitialize or the Dispatch call fails; the
            underlying exception is attached as the cause.
    """
    try:
        pythoncom.CoInitialize()  # v2.1 Paj
        c = Dispatch("ADODB.Connection")  # connect _after_ CoInitialize v2.1.1 adamvan
    except Exception as exc:
        # "Exception" (not a bare except) lets KeyboardInterrupt/SystemExit
        # through while still covering a missing pywin32 (NameError) and
        # COM failures. Chaining keeps the real reason visible in tracebacks.
        raise api.InterfaceError(
            "Windows COM Error: Dispatch('ADODB.Connection') failed."
        ) from exc
    return c
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def connect(*args, **kwargs):  # --> a db-api connection object
    """Open a database connection and return the Connection object.

    The positional and keyword arguments are combined by
    ``process_connect_string.process`` into one settings mapping, which is
    handed to ``Connection.connect``. Commonly supplied settings:

    :connection_string -- An ADODB formatted connection string, see:
         * https://www.connectionstrings.com
         * https://www.codeguru.com/dotnet/whats-in-an-ado-connection-string/
         * https://learn.microsoft.com/en-us/dotnet/framework/data/adonet/connection-strings
    :timeout -- A command timeout value, in seconds (default 30 seconds)

    Any exception raised while connecting is re-raised as
    ``api.OperationalError`` carrying the original exception and a message
    that names the connection string.
    """
    connection = Connection()  # starts out unconnected

    settings = process_connect_string.process(args, kwargs, True)

    try:
        connection.connect(settings)
    except Exception as e:
        message = 'Error opening connection to "%s"' % connection.connection_string
        raise api.OperationalError(e, message)
    return connection
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# so you could use something like:
|
| 100 |
+
# myConnection.paramstyle = 'named'
|
| 101 |
+
# The programmer may also change the default.
|
| 102 |
+
# For example, if I were using django, I would say:
|
| 103 |
+
# import adodbapi as Database
|
| 104 |
+
# Database.adodbapi.paramstyle = 'format'
|
| 105 |
+
|
| 106 |
+
# ------- other module level defaults --------
|
| 107 |
+
defaultIsolationLevel = adc.adXactReadCommitted
|
| 108 |
+
# Set defaultIsolationLevel on module level before creating the connection.
|
| 109 |
+
# For example:
|
| 110 |
+
# import adodbapi, ado_consts
|
| 111 |
+
# adodbapi.adodbapi.defaultIsolationLevel=ado_consts.adXactBrowse
|
| 112 |
+
#
|
| 113 |
+
# Set defaultCursorLocation on module level before creating the connection.
|
| 114 |
+
# It may be one of the "adUse..." consts.
|
| 115 |
+
defaultCursorLocation = adc.adUseClient # changed from adUseServer as of v 2.3.0
|
| 116 |
+
|
| 117 |
+
dateconverter = api.pythonDateTimeConverter() # default
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def format_parameters(ADOparameters, show_value=False):
    """Format a collection of ADO Command Parameters.

    Used by error reporting in _execute_command.

    Each parameter becomes one line listing its Name, Dir., Type, Size,
    Precision and NumericScale; with ``show_value`` true the Value is
    included (in double quotes) between Size and Precision. The lines are
    joined with newlines and wrapped in square brackets.

    This is a best-effort helper: if anything goes wrong while reading the
    parameters, "[]" is returned instead of raising.
    """
    # One template for both modes; only the Value field is optional.
    template = "Name: %s, Dir.: %s, Type: %s, Size: %s, "
    if show_value:
        template += 'Value: "%s", '
    template += "Precision: %s, NumericScale: %s"

    try:
        desc = []
        for p in ADOparameters:
            fields = [
                p.Name,
                adc.directions[p.Direction],
                adc.adTypeNames.get(p.Type, str(p.Type) + " (unknown type)"),
                p.Size,
            ]
            if show_value:
                fields.append(p.Value)
            fields.append(p.Precision)
            fields.append(p.NumericScale)
            desc.append(template % tuple(fields))
        return "[" + "\n".join(desc) + "]"
    except Exception:
        # Never let a formatting problem mask the error being reported,
        # but do not swallow KeyboardInterrupt/SystemExit either.
        return "[]"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _configure_parameter(p, value, adotype, settings_known):
    """Load the Python object 'value' into the ADO Parameter 'p'.

    The handling is chosen, in this order, by:
      1. binary ADO type  -> Size set from len(value), data sent with AppendChunk
      2. str value        -> Value (truncated for string columns when the
                             parameter settings are known) and Size
      3. Decimal value    -> Value, Precision and NumericScale
      4. date/time value  -> COM date, or an ISO-format string plus Size
      5. adEmpty type     -> Type forced to adInteger with a None value
      6. anything else    -> Value assigned unchanged
    """
    if adotype in api.adoBinaryTypes:
        p.Size = len(value)
        p.AppendChunk(value)
        return

    if isinstance(value, str):  # v2.1 Jevon
        text_length = len(value)
        if adotype not in api.adoStringTypes:
            p.Value = value  # don't limit if db column is numeric
        else:  # v2.2.1 Cole
            if settings_known:
                # v2.1 Cole limit data to defined size
                text_length = min(text_length, p.Size)
            p.Value = value[:text_length]  # v2.1 Jevon & v2.1 Cole
        if text_length > 0:  # v2.1 Cole something does not like p.Size as Zero
            p.Size = text_length  # v2.1 Jevon
        return

    if isinstance(value, decimal.Decimal):
        p.Value = value
        _sign, digits, exponent = value.as_tuple()
        digit_count = len(digits)
        p.Precision = digit_count
        if exponent < 0:
            # Digits after the decimal point: the scale is the exponent's magnitude,
            # and the precision may never be smaller than the scale.
            p.NumericScale = -exponent
            if p.Precision < p.NumericScale:
                p.Precision = p.NumericScale
        elif exponent > 0:
            # Trailing zeros implied by the exponent count towards the precision.
            p.NumericScale = 0
            p.Precision = digit_count + exponent
        else:
            p.NumericScale = 0
        return

    if type(value) in dateconverter.types:
        if settings_known and adotype in api.adoDateTimeTypes:
            p.Value = dateconverter.COMDate(value)
        else:  # probably a string
            # provide the date as a string in the format 'YYYY-MM-dd'
            iso_text = dateconverter.DateObjectToIsoFormatString(value)
            p.Value = iso_text
            p.Size = len(iso_text)
        return

    if adotype == adc.adEmpty:  # ADO will not let you specify a null column
        # so we will fake it to be an integer (just to have something)
        p.Type = adc.adInteger
        p.Value = None  # and pass in a Null *value*
        return

    # For any other type, set the value and let pythoncom do the right thing.
    p.Value = value
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# # # # # ----- the Class that defines a connection ----- # # # # #
|
| 212 |
+
class Connection:
|
| 213 |
+
# include connection attributes as class attributes required by api definition.
|
| 214 |
+
Warning = api.Warning
|
| 215 |
+
Error = api.Error
|
| 216 |
+
InterfaceError = api.InterfaceError
|
| 217 |
+
DataError = api.DataError
|
| 218 |
+
DatabaseError = api.DatabaseError
|
| 219 |
+
OperationalError = api.OperationalError
|
| 220 |
+
IntegrityError = api.IntegrityError
|
| 221 |
+
InternalError = api.InternalError
|
| 222 |
+
NotSupportedError = api.NotSupportedError
|
| 223 |
+
ProgrammingError = api.ProgrammingError
|
| 224 |
+
FetchFailedError = api.FetchFailedError # (special for django)
|
| 225 |
+
# ...class attributes... (can be overridden by instance attributes)
|
| 226 |
+
verbose = api.verbose
|
| 227 |
+
|
| 228 |
+
@property
def dbapi(self):  # a proposed db-api version 3 extension.
    """Return a reference to the DBAPI module for this Connection."""
    module = api
    return module
|
| 232 |
+
|
| 233 |
+
def __init__(self):  # now define the instance attributes
    """Set up an unconnected Connection with default attribute values.

    No ADO object is created here; that happens in connect().
    The assignment order is kept as-is on purpose.
    """
    # No ADO connection object until connect() is called.
    self.connector = None
    self.paramstyle = api.paramstyle
    self.supportsTransactions = False
    self.connection_string = ""
    # Cursors are held through weak references only.
    self.cursors = weakref.WeakValueDictionary[int, Cursor]()
    self.dbms_name = self.dbms_version = ""
    # None means: use the standard error handler for this instance.
    self.errorhandler = None
    # 0 == Not in a transaction, at the top level
    self.transaction_level = 0
    self._autocommit = False
|
| 244 |
+
|
| 245 |
+
def connect(self, kwargs, connection_maker=make_COM_connecter):
    """Open the ADO connection described by the 'kwargs' mapping.

    Steps, in order:
      * build the connection string by %-formatting
        kwargs["connection_string"] with kwargs itself;
      * read "timeout" (default 30) and "mode" (default adc.adModeUnknown);
      * create the connector with connection_maker() and open it;
      * read the "Transaction DDL", "DBMS Name" and "DBMS Version"
        properties of the connector;
      * set the cursor location and, when transactions are supported, the
        isolation level, starting a transaction unless "autocommit" is set;
      * apply an optional "paramstyle" and reset the message list.

    Errors are reported through _raiseConnectionError().
    """
    if verbose > 9:
        print(f"kwargs={kwargs!r}")

    try:
        # insert keyword arguments
        self.connection_string = kwargs["connection_string"] % kwargs
    except Exception:
        self._raiseConnectionError(
            KeyError, "Python string format error in connection string->"
        )

    self.timeout = kwargs.get("timeout", 30)
    self.mode = kwargs.get("mode", adc.adModeUnknown)
    self.kwargs = kwargs
    if verbose:
        print('%s attempting: "%s"' % (version, self.connection_string))

    self.connector = connection_maker()
    self.connector.ConnectionTimeout = self.timeout
    self.connector.ConnectionString = self.connection_string
    self.connector.Mode = self.mode

    try:
        self.connector.Open()  # Open the ADO connection
    except api.Error:
        self._raiseConnectionError(
            api.DatabaseError,
            "ADO error trying to Open=%s" % self.connection_string,
        )

    # Stefan Fuchs; support WINCCOLEDBProvider: the property may be missing.
    try:
        transaction_ddl = getIndexedValue(self.connector.Properties, "Transaction DDL")
        if transaction_ddl.Value != 0:
            self.supportsTransactions = True
    except pywintypes.com_error:
        pass  # Stefan Fuchs

    self.dbms_name = getIndexedValue(self.connector.Properties, "DBMS Name").Value

    try:  # Stefan Fuchs
        version_property = getIndexedValue(self.connector.Properties, "DBMS Version")
        self.dbms_version = version_property.Value
    except pywintypes.com_error:
        pass  # Stefan Fuchs

    self.connector.CursorLocation = defaultCursorLocation  # v2.1 Rose

    if not self.supportsTransactions:
        self._autocommit = True
    else:
        self.connector.IsolationLevel = defaultIsolationLevel
        self._autocommit = bool(kwargs.get("autocommit", False))
        if not self._autocommit:
            # Disables autocommit & inits transaction_level
            self.transaction_level = self.connector.BeginTrans()

    if "paramstyle" in kwargs:
        self.paramstyle = kwargs["paramstyle"]  # let setattr do the error checking
    self.messages = []
    if verbose:
        print("adodbapi New connection at %X" % id(self))
|
| 301 |
+
|
| 302 |
+
def _raiseConnectionError(self, errorclass, errorvalue):
    """Report a connection-level error through the active error handler.

    The instance's ``errorhandler`` is used when one is set; otherwise
    ``api.standardErrorHandler`` is used. The handler is called with this
    connection, no cursor (None), the error class and the error value.
    """
    custom_handler = self.errorhandler
    handler = api.standardErrorHandler if custom_handler is None else custom_handler
    handler(self, None, errorclass, errorvalue)
|
| 307 |
+
|
| 308 |
+
def _closeAdoConnection(self):  # all v2.1 Rose
    """Close the underlying ADO Connection object.

    Does nothing when there is no connector. When not in autocommit mode
    and a transaction level is recorded, a rollback is attempted first;
    a failure of that rollback is ignored so the Close() call still runs.
    """
    if self.connector is None:
        return
    if not self._autocommit and self.transaction_level:
        try:
            self.connector.RollbackTrans()
        except Exception:
            # Best effort only: the connection is being discarded anyway.
            # KeyboardInterrupt/SystemExit are deliberately not swallowed.
            pass
    self.connector.Close()
    if verbose:
        print("adodbapi Closed connection at %X" % id(self))
|
| 322 |
+
|
| 323 |
+
def close(self):
    """Close the connection now (rather than whenever __del__ is called).

    The connection will be unusable from this point forward;
    an Error (or subclass) exception will be raised if any operation is attempted with the connection.
    The same applies to all cursor objects trying to use the connection.
    """
    # Work on a copy of the cursor list, then close each one.
    open_cursors = list(self.cursors.values())
    for crsr in open_cursors:
        crsr.close(dont_tell_me=True)  # close without back-link clearing
    self.messages = []
    try:
        self._closeAdoConnection()  # v2.1 Rose
    except Exception:
        exc_class, exc_value = sys.exc_info()[:2]
        self._raiseConnectionError(exc_class, exc_value)

    self.connector = None  # v2.4.2.2 fix subtle timeout bug
    # per M.Hammond: "I expect the benefits of uninitializing are probably fairly small,
    # so never uninitializing will probably not cause any problems."
|
| 343 |
+
|
| 344 |
+
def commit(self):
    """Commit any pending transaction to the database.

    Note that if the database supports an auto-commit feature,
    this must be initially off. An interface method may be provided to turn it back on.
    Database modules that do not support transactions should implement this method with void functionality.

    When ``self.supportsTransactions`` is false the call only clears
    ``self.messages``.  Any exception raised by the ADO calls is reported
    through ``_raiseConnectionError`` as ``api.ProgrammingError``.
    """
    self.messages = []
    if not self.supportsTransactions:
        return

    try:
        self.transaction_level = self.connector.CommitTrans()
        if verbose > 1:
            print("commit done on connection at %X" % id(self))
        # The flag that matters after a commit is adXactCommitRetaining; the
        # abort-retaining flag only describes what RollbackTrans does.
        if not (
            self._autocommit
            or (self.connector.Attributes & adc.adXactCommitRetaining)
        ):
            # If attributes has adXactCommitRetaining it performs retaining commits that is,
            # calling CommitTrans automatically starts a new transaction. Not all providers support this.
            # If not, we will have to start a new transaction by this command:
            self.transaction_level = self.connector.BeginTrans()
    except Exception as e:
        self._raiseConnectionError(api.ProgrammingError, e)
|
| 369 |
+
|
| 370 |
+
def _rollback(self):
    """Roll back to the start of any pending transaction.

    Closing a connection without committing the changes first will cause an
    implicit rollback to be performed.

    The DB-API prefers that a provider without transaction support does not
    expose ``rollback`` at all, so that callers can probe for it with
    ``hasattr()``; interfaces configured dynamically may instead raise a
    NotSupportedError when the method is invoked.

    Here nothing is sent to ADO unless ``self.transaction_level`` is truthy.
    Exceptions from the ADO calls are reported through
    ``_raiseConnectionError`` as ``api.ProgrammingError``.
    """
    self.messages = []
    # trying to roll back with no open transaction causes an error
    if not self.transaction_level:
        return
    try:
        self.transaction_level = self.connector.RollbackTrans()
        if verbose > 1:
            print("rollback done on connection at %X" % id(self))
        restarts_itself = self._autocommit or (
            self.connector.Attributes & adc.adXactAbortRetaining
        )
        # If attributes has adXactAbortRetaining it performs retaining aborts that is,
        # calling RollbackTrans automatically starts a new transaction. Not all providers support this.
        # If not, we will have to start a new transaction by this command:
        if not restarts_itself and not self.transaction_level:
            self.transaction_level = self.connector.BeginTrans()
    except Exception as e:
        self._raiseConnectionError(api.ProgrammingError, e)
|
| 403 |
+
|
| 404 |
+
def __setattr__(self, name, value):
    """Intercept assignment to a few special connection attributes.

    * ``autocommit`` (extension): when the provider supports transactions,
      store ``bool(value)`` in ``_autocommit`` and attempt a rollback to clear
      outstanding transactions; otherwise the assignment is dropped.
    * ``paramstyle``: values outside ``api.accepted_paramstyles`` are reported
      through ``_raiseConnectionError`` before the attribute is stored.
    * ``variantConversions``: a shallow copy is stored so the default mapping
      is never modified.

    Every other name is stored unchanged.
    """
    if name == "autocommit":  # extension: allow user to turn autocommit on or off
        if not self.supportsTransactions:
            return
        object.__setattr__(self, "_autocommit", bool(value))
        try:
            self._rollback()  # must clear any outstanding transactions
        except:
            pass
        return

    if name == "paramstyle" and value not in api.accepted_paramstyles:
        self._raiseConnectionError(
            api.NotSupportedError,
            f"paramstyle={value!r} not in:{api.accepted_paramstyles!r}",
        )
    if name == "variantConversions":
        # make a new copy -- no changes in the default, please
        value = copy.copy(value)
    object.__setattr__(self, name, value)
|
| 423 |
+
|
| 424 |
+
def __getattr__(self, item):
|
| 425 |
+
if (
|
| 426 |
+
item == "rollback"
|
| 427 |
+
): # the rollback method only appears if the database supports transactions
|
| 428 |
+
if self.supportsTransactions:
|
| 429 |
+
return (
|
| 430 |
+
self._rollback
|
| 431 |
+
) # return the rollback method so the caller can execute it.
|
| 432 |
+
else:
|
| 433 |
+
raise AttributeError("this data provider does not support Rollback")
|
| 434 |
+
elif item == "autocommit":
|
| 435 |
+
return self._autocommit
|
| 436 |
+
else:
|
| 437 |
+
raise AttributeError(
|
| 438 |
+
'no such attribute in ADO connection object as="%s"' % item
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def cursor(self):
    "Return a new Cursor Object using the connection."
    # Clear the message list, then hand back a cursor bound to this connection.
    self.messages = []
    return Cursor(self)
|
| 446 |
+
|
| 447 |
+
def _i_am_here(self, crsr):
    "message from a new cursor proclaiming its existence"
    # Cursors are tracked in self.cursors, keyed by their id().
    self.cursors[id(crsr)] = crsr
|
| 451 |
+
|
| 452 |
+
def _i_am_closing(self, crsr):
    "message from a cursor giving connection a chance to clean up"
    # Drop the cursor from self.cursors; any failure to do so is ignored.
    try:
        registry_key = id(crsr)
        del self.cursors[registry_key]
    except:
        pass
|
| 458 |
+
|
| 459 |
+
def printADOerrors(self):
    """Print every entry of the connector's ADO ``Errors`` collection.

    A count header is printed when the collection is not empty.  For each
    error the description, number (with its name from ``adc.adoErrors`` when
    known), source, native error and SQL state are printed.  A timeout error
    additionally prints a hint about the ``timeout`` connect argument.
    """
    error_count = self.connector.Errors.Count
    if error_count:
        print("ADO Errors:(%i)" % error_count)
    for e in self.connector.Errors:
        print("Description: %s" % e.Description)
        print("Error: %s %s " % (e.Number, adc.adoErrors.get(e.Number, "unknown")))
        if e.Number == adc.ado_error_TIMEOUT:
            # The module is called adodbapi; the old hint misspelled it.
            print(
                "Timeout Error: Try using adodbapi.connect(constr,timeout=Nseconds)"
            )
        print("Source: %s" % e.Source)
        print("NativeError: %s" % e.NativeError)
        print("SQL State: %s" % e.SQLState)
|
| 473 |
+
|
| 474 |
+
def _suggest_error_class(self):
    """Introspect the current ADO Errors and determine an appropriate error class.

    Error.SQLState is a SQL-defined error condition, per the SQL specification:
    https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt

    The 23000 class of errors are integrity errors.
    Error 40002 is a transactional integrity error.

    Returns ``api.IntegrityError`` when any current error has such a state,
    otherwise (or when there is no connector) ``api.DatabaseError``.
    """
    if self.connector is None:
        return api.DatabaseError
    sql_states = (str(ado_error.SQLState) for ado_error in self.connector.Errors)
    if any(s.startswith("23") or s == "40002" for s in sql_states):
        return api.IntegrityError
    return api.DatabaseError
|
| 489 |
+
|
| 490 |
+
def __del__(self):
    """Finalizer: try to close the ADO connection, then drop the connector.

    Any exception from the close attempt is ignored.
    """
    try:
        self._closeAdoConnection()  # v2.1 Rose
    except:
        pass
    finally:
        self.connector = None
|
| 496 |
+
|
| 497 |
+
def __enter__(self):  # Connections are context managers
    """Enter a ``with`` block; the connection itself is the managed object."""
    return self
|
| 499 |
+
|
| 500 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave a ``with`` block: roll back if an exception occurred, else commit."""
    # automatic rollback on errors
    finish = self._rollback if exc_type else self.commit
    finish()
|
| 505 |
+
|
| 506 |
+
def get_table_names(self):
    """Return the ``TABLE_NAME`` value of every row in the tables schema.

    The schema recordset is opened with ``OpenSchema(20)`` and walked with
    ``MoveNext()`` until ``EOF``.
    """
    schema_rows = self.connector.OpenSchema(20)  # constant = adSchemaTables

    names = []
    while not schema_rows.EOF:
        names.append(getIndexedValue(schema_rows.Fields, "TABLE_NAME").Value)
        schema_rows.MoveNext()
    del schema_rows
    return names
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# # # # # ----- the Class that defines a cursor ----- # # # # #
|
| 519 |
+
class Cursor:
|
| 520 |
+
## ** api required attributes:
|
| 521 |
+
## description...
|
| 522 |
+
## This read-only attribute is a sequence of 7-item sequences.
|
| 523 |
+
## Each of these sequences contains information describing one result column:
|
| 524 |
+
## (name, type_code, display_size, internal_size, precision, scale, null_ok).
|
| 525 |
+
## This attribute will be None for operations that do not return rows or if the
|
| 526 |
+
## cursor has not had an operation invoked via the executeXXX() method yet.
|
| 527 |
+
## The type_code can be interpreted by comparing it to the Type Objects specified in the section below.
|
| 528 |
+
## rowcount...
|
| 529 |
+
## This read-only attribute specifies the number of rows that the last executeXXX() produced
|
| 530 |
+
## (for DQL statements like select) or affected (for DML statements like update or insert).
|
| 531 |
+
## The attribute is -1 in case no executeXXX() has been performed on the cursor or
|
| 532 |
+
## the rowcount of the last operation is not determinable by the interface.[7]
|
| 533 |
+
## arraysize...
|
| 534 |
+
## This read/write attribute specifies the number of rows to fetch at a time with fetchmany().
|
| 535 |
+
## It defaults to 1 meaning to fetch a single row at a time.
|
| 536 |
+
## Implementations must observe this value with respect to the fetchmany() method,
|
| 537 |
+
## but are free to interact with the database a single row at a time.
|
| 538 |
+
## It may also be used in the implementation of executemany().
|
| 539 |
+
## ** extension attributes:
|
| 540 |
+
## paramstyle...
|
| 541 |
+
## allows the programmer to override the connection's default paramstyle
|
| 542 |
+
## errorhandler...
|
| 543 |
+
## allows the programmer to override the connection's default error handler
|
| 544 |
+
|
| 545 |
+
def __init__(self, connection):
    """Create a cursor on *connection* and register it with that connection.

    ``paramstyle`` and ``errorhandler`` start as copies of the connection's
    values and may be overridden per cursor.
    """
    # statement state
    self.command = None
    self._ado_prepared = False
    self.messages = []
    # link to the owning connection and the settings inherited from it
    self.connection = connection
    self.paramstyle = connection.paramstyle  # used for overriding the paramstyle
    self._parameter_names = []
    # result-set state
    self.recordset_is_remote = False
    self.rs = None  # the ADO recordset for this cursor
    self.converters = []  # conversion function for each column
    self.columnNames = {}  # names of columns {lowercase name : number,...}
    self.numberOfColumns = 0
    self._description = None
    self.rowcount = -1
    self.errorhandler = connection.errorhandler
    self.arraysize = 1
    connection._i_am_here(self)
    if verbose:
        cursor_id, connection_id = id(self), id(self.connection)
        print("%s New cursor at %X on conn %X" % (version, cursor_id, connection_id))
|
| 567 |
+
|
| 568 |
+
def __iter__(self):  # [2.1 Zamarev]
    """Iterate over rows by calling ``fetchone`` until it returns ``None``."""
    end_marker = None
    return iter(self.fetchone, end_marker)  # [2.1 Zamarev]
|
| 570 |
+
|
| 571 |
+
def prepare(self, operation):
    """Remember *operation* as the current command and mark it for preparation.

    The cached description is cleared and ``_ado_prepared`` is set to the
    marker string ``"setup"``.
    """
    self._description = None
    self.command = operation
    self._ado_prepared = "setup"
|
| 575 |
+
|
| 576 |
+
def __next__(self):
    """Return the next row from ``fetchone``; raise ``StopIteration`` when it is falsy."""
    row = self.fetchone()
    if not row:
        raise StopIteration
    return row
|
| 581 |
+
|
| 582 |
+
def __enter__(self):
    "Allow database cursors to be used with context managers."
    # The cursor itself is the object bound by the ``with`` statement.
    return self
|
| 585 |
+
|
| 586 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
    "Allow database cursors to be used with context managers."
    # Leaving the ``with`` block always closes the cursor.
    self.close()
|
| 589 |
+
|
| 590 |
+
def _raiseCursorError(self, errorclass, errorvalue):
    """Report a cursor-level error through the active error handler.

    The cursor's own ``errorhandler`` is used when set, otherwise
    ``api.standardErrorHandler``.  The handler is called with the owning
    connection, this cursor, and the given error class and value.
    """
    configured = self.errorhandler
    handler = api.standardErrorHandler if configured is None else configured
    handler(self.connection, self, errorclass, errorvalue)
|
| 595 |
+
|
| 596 |
+
def build_column_info(self, recordset):
    """Reset per-column metadata and rebuild it from *recordset*.

    With no recordset, or a closed one, the cursor is left with ``rs`` set to
    ``None`` and zero columns.  Otherwise a converter is looked up for each
    field by its ADO type, and the lower-cased field name is mapped to its
    column index.  An unknown ADO type is reported through
    ``_raiseCursorError`` as ``api.InternalError``.
    """
    self.converters = []  # conversion function for each column
    self.columnNames = {}  # names of columns {lowercase name : number,...}
    self._description = None

    # if EOF and BOF are true at the same time, there are no records in the recordset
    if (recordset is None) or (recordset.State == adc.adStateClosed):
        self.rs = None
        self.numberOfColumns = 0
        return

    self.rs = recordset  # v2.1.1 bkline
    self.recordset_format = api.RS_WIN_32
    self.numberOfColumns = recordset.Fields.Count
    # Prefer the connection's conversion table; fall back to the module default.
    try:
        conversions = self.connection.variantConversions
    except AttributeError:
        conversions = api.variantConversions
    for column_index in range(self.numberOfColumns):
        field = getIndexedValue(self.rs.Fields, column_index)
        try:
            # conversion function for this column
            self.converters.append(conversions[field.Type])
        except KeyError:
            self._raiseCursorError(
                api.InternalError, "Data column of Unknown ADO type=%s" % field.Type
            )
        self.columnNames[field.Name.lower()] = column_index  # columnNames lookup
|
| 624 |
+
|
| 625 |
+
def _makeDescriptionFromRS(self):
    """Build ``self._description`` from the fields of the current recordset.

    Each column yields a 7-tuple
    ``(name, type, display_size, defined_size, precision, scale, null_ok)``.
    ``display_size`` is ``None`` while the recordset is at EOF or BOF.
    """
    # Abort if closed or no recordset.
    if self.rs is None:
        self._description = None
        return

    columns = []
    for column_index in range(self.numberOfColumns):
        field = getIndexedValue(self.rs.Fields, column_index)
        # TODO: Is this the correct defintion according to the DB API 2 Spec ?
        display_size = None if (self.rs.EOF or self.rs.BOF) else field.ActualSize
        null_ok = bool(field.Attributes & adc.adFldMayBeNull)  # v2.1 Cole
        column = (
            field.Name,
            field.Type,
            display_size,
            field.DefinedSize,
            field.Precision,
            field.NumericScale,
            null_ok,
        )
        columns.append(column)
    self._description = columns
|
| 651 |
+
|
| 652 |
+
def get_description(self):
    """Return the column description, building it first when it is empty or unset."""
    if self._description:
        return self._description
    self._makeDescriptionFromRS()
    return self._description
|
| 656 |
+
|
| 657 |
+
def __getattr__(self, item):
|
| 658 |
+
if item == "description":
|
| 659 |
+
return self.get_description()
|
| 660 |
+
object.__getattribute__(
|
| 661 |
+
self, item
|
| 662 |
+
) # may get here on Remote attribute calls for existing attributes
|
| 663 |
+
|
| 664 |
+
def format_description(self, d):
    """Format db_api description tuple for printing.

    *d* is either a description tuple or an integer column index into
    ``self.description``.
    """
    if self.description is None:
        self._makeDescriptionFromRS()
    if isinstance(d, int):
        d = self.description[d]
    template = "Name= %s, Type= %s, DispSize= %s, IntSize= %s, Precision= %s, Scale= %s NullOK=%s"
    values = (
        d[0],
        adc.adTypeNames.get(d[1], str(d[1]) + " (unknown type)"),
        d[2],
        d[3],
        d[4],
        d[5],
        d[6],
    )
    return template % values
|
| 683 |
+
|
| 684 |
+
def close(self, dont_tell_me=False):
    """Close the cursor now (rather than whenever __del__ is called).
    The cursor will be unusable from this point forward; an Error (or subclass)
    exception will be raised if any operation is attempted with the cursor.

    With ``dont_tell_me`` true the owning connection is not asked to remove
    this cursor from its registry.  Closing an already closed cursor is a
    no-op.
    """
    if self.connection is None:
        return
    self.messages = []
    # rs exists and is open #v2.1 Rose
    recordset_open = self.rs and self.rs.State != adc.adStateClosed
    if recordset_open:
        self.rs.Close()  # v2.1 Rose
        self.rs = None  # let go of the recordset so ADO will let it be disposed #v2.1 Rose
    if not dont_tell_me:
        # take me off the connection's cursors list
        self.connection._i_am_closing(self)
    # this will make all future method calls on me throw an exception
    self.connection = None
    if verbose:
        print("adodbapi Closed cursor at %X" % id(self))
|
| 706 |
+
|
| 707 |
+
def __del__(self):
    """Finalizer: close the cursor, ignoring any exception from doing so."""
    try:
        self.close()
    except:
        pass  # nothing useful can be done about a failure at this point
|
| 712 |
+
|
| 713 |
+
def _new_command(self, command_type=adc.adCmdText):
    """Create a fresh ``ADODB.Command`` in ``self.cmd`` for ``self.commandText``.

    On a closed cursor (no connection) an ``api.InterfaceError`` is reported
    and ``self.cmd`` stays ``None``.  Any failure while creating or
    configuring the command is reported as ``api.DatabaseError``.
    """
    self.cmd = None
    self.messages = []

    if self.connection is None:
        self._raiseCursorError(api.InterfaceError, None)
        return
    try:
        # self.cmd is assigned first so it refers to the new object even if a
        # later property assignment fails.
        command = self.cmd = Dispatch("ADODB.Command")
        command.ActiveConnection = self.connection.connector
        command.CommandTimeout = self.connection.timeout
        command.CommandType = command_type
        command.CommandText = self.commandText
        command.Prepared = bool(self._ado_prepared)
    except:
        self._raiseCursorError(
            api.DatabaseError,
            f"Error creating new ADODB.Command object for {self.commandText!r}",
        )
|
| 732 |
+
|
| 733 |
+
def _execute_command(self):
    """Execute ``self.cmd`` and load the outcome into the cursor.

    A failure of ``Execute`` is reported through ``_raiseCursorError`` with a
    message containing the exception arguments, the command text and the
    formatted parameters; the error class comes from the connection's
    ``_suggest_error_class``.  ``rowcount`` is taken from the recordset's
    ``RecordCount`` when that can be read, otherwise from the count returned
    by ``Execute`` (default -1).
    """
    # Stored procedures may have an integer return value
    self.return_value = None
    recordset = None
    count = -1  # default value
    if verbose:
        print('Executing command="%s"' % self.commandText)
    try:
        # ----- the actual SQL is executed here ---
        recordset, count = self.cmd.Execute()
        # ----- ------------------------------- ---
    except Exception as e:
        _message = str(e.args) + "\n" if hasattr(e, "args") else ""
        _message += "Command:\n%s\nParameters:\n%s" % (
            self.commandText,
            format_parameters(self.cmd.Parameters, True),
        )
        error_class = self.connection._suggest_error_class()
        self._raiseCursorError(error_class, _message)
    try:
        self.rowcount = recordset.RecordCount
    except:
        self.rowcount = count
    self.build_column_info(recordset)

    # The ADO documentation hints that obtaining the recordcount may be timeconsuming
    # "If the Recordset object does not support approximate positioning, this property
    # may be a significant drain on resources # [ekelund]
    # Therefore, COM will not return rowcount for server-side cursors. [Cole]
    # Client-side cursors (the default since v2.8) will force a static
    # cursor, and rowcount will then be set accurately [Cole]
|
| 766 |
+
|
| 767 |
+
def get_rowcount(self):
    """Return the cursor's current ``rowcount`` value."""
    current = self.rowcount
    return current
|
| 769 |
+
|
| 770 |
+
def get_returned_parameters(self):
    """Collect the parameter values held by the command after execution.

    With some providers, returned parameters and the .return_value are not
    available until after the last recordset has been read. In that case, you
    must call nextset() until it returns None, then call this method to get
    your returned information.

    Returns the converted values of all parameters except the one whose
    direction is "return value"; that one is stored in ``self.return_value``
    (and in the legacy ``self.returnValue``).
    """
    # store procedures may return altered parameters, including an added "return value" item
    returned = []
    for p in tuple(self.cmd.Parameters):
        if verbose > 2:
            print(
                'Returned=Name: %s, Dir.: %s, Type: %s, Size: %s, Value: "%s",'
                " Precision: %s, NumericScale: %s"
                % (
                    p.Name,
                    adc.directions[p.Direction],
                    adc.adTypeNames.get(p.Type, str(p.Type) + " (unknown type)"),
                    p.Size,
                    p.Value,
                    p.Precision,
                    p.NumericScale,
                )
            )
        converted = api.convert_to_python(p.Value, api.variantConversions[p.Type])
        if p.Direction != adc.adParamReturnValue:
            returned.append(converted)
            continue
        # also load the undocumented attribute (Vernon's Error!)
        self.returnValue = converted
        self.return_value = converted
    return returned  # return the parameter list to the caller
|
| 801 |
+
|
| 802 |
+
def callproc(self, procname, parameters=None):
    """Call a stored database procedure with the given name.

    The sequence of parameters must contain one entry for each argument that
    the sproc expects.  The result of the call is returned as a modified copy
    of the input sequence: input parameters are left untouched, output and
    input/output parameters are replaced with possibly new values.

    The sproc may also provide a result set as output, which is available
    through the standard .fetch*() methods.
    Extension: A "return_value" property may be set on the cursor if the
    sproc defines an integer return value.
    """
    # Start from a clean slate: no named parameters, command text = procedure name.
    self._parameter_names = []
    self.commandText = procname
    self._new_command(command_type=adc.adCmdStoredProc)
    self._buildADOparameterList(parameters, sproc=True)
    if verbose > 2:
        formatted = format_parameters(self.cmd.Parameters, True)
        print("Calling Stored Proc with Params=", formatted)
    self._execute_command()
    return self.get_returned_parameters()
|
| 826 |
+
|
| 827 |
+
def _reformat_operation(self, operation, parameters):
    """Rewrite *operation* to use ``?`` placeholders for this cursor's paramstyle.

    ``format``/``pyformat`` statements and ``named`` statements (or
    ``dynamic`` ones given a mapping) are converted, and the parameter names
    found are stored in ``self._parameter_names``.  Any other paramstyle
    returns *operation* unchanged.
    """
    style = self.paramstyle
    if style in ("format", "pyformat"):  # convert %s to ?
        to_qmark = api.changeFormatToQmark
    elif style == "named" or (style == "dynamic" and isinstance(parameters, Mapping)):
        to_qmark = api.changeNamedToQmark  # convert :name to ?
    else:
        return operation
    operation, self._parameter_names = to_qmark(operation)
    return operation
|
| 837 |
+
|
| 838 |
+
def _buildADOparameterList(self, parameters, sproc=False):
    """Fill ``self.cmd.Parameters`` from *parameters*.

    For a stored procedure (``sproc`` true) ADO is first asked to refresh its
    own parameter list; when that works, the supplied values are configured
    onto ADO's parameters, otherwise parameters are created here from the
    Python values.  ``self._parameter_names`` selects lookup by name instead
    of by position.  Conversion and append failures are reported through
    ``_raiseCursorError`` as ``api.DataError``.

    Raises ``api.ProgrammingError`` when a refreshed stored procedure expects
    a different number of parameters than were supplied.
    """
    self.parameters = parameters
    if parameters is None:
        parameters = []

    # Note: ADO does not preserve the parameter list, even if "Prepared" is True, so we must build every time.
    parameters_known = False
    if sproc:  # needed only if we are calling a stored procedure
        try:  # attempt to use ADO's parameter list
            self.cmd.Parameters.Refresh()
            if verbose > 2:
                print(
                    "ADO detected Params=",
                    format_parameters(self.cmd.Parameters, True),
                )
                print(f"Program Parameters={parameters!r}")
            parameters_known = True
        except api.Error:
            if verbose:
                print("ADO Parameter Refresh failed")
        else:
            if len(parameters) != self.cmd.Parameters.Count - 1:
                raise api.ProgrammingError(
                    "You must supply %d parameters for this stored procedure"
                    % (self.cmd.Parameters.Count - 1)
                )

    if not (sproc or parameters != []):
        return

    i = 0
    if parameters_known:  # use ado parameter list
        if self._parameter_names:  # named parameters
            for i, pm_name in enumerate(self._parameter_names):
                ado_param = getIndexedValue(self.cmd.Parameters, i)
                try:
                    _configure_parameter(
                        ado_param, parameters[pm_name], ado_param.Type, parameters_known
                    )
                except Exception as e:
                    _message = "Error Converting Parameter {}: {}, {} <- {!r}\n".format(
                        ado_param.Name,
                        adc.ado_type_name(ado_param.Type),
                        ado_param.Value,
                        parameters[pm_name],
                    )
                    self._raiseCursorError(api.DataError, f"{_message}->{e.args!r}")
        else:  # regular sequence of parameters
            for value in parameters:
                ado_param = getIndexedValue(self.cmd.Parameters, i)
                # this is an extra parameter added by ADO
                if ado_param.Direction == adc.adParamReturnValue:
                    i += 1  # skip the extra
                    ado_param = getIndexedValue(self.cmd.Parameters, i)
                try:
                    _configure_parameter(
                        ado_param, value, ado_param.Type, parameters_known
                    )
                except Exception as e:
                    _message = "Error Converting Parameter {}: {}, {} <- {!r}\n".format(
                        ado_param.Name,
                        adc.ado_type_name(ado_param.Type),
                        ado_param.Value,
                        value,
                    )
                    self._raiseCursorError(api.DataError, f"{_message}->{e.args!r}")
                i += 1
    else:  # -- build own parameter list
        # we expect a dictionary of parameters, this is the list of expected names
        if self._parameter_names:
            for parm_name in self._parameter_names:
                elem = parameters[parm_name]
                adotype = api.pyTypeToADOType(elem)
                ado_param = self.cmd.CreateParameter(
                    parm_name, adotype, adc.adParamInput
                )
                _configure_parameter(ado_param, elem, adotype, parameters_known)
                try:
                    self.cmd.Parameters.Append(ado_param)
                except Exception as e:
                    _message = "Error Building Parameter {}: {}, {} <- {!r}\n".format(
                        ado_param.Name,
                        adc.ado_type_name(ado_param.Type),
                        ado_param.Value,
                        elem,
                    )
                    self._raiseCursorError(api.DataError, f"{_message}->{e.args!r}")
        else:  # expecting the usual sequence of parameters
            if sproc:
                ado_param = self.cmd.CreateParameter(
                    "@RETURN_VALUE", adc.adInteger, adc.adParamReturnValue
                )
                self.cmd.Parameters.Append(ado_param)

            for elem in parameters:
                name = "p%i" % i
                adotype = api.pyTypeToADOType(elem)
                # Name, Type, Direction, Size, Value
                ado_param = self.cmd.CreateParameter(name, adotype, adc.adParamInput)
                _configure_parameter(ado_param, elem, adotype, parameters_known)
                try:
                    self.cmd.Parameters.Append(ado_param)
                except Exception as e:
                    _message = "Error Building Parameter {}: {}, {} <- {!r}\n".format(
                        ado_param.Name,
                        adc.ado_type_name(ado_param.Type),
                        ado_param.Value,
                        elem,
                    )
                    self._raiseCursorError(api.DataError, f"{_message}->{e.args!r}")
                i += 1
        if self._ado_prepared == "setup":
            # parameters will be "known" by ADO next loop
            self._ado_prepared = True
|
| 963 |
+
|
| 964 |
+
def execute(self, operation, parameters=None):
    """Prepare and execute a database operation (query or command).

    Parameters may be provided as sequence or mapping and will be bound to
    variables in the operation.  Variables are specified in a
    database-specific notation (see the module's paramstyle attribute for
    details). [5]

    A reference to the operation is retained by the cursor.  If the same
    operation object is passed in again, the command text is not rebuilt,
    which helps when one operation is run many times with different
    parameters.

    For maximum efficiency when reusing an operation, it is best to use the
    setinputsizes() method to specify the parameter types and sizes ahead of
    time.  It is legal for a parameter to not match the predefined
    information; the implementation should compensate, possibly with a loss
    of efficiency.

    The parameters may also be specified as list of tuples to e.g. insert
    multiple rows in a single operation, but this kind of usage is
    deprecated: executemany() should be used instead.

    Return value is not defined.

    [5] The module will use the __getitem__ method of the parameters object
    to map either positions (integers) or names (strings) to parameter
    values.  This allows for both sequences and mappings to be used as input.
    The term "bound" refers to the process of binding an input value to a
    database execution buffer.  In practical terms, this means that the input
    value is directly used as a value in the operation.  The client should
    not be required to "escape" the value so that it can be used -- the value
    should be equal to the actual database value.
    """
    operation_changed = self.command is not operation
    if (
        operation_changed
        or self._ado_prepared == "setup"
        or not hasattr(self, "commandText")
    ):
        if operation_changed:
            self._ado_prepared = False
            self.command = operation
        self._parameter_names = []
        # qmark statements, and statements without parameters, are used as given.
        if self.paramstyle == "qmark" or not parameters:
            self.commandText = operation
        else:
            self.commandText = self._reformat_operation(operation, parameters)
    self._new_command()
    self._buildADOparameterList(parameters)
    if verbose > 3:
        print("Params=", format_parameters(self.cmd.Parameters, True))
    self._execute_command()
|
| 1011 |
+
|
| 1012 |
+
def executemany(self, operation, seq_of_parameters):
    """Prepare a database operation (query or command)
    and then execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.

    ``rowcount`` ends up as the sum of the individual row counts, or -1 as
    soon as any single execution reports -1.

    Return values are not defined.
    """
    self.messages = []
    running_total = 0

    self.prepare(operation)
    for params in seq_of_parameters:
        self.execute(self.command, params)
        if self.rowcount == -1:
            running_total = -1  # once unknown, the total stays unknown
        if running_total != -1:
            running_total += self.rowcount
    self.rowcount = running_total
|
| 1029 |
+
|
| 1030 |
+
def _fetch(self, limit=None):
    """Fetch rows from the current recordset.

    limit -- Number of rows to fetch, or None (default) to fetch all rows.

    Returns an ``api.SQLrows`` object, or an empty list when the recordset is
    closed or positioned at BOF/EOF.  A closed connection or missing
    recordset is reported as ``api.FetchFailedError``.
    """
    if self.connection is None or self.rs is None:
        self._raiseCursorError(
            api.FetchFailedError, "fetch() on closed connection or empty query set"
        )
        return

    if self.rs.State == adc.adStateClosed or self.rs.BOF or self.rs.EOF:
        return []
    # limit number of rows retrieved, or get all rows
    ado_results = self.rs.GetRows(limit) if limit else self.rs.GetRows()
    if self.recordset_format == api.RS_ARRAY:
        # result of GetRows is a two-dimension array; length of first dimension
        row_count = len(ado_results) // self.numberOfColumns
    else:  # pywin32
        row_count = len(ado_results[0])  # result of GetRows is tuples in a tuple
    # new object to hold the results of the fetch
    return api.SQLrows(ado_results, row_count, self)
|
| 1059 |
+
|
| 1060 |
+
def fetchone(self):
    """Fetch the next row of a query result set, returning a single sequence,
    or None when no more data is available.

    An Error (or subclass) exception is raised if the previous call to executeXXX()
    did not produce any result set or no call was issued yet.
    """
    self.messages = []
    rows = self._fetch(1)
    # return record (not list of records)
    return rows[0] if rows else None
|
| 1072 |
+
|
| 1073 |
+
def fetchmany(self, size=None):
    """Return the next batch of rows from the current result set.

    size -- number of rows requested; when omitted (None) the cursor's
            ``arraysize`` attribute is used instead.

    The cursor's ``messages`` list is reset, then the request is delegated to
    ``_fetch``, whose result is returned unchanged.  Fewer rows than requested
    may come back when the result set runs out.
    """
    self.messages = []
    return self._fetch(self.arraysize if size is None else size)
|
| 1094 |
+
|
| 1095 |
+
def fetchall(self):
    """Return every remaining row of the current result set.

    The cursor's ``messages`` list is reset, then ``_fetch`` is called with
    no limit and its result is returned unchanged.
    """
    self.messages = []
    remaining = self._fetch()
    return remaining
|
| 1105 |
+
|
| 1106 |
+
def nextset(self):
    """Advance to the next recordset, abandoning unread rows of the current one.

    Returns True after the column information has been rebuilt for the new
    recordset, or None when no further recordset exists.  A missing
    connection/recordset, or a COM error while advancing, is reported through
    ``self._raiseCursorError``.
    """
    self.messages = []
    if self.connection is None or self.rs is None:
        self._raiseCursorError(
            api.OperationalError,
            ("nextset() on closed connection or empty query set"),
        )
        return None

    try:  # [begin 2.1 ekelund]
        advance_result = self.rs.NextRecordset()
    except pywintypes.com_error as exc:  # return appropriate error
        self._raiseCursorError(api.NotSupportedError, exc.args)  # [end 2.1 ekelund]

    # First element of the returned tuple is the new recordset (or None).
    following = advance_result[0]
    if following is not None:
        self.build_column_info(following)
        return True
    return None
|
| 1132 |
+
|
| 1133 |
+
def setinputsizes(self, sizes):
    """Accept the call and do nothing; ``sizes`` is ignored by this cursor."""
    return None
|
| 1135 |
+
|
| 1136 |
+
def setoutputsize(self, size, column=None):
    """Accept the call and do nothing; ``size`` and ``column`` are ignored."""
    return None
|
| 1138 |
+
|
| 1139 |
+
def _last_query(self):  # let the programmer see what query we actually used
    """Return the text of the most recent command, for debugging.

    Returns ``self.commandText`` alone when ``self.parameters`` is None,
    otherwise ``"<commandText>,parameters=<repr of parameters>"``.

    This is a best-effort aid: if the attributes are missing or cannot be
    formatted, None is returned instead of raising.  Only ``Exception``
    subclasses are absorbed, so ``KeyboardInterrupt`` and ``SystemExit`` still
    propagate.
    """
    try:
        if self.parameters is None:
            return self.commandText
        return f"{self.commandText},parameters={self.parameters!r}"
    except Exception:  # deliberately broad -- any failure here just means "unknown"
        return None


query = property(_last_query, None, None, "returns the last query executed")
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
# This file is meant to be imported; running it directly as a script is
# rejected with a ProgrammingError.
if __name__ == "__main__":
    raise api.ProgrammingError(version + " cannot be run as a main program.")
|
venv/Lib/site-packages/adodbapi/apibase.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""adodbapi.apibase - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
|
| 4 |
+
* https://sourceforge.net/projects/pywin32
|
| 5 |
+
* https://sourceforge.net/projects/adodbapi
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import datetime
|
| 11 |
+
import decimal
|
| 12 |
+
import numbers
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
from collections.abc import Callable, Iterable, Mapping
|
| 16 |
+
|
| 17 |
+
# noinspection PyUnresolvedReferences
|
| 18 |
+
from . import ado_consts as adc
|
| 19 |
+
|
| 20 |
+
verbose = False # debugging flag
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ------- Error handlers ------
|
| 24 |
+
def standardErrorHandler(connection, cursor, errorclass, errorvalue):
    """Record an error on the connection and cursor, then raise it.

    connection -- object whose ``messages`` list receives the error record;
                  may be None or lack the attribute (recording is best effort).
    cursor     -- optional object whose ``messages`` list also receives it.
    errorclass -- exception class to raise.
    errorvalue -- value passed to ``errorclass`` when it is instantiated.

    The record appended to each ``messages`` list is the tuple
    ``(errorclass, errorvalue)``.  Always raises ``errorclass(errorvalue)``.
    """
    err = (errorclass, errorvalue)
    try:
        connection.messages.append(err)
    except Exception:  # best effort; must not mask the error raised below
        pass
    if cursor is not None:
        try:
            cursor.messages.append(err)
        except Exception:  # best effort, as above
            pass
    raise errorclass(errorvalue)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Error(Exception):
    """Base class of all the error exceptions defined here.

    Catching this class catches every error with a single ``except``
    statement.  Warnings are not considered errors and do not derive from it.
    """


class Warning(Exception):
    """Warning category; derives from Exception directly, not from Error.

    Note that inside this module the name shadows the builtin ``Warning``.
    """


class InterfaceError(Error):
    """Error subclass declared directly under ``Error``."""


class DatabaseError(Error):
    """Error subclass that is the base of the more specific classes below."""


class InternalError(DatabaseError):
    """``DatabaseError`` subclass."""


class OperationalError(DatabaseError):
    """``DatabaseError`` subclass."""


class ProgrammingError(DatabaseError):
    """``DatabaseError`` subclass."""


class IntegrityError(DatabaseError):
    """``DatabaseError`` subclass."""


class DataError(DatabaseError):
    """``DatabaseError`` subclass."""


class NotSupportedError(DatabaseError):
    """``DatabaseError`` subclass."""


class FetchFailedError(OperationalError):
    """Fetch failed: the connection is closed or no record set was returned.

    Used by RawStoredProcedureQuerySet to recognise that situation.
    (Non-standard, added especially for django.)
    """
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# # # # # ----- Type Objects and Constructors ----- # # # # #
|
| 94 |
+
# Many databases need to have the input in a particular format for binding to an operation's input parameters.
|
| 95 |
+
# For example, if an input is destined for a DATE column, then it must be bound to the database in a particular
|
| 96 |
+
# string format. Similar problems exist for "Row ID" columns or large binary items (e.g. blobs or RAW columns).
|
| 97 |
+
# This presents problems for Python since the parameters to the executeXXX() method are untyped.
|
| 98 |
+
# When the database module sees a Python string object, it doesn't know if it should be bound as a simple CHAR
|
| 99 |
+
# column, as a raw BINARY item, or as a DATE.
|
| 100 |
+
#
|
| 101 |
+
# To overcome this problem, a module must provide the constructors defined below to create objects that can
|
| 102 |
+
# hold special values. When passed to the cursor methods, the module can then detect the proper type of
|
| 103 |
+
# the input parameter and bind it accordingly.
|
| 104 |
+
|
| 105 |
+
# A Cursor Object's description attribute returns information about each of the result columns of a query.
|
| 106 |
+
# The type_code must compare equal to one of Type Objects defined below. Type Objects may be equal to more than
|
| 107 |
+
# one type code (e.g. DATETIME could be equal to the type codes for date, time and timestamp columns;
|
| 108 |
+
# see the Implementation Hints below for details).
|
| 109 |
+
|
| 110 |
+
# SQL NULL values are represented by the Python None singleton on input and output.
|
| 111 |
+
|
| 112 |
+
# Note: Usage of Unix ticks for database interfacing can cause troubles because of the limited date range they cover.
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# def Date(year,month,day):
|
| 116 |
+
# "This function constructs an object holding a date value. "
|
| 117 |
+
# return dateconverter.date(year,month,day) #dateconverter.Date(year,month,day)
|
| 118 |
+
#
|
| 119 |
+
# def Time(hour,minute,second):
|
| 120 |
+
# "This function constructs an object holding a time value. "
|
| 121 |
+
# return dateconverter.time(hour, minute, second) # dateconverter.Time(hour,minute,second)
|
| 122 |
+
#
|
| 123 |
+
# def Timestamp(year,month,day,hour,minute,second):
|
| 124 |
+
# "This function constructs an object holding a time stamp value. "
|
| 125 |
+
# return dateconverter.datetime(year,month,day,hour,minute,second)
|
| 126 |
+
#
|
| 127 |
+
# def DateFromTicks(ticks):
|
| 128 |
+
# """This function constructs an object holding a date value from the given ticks value
|
| 129 |
+
# (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
|
| 130 |
+
# return Date(*time.gmtime(ticks)[:3])
|
| 131 |
+
#
|
| 132 |
+
# def TimeFromTicks(ticks):
|
| 133 |
+
# """This function constructs an object holding a time value from the given ticks value
|
| 134 |
+
# (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
|
| 135 |
+
# return Time(*time.gmtime(ticks)[3:6])
|
| 136 |
+
#
|
| 137 |
+
# def TimestampFromTicks(ticks):
|
| 138 |
+
# """This function constructs an object holding a time stamp value from the given
|
| 139 |
+
# ticks value (number of seconds since the epoch;
|
| 140 |
+
# see the documentation of the standard Python time module for details). """
|
| 141 |
+
# return Timestamp(*time.gmtime(ticks)[:6])
|
| 142 |
+
#
|
| 143 |
+
# def Binary(aString):
|
| 144 |
+
# """This function constructs an object capable of holding a binary (long) string value. """
|
| 145 |
+
# b = bytes(aString)
|
| 146 |
+
# return b
|
| 147 |
+
# ----- Time converters ----------------------------------------------
|
| 148 |
+
class TimeConverter:  # this is a generic time converter skeleton
    """Abstract base for converting between COM dates and Python date objects.

    A COM date is handled here as a float: whole days counted from
    1899-12-30 plus the time of day as a fraction of a day.  Subclasses supply
    the concrete date/time representation by overriding ``Date``, ``Time``,
    ``Timestamp`` and ``DateObjectFromCOMDate``; the base versions raise
    ``NotImplementedError``, so the base class itself cannot be instantiated.
    """

    def __init__(self):  # the details will be filled in by instances
        # Despite the attribute name, the "- 1" makes this the ordinal of
        # 1899-12-30, which is day zero on the COM date scale used below.
        self._ordinal_1899_12_31 = datetime.date(1899, 12, 31).toordinal() - 1
        # Use self.types to test whether an input parameter is a date/time value.
        self.types = {
            # Ask the factory methods for their result types dynamically,
            # because subclasses may override them.
            type(self.Date(2000, 1, 1)),
            type(self.Time(12, 1, 1)),
            type(self.Timestamp(2000, 1, 1, 12, 1, 1)),
            datetime.datetime,
            datetime.time,
            datetime.date,
        }

    def COMDate(self, obj):
        """Return the COM date (float) for a date-time object or a time tuple.

        Objects offering ``timetuple()`` are converted through it, with their
        ``microsecond`` attribute if they have one; anything else is tried as
        a ``(year, month, day, hour, minute, second, ...)`` sequence.

        Raises ValueError (chained to the underlying failure) when neither
        approach works.
        """
        try:  # most likely a datetime
            tt = obj.timetuple()

            try:
                ms = obj.microsecond
            except Exception:  # e.g. datetime.date has no microsecond
                ms = 0
            return self.ComDateFromTuple(tt, ms)
        except Exception:  # might be a tuple
            try:
                return self.ComDateFromTuple(obj)
            except Exception as exc:
                raise ValueError(f'Cannot convert "{obj!r}" to COMdate.') from exc

    def ComDateFromTuple(self, t, microseconds=0):
        """Return the COM date for a time tuple plus optional microseconds.

        ``t[0:3]`` are year, month, day; ``t[3:6]`` are hour, minute, second.
        """
        d = datetime.date(t[0], t[1], t[2])
        integerPart = d.toordinal() - self._ordinal_1899_12_31
        # time of day in microseconds, then as a fraction of a whole day
        ms = (t[3] * 3600 + t[4] * 60 + t[5]) * 1000000 + microseconds
        fractPart = float(ms) / 86400000000.0
        return integerPart + fractPart

    def DateObjectFromCOMDate(self, comDate):
        "Returns an object of the wanted type from a ComDate"
        raise NotImplementedError  # "Abstract class"

    def Date(self, year, month, day):
        "This function constructs an object holding a date value."
        raise NotImplementedError  # "Abstract class"

    def Time(self, hour, minute, second):
        "This function constructs an object holding a time value."
        raise NotImplementedError  # "Abstract class"

    def Timestamp(self, year, month, day, hour, minute, second):
        "This function constructs an object holding a time stamp value."
        raise NotImplementedError  # "Abstract class"
        # all purpose date to ISO format converter

    def DateObjectToIsoFormatString(self, obj):
        """Return obj formatted as an ISO-style 'YYYY-MM-DD HH:MM:SS' string.

        ``datetime.datetime`` values use ``isoformat(" ")`` (which appends
        fractional seconds when present); plain ``datetime.date`` values get
        an exact-midnight time part; ``time.struct_time`` values are formatted
        with ``time.strftime``.

        Raises ValueError (chained to the underlying failure) for anything else.
        """
        try:  # most likely, a datetime.datetime
            s = obj.isoformat(" ")
        except (TypeError, AttributeError):
            if isinstance(obj, datetime.date):
                s = obj.isoformat() + " 00:00:00"  # return exact midnight
            else:
                try:  # but may be time.struct_time
                    s = time.strftime("%Y-%m-%d %H:%M:%S", obj)
                except Exception as exc:
                    raise ValueError(f'Cannot convert "{obj!r}" to isoformat') from exc
        return s
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class pythonDateTimeConverter(TimeConverter):  # standard since Python 2.3
    """Converter that represents dates and times with the ``datetime`` module."""

    def __init__(self):
        TimeConverter.__init__(self)

    def DateObjectFromCOMDate(self, comDate):
        """Return a ``datetime.datetime`` for a COM date.

        comDate -- either a ``datetime.datetime`` (or subclass) instance, or a
                   value ``float()`` accepts, holding a COM day count.
        """
        if not isinstance(comDate, datetime.datetime):
            days = float(comDate)
            whole_days = int(days)
            day_fraction = days - whole_days
            midnight = datetime.datetime.fromordinal(
                whole_days + self._ordinal_1899_12_31
            )
            # 86400000 = milliseconds per day (24*60*60*1000)
            return midnight + datetime.timedelta(milliseconds=day_fraction * 86400000)

        # Already a datetime: rebuild it as a plain datetime.datetime from its
        # calendar day and its time of day (``time()`` carries no tzinfo, so
        # the result is naive).
        calendar_day = datetime.datetime.fromordinal(comDate.toordinal())
        return datetime.datetime.combine(calendar_day, comDate.time())

    def Date(self, year, month, day):
        """Build a ``datetime.date``."""
        return datetime.date(year, month, day)

    def Time(self, hour, minute, second):
        """Build a ``datetime.time``."""
        return datetime.time(hour, minute, second)

    def Timestamp(self, year, month, day, hour, minute, second):
        """Build a ``datetime.datetime``."""
        return datetime.datetime(year, month, day, hour, minute, second)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class pythonTimeConverter(TimeConverter):  # the old, ?nix type date and time
    """Converter that represents dates and times as ``time.struct_time`` values."""

    def __init__(self):  # caution: this class gets confused by timezones and DST
        TimeConverter.__init__(self)
        self.types.add(time.struct_time)

    def DateObjectFromCOMDate(self, comDate):
        """Return a ``time.struct_time`` for a COM date.

        A ``datetime.datetime`` argument is converted with ``timetuple()``;
        any other value is treated as a COM day count and converted with
        ``time.gmtime``.
        """
        if isinstance(comDate, datetime.datetime):
            return comDate.timetuple()
        seconds_per_day = 86400  # 24*60*60
        # 25569 is the COM day number of the gmtime epoch, 1970-01-01.
        days_since_epoch = float(comDate) - 25569.0
        # struct_time: year,month,day,hour,minute,second,weekday,julianday,daylightsaving
        return time.gmtime(seconds_per_day * days_since_epoch)

    def Date(self, year, month, day):
        """Build the struct_time for midnight on the given date."""
        return self.Timestamp(year, month, day, 0, 0, 0)

    def Time(self, hour, minute, second):
        """Build a struct_time from a time of day, counted in seconds from the epoch."""
        seconds = (hour * 60 + minute) * 60 + second
        return time.gmtime(seconds)

    def Timestamp(self, year, month, day, hour, minute, second):
        """Build a local-time struct_time by round-tripping through ``time.mktime``."""
        fields = (year, month, day, hour, minute, second, 0, 0, -1)
        return time.localtime(time.mktime(fields))
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Converter instance created once at import time, using the datetime-based class.
base_dateconverter = pythonDateTimeConverter()

# ------ DB API required module attributes ---------------------
threadsafety = 1  # TODO -- find out whether this module is actually BETTER than 1.

apilevel = "2.0"  # String constant stating the supported DB API level.

paramstyle = "qmark"  # the default parameter style

# ------ control for an extension which may become part of DB API 3.0 ---
# Every parameter style name this module accepts.
accepted_paramstyles = ("qmark", "named", "format", "pyformat", "dynamic")
|
| 289 |
+
|
| 290 |
+
# ------------------------------------------------------------------------------------------
|
| 291 |
+
# define similar types for generic conversion routines
|
| 292 |
+
# Each tuple below groups ADO type codes (from ado_consts) that are handled
# alike; the groups are used to build the DB API type objects and the default
# conversion table further down in this module.
adoIntegerTypes = (
    adc.adInteger,
    adc.adSmallInt,
    adc.adTinyInt,
    adc.adUnsignedInt,
    adc.adUnsignedSmallInt,
    adc.adUnsignedTinyInt,
    adc.adBoolean,
    adc.adError,
)  # max 32 bits
adoRowIdTypes = (adc.adChapter,)  # v2.1 Rose
adoLongTypes = (adc.adBigInt, adc.adFileTime, adc.adUnsignedBigInt)
adoExactNumericTypes = (
    adc.adDecimal,
    adc.adNumeric,
    adc.adVarNumeric,
    adc.adCurrency,
)  # v2.3 Cole
adoApproximateNumericTypes = (adc.adDouble, adc.adSingle)  # v2.1 Cole
adoStringTypes = (
    adc.adBSTR,
    adc.adChar,
    adc.adLongVarChar,
    adc.adLongVarWChar,
    adc.adVarChar,
    adc.adVarWChar,
    adc.adWChar,
)
adoBinaryTypes = (adc.adBinary, adc.adLongVarBinary, adc.adVarBinary)
adoDateTimeTypes = (adc.adDBTime, adc.adDBTimeStamp, adc.adDate, adc.adDBDate)
# Everything not covered by a group above.
adoRemainingTypes = (
    adc.adEmpty,
    adc.adIDispatch,
    adc.adIUnknown,
    adc.adPropVariant,
    adc.adArray,
    adc.adUserDefined,
    adc.adVariant,
    adc.adGUID,
)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# this class is a trick to determine whether a type is a member of a related group of types. see PEP notes
|
| 335 |
+
class DBAPITypeObject:
    """Type object that compares equal to every type code in its group.

    ``obj == code`` is True when ``code`` is one of the values given at
    construction, so a single object can stand for a family of related type
    codes.  Comparing against a value that cannot be a member -- including an
    unhashable one such as a list -- is simply unequal rather than an error.
    """

    def __init__(self, valuesTuple):
        """Remember the member codes; ``valuesTuple`` is any iterable of hashables."""
        self.values = frozenset(valuesTuple)

    def __eq__(self, other):
        try:
            return other in self.values
        except TypeError:  # unhashable operand: cannot be in the frozenset
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

    # Defining __eq__ already makes instances unhashable; say so explicitly.
    __hash__ = None
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# Type object describing string-based columns (e.g. CHAR).
STRING = DBAPITypeObject(adoStringTypes)

# Type object describing (long) binary columns (e.g. LONG, RAW, BLOBs).
BINARY = DBAPITypeObject(adoBinaryTypes)

# Type object describing numeric columns: integer, long, exact and
# approximate numeric codes together.
NUMBER = DBAPITypeObject(
    adoIntegerTypes + adoLongTypes + adoExactNumericTypes + adoApproximateNumericTypes
)

# Type object describing date/time columns.
DATETIME = DBAPITypeObject(adoDateTimeTypes)

# Type object describing the "Row ID" column.
ROWID = DBAPITypeObject(adoRowIdTypes)

# Type object for every remaining ADO type code.
OTHER = DBAPITypeObject(adoRemainingTypes)
|
| 364 |
+
|
| 365 |
+
# ------- utilities for translating python data types to ADO data types ---------------------------------
|
| 366 |
+
# Exact Python type -> ADO type code, consulted first by pyTypeToADOType().
typeMap = {
    memoryview: adc.adVarBinary,
    float: adc.adDouble,
    type(None): adc.adEmpty,
    str: adc.adBSTR,
    bool: adc.adBoolean,  # v2.1 Cole
    decimal.Decimal: adc.adDecimal,
    int: adc.adBigInt,
    bytes: adc.adVarBinary,
}


def pyTypeToADOType(d):
    """Return the ADO type code to use for the Python value ``d``.

    Lookup order: the exact type in ``typeMap``; the date/time types known to
    the package's ``dateconverter``; then ``isinstance`` probes for string,
    integral and real values, so subclasses and look-alike numeric types are
    handled too.

    Raises DataError when none of these match.
    """
    tp = type(d)
    try:
        return typeMap[tp]
    except KeyError:  # The type was not defined in the pre-computed Type table
        from . import dateconverter

        # maybe it is one of our supported Date/Time types
        if tp in dateconverter.types:
            return adc.adDate
        # otherwise probe the data object itself -- to handle duck typing
        probes = (
            (str, adc.adBSTR),
            (numbers.Integral, adc.adBigInt),
            (numbers.Real, adc.adDouble),
        )
        for abstract_type, ado_code in probes:
            if isinstance(d, abstract_type):
                return ado_code
        raise DataError(f'cannot convert "{d!r}" (type={tp}) to ADO')
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# # # # # # # # # # # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
| 399 |
+
# functions to convert database values to Python objects
|
| 400 |
+
# ------------------------------------------------------------------------
|
| 401 |
+
# variant type : function converting variant to Python value
|
| 402 |
+
def variantConvertDate(v):
    """Convert a COM date value with the package's current ``dateconverter``.

    The converter is imported at call time rather than at module import.
    """
    from . import dateconverter  # this function only called when adodbapi is running

    return dateconverter.DateObjectFromCOMDate(v)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def cvtString(variant):  # use to get old action of adodbapi v1 if desired
    """Return ``str(variant)``."""
    return str(variant)


def cvtDecimal(variant):  # better name
    """Convert to ``decimal.Decimal``, tolerating a comma decimal separator."""
    return _convertNumberWithCulture(variant, decimal.Decimal)


def cvtNumeric(variant):  # older name - don't break old code
    """Alias of ``cvtDecimal``."""
    return cvtDecimal(variant)


def cvtFloat(variant):
    """Convert to ``float``, tolerating a comma decimal separator."""
    return _convertNumberWithCulture(variant, float)


def _convertNumberWithCulture(variant, f):
    """Apply the numeric constructor ``f`` to ``variant``.

    If the direct call fails, the value is retried as text with every ","
    replaced by "." (European vs. US decimal separator).  Returns None when
    both attempts fail.
    """
    conversion_errors = (ValueError, TypeError, decimal.InvalidOperation)
    try:
        return f(variant)
    except conversion_errors:
        pass
    try:
        return f(str(variant).replace(",", "."))
    except conversion_errors:
        return None


def cvtInt(variant):
    """Return ``int(variant)``."""
    return int(variant)


def cvtLong(variant):  # only important in old versions where long and int differ
    """Return ``int(variant)``."""
    return int(variant)


def cvtBuffer(variant):
    """Return ``bytes(variant)``."""
    return bytes(variant)


def cvtUnicode(variant):
    """Return ``str(variant)``."""
    return str(variant)


def identity(x):
    """Return ``x`` unchanged."""
    return x
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def cvtUnusual(variant):
    """Hand back a value for which no conversion function exists.

    When the module-level ``verbose`` flag is greater than 1, a diagnostic
    line showing the value is written to stderr first.
    """
    if verbose > 1:
        sys.stderr.write(f"Conversion called for Unusual data={variant!r}\n")
    # cannot find conversion function -- just give the data to the user
    return variant


def convert_to_python(variant, func):  # convert DB value into Python value
    """Convert a database value with ``func``; None passes through untouched."""
    return None if variant is None else func(variant)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class MultiMap(dict[int, Callable[[object], object]]):
    """Dictionary of ADO type code -> conversion function.

    Assigning with an iterable of codes as the key stores the same function
    under every code in it, which is handy for whole groups of similar data
    types.  Built from a mapping of ``{(iterable, of, keys): function}``
    and/or ``{single_key: function}`` entries.
    """

    def __init__(self, aDict: Mapping[Iterable[int] | int, Callable[[object], object]]):
        for key, converter in aDict.items():
            # go through our own __setitem__ so iterable keys are expanded
            self[key] = converter

    def __setitem__(
        self, adoType: Iterable[int] | int, cvtFn: Callable[[object], object]
    ):
        "set a single item, or a whole iterable of items"
        # Treat a lone key as a one-element group so both cases share a loop.
        codes = adoType if isinstance(adoType, Iterable) else (adoType,)
        for ado_code in codes:
            dict.__setitem__(self, ado_code, cvtFn)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# initialize variantConversions dictionary used to convert SQL to Python
|
| 490 |
+
# this is the dictionary of default conversion functions, built by the class above.
|
| 491 |
+
# this becomes a class attribute for the Connection, and that attribute is used
|
| 492 |
+
# to build the list of column conversion functions for the Cursor
|
| 493 |
+
# Each key is a tuple of ADO type codes; MultiMap stores the converter on the
# right under every individual code in that tuple.
variantConversions = MultiMap(
    {
        adoDateTimeTypes: variantConvertDate,
        adoApproximateNumericTypes: cvtFloat,
        adoExactNumericTypes: cvtDecimal,  # use to force decimal rather than unicode
        adoLongTypes: cvtLong,
        adoIntegerTypes: cvtInt,
        adoRowIdTypes: cvtInt,
        adoStringTypes: identity,  # string values are passed through unchanged
        adoBinaryTypes: cvtBuffer,
        adoRemainingTypes: cvtUnusual,
    }
)
|
| 506 |
+
|
| 507 |
+
# # # # # classes to emulate the result of cursor.fetchxxx() as a sequence of sequences # # # # #
|
| 508 |
+
# "an ENUM of how my low level records are laid out"
|
| 509 |
+
RS_WIN_32, RS_ARRAY, RS_REMOTE = range(1, 4)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class SQLrow:  # a single database row
    """One row of a fetch result, readable by column number, slice, or name.

    The row stores no data of its own: it keeps a reference to its parent
    result container and its row index, and reads (and converts) each value
    from the parent on demand.  Values can be reached as ``row[2]``,
    ``row[1:3]``, ``row["name"]`` or ``row.name``; column names are matched
    after lower-casing.
    """

    # class to emulate a sequence, so that a column may be retrieved by either number or name
    def __init__(self, rows, index):  # "rows" is an _SQLrows object, index is which row
        self.rows = rows  # parent 'fetch' container object
        self.index = index  # my row number within parent

    def __getattr__(self, name):  # used for row.columnName type of value access
        # __getattr__ only runs when normal lookup has failed.  If that happens
        # for one of our own two instance attributes, the instance has not been
        # initialised (copy and pickle create it without calling __init__), and
        # touching self.rows below would recurse without end.
        if name in ("rows", "index"):
            raise AttributeError(name)
        try:
            return self._getValue(self.rows.columnNames[name.lower()])
        except KeyError:
            raise AttributeError('Unknown column name "{}"'.format(name))

    def _getValue(self, key):  # key must be an integer
        """Return the converted value of column number ``key`` in this row."""
        if (
            self.rows.recordset_format == RS_ARRAY
        ):  # retrieve from two-dimensional array
            v = self.rows.ado_results[key, self.index]
        elif self.rows.recordset_format == RS_REMOTE:
            v = self.rows.ado_results[self.index][key]
        else:  # pywin32 - retrieve from tuple of tuples
            v = self.rows.ado_results[key][self.index]
        if self.rows.converters is NotImplemented:
            return v  # no conversion table: hand back the raw value
        return convert_to_python(v, self.rows.converters[key])

    def __len__(self):
        return self.rows.numberOfColumns

    def __getitem__(self, key):  # used for row[key] type of value access
        """Return one value (int or column-name key) or a tuple of values (slice).

        Raises KeyError or TypeError, with a message naming the key, when a
        column-name lookup fails.
        """
        if isinstance(key, int):  # normal row[1] designation
            return self._getValue(key)
        if isinstance(key, slice):
            indices = key.indices(self.rows.numberOfColumns)
            return tuple(self._getValue(i) for i in range(*indices))
        try:
            return self._getValue(
                self.rows.columnNames[key.lower()]
            )  # extension row[columnName] designation
        except (KeyError, TypeError) as exc:
            # Re-raise the same exception type with a more helpful message,
            # keeping the original traceback.
            raise type(exc)(f'No such key as "{key!r}" in {self!r}').with_traceback(
                exc.__traceback__
            )

    def __iter__(self):
        return iter(self.__next__())

    def __next__(self):
        # NOTE: despite its name this is a generator function yielding every
        # column value in order; __iter__ relies on that.
        for n in range(self.rows.numberOfColumns):
            yield self._getValue(n)

    def __repr__(self):  # create a human readable representation
        # list the columns in positional order
        taglist = sorted(self.rows.columnNames.items(), key=lambda x: x[1])
        body = ", ".join(f"{name}:{self._getValue(i)!r}" for name, i in taglist)
        return "<SQLrow={" + body + "}>"

    def __str__(self):  # create a pretty human readable representation
        return str(
            tuple(str(self._getValue(i)) for i in range(self.rows.numberOfColumns))
        )

    # TO-DO implement pickling an SQLrow directly
    # def __getstate__(self): return self.__dict__
    # def __setstate__(self, d): self.__dict__.update(d)
    # which basically tell pickle to treat your class just like a normal one,
    # taking self.__dict__ as representing the whole of the instance state,
    # despite the existence of the __getattr__.
    # # # #
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class SQLrows:
    """Emulate a sequence of rows on top of one raw result container."""

    def __init__(self, ado_results, numberOfRows, cursor):
        """Remember the raw results and copy the layout details from cursor.

        When cursor lacks any of the expected attributes, fall back to an
        RS_ARRAY layout with no columns, no converters and no column names.
        """
        self.ado_results = ado_results  # raw result of SQL get
        try:
            layout = (
                cursor.recordset_format,
                cursor.numberOfColumns,
                cursor.converters,
                cursor.columnNames,
            )
        except AttributeError:
            layout = (RS_ARRAY, 0, [], {})
        (
            self.recordset_format,
            self.numberOfColumns,
            self.converters,
            self.columnNames,
        ) = layout
        self.numberOfRows = numberOfRows

    def __len__(self):
        """Return the number of rows given at construction time."""
        return self.numberOfRows

    def __getitem__(self, item):
        """Row access, or (row, column) access to a single datum.

        * no results at all  -> an empty list, whatever the item is
        * a slice            -> a list of SQLrow objects
        * a 2-tuple (i, j)   -> the datum at row i, column j, where j may be
                                a column number or a (case-insensitive) name
        * anything else      -> a new SQLrow descriptor for that item
        """
        if not self.ado_results:
            return []
        if isinstance(item, slice):
            # will return a list of row objects
            bounds = item.indices(self.numberOfRows)
            return [SQLrow(self, k) for k in range(*bounds)]
        if not (isinstance(item, tuple) and len(item) == 2):
            return SQLrow(self, item)  # new row descriptor

        # d = some_rowsObject[i,j] returns a datum from a two-dimension address
        row_number, column = item
        if not isinstance(column, int):
            try:
                # convert named column to numeric
                column = self.columnNames[column.lower()]
            except KeyError:
                raise KeyError(f"adodbapi: no such column name as {column!r}")

        if self.recordset_format == RS_ARRAY:
            # retrieve from two-dimensional array
            raw = self.ado_results[column, row_number]
        elif self.recordset_format == RS_REMOTE:
            raw = self.ado_results[row_number][column]
        else:
            # pywin32 - retrieve from tuple of tuples
            raw = self.ado_results[column][row_number]

        if self.converters is not NotImplemented:
            return convert_to_python(raw, self.converters[column])
        return raw

    def __iter__(self):
        """Return an iterator over SQLrow objects, one per row."""
        row_generator = self.__next__()
        return iter(row_generator)

    def __next__(self):
        """Generator yielding an SQLrow for every row number in order."""
        row_number = 0
        row_count = self.numberOfRows
        while row_number < row_count:
            yield SQLrow(self, row_number)
            row_number += 1
|
| 639 |
+
# # # # #
|
| 640 |
+
|
| 641 |
+
# # # # # functions to re-format SQL requests to other paramstyle requirements # # # # # # # # # #
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def changeNamedToQmark(op):
    """Convert from 'named' paramstyle to ADO required '?' mark parameters.

    op -- the SQL text, using ``:name`` placeholders.

    Returns a tuple ``(sql, names)``: the SQL text with each ``:name``
    replaced by ``?``, and the list of parameter names in order of appearance.

    Text between apostrophes is a quoted literal and is passed through
    untouched, including empty literals ('') and doubled apostrophes.
    A colon that is not followed by a name is dropped from the output.
    """
    pieces = []
    outparms = []
    # Splitting on the apostrophe puts SQL code at even positions and quoted
    # literals at odd positions.
    for position, chunk in enumerate(op.split("'")):
        if position % 2:
            # Always re-wrap the literal. An empty chunk here is the empty
            # literal '', not an escaped apostrophe: an escaped apostrophe
            # ('it''s') produces its empty chunk at an even position instead.
            pieces.append("'" + chunk + "'")
            continue
        while chunk:  # some SQL string remains
            before, _colon, rest = chunk.partition(":")
            pieces.append(before)  # the part up to the ":"
            if not rest:
                break
            # measure the run of identifier characters following the colon
            end = 0
            while end < len(rest) and (rest[end].isalnum() or rest[end] == "_"):
                end += 1
            name, chunk = rest[:end], rest[end:]
            if name:
                outparms.append(name)  # list the parameters in order
                pieces.append("?")  # put in the Qmark
    return "".join(pieces), outparms
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def changeFormatToQmark(op):
    """Convert from 'format' paramstyle to ADO required '?' mark parameters.

    op -- the SQL text, using ``%s`` placeholders, or pyformat ``%(name)s``
          placeholders.

    Returns a tuple ``(sql, names)``: the SQL text with each placeholder
    replaced by ``?``, and the list of pyformat names in order of appearance
    (empty when only plain ``%s`` placeholders are used).

    Text between apostrophes is a quoted literal and is passed through
    untouched, including empty literals ('') and doubled apostrophes.

    Raises ProgrammingError when a ``%(`` has no matching ``)s``.
    """
    pieces = []
    outparams = []
    # Splitting on the apostrophe puts SQL code at even positions and quoted
    # literals at odd positions.
    for position, chunk in enumerate(op.split("'")):
        if position % 2:
            # Always re-wrap the literal. An empty chunk here is the empty
            # literal '', not an escaped apostrophe: an escaped apostrophe
            # ('it''s') produces its empty chunk at an even position instead.
            pieces.append("'" + chunk + "'")
        elif "%(" in chunk:  # pyformat
            while chunk:  # some SQL string remains
                before, marker, rest = chunk.partition("%(")
                pieces.append(before)  # the part up to the "%("
                if not marker:
                    break
                name, closer, remainder = rest.partition(")s")  # find the ')s'
                if not closer:
                    raise ProgrammingError(
                        'Pyformat SQL has incorrect format near "%s"' % chunk
                    )
                outparams.append(name)
                pieces.append("?")  # put in the Qmark
                chunk = remainder
        else:  # proper '%s' format
            pieces.append(chunk.replace("%s", "?"))
    return "".join(pieces), outparams
|
venv/Lib/site-packages/adodbapi/is64bit.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""is64bit.Python() --> True if the running Python is a 64-bit build. is64bit.os() --> True if the operating system appears to be 64-bit."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def Python():
    """Return True when sys.maxsize exceeds the largest signed 32-bit value."""
    largest_int32 = 2147483647
    return sys.maxsize > largest_int32
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def os():
    """Return True when the operating system appears to be 64-bit.

    The checks are tried in order:
      1. ``platform.machine()`` ends with "64";
      2. the environment variable PROCESSOR_ARCHITEW6432 exists
         (32 bit program running on 64 bit Windows);
      3. the environment variable PROCESSOR_ARCHITECTURE ends with "64";
      4. "64" appears in ``platform.architecture()[0]``.
    Returns False when the last check cannot be performed.
    """
    import platform

    pm = platform.machine()
    if pm != ".." and pm.endswith("64"):  # recent 64 bit Python
        return True

    # Local import: this function is itself named "os", so inside it the name
    # refers to the standard library module only after this statement.
    import os

    if "PROCESSOR_ARCHITEW6432" in os.environ:
        return True  # 32 bit program running on 64 bit Windows
    try:
        # 64 bit Windows 64 bit program
        return os.environ["PROCESSOR_ARCHITECTURE"].endswith("64")
    except (IndexError, KeyError):
        pass  # not Windows
    try:
        return "64" in platform.architecture()[0]  # this often works in Linux
    except Exception:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate. Best guess: assume not 64-bit.
        return False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if __name__ == "__main__":
    # Report both detection results when the module is run as a script.
    print(f"is64bit.Python() = {Python()} is64bit.os() = {os()}")
|
venv/Lib/site-packages/adodbapi/license.txt
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GNU LESSER GENERAL PUBLIC LICENSE
|
| 2 |
+
Version 2.1, February 1999
|
| 3 |
+
|
| 4 |
+
Copyright (C) 1991, 1999 Free Software Foundation, Inc.
|
| 5 |
+
59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
| 6 |
+
Everyone is permitted to copy and distribute verbatim copies
|
| 7 |
+
of this license document, but changing it is not allowed.
|
| 8 |
+
|
| 9 |
+
[This is the first released version of the Lesser GPL. It also counts
|
| 10 |
+
as the successor of the GNU Library Public License, version 2, hence
|
| 11 |
+
the version number 2.1.]
|
| 12 |
+
|
| 13 |
+
Preamble
|
| 14 |
+
|
| 15 |
+
The licenses for most software are designed to take away your
|
| 16 |
+
freedom to share and change it. By contrast, the GNU General Public
|
| 17 |
+
Licenses are intended to guarantee your freedom to share and change
|
| 18 |
+
free software--to make sure the software is free for all its users.
|
| 19 |
+
|
| 20 |
+
This license, the Lesser General Public License, applies to some
|
| 21 |
+
specially designated software packages--typically libraries--of the
|
| 22 |
+
Free Software Foundation and other authors who decide to use it. You
|
| 23 |
+
can use it too, but we suggest you first think carefully about whether
|
| 24 |
+
this license or the ordinary General Public License is the better
|
| 25 |
+
strategy to use in any particular case, based on the explanations below.
|
| 26 |
+
|
| 27 |
+
When we speak of free software, we are referring to freedom of use,
|
| 28 |
+
not price. Our General Public Licenses are designed to make sure that
|
| 29 |
+
you have the freedom to distribute copies of free software (and charge
|
| 30 |
+
for this service if you wish); that you receive source code or can get
|
| 31 |
+
it if you want it; that you can change the software and use pieces of
|
| 32 |
+
it in new free programs; and that you are informed that you can do
|
| 33 |
+
these things.
|
| 34 |
+
|
| 35 |
+
To protect your rights, we need to make restrictions that forbid
|
| 36 |
+
distributors to deny you these rights or to ask you to surrender these
|
| 37 |
+
rights. These restrictions translate to certain responsibilities for
|
| 38 |
+
you if you distribute copies of the library or if you modify it.
|
| 39 |
+
|
| 40 |
+
For example, if you distribute copies of the library, whether gratis
|
| 41 |
+
or for a fee, you must give the recipients all the rights that we gave
|
| 42 |
+
you. You must make sure that they, too, receive or can get the source
|
| 43 |
+
code. If you link other code with the library, you must provide
|
| 44 |
+
complete object files to the recipients, so that they can relink them
|
| 45 |
+
with the library after making changes to the library and recompiling
|
| 46 |
+
it. And you must show them these terms so they know their rights.
|
| 47 |
+
|
| 48 |
+
We protect your rights with a two-step method: (1) we copyright the
|
| 49 |
+
library, and (2) we offer you this license, which gives you legal
|
| 50 |
+
permission to copy, distribute and/or modify the library.
|
| 51 |
+
|
| 52 |
+
To protect each distributor, we want to make it very clear that
|
| 53 |
+
there is no warranty for the free library. Also, if the library is
|
| 54 |
+
modified by someone else and passed on, the recipients should know
|
| 55 |
+
that what they have is not the original version, so that the original
|
| 56 |
+
author's reputation will not be affected by problems that might be
|
| 57 |
+
introduced by others.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
Finally, software patents pose a constant threat to the existence of
|
| 62 |
+
any free program. We wish to make sure that a company cannot
|
| 63 |
+
effectively restrict the users of a free program by obtaining a
|
| 64 |
+
restrictive license from a patent holder. Therefore, we insist that
|
| 65 |
+
any patent license obtained for a version of the library must be
|
| 66 |
+
consistent with the full freedom of use specified in this license.
|
| 67 |
+
|
| 68 |
+
Most GNU software, including some libraries, is covered by the
|
| 69 |
+
ordinary GNU General Public License. This license, the GNU Lesser
|
| 70 |
+
General Public License, applies to certain designated libraries, and
|
| 71 |
+
is quite different from the ordinary General Public License. We use
|
| 72 |
+
this license for certain libraries in order to permit linking those
|
| 73 |
+
libraries into non-free programs.
|
| 74 |
+
|
| 75 |
+
When a program is linked with a library, whether statically or using
|
| 76 |
+
a shared library, the combination of the two is legally speaking a
|
| 77 |
+
combined work, a derivative of the original library. The ordinary
|
| 78 |
+
General Public License therefore permits such linking only if the
|
| 79 |
+
entire combination fits its criteria of freedom. The Lesser General
|
| 80 |
+
Public License permits more lax criteria for linking other code with
|
| 81 |
+
the library.
|
| 82 |
+
|
| 83 |
+
We call this license the "Lesser" General Public License because it
|
| 84 |
+
does Less to protect the user's freedom than the ordinary General
|
| 85 |
+
Public License. It also provides other free software developers Less
|
| 86 |
+
of an advantage over competing non-free programs. These disadvantages
|
| 87 |
+
are the reason we use the ordinary General Public License for many
|
| 88 |
+
libraries. However, the Lesser license provides advantages in certain
|
| 89 |
+
special circumstances.
|
| 90 |
+
|
| 91 |
+
For example, on rare occasions, there may be a special need to
|
| 92 |
+
encourage the widest possible use of a certain library, so that it becomes
|
| 93 |
+
a de-facto standard. To achieve this, non-free programs must be
|
| 94 |
+
allowed to use the library. A more frequent case is that a free
|
| 95 |
+
library does the same job as widely used non-free libraries. In this
|
| 96 |
+
case, there is little to gain by limiting the free library to free
|
| 97 |
+
software only, so we use the Lesser General Public License.
|
| 98 |
+
|
| 99 |
+
In other cases, permission to use a particular library in non-free
|
| 100 |
+
programs enables a greater number of people to use a large body of
|
| 101 |
+
free software. For example, permission to use the GNU C Library in
|
| 102 |
+
non-free programs enables many more people to use the whole GNU
|
| 103 |
+
operating system, as well as its variant, the GNU/Linux operating
|
| 104 |
+
system.
|
| 105 |
+
|
| 106 |
+
Although the Lesser General Public License is Less protective of the
|
| 107 |
+
users' freedom, it does ensure that the user of a program that is
|
| 108 |
+
linked with the Library has the freedom and the wherewithal to run
|
| 109 |
+
that program using a modified version of the Library.
|
| 110 |
+
|
| 111 |
+
The precise terms and conditions for copying, distribution and
|
| 112 |
+
modification follow. Pay close attention to the difference between a
|
| 113 |
+
"work based on the library" and a "work that uses the library". The
|
| 114 |
+
former contains code derived from the library, whereas the latter must
|
| 115 |
+
be combined with the library in order to run.
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
GNU LESSER GENERAL PUBLIC LICENSE
|
| 120 |
+
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
|
| 121 |
+
|
| 122 |
+
0. This License Agreement applies to any software library or other
|
| 123 |
+
program which contains a notice placed by the copyright holder or
|
| 124 |
+
other authorized party saying it may be distributed under the terms of
|
| 125 |
+
this Lesser General Public License (also called "this License").
|
| 126 |
+
Each licensee is addressed as "you".
|
| 127 |
+
|
| 128 |
+
A "library" means a collection of software functions and/or data
|
| 129 |
+
prepared so as to be conveniently linked with application programs
|
| 130 |
+
(which use some of those functions and data) to form executables.
|
| 131 |
+
|
| 132 |
+
The "Library", below, refers to any such software library or work
|
| 133 |
+
which has been distributed under these terms. A "work based on the
|
| 134 |
+
Library" means either the Library or any derivative work under
|
| 135 |
+
copyright law: that is to say, a work containing the Library or a
|
| 136 |
+
portion of it, either verbatim or with modifications and/or translated
|
| 137 |
+
straightforwardly into another language. (Hereinafter, translation is
|
| 138 |
+
included without limitation in the term "modification".)
|
| 139 |
+
|
| 140 |
+
"Source code" for a work means the preferred form of the work for
|
| 141 |
+
making modifications to it. For a library, complete source code means
|
| 142 |
+
all the source code for all modules it contains, plus any associated
|
| 143 |
+
interface definition files, plus the scripts used to control compilation
|
| 144 |
+
and installation of the library.
|
| 145 |
+
|
| 146 |
+
Activities other than copying, distribution and modification are not
|
| 147 |
+
covered by this License; they are outside its scope. The act of
|
| 148 |
+
running a program using the Library is not restricted, and output from
|
| 149 |
+
such a program is covered only if its contents constitute a work based
|
| 150 |
+
on the Library (independent of the use of the Library in a tool for
|
| 151 |
+
writing it). Whether that is true depends on what the Library does
|
| 152 |
+
and what the program that uses the Library does.
|
| 153 |
+
|
| 154 |
+
1. You may copy and distribute verbatim copies of the Library's
|
| 155 |
+
complete source code as you receive it, in any medium, provided that
|
| 156 |
+
you conspicuously and appropriately publish on each copy an
|
| 157 |
+
appropriate copyright notice and disclaimer of warranty; keep intact
|
| 158 |
+
all the notices that refer to this License and to the absence of any
|
| 159 |
+
warranty; and distribute a copy of this License along with the
|
| 160 |
+
Library.
|
| 161 |
+
You may charge a fee for the physical act of transferring a copy,
|
| 162 |
+
and you may at your option offer warranty protection in exchange for a
|
| 163 |
+
fee.
|
| 164 |
+
|
| 165 |
+
2. You may modify your copy or copies of the Library or any portion
|
| 166 |
+
of it, thus forming a work based on the Library, and copy and
|
| 167 |
+
distribute such modifications or work under the terms of Section 1
|
| 168 |
+
above, provided that you also meet all of these conditions:
|
| 169 |
+
|
| 170 |
+
a) The modified work must itself be a software library.
|
| 171 |
+
|
| 172 |
+
b) You must cause the files modified to carry prominent notices
|
| 173 |
+
stating that you changed the files and the date of any change.
|
| 174 |
+
|
| 175 |
+
c) You must cause the whole of the work to be licensed at no
|
| 176 |
+
charge to all third parties under the terms of this License.
|
| 177 |
+
|
| 178 |
+
d) If a facility in the modified Library refers to a function or a
|
| 179 |
+
table of data to be supplied by an application program that uses
|
| 180 |
+
the facility, other than as an argument passed when the facility
|
| 181 |
+
is invoked, then you must make a good faith effort to ensure that,
|
| 182 |
+
in the event an application does not supply such function or
|
| 183 |
+
table, the facility still operates, and performs whatever part of
|
| 184 |
+
its purpose remains meaningful.
|
| 185 |
+
|
| 186 |
+
(For example, a function in a library to compute square roots has
|
| 187 |
+
a purpose that is entirely well-defined independent of the
|
| 188 |
+
application. Therefore, Subsection 2d requires that any
|
| 189 |
+
application-supplied function or table used by this function must
|
| 190 |
+
be optional: if the application does not supply it, the square
|
| 191 |
+
root function must still compute square roots.)
|
| 192 |
+
|
| 193 |
+
These requirements apply to the modified work as a whole. If
|
| 194 |
+
identifiable sections of that work are not derived from the Library,
|
| 195 |
+
and can be reasonably considered independent and separate works in
|
| 196 |
+
themselves, then this License, and its terms, do not apply to those
|
| 197 |
+
sections when you distribute them as separate works. But when you
|
| 198 |
+
distribute the same sections as part of a whole which is a work based
|
| 199 |
+
on the Library, the distribution of the whole must be on the terms of
|
| 200 |
+
this License, whose permissions for other licensees extend to the
|
| 201 |
+
entire whole, and thus to each and every part regardless of who wrote
|
| 202 |
+
it.
|
| 203 |
+
|
| 204 |
+
Thus, it is not the intent of this section to claim rights or contest
|
| 205 |
+
your rights to work written entirely by you; rather, the intent is to
|
| 206 |
+
exercise the right to control the distribution of derivative or
|
| 207 |
+
collective works based on the Library.
|
| 208 |
+
|
| 209 |
+
In addition, mere aggregation of another work not based on the Library
|
| 210 |
+
with the Library (or with a work based on the Library) on a volume of
|
| 211 |
+
a storage or distribution medium does not bring the other work under
|
| 212 |
+
the scope of this License.
|
| 213 |
+
|
| 214 |
+
3. You may opt to apply the terms of the ordinary GNU General Public
|
| 215 |
+
License instead of this License to a given copy of the Library. To do
|
| 216 |
+
this, you must alter all the notices that refer to this License, so
|
| 217 |
+
that they refer to the ordinary GNU General Public License, version 2,
|
| 218 |
+
instead of to this License. (If a newer version than version 2 of the
|
| 219 |
+
ordinary GNU General Public License has appeared, then you can specify
|
| 220 |
+
that version instead if you wish.) Do not make any other change in
|
| 221 |
+
these notices.
|
| 222 |
+
|
| 223 |
+
Once this change is made in a given copy, it is irreversible for
|
| 224 |
+
that copy, so the ordinary GNU General Public License applies to all
|
| 225 |
+
subsequent copies and derivative works made from that copy.
|
| 226 |
+
|
| 227 |
+
This option is useful when you wish to copy part of the code of
|
| 228 |
+
the Library into a program that is not a library.
|
| 229 |
+
|
| 230 |
+
4. You may copy and distribute the Library (or a portion or
|
| 231 |
+
derivative of it, under Section 2) in object code or executable form
|
| 232 |
+
under the terms of Sections 1 and 2 above provided that you accompany
|
| 233 |
+
it with the complete corresponding machine-readable source code, which
|
| 234 |
+
must be distributed under the terms of Sections 1 and 2 above on a
|
| 235 |
+
medium customarily used for software interchange.
|
| 236 |
+
|
| 237 |
+
If distribution of object code is made by offering access to copy
|
| 238 |
+
from a designated place, then offering equivalent access to copy the
|
| 239 |
+
source code from the same place satisfies the requirement to
|
| 240 |
+
distribute the source code, even though third parties are not
|
| 241 |
+
compelled to copy the source along with the object code.
|
| 242 |
+
|
| 243 |
+
5. A program that contains no derivative of any portion of the
|
| 244 |
+
Library, but is designed to work with the Library by being compiled or
|
| 245 |
+
linked with it, is called a "work that uses the Library". Such a
|
| 246 |
+
work, in isolation, is not a derivative work of the Library, and
|
| 247 |
+
therefore falls outside the scope of this License.
|
| 248 |
+
|
| 249 |
+
However, linking a "work that uses the Library" with the Library
|
| 250 |
+
creates an executable that is a derivative of the Library (because it
|
| 251 |
+
contains portions of the Library), rather than a "work that uses the
|
| 252 |
+
library". The executable is therefore covered by this License.
|
| 253 |
+
Section 6 states terms for distribution of such executables.
|
| 254 |
+
|
| 255 |
+
When a "work that uses the Library" uses material from a header file
|
| 256 |
+
that is part of the Library, the object code for the work may be a
|
| 257 |
+
derivative work of the Library even though the source code is not.
|
| 258 |
+
Whether this is true is especially significant if the work can be
|
| 259 |
+
linked without the Library, or if the work is itself a library. The
|
| 260 |
+
threshold for this to be true is not precisely defined by law.
|
| 261 |
+
|
| 262 |
+
If such an object file uses only numerical parameters, data
|
| 263 |
+
structure layouts and accessors, and small macros and small inline
|
| 264 |
+
functions (ten lines or less in length), then the use of the object
|
| 265 |
+
file is unrestricted, regardless of whether it is legally a derivative
|
| 266 |
+
work. (Executables containing this object code plus portions of the
|
| 267 |
+
Library will still fall under Section 6.)
|
| 268 |
+
|
| 269 |
+
Otherwise, if the work is a derivative of the Library, you may
|
| 270 |
+
distribute the object code for the work under the terms of Section 6.
|
| 271 |
+
Any executables containing that work also fall under Section 6,
|
| 272 |
+
whether or not they are linked directly with the Library itself.
|
| 273 |
+
|
| 274 |
+
6. As an exception to the Sections above, you may also combine or
|
| 275 |
+
link a "work that uses the Library" with the Library to produce a
|
| 276 |
+
work containing portions of the Library, and distribute that work
|
| 277 |
+
under terms of your choice, provided that the terms permit
|
| 278 |
+
modification of the work for the customer's own use and reverse
|
| 279 |
+
engineering for debugging such modifications.
|
| 280 |
+
|
| 281 |
+
You must give prominent notice with each copy of the work that the
|
| 282 |
+
Library is used in it and that the Library and its use are covered by
|
| 283 |
+
this License. You must supply a copy of this License. If the work
|
| 284 |
+
during execution displays copyright notices, you must include the
|
| 285 |
+
copyright notice for the Library among them, as well as a reference
|
| 286 |
+
directing the user to the copy of this License. Also, you must do one
|
| 287 |
+
of these things:
|
| 288 |
+
|
| 289 |
+
a) Accompany the work with the complete corresponding
|
| 290 |
+
machine-readable source code for the Library including whatever
|
| 291 |
+
changes were used in the work (which must be distributed under
|
| 292 |
+
Sections 1 and 2 above); and, if the work is an executable linked
|
| 293 |
+
with the Library, with the complete machine-readable "work that
|
| 294 |
+
uses the Library", as object code and/or source code, so that the
|
| 295 |
+
user can modify the Library and then relink to produce a modified
|
| 296 |
+
executable containing the modified Library. (It is understood
|
| 297 |
+
that the user who changes the contents of definitions files in the
|
| 298 |
+
Library will not necessarily be able to recompile the application
|
| 299 |
+
to use the modified definitions.)
|
| 300 |
+
|
| 301 |
+
b) Use a suitable shared library mechanism for linking with the
|
| 302 |
+
Library. A suitable mechanism is one that (1) uses at run time a
|
| 303 |
+
copy of the library already present on the user's computer system,
|
| 304 |
+
rather than copying library functions into the executable, and (2)
|
| 305 |
+
will operate properly with a modified version of the library, if
|
| 306 |
+
the user installs one, as long as the modified version is
|
| 307 |
+
interface-compatible with the version that the work was made with.
|
| 308 |
+
|
| 309 |
+
c) Accompany the work with a written offer, valid for at
|
| 310 |
+
least three years, to give the same user the materials
|
| 311 |
+
specified in Subsection 6a, above, for a charge no more
|
| 312 |
+
than the cost of performing this distribution.
|
| 313 |
+
|
| 314 |
+
d) If distribution of the work is made by offering access to copy
|
| 315 |
+
from a designated place, offer equivalent access to copy the above
|
| 316 |
+
specified materials from the same place.
|
| 317 |
+
|
| 318 |
+
e) Verify that the user has already received a copy of these
|
| 319 |
+
materials or that you have already sent this user a copy.
|
| 320 |
+
|
| 321 |
+
For an executable, the required form of the "work that uses the
|
| 322 |
+
Library" must include any data and utility programs needed for
|
| 323 |
+
reproducing the executable from it. However, as a special exception,
|
| 324 |
+
the materials to be distributed need not include anything that is
|
| 325 |
+
normally distributed (in either source or binary form) with the major
|
| 326 |
+
components (compiler, kernel, and so on) of the operating system on
|
| 327 |
+
which the executable runs, unless that component itself accompanies
|
| 328 |
+
the executable.
|
| 329 |
+
|
| 330 |
+
It may happen that this requirement contradicts the license
|
| 331 |
+
restrictions of other proprietary libraries that do not normally
|
| 332 |
+
accompany the operating system. Such a contradiction means you cannot
|
| 333 |
+
use both them and the Library together in an executable that you
|
| 334 |
+
distribute.
|
| 335 |
+
|
| 336 |
+
7. You may place library facilities that are a work based on the
|
| 337 |
+
Library side-by-side in a single library together with other library
|
| 338 |
+
facilities not covered by this License, and distribute such a combined
|
| 339 |
+
library, provided that the separate distribution of the work based on
|
| 340 |
+
the Library and of the other library facilities is otherwise
|
| 341 |
+
permitted, and provided that you do these two things:
|
| 342 |
+
|
| 343 |
+
a) Accompany the combined library with a copy of the same work
|
| 344 |
+
based on the Library, uncombined with any other library
|
| 345 |
+
facilities. This must be distributed under the terms of the
|
| 346 |
+
Sections above.
|
| 347 |
+
|
| 348 |
+
b) Give prominent notice with the combined library of the fact
|
| 349 |
+
that part of it is a work based on the Library, and explaining
|
| 350 |
+
where to find the accompanying uncombined form of the same work.
|
| 351 |
+
|
| 352 |
+
8. You may not copy, modify, sublicense, link with, or distribute
|
| 353 |
+
the Library except as expressly provided under this License. Any
|
| 354 |
+
attempt otherwise to copy, modify, sublicense, link with, or
|
| 355 |
+
distribute the Library is void, and will automatically terminate your
|
| 356 |
+
rights under this License. However, parties who have received copies,
|
| 357 |
+
or rights, from you under this License will not have their licenses
|
| 358 |
+
terminated so long as such parties remain in full compliance.
|
| 359 |
+
|
| 360 |
+
9. You are not required to accept this License, since you have not
|
| 361 |
+
signed it. However, nothing else grants you permission to modify or
|
| 362 |
+
distribute the Library or its derivative works. These actions are
|
| 363 |
+
prohibited by law if you do not accept this License. Therefore, by
|
| 364 |
+
modifying or distributing the Library (or any work based on the
|
| 365 |
+
Library), you indicate your acceptance of this License to do so, and
|
| 366 |
+
all its terms and conditions for copying, distributing or modifying
|
| 367 |
+
the Library or works based on it.
|
| 368 |
+
|
| 369 |
+
10. Each time you redistribute the Library (or any work based on the
|
| 370 |
+
Library), the recipient automatically receives a license from the
|
| 371 |
+
original licensor to copy, distribute, link with or modify the Library
|
| 372 |
+
subject to these terms and conditions. You may not impose any further
|
| 373 |
+
restrictions on the recipients' exercise of the rights granted herein.
|
| 374 |
+
You are not responsible for enforcing compliance by third parties with
|
| 375 |
+
this License.
|
| 376 |
+
|
| 377 |
+
11. If, as a consequence of a court judgment or allegation of patent
|
| 378 |
+
infringement or for any other reason (not limited to patent issues),
|
| 379 |
+
conditions are imposed on you (whether by court order, agreement or
|
| 380 |
+
otherwise) that contradict the conditions of this License, they do not
|
| 381 |
+
excuse you from the conditions of this License. If you cannot
|
| 382 |
+
distribute so as to satisfy simultaneously your obligations under this
|
| 383 |
+
License and any other pertinent obligations, then as a consequence you
|
| 384 |
+
may not distribute the Library at all. For example, if a patent
|
| 385 |
+
license would not permit royalty-free redistribution of the Library by
|
| 386 |
+
all those who receive copies directly or indirectly through you, then
|
| 387 |
+
the only way you could satisfy both it and this License would be to
|
| 388 |
+
refrain entirely from distribution of the Library.
|
| 389 |
+
|
| 390 |
+
If any portion of this section is held invalid or unenforceable under any
|
| 391 |
+
particular circumstance, the balance of the section is intended to apply,
|
| 392 |
+
and the section as a whole is intended to apply in other circumstances.
|
| 393 |
+
|
| 394 |
+
It is not the purpose of this section to induce you to infringe any
|
| 395 |
+
patents or other property right claims or to contest validity of any
|
| 396 |
+
such claims; this section has the sole purpose of protecting the
|
| 397 |
+
integrity of the free software distribution system which is
|
| 398 |
+
implemented by public license practices. Many people have made
|
| 399 |
+
generous contributions to the wide range of software distributed
|
| 400 |
+
through that system in reliance on consistent application of that
|
| 401 |
+
system; it is up to the author/donor to decide if he or she is willing
|
| 402 |
+
to distribute software through any other system and a licensee cannot
|
| 403 |
+
impose that choice.
|
| 404 |
+
|
| 405 |
+
This section is intended to make thoroughly clear what is believed to
|
| 406 |
+
be a consequence of the rest of this License.
|
| 407 |
+
|
| 408 |
+
12. If the distribution and/or use of the Library is restricted in
|
| 409 |
+
certain countries either by patents or by copyrighted interfaces, the
|
| 410 |
+
original copyright holder who places the Library under this License may add
|
| 411 |
+
an explicit geographical distribution limitation excluding those countries,
|
| 412 |
+
so that distribution is permitted only in or among countries not thus
|
| 413 |
+
excluded. In such case, this License incorporates the limitation as if
|
| 414 |
+
written in the body of this License.
|
| 415 |
+
|
| 416 |
+
13. The Free Software Foundation may publish revised and/or new
|
| 417 |
+
versions of the Lesser General Public License from time to time.
|
| 418 |
+
Such new versions will be similar in spirit to the present version,
|
| 419 |
+
but may differ in detail to address new problems or concerns.
|
| 420 |
+
|
| 421 |
+
Each version is given a distinguishing version number. If the Library
|
| 422 |
+
specifies a version number of this License which applies to it and
|
| 423 |
+
"any later version", you have the option of following the terms and
|
| 424 |
+
conditions either of that version or of any later version published by
|
| 425 |
+
the Free Software Foundation. If the Library does not specify a
|
| 426 |
+
license version number, you may choose any version ever published by
|
| 427 |
+
the Free Software Foundation.
|
| 428 |
+
|
| 429 |
+
14. If you wish to incorporate parts of the Library into other free
|
| 430 |
+
programs whose distribution conditions are incompatible with these,
|
| 431 |
+
write to the author to ask for permission. For software which is
|
| 432 |
+
copyrighted by the Free Software Foundation, write to the Free
|
| 433 |
+
Software Foundation; we sometimes make exceptions for this. Our
|
| 434 |
+
decision will be guided by the two goals of preserving the free status
|
| 435 |
+
of all derivatives of our free software and of promoting the sharing
|
| 436 |
+
and reuse of software generally.
|
| 437 |
+
|
| 438 |
+
NO WARRANTY
|
| 439 |
+
|
| 440 |
+
15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO
|
| 441 |
+
WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW.
|
| 442 |
+
EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR
|
| 443 |
+
OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY
|
| 444 |
+
KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE
|
| 445 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
| 446 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE
|
| 447 |
+
LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME
|
| 448 |
+
THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
| 449 |
+
|
| 450 |
+
16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN
|
| 451 |
+
WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
|
| 452 |
+
AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU
|
| 453 |
+
FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR
|
| 454 |
+
CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE
|
| 455 |
+
LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING
|
| 456 |
+
RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A
|
| 457 |
+
FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF
|
| 458 |
+
SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
|
| 459 |
+
DAMAGES.
|
| 460 |
+
|
| 461 |
+
END OF TERMS AND CONDITIONS
|
| 462 |
+
|
| 463 |
+
How to Apply These Terms to Your New Libraries
|
| 464 |
+
|
| 465 |
+
If you develop a new library, and you want it to be of the greatest
|
| 466 |
+
possible use to the public, we recommend making it free software that
|
| 467 |
+
everyone can redistribute and change. You can do so by permitting
|
| 468 |
+
redistribution under these terms (or, alternatively, under the terms of the
|
| 469 |
+
ordinary General Public License).
|
| 470 |
+
|
| 471 |
+
To apply these terms, attach the following notices to the library. It is
|
| 472 |
+
safest to attach them to the start of each source file to most effectively
|
| 473 |
+
convey the exclusion of warranty; and each file should have at least the
|
| 474 |
+
"copyright" line and a pointer to where the full notice is found.
|
| 475 |
+
|
| 476 |
+
<one line to give the library's name and a brief idea of what it does.>
|
| 477 |
+
Copyright (C) <year> <name of author>
|
| 478 |
+
|
| 479 |
+
This library is free software; you can redistribute it and/or
|
| 480 |
+
modify it under the terms of the GNU Lesser General Public
|
| 481 |
+
License as published by the Free Software Foundation; either
|
| 482 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 483 |
+
|
| 484 |
+
This library is distributed in the hope that it will be useful,
|
| 485 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 486 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 487 |
+
Lesser General Public License for more details.
|
| 488 |
+
|
| 489 |
+
You should have received a copy of the GNU Lesser General Public
|
| 490 |
+
License along with this library; if not, write to the Free Software
|
| 491 |
+
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
| 492 |
+
|
| 493 |
+
Also add information on how to contact you by electronic and paper mail.
|
| 494 |
+
|
| 495 |
+
You should also get your employer (if you work as a programmer) or your
|
| 496 |
+
school, if any, to sign a "copyright disclaimer" for the library, if
|
| 497 |
+
necessary. Here is a sample; alter the names:
|
| 498 |
+
|
| 499 |
+
Yoyodyne, Inc., hereby disclaims all copyright interest in the
|
| 500 |
+
library `Frob' (a library for tweaking knobs) written by James Random Hacker.
|
| 501 |
+
|
| 502 |
+
<signature of Ty Coon>, 1 April 1990
|
| 503 |
+
Ty Coon, President of Vice
|
| 504 |
+
|
| 505 |
+
That's all there is to it!
|
venv/Lib/site-packages/adodbapi/process_connect_string.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""a clumsy attempt at a macro language to let the programmer execute code on the server (ex: determine 64bit)"""
|
| 2 |
+
|
| 3 |
+
from . import is64bit
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def macro_call(macro_name, args, kwargs):
    """Evaluate one connection-string macro and return ``(new_key, value)``.

    Allows the programmer to perform limited processing on the server by
    passing macro names and arguments.

    :macro_name - which macro to run: "is64bit", "getuser", "getnode",
                  "getenv", "auto_security" or "find_temp_test_path"
    :args       - a sequence (or a single string, treated as a one-item
                  sequence):
                    args[0]  - the key name the macro result is stored under
                    args[1:] - any arguments for the macro
    :kwargs     - the connection keyword dictionary (the macro key itself has
                  already been removed by the caller)
    --> the tuple (new_key, value), meaning kwargs[new_key] = value

    Raises ValueError for an unknown macro or when a macro fails; the
    underlying exception is attached as ``__cause__``.
    """
    if isinstance(args, str):
        # the user forgot to pass a sequence, so make a string into args[0]
        args = [args]
    new_key = args[0]
    try:
        if macro_name == "is64bit":
            if is64bit.Python():  # if on 64 bit Python
                return new_key, args[1]  # return first argument
            try:
                return new_key, args[2]  # else return second argument (if defined)
            except IndexError:
                return new_key, ""  # else return blank

        elif macro_name == "getuser":
            # get the name of the user the server is logged in under
            if new_key not in kwargs:
                import getpass

                return new_key, getpass.getuser()
            # The caller already supplied a value: keep it, rather than
            # falling through to the "unknown macro" error below.
            return new_key, kwargs[new_key]

        elif macro_name == "getnode":
            # get the name of the computer running the server
            import platform

            try:
                # args[1], when given, is a %-format template for the node name
                return new_key, args[1] % platform.node()
            except IndexError:
                return new_key, platform.node()

        elif macro_name == "getenv":
            # expand the server's environment variable args[1]
            import os

            try:
                dflt = args[2]  # if not found, default from args[2]
            except IndexError:  # or blank
                dflt = ""
            return new_key, os.environ.get(args[1], dflt)

        elif macro_name == "auto_security":
            if "user" not in kwargs or not kwargs["user"]:
                # missing, blank, or Null username
                return new_key, "Integrated Security=SSPI"
            return new_key, "User ID=%(user)s; Password=%(password)s" % kwargs

        elif macro_name == "find_temp_test_path":
            # helper function for testing ado operation -- undocumented
            import os
            import tempfile

            return new_key, os.path.join(
                tempfile.gettempdir(), "adodbapi_test", args[1]
            )

        raise ValueError(f"Unknown connect string macro={macro_name}")
    except Exception as exc:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # propagate untouched, and chain the cause so the real failure is
        # visible in the traceback.
        raise ValueError(
            f"Error in macro processing {macro_name} {args!r}"
        ) from exc
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def process(args, kwargs, expand_macros=False):
    """Fold the positional arguments of a .connect() call into ``kwargs``.

    args: positional parameters from the .connect() call, interpreted as
        args[0] - connection string, or a dict of settings merged into kwargs
        args[1] - int: timeout; str: user; dict: more keyword arguments
        args[2] - password
        args[3] - host name
        args[4] - database name
    kwargs: keyword arguments from the .connect() call; updated in place
    expand_macros: when true, every string key starting with "macro_" is
        removed and replaced by the (key, value) pair macro_call() returns

    --> the same kwargs dictionary, with "connection_string" defined.

    Raises TypeError when no connection string can be found in
    "connection_string", "dsn" or "host".
    """
    try:
        dsn = args[0]
    except IndexError:
        dsn = None
    # as a convenience the first argument may be django settings
    if isinstance(dsn, dict):
        kwargs.update(dsn)
    # the connection string is passed to the connection as part of the keyword dictionary
    elif dsn:
        kwargs["connection_string"] = dsn

    try:
        a1 = args[1]
    except IndexError:
        a1 = None
    # historically, the second positional argument might be a timeout value
    if isinstance(a1, int):
        kwargs["timeout"] = a1
    # if the second positional argument is a string, then it is user
    elif isinstance(a1, str):
        kwargs["user"] = a1
    # if the second positional argument is a dictionary, use it as keyword arguments, too
    elif isinstance(a1, dict):
        kwargs.update(a1)

    # Each assignment is attempted in order; the first missing positional
    # argument ends the sequence, leaving the earlier ones in place.
    try:
        kwargs["password"] = args[2]  # the third positional argument is password
        kwargs["host"] = args[3]  # the fourth positional argument is host name
        kwargs["database"] = args[4]  # the fifth positional argument is database name
    except IndexError:
        pass

    # make sure connection string is defined somehow
    if "connection_string" not in kwargs:
        # perhaps 'dsn' was defined; as a last effort, use the "host" keyword
        for fallback in ("dsn", "host"):
            if fallback in kwargs:
                kwargs["connection_string"] = kwargs[fallback]
                break
        else:
            # raised outside any except clause, so the traceback is not
            # cluttered with irrelevant chained KeyErrors
            raise TypeError("Must define 'connection_string' for ado connections")

    if expand_macros:
        # iterate over a snapshot because the dictionary is modified in the loop
        for kwarg in list(kwargs):
            # non-string keys (possible when a settings dict was merged in)
            # cannot name a macro, so skip them instead of failing
            if isinstance(kwarg, str) and kwarg.startswith("macro_"):
                macro_name = kwarg[6:]  # name without the "macro_"
                # we remove the macro_key and get the code to execute
                macro_code = kwargs.pop(kwarg)
                # run the code in the local context
                new_key, rslt = macro_call(macro_name, macro_code, kwargs)
                kwargs[new_key] = rslt  # put the result back in the keywords dict
    return kwargs
|
venv/Lib/site-packages/adodbapi/readme.txt
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Project
|
| 2 |
+
-------
|
| 3 |
+
adodbapi
|
| 4 |
+
|
| 5 |
+
A Python DB-API 2.0 (PEP-249) module that makes it easy to use Microsoft ADO
|
| 6 |
+
for connecting with databases and other data sources using CPython.
|
| 7 |
+
|
| 8 |
+
Home page: <https://sourceforge.net/projects/adodbapi>
|
| 9 |
+
|
| 10 |
+
Features:
|
| 11 |
+
* 100% DB-API 2.0 (PEP-249) compliant (including most extensions and recommendations).
|
| 12 |
+
* Includes pyunit testcases that describe how to use the module.
|
| 13 |
+
* Fully implemented in Python. -- runs in current versions of Python 3
|
| 14 |
+
* Licensed under the LGPL license, which means that it can be used freely even in commercial programs subject to certain restrictions.
|
| 15 |
+
* The user can choose between paramstyles: 'qmark' 'named' 'format' 'pyformat' 'dynamic'
|
| 16 |
+
* Supports data retrieval by column name e.g.:
|
| 17 |
+
for row in myCursor.execute("select name,age from students"):
|
| 18 |
+
print("Student", row.name, "is", row.age, "years old.")
|
| 19 |
+
* Supports user-definable system-to-Python data conversion functions (selected by ADO data type, or by column)
|
| 20 |
+
|
| 21 |
+
Prerequisites:
|
| 22 |
+
* CPython 3.6 or higher
|
| 23 |
+
and pywin32 (Mark Hammond's python for windows extensions.)
|
| 24 |
+
|
| 25 |
+
Installation:
|
| 26 |
+
* (C-Python on Windows): Install pywin32 (`python -m pip install pywin32`) which includes adodbapi.
|
| 27 |
+
* (IronPython on Windows): Download adodbapi from https://sourceforge.net/projects/adodbapi/ . Unpack the zip.
|
| 28 |
+
|
| 29 |
+
NOTE: ...........
|
| 30 |
+
If you do not like the new default operation of returning Numeric columns as decimal.Decimal,
|
| 31 |
+
you can select other options by the user defined conversion feature.
|
| 32 |
+
Try:
|
| 33 |
+
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtString
|
| 34 |
+
or:
|
| 35 |
+
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtFloat
|
| 36 |
+
or:
|
| 37 |
+
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = write_your_own_conversion_function
|
| 38 |
+
............
|
| 39 |
+
notes for 2.6.2:
|
| 40 |
+
The definitive source has been moved to https://github.com/mhammond/pywin32/tree/main/adodbapi.
|
| 41 |
+
Remote has proven too hard to configure and test with Pyro4. I am moving it to unsupported status
|
| 42 |
+
until I can change to a different connection method.
|
| 43 |
+
what's new in version 2.6
|
| 44 |
+
A cursor.prepare() method and support for prepared SQL statements.
|
| 45 |
+
Lots of refactoring, especially of the Remote and Server modules (still to be treated as Beta code).
|
| 46 |
+
The quick start document 'quick_reference.odt' will export as a nice-looking pdf.
|
| 47 |
+
Added paramstyles 'pyformat' and 'dynamic'. If your 'paramstyle' is 'named' you _must_ pass a dictionary of
|
| 48 |
+
parameters to your .execute() method. If your 'paramstyle' is 'format' 'pyformat' or 'dynamic', you _may_
|
| 49 |
+
pass a dictionary of parameters -- provided your SQL operation string is formatted correctly.
|
| 50 |
+
|
| 51 |
+
what's new in version 2.5
|
| 52 |
+
Remote module: (works on Linux!) allows a Windows computer to serve ADO databases via PyRO
|
| 53 |
+
Server module: PyRO server for ADO. Run using a command like: C:>python -m adodbapi.server
|
| 54 |
+
(server has simple connection string macros: is64bit, getuser, sql_provider, auto_security)
|
| 55 |
+
Brief documentation included. See adodbapi/examples folder adodbapi.rtf
|
| 56 |
+
New connection method conn.get_table_names() --> list of names of tables in database
|
| 57 |
+
|
| 58 |
+
Vastly refactored. Data conversion things have been moved to the new adodbapi.apibase module.
|
| 59 |
+
Many former module-level attributes are now class attributes. (Should be more thread-safe)
|
| 60 |
+
Connection objects are now context managers for transactions and will commit or rollback.
|
| 61 |
+
Cursor objects are context managers and will automatically close themselves.
|
| 62 |
+
Autocommit can be switched on and off.
|
| 63 |
+
Keyword and positional arguments on the connect() method work as documented in PEP 249.
|
| 64 |
+
Keyword arguments from the connect call can be formatted into the connection string.
|
| 65 |
+
New keyword arguments defined, such as: autocommit, paramstyle, remote_proxy, remote_port.
|
| 66 |
+
*** Breaking change: variantConversion lookups are simplified: the following will raise KeyError:
|
| 67 |
+
oldconverter=adodbapi.variantConversions[adodbapi.adoStringTypes]
|
| 68 |
+
Refactor as: oldconverter=adodbapi.variantConversions[adodbapi.adoStringTypes[0]]
|
| 69 |
+
|
| 70 |
+
License
|
| 71 |
+
-------
|
| 72 |
+
LGPL, see https://opensource.org/license/lgpl-2-1
|
| 73 |
+
|
| 74 |
+
Documentation
|
| 75 |
+
-------------
|
| 76 |
+
|
| 77 |
+
Look at:
|
| 78 |
+
- `adodbapi/quick_reference.md`
|
| 79 |
+
- https://wiki.python.org/moin/DatabaseProgramming#The_DB-API
|
| 80 |
+
- read the examples in adodbapi/examples
|
| 81 |
+
- and the test cases in the `adodbapi/test` directory
|
| 82 |
+
|
| 83 |
+
Mailing lists
|
| 84 |
+
-------------
|
| 85 |
+
The adodbapi mailing lists have been deactivated. Submit comments to the
|
| 86 |
+
pywin32 mailing lists.
|
| 87 |
+
-- the bug tracker on sourceforge.net/projects/adodbapi may be checked, (infrequently).
|
| 88 |
+
-- please use: https://github.com/mhammond/pywin32/issues
|
venv/Lib/site-packages/adodbapi/schema_table.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""call using an open ADO connection --> list of table names"""
|
| 2 |
+
|
| 3 |
+
from . import adodbapi
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def names(connection_object):
    """Return the TABLE_NAME value of every row in the connection's table schema.

    connection_object: an open connection exposing the ADO connection as
        its ``adoConn`` attribute.
    --> list of table names, in the order the schema rows are read.
    """
    # 20 is the ADO constant adSchemaTables
    recordset = connection_object.adoConn.OpenSchema(20)

    found = []
    while True:
        if recordset.EOF:
            break
        field = adodbapi.getIndexedValue(recordset.Fields, "TABLE_NAME")
        found.append(field.Value)
        recordset.MoveNext()
    del recordset
    return found
|
venv/Lib/site-packages/adodbapi/setup.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""adodbapi -- a pure Python PEP 249 DB-API package using Microsoft ADO
|
| 2 |
+
|
| 3 |
+
Adodbapi can be run on CPython 3.5 and later.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Package metadata handed to setuptools.setup() by setup_package().
NAME = "adodbapi"
MAINTAINER = "Vernon Cole"
MAINTAINER_EMAIL = "vernondcole@gmail.com"
DESCRIPTION = (
    """A pure Python package implementing PEP 249 DB-API using Microsoft ADO."""
)
URL = "https://sourceforge.net/projects/adodbapi"
LICENSE = "LGPL"
CLASSIFIERS = [
    "Development Status :: 5 - Production/Stable",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)",
    "Operating System :: Microsoft :: Windows",
    "Operating System :: POSIX :: Linux",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3",
    "Programming Language :: SQL",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Database",
]
AUTHOR = "Henrik Ekelund, Vernon Cole, et.al."
AUTHOR_EMAIL = "vernondcole@gmail.com"
PLATFORMS = ["Windows", "Linux"]

VERSION = None  # in case searching for version fails
# Find the version string in the source code: take the text between the first
# pair of single quotes on the first line that mentions __version__.
# The "with" block closes the file even if the scan raises.
with open("adodbapi.py") as source_file:
    for line in source_file:
        if "__version__" in line:
            VERSION = line.split("'")[1]  # pyright: ignore[reportConstantRedefinition]
            print('adodbapi version="%s"' % VERSION)
            break
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def setup_package():
    """Call setuptools.setup() with this module's metadata constants.

    Reads README.txt from the current directory for the long description.
    """
    # imported here so that setuptools is only needed when actually packaging
    from setuptools import setup
    from setuptools.command.build_py import build_py

    # read the long description with a context manager so the file is closed
    with open("README.txt") as readme_file:
        long_description = readme_file.read()

    setup(
        cmdclass={"build_py": build_py},
        name=NAME,
        maintainer=MAINTAINER,
        maintainer_email=MAINTAINER_EMAIL,
        description=DESCRIPTION,
        url=URL,
        keywords="database ado odbc dbapi db-api Microsoft SQL",
        ## download_url=DOWNLOAD_URL,
        long_description=long_description,
        license=LICENSE,
        classifiers=CLASSIFIERS,
        author=AUTHOR,
        author_email=AUTHOR_EMAIL,
        platforms=PLATFORMS,
        version=VERSION,
        package_dir={"adodbapi": ""},
        packages=["adodbapi"],
    )


if __name__ == "__main__":
    setup_package()
|
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
A. HISTORY OF THE SOFTWARE
|
| 2 |
+
==========================
|
| 3 |
+
|
| 4 |
+
Python was created in the early 1990s by Guido van Rossum at Stichting
|
| 5 |
+
Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands
|
| 6 |
+
as a successor of a language called ABC. Guido remains Python's
|
| 7 |
+
principal author, although it includes many contributions from others.
|
| 8 |
+
|
| 9 |
+
In 1995, Guido continued his work on Python at the Corporation for
|
| 10 |
+
National Research Initiatives (CNRI, see https://www.cnri.reston.va.us)
|
| 11 |
+
in Reston, Virginia where he released several versions of the
|
| 12 |
+
software.
|
| 13 |
+
|
| 14 |
+
In May 2000, Guido and the Python core development team moved to
|
| 15 |
+
BeOpen.com to form the BeOpen PythonLabs team. In October of the same
|
| 16 |
+
year, the PythonLabs team moved to Digital Creations, which became
|
| 17 |
+
Zope Corporation. In 2001, the Python Software Foundation (PSF, see
|
| 18 |
+
https://www.python.org/psf/) was formed, a non-profit organization
|
| 19 |
+
created specifically to own Python-related Intellectual Property.
|
| 20 |
+
Zope Corporation was a sponsoring member of the PSF.
|
| 21 |
+
|
| 22 |
+
All Python releases are Open Source (see https://opensource.org for
|
| 23 |
+
the Open Source Definition). Historically, most, but not all, Python
|
| 24 |
+
releases have also been GPL-compatible; the table below summarizes
|
| 25 |
+
the various releases.
|
| 26 |
+
|
| 27 |
+
Release Derived Year Owner GPL-
|
| 28 |
+
from compatible? (1)
|
| 29 |
+
|
| 30 |
+
0.9.0 thru 1.2 1991-1995 CWI yes
|
| 31 |
+
1.3 thru 1.5.2 1.2 1995-1999 CNRI yes
|
| 32 |
+
1.6 1.5.2 2000 CNRI no
|
| 33 |
+
2.0 1.6 2000 BeOpen.com no
|
| 34 |
+
1.6.1 1.6 2001 CNRI yes (2)
|
| 35 |
+
2.1 2.0+1.6.1 2001 PSF no
|
| 36 |
+
2.0.1 2.0+1.6.1 2001 PSF yes
|
| 37 |
+
2.1.1 2.1+2.0.1 2001 PSF yes
|
| 38 |
+
2.1.2 2.1.1 2002 PSF yes
|
| 39 |
+
2.1.3 2.1.2 2002 PSF yes
|
| 40 |
+
2.2 and above 2.1.1 2001-now PSF yes
|
| 41 |
+
|
| 42 |
+
Footnotes:
|
| 43 |
+
|
| 44 |
+
(1) GPL-compatible doesn't mean that we're distributing Python under
|
| 45 |
+
the GPL. All Python licenses, unlike the GPL, let you distribute
|
| 46 |
+
a modified version without making your changes open source. The
|
| 47 |
+
GPL-compatible licenses make it possible to combine Python with
|
| 48 |
+
other software that is released under the GPL; the others don't.
|
| 49 |
+
|
| 50 |
+
(2) According to Richard Stallman, 1.6.1 is not GPL-compatible,
|
| 51 |
+
because its license has a choice of law clause. According to
|
| 52 |
+
CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1
|
| 53 |
+
is "not incompatible" with the GPL.
|
| 54 |
+
|
| 55 |
+
Thanks to the many outside volunteers who have worked under Guido's
|
| 56 |
+
direction to make these releases possible.
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON
|
| 60 |
+
===============================================================
|
| 61 |
+
|
| 62 |
+
Python software and documentation are licensed under the
|
| 63 |
+
Python Software Foundation License Version 2.
|
| 64 |
+
|
| 65 |
+
Starting with Python 3.8.6, examples, recipes, and other code in
|
| 66 |
+
the documentation are dual licensed under the PSF License Version 2
|
| 67 |
+
and the Zero-Clause BSD license.
|
| 68 |
+
|
| 69 |
+
Some software incorporated into Python is under different licenses.
|
| 70 |
+
The licenses are listed with code falling under that license.
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
| 74 |
+
--------------------------------------------
|
| 75 |
+
|
| 76 |
+
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
| 77 |
+
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
| 78 |
+
otherwise using this software ("Python") in source or binary form and
|
| 79 |
+
its associated documentation.
|
| 80 |
+
|
| 81 |
+
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
| 82 |
+
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
| 83 |
+
analyze, test, perform and/or display publicly, prepare derivative works,
|
| 84 |
+
distribute, and otherwise use Python alone or in any derivative version,
|
| 85 |
+
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
| 86 |
+
i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
|
| 87 |
+
2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation;
|
| 88 |
+
All Rights Reserved" are retained in Python alone or in any derivative version
|
| 89 |
+
prepared by Licensee.
|
| 90 |
+
|
| 91 |
+
3. In the event Licensee prepares a derivative work that is based on
|
| 92 |
+
or incorporates Python or any part thereof, and wants to make
|
| 93 |
+
the derivative work available to others as provided herein, then
|
| 94 |
+
Licensee hereby agrees to include in any such work a brief summary of
|
| 95 |
+
the changes made to Python.
|
| 96 |
+
|
| 97 |
+
4. PSF is making Python available to Licensee on an "AS IS"
|
| 98 |
+
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
| 99 |
+
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
| 100 |
+
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
| 101 |
+
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
| 102 |
+
INFRINGE ANY THIRD PARTY RIGHTS.
|
| 103 |
+
|
| 104 |
+
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
| 105 |
+
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
| 106 |
+
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
| 107 |
+
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
| 108 |
+
|
| 109 |
+
6. This License Agreement will automatically terminate upon a material
|
| 110 |
+
breach of its terms and conditions.
|
| 111 |
+
|
| 112 |
+
7. Nothing in this License Agreement shall be deemed to create any
|
| 113 |
+
relationship of agency, partnership, or joint venture between PSF and
|
| 114 |
+
Licensee. This License Agreement does not grant permission to use PSF
|
| 115 |
+
trademarks or trade name in a trademark sense to endorse or promote
|
| 116 |
+
products or services of Licensee, or any third party.
|
| 117 |
+
|
| 118 |
+
8. By copying, installing or otherwise using Python, Licensee
|
| 119 |
+
agrees to be bound by the terms and conditions of this License
|
| 120 |
+
Agreement.
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0
|
| 124 |
+
-------------------------------------------
|
| 125 |
+
|
| 126 |
+
BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1
|
| 127 |
+
|
| 128 |
+
1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an
|
| 129 |
+
office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the
|
| 130 |
+
Individual or Organization ("Licensee") accessing and otherwise using
|
| 131 |
+
this software in source or binary form and its associated
|
| 132 |
+
documentation ("the Software").
|
| 133 |
+
|
| 134 |
+
2. Subject to the terms and conditions of this BeOpen Python License
|
| 135 |
+
Agreement, BeOpen hereby grants Licensee a non-exclusive,
|
| 136 |
+
royalty-free, world-wide license to reproduce, analyze, test, perform
|
| 137 |
+
and/or display publicly, prepare derivative works, distribute, and
|
| 138 |
+
otherwise use the Software alone or in any derivative version,
|
| 139 |
+
provided, however, that the BeOpen Python License is retained in the
|
| 140 |
+
Software, alone or in any derivative version prepared by Licensee.
|
| 141 |
+
|
| 142 |
+
3. BeOpen is making the Software available to Licensee on an "AS IS"
|
| 143 |
+
basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
| 144 |
+
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND
|
| 145 |
+
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
| 146 |
+
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT
|
| 147 |
+
INFRINGE ANY THIRD PARTY RIGHTS.
|
| 148 |
+
|
| 149 |
+
4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE
|
| 150 |
+
SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS
|
| 151 |
+
AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY
|
| 152 |
+
DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
| 153 |
+
|
| 154 |
+
5. This License Agreement will automatically terminate upon a material
|
| 155 |
+
breach of its terms and conditions.
|
| 156 |
+
|
| 157 |
+
6. This License Agreement shall be governed by and interpreted in all
|
| 158 |
+
respects by the law of the State of California, excluding conflict of
|
| 159 |
+
law provisions. Nothing in this License Agreement shall be deemed to
|
| 160 |
+
create any relationship of agency, partnership, or joint venture
|
| 161 |
+
between BeOpen and Licensee. This License Agreement does not grant
|
| 162 |
+
permission to use BeOpen trademarks or trade names in a trademark
|
| 163 |
+
sense to endorse or promote products or services of Licensee, or any
|
| 164 |
+
third party. As an exception, the "BeOpen Python" logos available at
|
| 165 |
+
http://www.pythonlabs.com/logos.html may be used according to the
|
| 166 |
+
permissions granted on that web page.
|
| 167 |
+
|
| 168 |
+
7. By copying, installing or otherwise using the software, Licensee
|
| 169 |
+
agrees to be bound by the terms and conditions of this License
|
| 170 |
+
Agreement.
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1
|
| 174 |
+
---------------------------------------
|
| 175 |
+
|
| 176 |
+
1. This LICENSE AGREEMENT is between the Corporation for National
|
| 177 |
+
Research Initiatives, having an office at 1895 Preston White Drive,
|
| 178 |
+
Reston, VA 20191 ("CNRI"), and the Individual or Organization
|
| 179 |
+
("Licensee") accessing and otherwise using Python 1.6.1 software in
|
| 180 |
+
source or binary form and its associated documentation.
|
| 181 |
+
|
| 182 |
+
2. Subject to the terms and conditions of this License Agreement, CNRI
|
| 183 |
+
hereby grants Licensee a nonexclusive, royalty-free, world-wide
|
| 184 |
+
license to reproduce, analyze, test, perform and/or display publicly,
|
| 185 |
+
prepare derivative works, distribute, and otherwise use Python 1.6.1
|
| 186 |
+
alone or in any derivative version, provided, however, that CNRI's
|
| 187 |
+
License Agreement and CNRI's notice of copyright, i.e., "Copyright (c)
|
| 188 |
+
1995-2001 Corporation for National Research Initiatives; All Rights
|
| 189 |
+
Reserved" are retained in Python 1.6.1 alone or in any derivative
|
| 190 |
+
version prepared by Licensee. Alternately, in lieu of CNRI's License
|
| 191 |
+
Agreement, Licensee may substitute the following text (omitting the
|
| 192 |
+
quotes): "Python 1.6.1 is made available subject to the terms and
|
| 193 |
+
conditions in CNRI's License Agreement. This Agreement together with
|
| 194 |
+
Python 1.6.1 may be located on the internet using the following
|
| 195 |
+
unique, persistent identifier (known as a handle): 1895.22/1013. This
|
| 196 |
+
Agreement may also be obtained from a proxy server on the internet
|
| 197 |
+
using the following URL: http://hdl.handle.net/1895.22/1013".
|
| 198 |
+
|
| 199 |
+
3. In the event Licensee prepares a derivative work that is based on
|
| 200 |
+
or incorporates Python 1.6.1 or any part thereof, and wants to make
|
| 201 |
+
the derivative work available to others as provided herein, then
|
| 202 |
+
Licensee hereby agrees to include in any such work a brief summary of
|
| 203 |
+
the changes made to Python 1.6.1.
|
| 204 |
+
|
| 205 |
+
4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS"
|
| 206 |
+
basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
| 207 |
+
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND
|
| 208 |
+
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
| 209 |
+
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT
|
| 210 |
+
INFRINGE ANY THIRD PARTY RIGHTS.
|
| 211 |
+
|
| 212 |
+
5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
| 213 |
+
1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
| 214 |
+
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1,
|
| 215 |
+
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
| 216 |
+
|
| 217 |
+
6. This License Agreement will automatically terminate upon a material
|
| 218 |
+
breach of its terms and conditions.
|
| 219 |
+
|
| 220 |
+
7. This License Agreement shall be governed by the federal
|
| 221 |
+
intellectual property law of the United States, including without
|
| 222 |
+
limitation the federal copyright law, and, to the extent such
|
| 223 |
+
U.S. federal law does not apply, by the law of the Commonwealth of
|
| 224 |
+
Virginia, excluding Virginia's conflict of law provisions.
|
| 225 |
+
Notwithstanding the foregoing, with regard to derivative works based
|
| 226 |
+
on Python 1.6.1 that incorporate non-separable material that was
|
| 227 |
+
previously distributed under the GNU General Public License (GPL), the
|
| 228 |
+
law of the Commonwealth of Virginia shall govern this License
|
| 229 |
+
Agreement only as to issues arising under or with respect to
|
| 230 |
+
Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this
|
| 231 |
+
License Agreement shall be deemed to create any relationship of
|
| 232 |
+
agency, partnership, or joint venture between CNRI and Licensee. This
|
| 233 |
+
License Agreement does not grant permission to use CNRI trademarks or
|
| 234 |
+
trade name in a trademark sense to endorse or promote products or
|
| 235 |
+
services of Licensee, or any third party.
|
| 236 |
+
|
| 237 |
+
8. By clicking on the "ACCEPT" button where indicated, or by copying,
|
| 238 |
+
installing or otherwise using Python 1.6.1, Licensee agrees to be
|
| 239 |
+
bound by the terms and conditions of this License Agreement.
|
| 240 |
+
|
| 241 |
+
ACCEPT
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2
|
| 245 |
+
--------------------------------------------------
|
| 246 |
+
|
| 247 |
+
Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam,
|
| 248 |
+
The Netherlands. All rights reserved.
|
| 249 |
+
|
| 250 |
+
Permission to use, copy, modify, and distribute this software and its
|
| 251 |
+
documentation for any purpose and without fee is hereby granted,
|
| 252 |
+
provided that the above copyright notice appear in all copies and that
|
| 253 |
+
both that copyright notice and this permission notice appear in
|
| 254 |
+
supporting documentation, and that the name of Stichting Mathematisch
|
| 255 |
+
Centrum or CWI not be used in advertising or publicity pertaining to
|
| 256 |
+
distribution of the software without specific, written prior
|
| 257 |
+
permission.
|
| 258 |
+
|
| 259 |
+
STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO
|
| 260 |
+
THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
| 261 |
+
FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE
|
| 262 |
+
FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
| 263 |
+
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
| 264 |
+
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
| 265 |
+
OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
| 266 |
+
|
| 267 |
+
ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION
|
| 268 |
+
----------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
Permission to use, copy, modify, and/or distribute this software for any
|
| 271 |
+
purpose with or without fee is hereby granted.
|
| 272 |
+
|
| 273 |
+
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
| 274 |
+
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
| 275 |
+
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
| 276 |
+
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
| 277 |
+
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
| 278 |
+
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
| 279 |
+
PERFORMANCE OF THIS SOFTWARE.
|
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/METADATA
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.3
|
| 2 |
+
Name: aiohappyeyeballs
|
| 3 |
+
Version: 2.6.1
|
| 4 |
+
Summary: Happy Eyeballs for asyncio
|
| 5 |
+
License: PSF-2.0
|
| 6 |
+
Author: J. Nick Koston
|
| 7 |
+
Author-email: nick@koston.org
|
| 8 |
+
Requires-Python: >=3.9
|
| 9 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 10 |
+
Classifier: Intended Audience :: Developers
|
| 11 |
+
Classifier: Natural Language :: English
|
| 12 |
+
Classifier: Operating System :: OS Independent
|
| 13 |
+
Classifier: Topic :: Software Development :: Libraries
|
| 14 |
+
Classifier: Programming Language :: Python :: 3
|
| 15 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 16 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 20 |
+
Classifier: License :: OSI Approved :: Python Software Foundation License
|
| 21 |
+
Project-URL: Bug Tracker, https://github.com/aio-libs/aiohappyeyeballs/issues
|
| 22 |
+
Project-URL: Changelog, https://github.com/aio-libs/aiohappyeyeballs/blob/main/CHANGELOG.md
|
| 23 |
+
Project-URL: Documentation, https://aiohappyeyeballs.readthedocs.io
|
| 24 |
+
Project-URL: Repository, https://github.com/aio-libs/aiohappyeyeballs
|
| 25 |
+
Description-Content-Type: text/markdown
|
| 26 |
+
|
| 27 |
+
# aiohappyeyeballs
|
| 28 |
+
|
| 29 |
+
<p align="center">
|
| 30 |
+
<a href="https://github.com/aio-libs/aiohappyeyeballs/actions/workflows/ci.yml?query=branch%3Amain">
|
| 31 |
+
<img src="https://img.shields.io/github/actions/workflow/status/aio-libs/aiohappyeyeballs/ci-cd.yml?branch=main&label=CI&logo=github&style=flat-square" alt="CI Status" >
|
| 32 |
+
</a>
|
| 33 |
+
<a href="https://aiohappyeyeballs.readthedocs.io">
|
| 34 |
+
<img src="https://img.shields.io/readthedocs/aiohappyeyeballs.svg?logo=read-the-docs&logoColor=fff&style=flat-square" alt="Documentation Status">
|
| 35 |
+
</a>
|
| 36 |
+
<a href="https://codecov.io/gh/aio-libs/aiohappyeyeballs">
|
| 37 |
+
<img src="https://img.shields.io/codecov/c/github/aio-libs/aiohappyeyeballs.svg?logo=codecov&logoColor=fff&style=flat-square" alt="Test coverage percentage">
|
| 38 |
+
</a>
|
| 39 |
+
</p>
|
| 40 |
+
<p align="center">
|
| 41 |
+
<a href="https://python-poetry.org/">
|
| 42 |
+
<img src="https://img.shields.io/badge/packaging-poetry-299bd7?style=flat-square&logo=" alt="Poetry">
|
| 43 |
+
</a>
|
| 44 |
+
<a href="https://github.com/astral-sh/ruff">
|
| 45 |
+
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff">
|
| 46 |
+
</a>
|
| 47 |
+
<a href="https://github.com/pre-commit/pre-commit">
|
| 48 |
+
<img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=flat-square" alt="pre-commit">
|
| 49 |
+
</a>
|
| 50 |
+
</p>
|
| 51 |
+
<p align="center">
|
| 52 |
+
<a href="https://pypi.org/project/aiohappyeyeballs/">
|
| 53 |
+
<img src="https://img.shields.io/pypi/v/aiohappyeyeballs.svg?logo=python&logoColor=fff&style=flat-square" alt="PyPI Version">
|
| 54 |
+
</a>
|
| 55 |
+
<img src="https://img.shields.io/pypi/pyversions/aiohappyeyeballs.svg?style=flat-square&logo=python&logoColor=fff" alt="Supported Python versions">
|
| 56 |
+
<img src="https://img.shields.io/pypi/l/aiohappyeyeballs.svg?style=flat-square" alt="License">
|
| 57 |
+
</p>
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
**Documentation**: <a href="https://aiohappyeyeballs.readthedocs.io" target="_blank">https://aiohappyeyeballs.readthedocs.io </a>
|
| 62 |
+
|
| 63 |
+
**Source Code**: <a href="https://github.com/aio-libs/aiohappyeyeballs" target="_blank">https://github.com/aio-libs/aiohappyeyeballs </a>
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
|
| 68 |
+
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
|
| 69 |
+
|
| 70 |
+
## Use case
|
| 71 |
+
|
| 72 |
+
This library exists to allow connecting with
|
| 73 |
+
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
|
| 74 |
+
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
|
| 75 |
+
when you
|
| 76 |
+
already have a list of addrinfo and not a DNS name.
|
| 77 |
+
|
| 78 |
+
The stdlib version of `loop.create_connection()`
|
| 79 |
+
will only work when you pass in an unresolved name which
|
| 80 |
+
is not a good fit when using DNS caching or resolving
|
| 81 |
+
names via another method such as `zeroconf`.
|
| 82 |
+
|
| 83 |
+
## Installation
|
| 84 |
+
|
| 85 |
+
Install this via pip (or your favourite package manager):
|
| 86 |
+
|
| 87 |
+
`pip install aiohappyeyeballs`
|
| 88 |
+
|
| 89 |
+
## License
|
| 90 |
+
|
| 91 |
+
[aiohappyeyeballs is licensed under the same terms as cpython itself.](https://github.com/python/cpython/blob/main/LICENSE)
|
| 92 |
+
|
| 93 |
+
## Example usage
|
| 94 |
+
|
| 95 |
+
```python
|
| 96 |
+
|
| 97 |
+
addr_infos = await loop.getaddrinfo("example.org", 80)
|
| 98 |
+
|
| 99 |
+
socket = await start_connection(addr_infos)
|
| 100 |
+
socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, happy_eyeballs_delay=0.2)
|
| 101 |
+
|
| 102 |
+
transport, protocol = await loop.create_connection(
|
| 103 |
+
MyProtocol, sock=socket, ...)
|
| 104 |
+
|
| 105 |
+
# Remove the first address for each family from addr_infos
|
| 106 |
+
pop_addr_infos_interleave(addr_infos, 1)
|
| 107 |
+
|
| 108 |
+
# Remove all matching addresses from addr_infos
|
| 109 |
+
remove_addr_infos(addr_infos, "dead::beef::")
|
| 110 |
+
|
| 111 |
+
# Convert a local_addr to local_addr_infos
|
| 112 |
+
local_addr_infos = addr_to_addr_infos(("127.0.0.1",0))
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Credits
|
| 116 |
+
|
| 117 |
+
This package contains code from cpython and is licensed under the same terms as cpython itself.
|
| 118 |
+
|
| 119 |
+
This package was created with
|
| 120 |
+
[Copier](https://copier.readthedocs.io/) and the
|
| 121 |
+
[browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template)
|
| 122 |
+
project template.
|
| 123 |
+
|
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/RECORD
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs-2.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
aiohappyeyeballs-2.6.1.dist-info/LICENSE,sha256=Oy-B_iHRgcSZxZolbI4ZaEVdZonSaaqFNzv7avQdo78,13936
|
| 3 |
+
aiohappyeyeballs-2.6.1.dist-info/METADATA,sha256=NSXlhJwAfi380eEjAo7BQ4P_TVal9xi0qkyZWibMsVM,5915
|
| 4 |
+
aiohappyeyeballs-2.6.1.dist-info/RECORD,,
|
| 5 |
+
aiohappyeyeballs-2.6.1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
| 6 |
+
aiohappyeyeballs/__init__.py,sha256=x7kktHEtaD9quBcWDJPuLeKyjuVAI-Jj14S9B_5hcTs,361
|
| 7 |
+
aiohappyeyeballs/__pycache__/__init__.cpython-312.pyc,,
|
| 8 |
+
aiohappyeyeballs/__pycache__/_staggered.cpython-312.pyc,,
|
| 9 |
+
aiohappyeyeballs/__pycache__/impl.cpython-312.pyc,,
|
| 10 |
+
aiohappyeyeballs/__pycache__/types.cpython-312.pyc,,
|
| 11 |
+
aiohappyeyeballs/__pycache__/utils.cpython-312.pyc,,
|
| 12 |
+
aiohappyeyeballs/_staggered.py,sha256=edfVowFx-P-ywJjIEF3MdPtEMVODujV6CeMYr65otac,6900
|
| 13 |
+
aiohappyeyeballs/impl.py,sha256=Dlcm2mTJ28ucrGnxkb_fo9CZzLAkOOBizOt7dreBbXE,9681
|
| 14 |
+
aiohappyeyeballs/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 15 |
+
aiohappyeyeballs/types.py,sha256=YZJIAnyoV4Dz0WFtlaf_OyE4EW7Xus1z7aIfNI6tDDQ,425
|
| 16 |
+
aiohappyeyeballs/utils.py,sha256=on9GxIR0LhEfZu8P6Twi9hepX9zDanuZM20MWsb3xlQ,3028
|
venv/Lib/site-packages/aiohappyeyeballs-2.6.1.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: poetry-core 2.1.1
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
venv/Lib/site-packages/aiohappyeyeballs/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "2.6.1"
|
| 2 |
+
|
| 3 |
+
from .impl import start_connection
|
| 4 |
+
from .types import AddrInfoType, SocketFactoryType
|
| 5 |
+
from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos
|
| 6 |
+
|
| 7 |
+
__all__ = (
|
| 8 |
+
"AddrInfoType",
|
| 9 |
+
"SocketFactoryType",
|
| 10 |
+
"addr_to_addr_infos",
|
| 11 |
+
"pop_addr_infos_interleave",
|
| 12 |
+
"remove_addr_infos",
|
| 13 |
+
"start_connection",
|
| 14 |
+
)
|
venv/Lib/site-packages/aiohappyeyeballs/_staggered.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import contextlib
|
| 3 |
+
|
| 4 |
+
# PY3.9: Import Callable from typing until we drop Python 3.9 support
|
| 5 |
+
# https://github.com/python/cpython/issues/87131
|
| 6 |
+
from typing import (
|
| 7 |
+
TYPE_CHECKING,
|
| 8 |
+
Any,
|
| 9 |
+
Awaitable,
|
| 10 |
+
Callable,
|
| 11 |
+
Iterable,
|
| 12 |
+
List,
|
| 13 |
+
Optional,
|
| 14 |
+
Set,
|
| 15 |
+
Tuple,
|
| 16 |
+
TypeVar,
|
| 17 |
+
Union,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
_T = TypeVar("_T")
|
| 21 |
+
|
| 22 |
+
RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _set_result(wait_next: "asyncio.Future[None]") -> None:
    """Resolve *wait_next* with ``None`` unless it has already completed."""
    if wait_next.done():
        # Already resolved, failed or cancelled: leave it untouched.
        return
    wait_next.set_result(None)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
async def _wait_one(
    futures: "Iterable[asyncio.Future[Any]]",
    loop: asyncio.AbstractEventLoop,
) -> "asyncio.Future[Any]":
    """
    Wait until one of *futures* completes and return that future.

    Args:
    ----
    futures: the futures (or tasks) to watch. Any iterable is accepted,
        including one-shot iterators such as generators.

    loop: the event loop used to create the internal waiter future.

    Returns:
    -------
    The first future whose done callback fired. The future itself is
    returned, not its result; later completions are ignored.

    """
    # Materialize once: the futures are walked twice (attach, then detach),
    # and a one-shot iterator would already be exhausted on the second pass,
    # leaving the callbacks attached to every future.
    pending = tuple(futures)
    wait_next = loop.create_future()

    def _on_completion(fut: "asyncio.Future[Any]") -> None:
        # Only the first completion wins; ignore any that follow.
        if not wait_next.done():
            wait_next.set_result(fut)

    for f in pending:
        f.add_done_callback(_on_completion)

    try:
        return await wait_next
    finally:
        # Always detach, including when this wait itself is cancelled.
        for f in pending:
            f.remove_done_callback(_on_completion)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
async def staggered_race(
|
| 53 |
+
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
|
| 54 |
+
delay: Optional[float],
|
| 55 |
+
*,
|
| 56 |
+
loop: Optional[asyncio.AbstractEventLoop] = None,
|
| 57 |
+
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
|
| 58 |
+
"""
|
| 59 |
+
Run coroutines with staggered start times and take the first to finish.
|
| 60 |
+
|
| 61 |
+
This method takes an iterable of coroutine functions. The first one is
|
| 62 |
+
started immediately. From then on, whenever the immediately preceding one
|
| 63 |
+
fails (raises an exception), or when *delay* seconds has passed, the next
|
| 64 |
+
coroutine is started. This continues until one of the coroutines complete
|
| 65 |
+
successfully, in which case all others are cancelled, or until all
|
| 66 |
+
coroutines fail.
|
| 67 |
+
|
| 68 |
+
The coroutines provided should be well-behaved in the following way:
|
| 69 |
+
|
| 70 |
+
* They should only ``return`` if completed successfully.
|
| 71 |
+
|
| 72 |
+
* They should always raise an exception if they did not complete
|
| 73 |
+
successfully. In particular, if they handle cancellation, they should
|
| 74 |
+
probably reraise, like this::
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
# do work
|
| 78 |
+
except asyncio.CancelledError:
|
| 79 |
+
# undo partially completed work
|
| 80 |
+
raise
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
----
|
| 84 |
+
coro_fns: an iterable of coroutine functions, i.e. callables that
|
| 85 |
+
return a coroutine object when called. Use ``functools.partial`` or
|
| 86 |
+
lambdas to pass arguments.
|
| 87 |
+
|
| 88 |
+
delay: amount of time, in seconds, between starting coroutines. If
|
| 89 |
+
``None``, the coroutines will run sequentially.
|
| 90 |
+
|
| 91 |
+
loop: the event loop to use. If ``None``, the running loop is used.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
-------
|
| 95 |
+
tuple *(winner_result, winner_index, exceptions)* where
|
| 96 |
+
|
| 97 |
+
- *winner_result*: the result of the winning coroutine, or ``None``
|
| 98 |
+
if no coroutines won.
|
| 99 |
+
|
| 100 |
+
- *winner_index*: the index of the winning coroutine in
|
| 101 |
+
``coro_fns``, or ``None`` if no coroutines won. If the winning
|
| 102 |
+
coroutine may return None on success, *winner_index* can be used
|
| 103 |
+
to definitively determine whether any coroutine won.
|
| 104 |
+
|
| 105 |
+
- *exceptions*: list of exceptions returned by the coroutines.
|
| 106 |
+
``len(exceptions)`` is equal to the number of coroutines actually
|
| 107 |
+
started, and the order is the same as in ``coro_fns``. The winning
|
| 108 |
+
coroutine's entry is ``None``.
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
loop = loop or asyncio.get_running_loop()
|
| 112 |
+
exceptions: List[Optional[BaseException]] = []
|
| 113 |
+
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
|
| 114 |
+
|
| 115 |
+
async def run_one_coro(
|
| 116 |
+
coro_fn: Callable[[], Awaitable[_T]],
|
| 117 |
+
this_index: int,
|
| 118 |
+
start_next: "asyncio.Future[None]",
|
| 119 |
+
) -> Optional[Tuple[_T, int]]:
|
| 120 |
+
"""
|
| 121 |
+
Run a single coroutine.
|
| 122 |
+
|
| 123 |
+
If the coroutine fails, set the exception in the exceptions list and
|
| 124 |
+
start the next coroutine by setting the result of the start_next.
|
| 125 |
+
|
| 126 |
+
If the coroutine succeeds, return the result and the index of the
|
| 127 |
+
coroutine in the coro_fns list.
|
| 128 |
+
|
| 129 |
+
If SystemExit or KeyboardInterrupt is raised, re-raise it.
|
| 130 |
+
"""
|
| 131 |
+
try:
|
| 132 |
+
result = await coro_fn()
|
| 133 |
+
except RE_RAISE_EXCEPTIONS:
|
| 134 |
+
raise
|
| 135 |
+
except BaseException as e:
|
| 136 |
+
exceptions[this_index] = e
|
| 137 |
+
_set_result(start_next) # Kickstart the next coroutine
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
return result, this_index
|
| 141 |
+
|
| 142 |
+
start_next_timer: Optional[asyncio.TimerHandle] = None
|
| 143 |
+
start_next: Optional[asyncio.Future[None]]
|
| 144 |
+
task: asyncio.Task[Optional[Tuple[_T, int]]]
|
| 145 |
+
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
|
| 146 |
+
coro_iter = iter(coro_fns)
|
| 147 |
+
this_index = -1
|
| 148 |
+
try:
|
| 149 |
+
while True:
|
| 150 |
+
if coro_fn := next(coro_iter, None):
|
| 151 |
+
this_index += 1
|
| 152 |
+
exceptions.append(None)
|
| 153 |
+
start_next = loop.create_future()
|
| 154 |
+
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
|
| 155 |
+
tasks.add(task)
|
| 156 |
+
start_next_timer = (
|
| 157 |
+
loop.call_later(delay, _set_result, start_next) if delay else None
|
| 158 |
+
)
|
| 159 |
+
elif not tasks:
|
| 160 |
+
# We exhausted the coro_fns list and no tasks are running
|
| 161 |
+
# so we have no winner and all coroutines failed.
|
| 162 |
+
break
|
| 163 |
+
|
| 164 |
+
while tasks or start_next:
|
| 165 |
+
done = await _wait_one(
|
| 166 |
+
(*tasks, start_next) if start_next else tasks, loop
|
| 167 |
+
)
|
| 168 |
+
if done is start_next:
|
| 169 |
+
# The current task has failed or the timer has expired
|
| 170 |
+
# so we need to start the next task.
|
| 171 |
+
start_next = None
|
| 172 |
+
if start_next_timer:
|
| 173 |
+
start_next_timer.cancel()
|
| 174 |
+
start_next_timer = None
|
| 175 |
+
|
| 176 |
+
# Break out of the task waiting loop to start the next
|
| 177 |
+
# task.
|
| 178 |
+
break
|
| 179 |
+
|
| 180 |
+
if TYPE_CHECKING:
|
| 181 |
+
assert isinstance(done, asyncio.Task)
|
| 182 |
+
|
| 183 |
+
tasks.remove(done)
|
| 184 |
+
if winner := done.result():
|
| 185 |
+
return *winner, exceptions
|
| 186 |
+
finally:
|
| 187 |
+
# We either have:
|
| 188 |
+
# - a winner
|
| 189 |
+
# - all tasks failed
|
| 190 |
+
# - a KeyboardInterrupt or SystemExit.
|
| 191 |
+
|
| 192 |
+
#
|
| 193 |
+
# If the timer is still running, cancel it.
|
| 194 |
+
#
|
| 195 |
+
if start_next_timer:
|
| 196 |
+
start_next_timer.cancel()
|
| 197 |
+
|
| 198 |
+
#
|
| 199 |
+
# If there are any tasks left, cancel them and then
|
| 200 |
+
# wait them so they fill the exceptions list.
|
| 201 |
+
#
|
| 202 |
+
for task in tasks:
|
| 203 |
+
task.cancel()
|
| 204 |
+
with contextlib.suppress(asyncio.CancelledError):
|
| 205 |
+
await task
|
| 206 |
+
|
| 207 |
+
return None, None, exceptions
|
venv/Lib/site-packages/aiohappyeyeballs/impl.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base implementation."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import collections
|
| 5 |
+
import contextlib
|
| 6 |
+
import functools
|
| 7 |
+
import itertools
|
| 8 |
+
import socket
|
| 9 |
+
from typing import List, Optional, Sequence, Set, Union
|
| 10 |
+
|
| 11 |
+
from . import _staggered
|
| 12 |
+
from .types import AddrInfoType, SocketFactoryType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def start_connection(
    addr_infos: Sequence[AddrInfoType],
    *,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
    happy_eyeballs_delay: Optional[float] = None,
    interleave: Optional[int] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    socket_factory: Optional[SocketFactoryType] = None,
) -> socket.socket:
    """
    Connect to a TCP server.

    Create a socket connection to a specified destination.  The
    destination is specified as a list of AddrInfoType tuples as
    returned from getaddrinfo().

    Each AddrInfoType tuple contains, in order:

    * ``family``: the address family, e.g. ``socket.AF_INET`` or
        ``socket.AF_INET6``.
    * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
        ``socket.SOCK_DGRAM``.
    * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
        ``socket.IPPROTO_UDP``.
    * ``canonname``: the canonical name of the address, e.g.
        ``"www.python.org"``.
    * ``sockaddr``: the socket address

    This method is a coroutine which will try to establish the connection
    in the background. When successful, the coroutine returns a
    socket.

    The expected use case is to use this method in conjunction with
    loop.create_connection() to establish a connection to a server::

            socket = await start_connection(addr_infos)
            transport, protocol = await loop.create_connection(
                MyProtocol, sock=socket, ...)

    Args:
    ----
        addr_infos: candidate destinations, tried in order (after optional
            interleaving).
        local_addr_infos: optional local addresses to bind to before
            connecting; passed through to ``_connect_sock``.
        happy_eyeballs_delay: if not ``None`` and more than one address is
            given, attempts are raced through ``_staggered.staggered_race``
            with this delay; otherwise addresses are tried one at a time.
        interleave: passed to ``_interleave_addrinfos`` when truthy and more
            than one address is given.  Defaults to ``1`` when
            ``happy_eyeballs_delay`` is set.
        loop: event loop to use; the running loop is used when omitted.
        socket_factory: optional callable used by ``_connect_sock`` to create
            each socket.

    Raises:
    ------
        OSError or RuntimeError: if no connection could be established.  A
            single recorded error (or several with identical text) is
            re-raised as is; otherwise a combined error is raised.  ``OSError``
            is also raised when nothing was recorded at all (for example when
            ``addr_infos`` is empty).

    """
    if not (current_loop := loop):
        current_loop = asyncio.get_running_loop()

    single_addr_info = len(addr_infos) == 1

    if happy_eyeballs_delay is not None and interleave is None:
        # If using happy eyeballs, default to interleave addresses by family
        interleave = 1

    if interleave and not single_addr_info:
        addr_infos = _interleave_addrinfos(addr_infos, interleave)

    sock: Optional[socket.socket] = None
    # uvloop can raise RuntimeError instead of OSError
    # One inner list per connection attempt; each is filled by _connect_sock.
    exceptions: List[List[Union[OSError, RuntimeError]]] = []
    if happy_eyeballs_delay is None or single_addr_info:
        # not using happy eyeballs
        for addrinfo in addr_infos:
            try:
                sock = await _connect_sock(
                    current_loop,
                    exceptions,
                    addrinfo,
                    local_addr_infos,
                    None,
                    socket_factory,
                )
                break
            except (RuntimeError, OSError):
                # Already recorded in ``exceptions``; try the next address.
                continue
    else:  # using happy eyeballs
        open_sockets: Set[socket.socket] = set()
        try:
            sock, _, _ = await _staggered.staggered_race(
                (
                    functools.partial(
                        _connect_sock,
                        current_loop,
                        exceptions,
                        addrinfo,
                        local_addr_infos,
                        open_sockets,
                        socket_factory,
                    )
                    for addrinfo in addr_infos
                ),
                happy_eyeballs_delay,
            )
        finally:
            # If we have a winner, staggered_race will
            # cancel the other tasks, however there is a
            # small race window where any of the other tasks
            # can be done before they are cancelled which
            # will leave the socket open. To avoid this problem
            # we pass a set to _connect_sock to keep track of
            # the open sockets and close them here if there
            # are any "runner up" sockets.
            for s in open_sockets:
                if s is not sock:
                    with contextlib.suppress(OSError):
                        s.close()
            open_sockets = None  # type: ignore[assignment]

    if sock is None:
        all_exceptions = [exc for sub in exceptions for exc in sub]
        try:
            if not all_exceptions:
                # Nothing was recorded (e.g. ``addr_infos`` was empty).  Raise
                # a connection-style error instead of letting the indexing
                # below fail with an unrelated IndexError.
                raise OSError(
                    "start_connection failed without any recorded error; "
                    "addr_infos may be empty"
                )
            first_exception = all_exceptions[0]
            if len(all_exceptions) == 1:
                raise first_exception
            # If they all have the same str(), raise one.
            model = str(first_exception)
            if all(str(exc) == model for exc in all_exceptions):
                raise first_exception
            # Raise a combined exception so the user can see all
            # the various error messages.
            msg = "Multiple exceptions: {}".format(
                ", ".join(str(exc) for exc in all_exceptions)
            )
            # If the errno is the same for all exceptions, raise
            # an OSError with that errno.
            if isinstance(first_exception, OSError):
                first_errno = first_exception.errno
                if all(
                    isinstance(exc, OSError) and exc.errno == first_errno
                    for exc in all_exceptions
                ):
                    raise OSError(first_errno, msg)
            elif isinstance(first_exception, RuntimeError) and all(
                isinstance(exc, RuntimeError) for exc in all_exceptions
            ):
                raise RuntimeError(msg)
            # We have a mix of OSError and RuntimeError
            # so we have to pick which one to raise.
            # and we raise OSError for compatibility
            raise OSError(msg)
        finally:
            # Drop the local references to the collected exceptions before
            # the raised exception leaves this frame.
            all_exceptions = None  # type: ignore[assignment]
            exceptions = None  # type: ignore[assignment]

    return sock
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
async def _connect_sock(
    loop: asyncio.AbstractEventLoop,
    exceptions: List[List[Union[OSError, RuntimeError]]],
    addr_info: AddrInfoType,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
    open_sockets: Optional[Set[socket.socket]] = None,
    socket_factory: Optional[SocketFactoryType] = None,
) -> socket.socket:
    """
    Create, bind and connect one socket.

    A fresh list is appended to ``exceptions`` for this attempt and every
    ``OSError``/``RuntimeError`` hit along the way is recorded in it.

    If open_sockets is passed, add the socket to the set of open sockets.
    Any failure caught here will remove the socket from the set and close it.

    Callers can use this set to close any sockets that are not the winner
    when several staggered tasks complete successfully, i.e. when there are
    "runner up" sockets in addition to the winner.

    The socket is created with ``socket_factory(addr_info)`` when a factory
    is given, otherwise with ``socket.socket()`` from the family, type and
    proto in ``addr_info``.  If ``local_addr_infos`` is given, the socket is
    bound to the first entry of the same family that binds successfully
    before connecting.

    Returns the connected, non-blocking socket.  Any exception raised while
    creating, binding or connecting is re-raised after cleanup.
    """
    my_exceptions: List[Union[OSError, RuntimeError]] = []
    # Register this attempt's error list up front so that anything recorded
    # below is visible to the caller through ``exceptions``.
    exceptions.append(my_exceptions)
    family, type_, proto, _, address = addr_info
    sock = None
    try:
        if socket_factory is not None:
            sock = socket_factory(addr_info)
        else:
            sock = socket.socket(family=family, type=type_, proto=proto)
        if open_sockets is not None:
            open_sockets.add(sock)
        sock.setblocking(False)
        if local_addr_infos is not None:
            for lfamily, _, _, _, laddr in local_addr_infos:
                # skip local addresses of different family
                if lfamily != family:
                    continue
                try:
                    sock.bind(laddr)
                    break
                except OSError as exc:
                    # Record a new OSError whose message names the local
                    # address, then keep trying the remaining candidates.
                    msg = (
                        f"error while attempting to bind on "
                        f"address {laddr!r}: "
                        f"{(exc.strerror or '').lower()}"
                    )
                    exc = OSError(exc.errno, msg)
                    my_exceptions.append(exc)
            else:  # all bind attempts failed
                if my_exceptions:
                    # pop(): the handler below appends the raised exception
                    # again, so popping avoids recording it twice.
                    raise my_exceptions.pop()
                else:
                    # No candidate of this family was present at all.
                    raise OSError(f"no matching local address with {family=} found")
        await loop.sock_connect(sock, address)
        return sock
    except (RuntimeError, OSError) as exc:
        my_exceptions.append(exc)
        if sock is not None:
            if open_sockets is not None:
                open_sockets.remove(sock)
            try:
                sock.close()
            except OSError as e:
                # A failing close() is recorded and propagates in place of
                # the original error.
                my_exceptions.append(e)
                raise
        raise
    except:
        # Anything that is not RuntimeError/OSError (this includes
        # BaseException subclasses such as task cancellation): same socket
        # cleanup, but the exception itself is not recorded.
        if sock is not None:
            if open_sockets is not None:
                open_sockets.remove(sock)
            try:
                sock.close()
            except OSError as e:
                my_exceptions.append(e)
                raise
        raise
    finally:
        # Release this frame's references to the exception lists.
        # NOTE(review): presumably to avoid a reference cycle between the
        # stored exceptions' tracebacks and this frame - confirm.
        exceptions = my_exceptions = None  # type: ignore[assignment]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _interleave_addrinfos(
|
| 236 |
+
addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
|
| 237 |
+
) -> List[AddrInfoType]:
|
| 238 |
+
"""Interleave list of addrinfo tuples by family."""
|
| 239 |
+
# Group addresses by family
|
| 240 |
+
addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
|
| 241 |
+
collections.OrderedDict()
|
| 242 |
+
)
|
| 243 |
+
for addr in addrinfos:
|
| 244 |
+
family = addr[0]
|
| 245 |
+
if family not in addrinfos_by_family:
|
| 246 |
+
addrinfos_by_family[family] = []
|
| 247 |
+
addrinfos_by_family[family].append(addr)
|
| 248 |
+
addrinfos_lists = list(addrinfos_by_family.values())
|
| 249 |
+
|
| 250 |
+
reordered: List[AddrInfoType] = []
|
| 251 |
+
if first_address_family_count > 1:
|
| 252 |
+
reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
|
| 253 |
+
del addrinfos_lists[0][: first_address_family_count - 1]
|
| 254 |
+
reordered.extend(
|
| 255 |
+
a
|
| 256 |
+
for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
|
| 257 |
+
if a is not None
|
| 258 |
+
)
|
| 259 |
+
return reordered
|
venv/Lib/site-packages/aiohappyeyeballs/py.typed
ADDED
|
File without changes
|
venv/Lib/site-packages/aiohappyeyeballs/types.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Types for aiohappyeyeballs."""
|
| 2 |
+
|
| 3 |
+
import socket
|
| 4 |
+
|
| 5 |
+
# PY3.9: Import Callable from typing until we drop Python 3.9 support
|
| 6 |
+
# https://github.com/python/cpython/issues/87131
|
| 7 |
+
from typing import Callable, Tuple, Union
|
| 8 |
+
|
| 9 |
+
AddrInfoType = Tuple[
|
| 10 |
+
Union[int, socket.AddressFamily],
|
| 11 |
+
Union[int, socket.SocketKind],
|
| 12 |
+
int,
|
| 13 |
+
str,
|
| 14 |
+
Tuple, # type: ignore[type-arg]
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
SocketFactoryType = Callable[[AddrInfoType], socket.socket]
|
venv/Lib/site-packages/aiohappyeyeballs/utils.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for aiohappyeyeballs."""
|
| 2 |
+
|
| 3 |
+
import ipaddress
|
| 4 |
+
import socket
|
| 5 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
from .types import AddrInfoType
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def addr_to_addr_infos(
    addr: Optional[
        Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]]
    ],
) -> Optional[List[AddrInfoType]]:
    """Build a one-element addr_info list from a socket address tuple.

    ``None`` is passed through unchanged.  A host containing ``":"`` is
    treated as IPv6 and its address is normalised to the 4-tuple
    ``(host, port, flowinfo, scopeid)``, with missing trailing fields set to
    0.  Any other host is treated as IPv4 and reduced to ``(host, port)``.
    The resulting entry always uses ``SOCK_STREAM`` / ``IPPROTO_TCP`` and an
    empty canonical name.
    """
    if addr is None:
        return None

    host, port = addr[0], addr[1]
    if ":" not in host:
        family = socket.AF_INET
        sockaddr = (host, port)
    else:
        family = socket.AF_INET6
        field_count = len(addr)
        flowinfo = addr[2] if field_count >= 3 else 0  # type: ignore[misc]
        scopeid = addr[3] if field_count >= 4 else 0  # type: ignore[misc]
        sockaddr = (host, port, flowinfo, scopeid)
    return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr)]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pop_addr_infos_interleave(
    addr_infos: List[AddrInfoType], interleave: Optional[int] = None
) -> None:
    """
    Drop the leading addr_infos of every family from the list, in place.

    For each address family, the first ``interleave`` entries encountered
    (1 when ``interleave`` is ``None``) are removed from ``addr_infos``.
    Nothing is returned; the list passed in is mutated.
    """
    limit = 1 if interleave is None else interleave
    per_family_count: Dict[int, int] = {}
    popped: List[AddrInfoType] = []
    for info in addr_infos:
        fam = info[0]
        already_seen = per_family_count.get(fam, 0)
        if already_seen < limit:
            popped.append(info)
        per_family_count[fam] = already_seen + 1
    # Removal is done after the scan so the list is not mutated while it is
    # being iterated.
    for info in popped:
        addr_infos.remove(info)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _addr_tuple_to_ip_address(
|
| 62 |
+
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
|
| 63 |
+
) -> Union[
|
| 64 |
+
Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
|
| 65 |
+
]:
|
| 66 |
+
"""Convert an address tuple to an IPv4Address."""
|
| 67 |
+
return (ipaddress.ip_address(addr[0]), *addr[1:])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def remove_addr_infos(
    addr_infos: List[AddrInfoType],
    addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> None:
    """
    Delete every addr_info whose socket address is *addr*, in place.

    The addr value is typically the return value of
    sock.getpeername().

    Entries are first matched by plain equality of their last element with
    *addr*.  Only if that finds nothing are both sides converted with
    ``_addr_tuple_to_ip_address`` and compared again.

    Raises ValueError if neither comparison finds a match.
    """
    matches: List[AddrInfoType] = [
        info for info in addr_infos if info[-1] == addr
    ]
    if not matches:
        # Slow path in case addr is formatted differently
        wanted = _addr_tuple_to_ip_address(addr)
        matches = [
            info
            for info in addr_infos
            if wanted == _addr_tuple_to_ip_address(info[-1])
        ]
    if not matches:
        raise ValueError(f"Address {addr} not found in addr_infos")
    for info in matches:
        addr_infos.remove(info)
|
venv/Lib/site-packages/aiohttp/abc.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import socket
|
| 4 |
+
import zlib
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from collections.abc import Sized
|
| 7 |
+
from http.cookies import BaseCookie, Morsel
|
| 8 |
+
from typing import (
|
| 9 |
+
TYPE_CHECKING,
|
| 10 |
+
Any,
|
| 11 |
+
Awaitable,
|
| 12 |
+
Callable,
|
| 13 |
+
Dict,
|
| 14 |
+
Generator,
|
| 15 |
+
Iterable,
|
| 16 |
+
List,
|
| 17 |
+
Optional,
|
| 18 |
+
Tuple,
|
| 19 |
+
TypedDict,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from multidict import CIMultiDict
|
| 24 |
+
from yarl import URL
|
| 25 |
+
|
| 26 |
+
from .typedefs import LooseCookies
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from .web_app import Application
|
| 30 |
+
from .web_exceptions import HTTPException
|
| 31 |
+
from .web_request import BaseRequest, Request
|
| 32 |
+
from .web_response import StreamResponse
|
| 33 |
+
else:
|
| 34 |
+
BaseRequest = Request = Application = StreamResponse = None
|
| 35 |
+
HTTPException = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AbstractRouter(ABC):
    """Abstract base class for routers.

    Keeps a ``frozen`` flag and requires subclasses to implement
    ``resolve()``.
    """

    def __init__(self) -> None:
        # Set to True by freeze(); exposed read-only through ``frozen``.
        self._frozen = False

    def post_init(self, app: Application) -> None:
        """Post init stage.

        Not an abstract method for sake of backward compatibility,
        but if the router wants to be aware of the application
        it can override this.
        """

    @property
    def frozen(self) -> bool:
        """Return True once freeze() has been called."""
        return self._frozen

    def freeze(self) -> None:
        """Freeze router."""
        self._frozen = True

    @abstractmethod
    async def resolve(self, request: Request) -> "AbstractMatchInfo":
        """Return MATCH_INFO for given request"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AbstractMatchInfo(ABC):
    """Abstract interface of the object returned by a router's resolve().

    Every member is abstract; concrete match-info classes must provide the
    handler, the expect handler, the optional HTTP exception, introspection
    info and the stack of nested applications.
    """

    # No instance attributes are declared at this level.
    __slots__ = ()

    @property  # pragma: no branch
    @abstractmethod
    def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
        """Execute matched request handler"""

    @property
    @abstractmethod
    def expect_handler(
        self,
    ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
        """Expect handler for 100-continue processing"""

    @property  # pragma: no branch
    @abstractmethod
    def http_exception(self) -> Optional[HTTPException]:
        """HTTPException instance raised on router's resolving, or None"""

    @abstractmethod  # pragma: no branch
    def get_info(self) -> Dict[str, Any]:
        """Return a dict with additional info useful for introspection"""

    @property  # pragma: no branch
    @abstractmethod
    def apps(self) -> Tuple[Application, ...]:
        """Stack of nested applications.

        Top level application is left-most element.

        """

    @abstractmethod
    def add_app(self, app: Application) -> None:
        """Add application to the nested apps stack."""

    @abstractmethod
    def freeze(self) -> None:
        """Freeze the match info.

        The method is called after route resolution.

        After the call .add_app() is forbidden.

        """
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AbstractView(ABC):
    """Abstract class based view.

    Stores the request it was created with; subclasses implement
    ``__await__`` so that awaiting the view produces the response.
    """

    def __init__(self, request: Request) -> None:
        # Kept privately and exposed read-only through ``request``.
        self._request = request

    @property
    def request(self) -> Request:
        """Request instance."""
        return self._request

    @abstractmethod
    def __await__(self) -> Generator[Any, None, StreamResponse]:
        """Execute the view handler."""
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class ResolveResult(TypedDict):
    """Resolve result.

    This is the result returned from an AbstractResolver's
    resolve method.

    :param hostname: The hostname that was provided.
    :param host: The IP address that was resolved.
    :param port: The port that was resolved.
    :param family: The address family that was resolved.
    :param proto: The protocol that was resolved.
    :param flags: The flags that were resolved.
    """

    # Keys of the dict, in the same order as the :param: entries above.
    hostname: str
    host: str
    port: int
    family: int
    proto: int
    flags: int
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class AbstractResolver(ABC):
    """Abstract DNS resolver."""

    @abstractmethod
    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Return the resolved addresses for the given hostname.

        The result is a list of ResolveResult dicts.  ``port`` defaults to 0
        and ``family`` to ``socket.AF_INET``.
        """

    @abstractmethod
    async def close(self) -> None:
        """Release resolver"""
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Base class used by AbstractCookieJar: type checkers see the parametrised
# Iterable[Morsel[str]], while at runtime the plain Iterable is used.
if TYPE_CHECKING:
    IterableBase = Iterable[Morsel[str]]
else:
    IterableBase = Iterable


# Signature of the predicate accepted by AbstractCookieJar.clear(): it takes
# a cookie morsel and returns a bool.
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class AbstractCookieJar(Sized, IterableBase):
    """Abstract Cookie Jar.

    A sized, iterable container interface for cookies.  Concrete jars must
    implement quoting policy, clearing, updating and filtering.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Remember the event loop, defaulting to the running loop.

        ``asyncio.get_running_loop()`` is called when ``loop`` is omitted.
        """
        self._loop = loop or asyncio.get_running_loop()

    @property
    @abstractmethod
    def quote_cookie(self) -> bool:
        """Return True if cookies should be quoted."""

    @abstractmethod
    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Clear all cookies if no predicate is passed."""

    @abstractmethod
    def clear_domain(self, domain: str) -> None:
        """Clear all cookies for domain and all subdomains."""

    @abstractmethod
    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""

    @abstractmethod
    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return the jar's cookies filtered by their attributes."""
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class AbstractStreamWriter(ABC):
    """Abstract stream writer.

    Declares the write/drain/compression/chunking/header operations that a
    concrete writer must implement.
    """

    # Class-level defaults.
    # NOTE(review): presumably size counters and the expected body length
    # maintained by concrete writers - confirm against implementations.
    buffer_size: int = 0
    output_size: int = 0
    length: Optional[int] = 0

    @abstractmethod
    async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
        """Write chunk into stream."""

    @abstractmethod
    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write last chunk."""

    @abstractmethod
    async def drain(self) -> None:
        """Flush the write buffer."""

    @abstractmethod
    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Enable HTTP body compression"""

    @abstractmethod
    def enable_chunking(self) -> None:
        """Enable HTTP chunked mode"""

    @abstractmethod
    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write HTTP headers"""
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class AbstractAccessLogger(ABC):
    """Abstract writer to access log.

    Holds the target logger and the format string; subclasses implement
    ``log()``.
    """

    # Only these two instance attributes are declared.
    __slots__ = ("logger", "log_format")

    def __init__(self, logger: logging.Logger, log_format: str) -> None:
        """Store the logger and the log format string."""
        self.logger = logger
        self.log_format = log_format

    @abstractmethod
    def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
        """Emit log to logger."""

    @property
    def enabled(self) -> bool:
        """Check if logger is enabled.

        This base implementation always returns True.
        """
        return True
|
venv/Lib/site-packages/aiohttp/base_protocol.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from typing import Optional, cast
|
| 3 |
+
|
| 4 |
+
from .client_exceptions import ClientConnectionResetError
|
| 5 |
+
from .helpers import set_exception
|
| 6 |
+
from .tcp_helpers import tcp_nodelay
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseProtocol(asyncio.Protocol):
    """asyncio protocol base that tracks the transport and its pause state.

    It records whether writing and reading are currently paused and keeps a
    single "drain" future that writers await while writing is paused.
    """

    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop: asyncio.AbstractEventLoop = loop
        # Write-side flow control: paused flag plus the future drainers wait on.
        self._paused = False
        self._drain_waiter: Optional[asyncio.Future[None]] = None
        # Read-side flow control.
        self._reading_paused = False

        # Set in connection_made(), reset to None in connection_lost().
        self.transport: Optional[asyncio.Transport] = None

    @property
    def connected(self) -> bool:
        """Return True while a transport is attached."""
        return self.transport is not None

    @property
    def writing_paused(self) -> bool:
        """Return True while writing is paused."""
        return self._paused

    def pause_writing(self) -> None:
        """Mark writing as paused; must not already be paused."""
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """Mark writing as resumed and release a pending drain waiter, if any."""
        assert self._paused
        self._paused = False

        pending = self._drain_waiter
        if pending is None:
            return
        self._drain_waiter = None
        if pending.done():
            return
        pending.set_result(None)

    def pause_reading(self) -> None:
        """Ask the transport to stop reading, if attached and not yet paused."""
        if self._reading_paused or self.transport is None:
            return
        try:
            self.transport.pause_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: the flag below is set even when the transport refuses.
            pass
        self._reading_paused = True

    def resume_reading(self) -> None:
        """Ask the transport to resume reading, if attached and currently paused."""
        if not self._reading_paused or self.transport is None:
            return
        try:
            self.transport.resume_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: the flag below is cleared even when the transport refuses.
            pass
        self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Adopt *transport* after calling ``tcp_nodelay(transport, True)`` on it."""
        stream = cast(asyncio.Transport, transport)
        tcp_nodelay(stream, True)
        self.transport = stream

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Drop the transport and wake a paused writer.

        The drain waiter is completed with ``None`` when *exc* is ``None``;
        otherwise it receives ``ConnectionError("Connection lost")`` via
        ``set_exception`` together with *exc*.
        """
        # Wake up the writer if currently paused.
        self.transport = None
        pending = self._drain_waiter
        if not self._paused or pending is None:
            return
        self._drain_waiter = None
        if pending.done():
            return
        if exc is not None:
            set_exception(
                pending,
                ConnectionError("Connection lost"),
                exc,
            )
        else:
            pending.set_result(None)

    async def _drain_helper(self) -> None:
        """Wait until writing is resumed.

        Raises ClientConnectionResetError when no transport is attached and
        returns immediately when writing is not paused.
        """
        if self.transport is None:
            raise ClientConnectionResetError("Connection lost")
        if not self._paused:
            return
        pending = self._drain_waiter
        if pending is None:
            pending = self._loop.create_future()
            self._drain_waiter = pending
        # shield() so cancelling one drainer does not cancel the shared future.
        await asyncio.shield(pending)
|
venv/Lib/site-packages/scipy-1.15.3-cp312-cp312-win_amd64.whl
ADDED
|
File without changes
|
venv/Lib/site-packages/six.py
ADDED
|
@@ -0,0 +1,1003 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2010-2024 Benjamin Peterson
|
| 2 |
+
#
|
| 3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 5 |
+
# in the Software without restriction, including without limitation the rights
|
| 6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 8 |
+
# furnished to do so, subject to the following conditions:
|
| 9 |
+
#
|
| 10 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 11 |
+
# copies or substantial portions of the Software.
|
| 12 |
+
#
|
| 13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 19 |
+
# SOFTWARE.
|
| 20 |
+
|
| 21 |
+
"""Utilities for writing code that runs on Python 2 and 3"""
|
| 22 |
+
|
| 23 |
+
from __future__ import absolute_import
|
| 24 |
+
|
| 25 |
+
import functools
|
| 26 |
+
import itertools
|
| 27 |
+
import operator
|
| 28 |
+
import sys
|
| 29 |
+
import types
|
| 30 |
+
|
| 31 |
+
__author__ = "Benjamin Peterson <benjamin@python.org>"
|
| 32 |
+
__version__ = "1.17.0"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Coarse interpreter-generation switches.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[:2] >= (3, 4)

if PY3:
    # The *_types names are always tuples, even with a single member.
    string_types = (str,)
    integer_types = (int,)
    class_types = (type,)
    text_type = str
    binary_type = bytes

    MAXSIZE = sys.maxsize
else:
    string_types = (basestring,)
    integer_types = (int, long)
    class_types = (type, types.ClassType)
    text_type = unicode
    binary_type = str

    if sys.platform.startswith("java"):
        # Jython always uses 32 bits.
        MAXSIZE = int((1 << 31) - 1)
    else:
        # sizeof(long) can differ from sizeof(Py_ssize_t), so probe what
        # len() accepts instead of assuming.
        class X(object):

            def __len__(self):
                return 1 << 31
        try:
            len(X())
        except OverflowError:
            # 32-bit
            MAXSIZE = int((1 << 31) - 1)
        else:
            # 64-bit
            MAXSIZE = int((1 << 63) - 1)
        del X

# spec_from_loader is imported only on 3.4+; otherwise the name is bound to None.
if PY34:
    from importlib.util import spec_from_loader
else:
    spec_from_loader = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _add_doc(func, doc):
    """Set *doc* as the ``__doc__`` of *func*."""
    setattr(func, "__doc__", doc)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _import_module(name):
    """Import *name* and return the module registered under that exact name.

    The result is looked up in ``sys.modules``, so for a dotted name it is
    the module after the last dot.
    """
    __import__(name)
    registry = sys.modules
    return registry[name]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class _LazyDescr(object):
    """Descriptor that resolves its value on first access, then removes itself.

    The value comes from ``self._resolve()``, which this class does not define.
    """

    def __init__(self, name):
        self.name = name

    def __get__(self, obj, tp):
        resolved = self._resolve()
        # Store the value under the same name on the object (invokes __set__).
        setattr(obj, self.name, resolved)
        try:
            # A bit ugly, but deleting the descriptor from the class keeps
            # this method from running again.
            delattr(obj.__class__, self.name)
        except AttributeError:
            pass
        return resolved
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class MovedModule(_LazyDescr):
    """Lazy reference to a module whose import name depends on the Python major version."""

    def __init__(self, name, old, new=None):
        super(MovedModule, self).__init__(name)
        if PY3:
            # With no explicit new name, the move name doubles as the module name.
            self.mod = name if new is None else new
        else:
            self.mod = old

    def _resolve(self):
        return _import_module(self.mod)

    def __getattr__(self, attr):
        # Reached only for attributes not found normally: fetch the attribute
        # from the real module and cache it on this object.
        value = getattr(self._resolve(), attr)
        setattr(self, attr, value)
        return value
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class _LazyModule(types.ModuleType):
    # Module type whose __dir__ lists the names held in _moved_attributes.
    # Documented with comments only: __init__ copies the class __doc__ onto
    # every instance, so a docstring here would change what instances report.

    # Subclasses should override this
    _moved_attributes = []

    def __init__(self, name):
        super(_LazyModule, self).__init__(name)
        self.__doc__ = self.__class__.__doc__

    def __dir__(self):
        return ["__doc__", "__name__"] + [moved.name for moved in self._moved_attributes]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class MovedAttribute(_LazyDescr):
    """Lazy reference to an attribute whose module and/or name depends on the Python major version."""

    def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
        super(MovedAttribute, self).__init__(name)
        if PY3:
            self.mod = name if new_mod is None else new_mod
            # Attribute name fallback order: new_attr, then old_attr, then name.
            if new_attr is not None:
                self.attr = new_attr
            elif old_attr is not None:
                self.attr = old_attr
            else:
                self.attr = name
        else:
            self.mod = old_mod
            self.attr = name if old_attr is None else old_attr

    def _resolve(self):
        owner = _import_module(self.mod)
        return getattr(owner, self.attr)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class _SixMetaPathImporter(object):

    """
    Meta path importer serving six.moves and its submodules.

    Provides the PEP302 finder and loader methods plus the spec-based hooks
    (find_spec, create_module, exec_module). It should be compatible with
    Python 2.5 and all existing versions of Python3
    """

    def __init__(self, six_module_name):
        self.known_modules = {}
        self.name = six_module_name

    def _add_module(self, mod, *fullnames):
        # Every alias is stored under "<six module name>.<alias>".
        prefix = self.name + "."
        for fullname in fullnames:
            self.known_modules[prefix + fullname] = mod

    def _get_module(self, fullname):
        qualified = self.name + "." + fullname
        return self.known_modules[qualified]

    def find_module(self, fullname, path=None):
        return self if fullname in self.known_modules else None

    def find_spec(self, fullname, path, target=None):
        if fullname not in self.known_modules:
            return None
        return spec_from_loader(fullname, self)

    def __get_module(self, fullname):
        try:
            return self.known_modules[fullname]
        except KeyError:
            raise ImportError("This loader does not know module " + fullname)

    def load_module(self, fullname):
        not_loaded = object()
        cached = sys.modules.get(fullname, not_loaded)
        if cached is not not_loaded:
            # in case of a reload
            return cached
        mod = self.__get_module(fullname)
        if isinstance(mod, MovedModule):
            mod = mod._resolve()
        else:
            mod.__loader__ = self
        sys.modules[fullname] = mod
        return mod

    def is_package(self, fullname):
        """
        Return true, if the named module is a package.

        A module counts as a package when it has a ``__path__`` attribute.
        We need this method to get correct spec objects with
        Python 3.4 (see PEP451)
        """
        module = self.__get_module(fullname)
        return hasattr(module, "__path__")

    def get_code(self, fullname):
        """Return None

        Required, if is_package is implemented"""
        self.__get_module(fullname)  # eventually raises ImportError
        return None
    get_source = get_code  # same as get_code

    def create_module(self, spec):
        return self.load_module(spec.name)

    def exec_module(self, module):
        # Nothing to do: create_module already returns the loaded module.
        pass

_importer = _SixMetaPathImporter(__name__)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class _MovedItems(_LazyModule):

    """Lazy loading of moved objects"""
    __path__ = []  # mark as package


# Each entry names one move; see MovedAttribute and MovedModule for the
# meaning of the positional arguments.
_moved_attributes = [
    # Moved attributes.
    MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
    MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
    MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
    MovedAttribute("intern", "__builtin__", "sys"),
    MovedAttribute("map", "itertools", "builtins", "imap", "map"),
    MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
    MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
    MovedAttribute("getoutput", "commands", "subprocess"),
    MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
    MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
    MovedAttribute("reduce", "__builtin__", "functools"),
    MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
    MovedAttribute("StringIO", "StringIO", "io"),
    MovedAttribute("UserDict", "UserDict", "collections", "IterableUserDict", "UserDict"),
    MovedAttribute("UserList", "UserList", "collections"),
    MovedAttribute("UserString", "UserString", "collections"),
    MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
    MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
    # Moved modules.
    MovedModule("builtins", "__builtin__"),
    MovedModule("configparser", "ConfigParser"),
    MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"),
    MovedModule("copyreg", "copy_reg"),
    MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
    MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"),
    MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"),
    MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
    MovedModule("http_cookies", "Cookie", "http.cookies"),
    MovedModule("html_entities", "htmlentitydefs", "html.entities"),
    MovedModule("html_parser", "HTMLParser", "html.parser"),
    MovedModule("http_client", "httplib", "http.client"),
    MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
    MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
    MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
    MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
    MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
    MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
    MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
    MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
    MovedModule("cPickle", "cPickle", "pickle"),
    MovedModule("queue", "Queue"),
    MovedModule("reprlib", "repr"),
    MovedModule("socketserver", "SocketServer"),
    MovedModule("_thread", "thread", "_thread"),
    MovedModule("tkinter", "Tkinter"),
    MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
    MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
    MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
    MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
    MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
    MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
    MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
    MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
    MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"),
    MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"),
    MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
    MovedModule("tkinter_font", "tkFont", "tkinter.font"),
    MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
    MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"),
    MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
    MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
    MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
    MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
    MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
    MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
]
# Add windows specific modules.
if sys.platform == "win32":
    _moved_attributes.append(MovedModule("winreg", "_winreg"))

# Install every move as a class attribute; moved modules are also registered
# with the importer under "moves.<name>".
for move in _moved_attributes:
    setattr(_MovedItems, move.name, move)
    if isinstance(move, MovedModule):
        _importer._add_module(move, "moves." + move.name)
del move

_MovedItems._moved_attributes = _moved_attributes

moves = _MovedItems(__name__ + ".moves")
_importer._add_module(moves, "moves")
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class Module_six_moves_urllib_parse(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_parse"""


_urllib_parse_moved_attributes = [
    # Python 2 module "urlparse".
    MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
    MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
    MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
    MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
    MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
    MovedAttribute("urljoin", "urlparse", "urllib.parse"),
    MovedAttribute("urlparse", "urlparse", "urllib.parse"),
    MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
    MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
    MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
    # Python 2 module "urllib".
    MovedAttribute("quote", "urllib", "urllib.parse"),
    MovedAttribute("quote_plus", "urllib", "urllib.parse"),
    MovedAttribute("unquote", "urllib", "urllib.parse"),
    MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
    MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
    MovedAttribute("urlencode", "urllib", "urllib.parse"),
    MovedAttribute("splitquery", "urllib", "urllib.parse"),
    MovedAttribute("splittag", "urllib", "urllib.parse"),
    MovedAttribute("splituser", "urllib", "urllib.parse"),
    MovedAttribute("splitvalue", "urllib", "urllib.parse"),
    # Python 2 module "urlparse" again.
    MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
    MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
    MovedAttribute("uses_params", "urlparse", "urllib.parse"),
    MovedAttribute("uses_query", "urlparse", "urllib.parse"),
    MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
]
# Install each move as a class attribute of the lazy module type.
for move in _urllib_parse_moved_attributes:
    setattr(Module_six_moves_urllib_parse, move.name, move)
del move

Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes

# The same module object is registered under both spellings.
_importer._add_module(
    Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
    "moves.urllib_parse",
    "moves.urllib.parse",
)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class Module_six_moves_urllib_error(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_error"""


_urllib_error_moved_attributes = [
    MovedAttribute("URLError", "urllib2", "urllib.error"),
    MovedAttribute("HTTPError", "urllib2", "urllib.error"),
    MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
]
# Install each move as a class attribute of the lazy module type.
for move in _urllib_error_moved_attributes:
    setattr(Module_six_moves_urllib_error, move.name, move)
del move

Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes

# The same module object is registered under both spellings.
_importer._add_module(
    Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
    "moves.urllib_error",
    "moves.urllib.error",
)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class Module_six_moves_urllib_request(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_request"""


_urllib_request_moved_attributes = [
    MovedAttribute("urlopen", "urllib2", "urllib.request"),
    MovedAttribute("install_opener", "urllib2", "urllib.request"),
    MovedAttribute("build_opener", "urllib2", "urllib.request"),
    MovedAttribute("pathname2url", "urllib", "urllib.request"),
    MovedAttribute("url2pathname", "urllib", "urllib.request"),
    MovedAttribute("getproxies", "urllib", "urllib.request"),
    MovedAttribute("Request", "urllib2", "urllib.request"),
    MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
    MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
    MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
    MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
    MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
    MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
    MovedAttribute("FileHandler", "urllib2", "urllib.request"),
    MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
    MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
    MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
    MovedAttribute("urlretrieve", "urllib", "urllib.request"),
    MovedAttribute("urlcleanup", "urllib", "urllib.request"),
    MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
    MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
    MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
# These two moves are only registered on interpreters older than 3.14.
if sys.version_info[:2] < (3, 14):
    _urllib_request_moved_attributes.extend(
        [
            MovedAttribute("URLopener", "urllib", "urllib.request"),
            MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
        ]
    )
# Install each move as a class attribute of the lazy module type.
for move in _urllib_request_moved_attributes:
    setattr(Module_six_moves_urllib_request, move.name, move)
del move

Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes

# The same module object is registered under both spellings.
_importer._add_module(
    Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
    "moves.urllib_request",
    "moves.urllib.request",
)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class Module_six_moves_urllib_response(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_response"""


_urllib_response_moved_attributes = [
    MovedAttribute("addbase", "urllib", "urllib.response"),
    MovedAttribute("addclosehook", "urllib", "urllib.response"),
    MovedAttribute("addinfo", "urllib", "urllib.response"),
    MovedAttribute("addinfourl", "urllib", "urllib.response"),
]
# Install each move as a class attribute of the lazy module type.
for move in _urllib_response_moved_attributes:
    setattr(Module_six_moves_urllib_response, move.name, move)
del move

Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes

# The same module object is registered under both spellings.
_importer._add_module(
    Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
    "moves.urllib_response",
    "moves.urllib.response",
)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class Module_six_moves_urllib_robotparser(_LazyModule):

    """Lazy loading of moved objects in six.moves.urllib_robotparser"""


_urllib_robotparser_moved_attributes = [
    MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
]
# Install each move as a class attribute of the lazy module type.
for move in _urllib_robotparser_moved_attributes:
    setattr(Module_six_moves_urllib_robotparser, move.name, move)
del move

Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes

# The same module object is registered under both spellings.
_importer._add_module(
    Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
    "moves.urllib_robotparser",
    "moves.urllib.robotparser",
)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class Module_six_moves_urllib(types.ModuleType):
    """Create a six.moves.urllib namespace that resembles the Python 3 namespace"""

    __path__ = []  # mark as package
    # Submodules are looked up from the importer by their flat names.
    parse = _importer._get_module("moves.urllib_parse")
    error = _importer._get_module("moves.urllib_error")
    request = _importer._get_module("moves.urllib_request")
    response = _importer._get_module("moves.urllib_response")
    robotparser = _importer._get_module("moves.urllib_robotparser")

    def __dir__(self):
        """Return the names of the five submodules exposed by this package."""
        return ['parse', 'error', 'request', 'response', 'robotparser']


_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
                      "moves.urllib")
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def add_move(move):
    """Add an item to six.moves."""
    # The move is stored as a class attribute of _MovedItems under move.name.
    setattr(_MovedItems, move.name, move)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def remove_move(name):
    """Remove item from six.moves.

    Raises AttributeError when ``name`` is neither an attribute of
    ``_MovedItems`` nor a key of ``moves.__dict__``.
    """
    try:
        delattr(_MovedItems, name)
    except AttributeError:
        # Not on the class: fall back to the instance dictionary of ``moves``.
        try:
            del moves.__dict__[name]
        except KeyError:
            raise AttributeError("no such move, %r" % (name,))
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# Attribute names used to reach into method and function objects. The two
# branches select the spelling that matches the running interpreter.
if PY3:
    _meth_func = "__func__"
    _meth_self = "__self__"

    _func_closure = "__closure__"
    _func_code = "__code__"
    _func_defaults = "__defaults__"
    _func_globals = "__globals__"
else:
    _meth_func = "im_func"
    _meth_self = "im_self"

    _func_closure = "func_closure"
    _func_code = "func_code"
    _func_defaults = "func_defaults"
    _func_globals = "func_globals"
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# Use the builtin ``next`` when the interpreter provides it; otherwise fall
# back to calling the iterator's own ``next()`` method.
try:
    advance_iterator = next
except NameError:
    def advance_iterator(it):
        return it.next()
# Export the same callable under the name ``next``.
next = advance_iterator
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# Re-export the builtin ``callable`` when it exists; otherwise emulate it by
# looking for ``__call__`` in the dictionaries of the object's type and bases.
try:
    callable = callable
except NameError:
    def callable(obj):
        return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
if PY3:
    # On this branch the argument is handed back unchanged.
    def get_unbound_function(unbound):
        return unbound

    create_bound_method = types.MethodType

    # ``cls`` is accepted for signature parity with the other branch but is
    # not used.
    def create_unbound_method(func, cls):
        return func

    Iterator = object
else:
    def get_unbound_function(unbound):
        return unbound.im_func

    def create_bound_method(func, obj):
        return types.MethodType(func, obj, obj.__class__)

    def create_unbound_method(func, cls):
        return types.MethodType(func, None, cls)

    class Iterator(object):

        # Forward the old-style ``next()`` call to the class's ``__next__``.
        def next(self):
            return type(self).__next__(self)

    callable = callable
_add_doc(get_unbound_function,
         """Get the function out of a possibly unbound function""")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# Accessors built from the attribute names chosen above (``_meth_*`` and
# ``_func_*``); each one reads the named attribute from its argument.
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_closure = operator.attrgetter(_func_closure)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
get_function_globals = operator.attrgetter(_func_globals)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# Dictionary iteration helpers. Extra keyword arguments are forwarded to the
# underlying mapping method in both branches.
if PY3:
    def iterkeys(d, **kw):
        return iter(d.keys(**kw))

    def itervalues(d, **kw):
        return iter(d.values(**kw))

    def iteritems(d, **kw):
        return iter(d.items(**kw))

    def iterlists(d, **kw):
        return iter(d.lists(**kw))

    viewkeys = operator.methodcaller("keys")

    viewvalues = operator.methodcaller("values")

    viewitems = operator.methodcaller("items")
else:
    def iterkeys(d, **kw):
        return d.iterkeys(**kw)

    def itervalues(d, **kw):
        return d.itervalues(**kw)

    def iteritems(d, **kw):
        return d.iteritems(**kw)

    def iterlists(d, **kw):
        return d.iterlists(**kw)

    viewkeys = operator.methodcaller("viewkeys")

    viewvalues = operator.methodcaller("viewvalues")

    viewitems = operator.methodcaller("viewitems")

_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
         "Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
         "Return an iterator over the (key, [values]) pairs of a dictionary.")
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
# Byte/text literal helpers, byte-indexing helpers, in-memory stream classes
# and the names of the unittest assertion methods for the running interpreter.
if PY3:
    def b(s):
        return s.encode("latin-1")

    def u(s):
        return s
    unichr = chr
    # ``struct`` and ``io`` are imported only long enough to grab what is
    # needed and are then deleted from the module namespace.
    import struct
    int2byte = struct.Struct(">B").pack
    del struct
    byte2int = operator.itemgetter(0)
    indexbytes = operator.getitem
    iterbytes = iter
    import io
    StringIO = io.StringIO
    BytesIO = io.BytesIO
    del io
    _assertCountEqual = "assertCountEqual"
    if sys.version_info[1] <= 1:
        _assertRaisesRegex = "assertRaisesRegexp"
        _assertRegex = "assertRegexpMatches"
        _assertNotRegex = "assertNotRegexpMatches"
    else:
        _assertRaisesRegex = "assertRaisesRegex"
        _assertRegex = "assertRegex"
        _assertNotRegex = "assertNotRegex"
else:
    def b(s):
        return s
    # Workaround for standalone backslash

    def u(s):
        return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
    unichr = unichr
    int2byte = chr

    def byte2int(bs):
        return ord(bs[0])

    def indexbytes(buf, i):
        return ord(buf[i])
    iterbytes = functools.partial(itertools.imap, ord)
    import StringIO
    StringIO = BytesIO = StringIO.StringIO
    _assertCountEqual = "assertItemsEqual"
    _assertRaisesRegex = "assertRaisesRegexp"
    _assertRegex = "assertRegexpMatches"
    _assertNotRegex = "assertNotRegexpMatches"
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def assertCountEqual(self, *args, **kwargs):
    # Dispatch to the method of ``self`` whose name is held in the
    # module-level ``_assertCountEqual`` and return its result.
    return getattr(self, _assertCountEqual)(*args, **kwargs)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def assertRaisesRegex(self, *args, **kwargs):
    # Dispatch to the method of ``self`` whose name is held in the
    # module-level ``_assertRaisesRegex`` and return its result.
    return getattr(self, _assertRaisesRegex)(*args, **kwargs)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def assertRegex(self, *args, **kwargs):
    # Dispatch to the method of ``self`` whose name is held in the
    # module-level ``_assertRegex`` and return its result.
    return getattr(self, _assertRegex)(*args, **kwargs)
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def assertNotRegex(self, *args, **kwargs):
    # Dispatch to the method of ``self`` whose name is held in the
    # module-level ``_assertNotRegex`` and return its result.
    return getattr(self, _assertNotRegex)(*args, **kwargs)
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
if PY3:
    exec_ = getattr(moves.builtins, "exec")

    def reraise(tp, value, tb=None):
        try:
            # A missing value is replaced by a fresh instance of ``tp``.
            if value is None:
                value = tp()
            if value.__traceback__ is not tb:
                raise value.with_traceback(tb)
            raise value
        finally:
            # Clear the local references before the frame is left.
            value = None
            tb = None

else:
    def exec_(_code_, _globs_=None, _locs_=None):
        """Execute code in a namespace."""
        if _globs_ is None:
            # No namespace given: use the globals (and, unless provided,
            # the locals) of the calling frame.
            frame = sys._getframe(1)
            _globs_ = frame.f_globals
            if _locs_ is None:
                _locs_ = frame.f_locals
            del frame
        elif _locs_ is None:
            _locs_ = _globs_
        exec("""exec _code_ in _globs_, _locs_""")

    # The three-argument ``raise`` statement is kept inside a string so that
    # this file still compiles where that syntax is not accepted.
    exec_("""def reraise(tp, value, tb=None):
    try:
        raise tp, value, tb
    finally:
        tb = None
""")
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
if sys.version_info[:2] > (3,):
    # ``raise ... from ...`` is kept inside a string so that this file still
    # compiles where that syntax is not accepted.
    exec_("""def raise_from(value, from_value):
    try:
        raise value from from_value
    finally:
        value = None
""")
else:
    # Fallback: ``from_value`` is ignored and ``value`` is raised as is.
    def raise_from(value, from_value):
        raise value
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
# Use the builtin ``print`` function when ``moves.builtins`` provides one;
# otherwise define an emulation below.
print_ = getattr(moves.builtins, "print", None)
if print_ is None:
    def print_(*args, **kwargs):
        """The new-style print function for Python 2.4 and 2.5."""
        fp = kwargs.pop("file", sys.stdout)
        if fp is None:
            return

        def write(data):
            if not isinstance(data, basestring):
                data = str(data)
            # If the file has an encoding, encode unicode with it.
            if (isinstance(fp, file) and
                    isinstance(data, unicode) and
                    fp.encoding is not None):
                errors = getattr(fp, "errors", None)
                if errors is None:
                    errors = "strict"
                data = data.encode(fp.encoding, errors)
            fp.write(data)
        # The output is unicode as soon as ``sep``, ``end`` or any argument
        # is unicode; this decides the type of the default separators.
        want_unicode = False
        sep = kwargs.pop("sep", None)
        if sep is not None:
            if isinstance(sep, unicode):
                want_unicode = True
            elif not isinstance(sep, str):
                raise TypeError("sep must be None or a string")
        end = kwargs.pop("end", None)
        if end is not None:
            if isinstance(end, unicode):
                want_unicode = True
            elif not isinstance(end, str):
                raise TypeError("end must be None or a string")
        if kwargs:
            raise TypeError("invalid keyword arguments to print()")
        if not want_unicode:
            for arg in args:
                if isinstance(arg, unicode):
                    want_unicode = True
                    break
        if want_unicode:
            newline = unicode("\n")
            space = unicode(" ")
        else:
            newline = "\n"
            space = " "
        if sep is None:
            sep = space
        if end is None:
            end = newline
        for i, arg in enumerate(args):
            if i:
                write(sep)
            write(arg)
        write(end)
if sys.version_info[:2] < (3, 3):
    _print = print_

    # Wrap the selected print function to add the ``flush`` keyword: it is
    # removed from ``kwargs`` before delegating and applied afterwards.
    def print_(*args, **kwargs):
        fp = kwargs.get("file", sys.stdout)
        flush = kwargs.pop("flush", False)
        _print(*args, **kwargs)
        if flush and fp is not None:
            fp.flush()
|
| 826 |
+
|
| 827 |
+
_add_doc(reraise, """Reraise an exception.""")
|
| 828 |
+
|
| 829 |
+
if sys.version_info[0:2] < (3, 4):
    # This does exactly the same what the :func:`py3:functools.update_wrapper`
    # function does on Python versions after 3.2. It sets the ``__wrapped__``
    # attribute on ``wrapper`` object and it doesn't raise an error if any of
    # the attributes mentioned in ``assigned`` and ``updated`` are missing on
    # ``wrapped`` object.
    def _update_wrapper(wrapper, wrapped,
                        assigned=functools.WRAPPER_ASSIGNMENTS,
                        updated=functools.WRAPPER_UPDATES):
        for attr in assigned:
            try:
                value = getattr(wrapped, attr)
            except AttributeError:
                continue
            else:
                setattr(wrapper, attr, value)
        for attr in updated:
            getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
        wrapper.__wrapped__ = wrapped
        return wrapper
    _update_wrapper.__doc__ = functools.update_wrapper.__doc__

    def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
              updated=functools.WRAPPER_UPDATES):
        return functools.partial(_update_wrapper, wrapped=wrapped,
                                 assigned=assigned, updated=updated)
    wraps.__doc__ = functools.wraps.__doc__

else:
    # From 3.4 on the standard library implementation is used directly.
    wraps = functools.wraps
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def with_metaclass(meta, *bases):
    """Create a base class with a metaclass.

    The returned object is a temporary class. When it is used as the sole
    base of a class statement, the class that results is built by calling
    ``meta`` with ``bases`` in place of the temporary class.
    """
    # This requires a bit of explanation: the basic idea is to make a dummy
    # metaclass for one level of class instantiation that replaces itself with
    # the actual metaclass.
    class metaclass(type):

        def __new__(cls, name, this_bases, d):
            if sys.version_info[:2] >= (3, 7):
                # This version introduced PEP 560 that requires a bit
                # of extra care (we mimic what is done by __build_class__).
                resolved_bases = types.resolve_bases(bases)
                if resolved_bases is not bases:
                    d['__orig_bases__'] = bases
            else:
                resolved_bases = bases
            return meta(name, resolved_bases, d)

        @classmethod
        def __prepare__(cls, name, this_bases):
            return meta.__prepare__(name, bases)
    return type.__new__(metaclass, 'temporary_class', (), {})
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def add_metaclass(metaclass):
    """Class decorator for creating a class with a metaclass."""
    def wrapper(cls):
        # Rebuild the class from a copy of its namespace using ``metaclass``.
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            # Drop the entries named in ``__slots__`` from the copied
            # namespace; ``__slots__`` itself stays in it.
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        if hasattr(cls, '__qualname__'):
            orig_vars['__qualname__'] = cls.__qualname__
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
def ensure_binary(s, encoding='utf-8', errors='strict'):
    """Coerce **s** to six.binary_type.

    For Python 2:
      - `unicode` -> encoded to `str`
      - `str` -> `str`

    For Python 3:
      - `str` -> encoded to `bytes`
      - `bytes` -> `bytes`

    Raises TypeError when ``s`` is neither binary nor text.
    """
    if isinstance(s, binary_type):
        return s
    if isinstance(s, text_type):
        return s.encode(encoding, errors)
    raise TypeError("not expecting type '%s'" % type(s))
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def ensure_str(s, encoding='utf-8', errors='strict'):
    """Coerce *s* to `str`.

    For Python 2:
      - `unicode` -> encoded to `str`
      - `str` -> `str`

    For Python 3:
      - `str` -> `str`
      - `bytes` -> decoded to `str`

    Raises TypeError when ``s`` is neither binary nor text.
    """
    # Optimization: Fast return for the common case.
    if type(s) is str:
        return s
    if PY2 and isinstance(s, text_type):
        return s.encode(encoding, errors)
    elif PY3 and isinstance(s, binary_type):
        return s.decode(encoding, errors)
    elif not isinstance(s, (text_type, binary_type)):
        raise TypeError("not expecting type '%s'" % type(s))
    # Reached for text or binary values not converted above (for example a
    # ``str`` subclass, which the exact-type check at the top skips).
    return s
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def ensure_text(s, encoding='utf-8', errors='strict'):
    """Coerce *s* to six.text_type.

    For Python 2:
      - `unicode` -> `unicode`
      - `str` -> `unicode`

    For Python 3:
      - `str` -> `str`
      - `bytes` -> decoded to `str`

    Raises TypeError when ``s`` is neither binary nor text.
    """
    if isinstance(s, binary_type):
        return s.decode(encoding, errors)
    elif isinstance(s, text_type):
        return s
    else:
        raise TypeError("not expecting type '%s'" % type(s))
|
| 961 |
+
|
| 962 |
+
|
| 963 |
+
def python_2_unicode_compatible(klass):
    """
    A class decorator that defines __unicode__ and __str__ methods under Python 2.
    Under Python 3 it does nothing.

    To support Python 2 and 3 with a single code base, define a __str__ method
    returning text and apply this decorator to the class.
    """
    if PY2:
        # The class itself (not a base) must define __str__.
        if '__str__' not in klass.__dict__:
            raise ValueError("@python_2_unicode_compatible cannot be applied "
                             "to %s because it doesn't define __str__()." %
                             klass.__name__)
        # The text-returning method becomes __unicode__; __str__ is replaced
        # by a wrapper that encodes that text as UTF-8.
        klass.__unicode__ = klass.__str__
        klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
    return klass
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
# Complete the moves implementation.
# This code is at the end of this module to speed up module loading.
# Turn this module into a package.
__path__ = []  # required for PEP 302 and PEP 451
__package__ = __name__  # see PEP 366 @ReservedAssignment
if globals().get("__spec__") is not None:
    __spec__.submodule_search_locations = []  # PEP 451 @UndefinedVariable
# Remove other six meta path importers, since they cause problems. This can
# happen if six is removed from sys.modules and then reloaded. (Setuptools does
# this for some reason.)
if sys.meta_path:
    for i, importer in enumerate(sys.meta_path):
        # Here's some real nastiness: Another "instance" of the six module might
        # be floating around. Therefore, we can't use isinstance() to check for
        # the six meta path importer, since the other six instance will have
        # inserted an importer with different class.
        if (type(importer).__name__ == "_SixMetaPathImporter" and
                importer.name == __name__):
            # Only the first match is removed; the loop stops right after,
            # so deleting while enumerating is safe here.
            del sys.meta_path[i]
            break
    del i, importer
# Finally, add the importer to the meta path import hook.
sys.meta_path.append(_importer)
|
venv/Lib/site-packages/threadpoolctl.py
ADDED
|
@@ -0,0 +1,1292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""threadpoolctl
|
| 2 |
+
|
| 3 |
+
This module provides utilities to introspect native libraries that rely on
|
| 4 |
+
thread pools (notably BLAS and OpenMP implementations) and dynamically set the
|
| 5 |
+
maximal number of threads they can use.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# License: BSD 3-Clause
|
| 9 |
+
|
| 10 |
+
# The code to introspect dynamically loaded libraries on POSIX systems is
|
| 11 |
+
# adapted from code by Intel developer @anton-malakhov available at
|
| 12 |
+
# https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
|
| 13 |
+
# and also published under the BSD 3-Clause license
|
| 14 |
+
import os
|
| 15 |
+
import re
|
| 16 |
+
import sys
|
| 17 |
+
import ctypes
|
| 18 |
+
import itertools
|
| 19 |
+
import textwrap
|
| 20 |
+
from typing import final
|
| 21 |
+
import warnings
|
| 22 |
+
from ctypes.util import find_library
|
| 23 |
+
from abc import ABC, abstractmethod
|
| 24 |
+
from functools import lru_cache
|
| 25 |
+
from contextlib import ContextDecorator
|
| 26 |
+
|
| 27 |
+
__version__ = "3.6.0"
|
| 28 |
+
__all__ = [
|
| 29 |
+
"threadpool_limits",
|
| 30 |
+
"threadpool_info",
|
| 31 |
+
"ThreadpoolController",
|
| 32 |
+
"LibController",
|
| 33 |
+
"register",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# One can get runtime errors or even segfaults due to multiple OpenMP libraries
|
| 38 |
+
# loaded simultaneously which can happen easily in Python when importing and
|
| 39 |
+
# using compiled extensions built with different compilers and therefore
|
| 40 |
+
# different OpenMP runtimes in the same program. In particular libiomp (used by
|
| 41 |
+
# Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for
|
| 42 |
+
# instance when calling BLAS inside a prange. Setting the following environment
|
| 43 |
+
# variable allows multiple OpenMP libraries to be loaded. It should not degrade
|
| 44 |
+
# performances since we manually take care of potential over-subscription
|
| 45 |
+
# performance issues, in sections of the code where nested OpenMP loops can
|
| 46 |
+
# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
|
| 47 |
+
# disable it while under the scope of the outer OpenMP parallel section.
|
| 48 |
+
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")
|
| 49 |
+
|
| 50 |
+
# Structure to cast the info on dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
# The integer widths are chosen from ``sys.maxsize``: the wider pair is used
# when it exceeds 2**32.
_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32
_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16


class _dl_phdr_info(ctypes.Structure):
    _fields_ = [
        ("dlpi_addr", _SYSTEM_UINT),  # Base address of object
        ("dlpi_name", ctypes.c_char_p),  # path to the library
        ("dlpi_phdr", ctypes.c_void_p),  # pointer on dlpi_headers
        ("dlpi_phnum", _SYSTEM_UINT_HALF),  # number of elements in dlpi_phdr
    ]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows.
# When ``os`` does not expose it, fall back to the default ctypes mode.
try:
    _RTLD_NOLOAD = os.RTLD_NOLOAD
except AttributeError:
    _RTLD_NOLOAD = ctypes.DEFAULT_MODE
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LibController(ABC):
|
| 73 |
+
"""Abstract base class for the individual library controllers
|
| 74 |
+
|
| 75 |
+
A library controller must expose the following class attributes:
|
| 76 |
+
- user_api : str
|
| 77 |
+
Usually the name of the library or generic specification the library
|
| 78 |
+
implements, e.g. "blas" is a specification with different implementations.
|
| 79 |
+
- internal_api : str
|
| 80 |
+
Usually the name of the library or concrete implementation of some
|
| 81 |
+
specification, e.g. "openblas" is an implementation of the "blas"
|
| 82 |
+
specification.
|
| 83 |
+
- filename_prefixes : tuple
|
| 84 |
+
Possible prefixes of the shared library's filename that allow to
|
| 85 |
+
identify the library. e.g. "libopenblas" for libopenblas.so.
|
| 86 |
+
|
| 87 |
+
and implement the following methods: `get_num_threads`, `set_num_threads` and
|
| 88 |
+
`get_version`.
|
| 89 |
+
|
| 90 |
+
Threadpoolctl loops through all the loaded shared libraries and tries to match
|
| 91 |
+
the filename of each library with the `filename_prefixes`. If a match is found, a
|
| 92 |
+
controller is instantiated and a handler to the library is stored in the `dynlib`
|
| 93 |
+
attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
|
| 94 |
+
of the shared library to implement the above methods.
|
| 95 |
+
|
| 96 |
+
The following information will be exposed in the info dictionary:
|
| 97 |
+
- user_api : standardized API, if any, or a copy of internal_api.
|
| 98 |
+
- internal_api : implementation-specific API.
|
| 99 |
+
- num_threads : the current thread limit.
|
| 100 |
+
- prefix : prefix of the shared library's filename.
|
| 101 |
+
- filepath : path to the loaded shared library.
|
| 102 |
+
- version : version of the library (if available).
|
| 103 |
+
|
| 104 |
+
In addition, each library controller may expose internal API specific entries. They
|
| 105 |
+
must be set as attributes in the `set_additional_attributes` method.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
@final
def __init__(self, *, filepath=None, prefix=None, parent=None):
    """This is not meant to be overridden by subclasses.

    Stores ``parent``, ``prefix`` and ``filepath``, opens ``filepath`` with
    ``ctypes.CDLL`` using the ``_RTLD_NOLOAD`` mode, and then fills in the
    symbol affixes, the version and any additional attributes by calling
    the corresponding methods.
    """
    self.parent = parent
    self.prefix = prefix
    self.filepath = filepath
    self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
    # The calls below run after ``dynlib`` is set, so their implementations
    # can use it.
    self._symbol_prefix, self._symbol_suffix = self._find_affixes()
    self.version = self.get_version()
    self.set_additional_attributes()
|
| 118 |
+
|
| 119 |
+
def info(self):
    """Return relevant info wrapped in a dict

    The dict holds "user_api", "internal_api" and "num_threads", followed
    by every instance attribute that is not listed in `hidden_attrs`.
    """
    # Handles and internal bookkeeping that must not leak into the dict.
    hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
    return {
        "user_api": self.user_api,
        "internal_api": self.internal_api,
        "num_threads": self.num_threads,
        **{k: v for k, v in vars(self).items() if k not in hidden_attrs},
    }
|
| 128 |
+
|
| 129 |
+
def set_additional_attributes(self):
    """Set additional attributes meant to be exposed in the info dict

    The base implementation does nothing. Subclasses override it to set
    extra instance attributes.
    """
|
| 131 |
+
|
| 132 |
+
@property
def num_threads(self):
    """Exposes the current thread limit as a dynamic property

    This is not meant to be used or overridden by subclasses. The value is
    obtained from `get_num_threads` on every access.
    """
    return self.get_num_threads()
|
| 139 |
+
|
| 140 |
+
@abstractmethod
def get_num_threads(self):
    """Return the maximum number of threads available to use"""
|
| 143 |
+
|
| 144 |
+
@abstractmethod
def set_num_threads(self, num_threads):
    """Set the maximum number of threads to use"""
|
| 147 |
+
|
| 148 |
+
@abstractmethod
def get_version(self):
    """Return the version of the shared library"""
|
| 151 |
+
|
| 152 |
+
def _find_affixes(self):
|
| 153 |
+
"""Return the affixes for the symbols of the shared library"""
|
| 154 |
+
return "", ""
|
| 155 |
+
|
| 156 |
+
def _get_symbol(self, name):
|
| 157 |
+
"""Return the symbol of the shared library accounding for the affixes"""
|
| 158 |
+
return getattr(
|
| 159 |
+
self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

    # Candidate decorations of the exported symbol names.
    _symbol_prefixes = ("", "scipy_")
    _symbol_suffixes = ("", "64_", "_64")

    # All variations of "openblas_get_num_threads", accounting for the affixes
    check_symbols = tuple(
        f"{prefix}openblas_get_num_threads{suffix}"
        for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
    )

    def _find_affixes(self):
        """Return the (prefix, suffix) pair used by this library's symbols

        The first combination for which the decorated
        "openblas_get_num_threads" symbol exists is returned. When no
        combination matches, fall back to empty affixes (the base class
        default) so that the caller can always unpack a pair.
        """
        for prefix, suffix in itertools.product(
            self._symbol_prefixes, self._symbol_suffixes
        ):
            if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
                return prefix, suffix
        return "", ""

    def set_additional_attributes(self):
        """Expose the threading layer and the architecture in the info dict"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the library's thread count, or None if the symbol is missing"""
        get_num_threads_func = self._get_symbol("openblas_get_num_threads")
        if get_num_threads_func is not None:
            return get_num_threads_func()
        return None

    def set_num_threads(self, num_threads):
        """Set the library's thread count; no-op if the symbol is missing"""
        set_num_threads_func = self._get_symbol("openblas_set_num_threads")
        if set_num_threads_func is not None:
            return set_num_threads_func(num_threads)
        return None

    def get_version(self):
        """Return the OpenBLAS version string parsed from its config, or None"""
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_version_func = self._get_symbol("openblas_get_config")
        if get_version_func is None:
            return None

        get_version_func.restype = ctypes.c_char_p
        raw_config = get_version_func()
        # A NULL pointer is converted to None by ctypes.
        if raw_config is None:
            return None

        config = raw_config.split()
        # Expect at least b"OpenBLAS" followed by the version token.
        if len(config) >= 2 and config[0] == b"OpenBLAS":
            return config[1].decode("utf-8")
        return None

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_threading_layer_func = self._get_symbol("openblas_get_parallel")
        if get_threading_layer_func is None:
            return "unknown"

        threading_layer = get_threading_layer_func()
        if threading_layer == 2:
            return "openmp"
        if threading_layer == 1:
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_architecture_func = self._get_symbol("openblas_get_corename")
        if get_architecture_func is None:
            return None

        get_architecture_func.restype = ctypes.c_char_p
        corename = get_architecture_func()
        # A NULL pointer is converted to None by ctypes.
        return None if corename is None else corename.decode("utf-8")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = (
        "bli_thread_get_num_threads",
        "bli_thread_set_num_threads",
        "bli_info_get_version_str",
        "bli_info_get_enable_openmp",
        "bli_info_get_enable_pthreads",
        "bli_arch_query_id",
        "bli_arch_string",
    )

    def set_additional_attributes(self):
        """Expose the threading layer and the architecture in the info dict"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the thread count, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        num_threads = get_func()
        # by default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Set the thread count; no-op if the symbol is missing"""
        set_func = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the BLIS version string, or None if the symbol is missing"""
        get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None)
        if get_version_ is None:
            return None

        get_version_.restype = ctypes.c_char_p
        return get_version_().decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of BLIS"""
        # A missing symbol is treated the same as a disabled layer.
        if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)():
            return "openmp"
        if getattr(self.dynlib, "bli_info_get_enable_pthreads", lambda: False)():
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        bli_arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if bli_arch_query_id is None or bli_arch_string is None:
            return None

        # the true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        bli_arch_query_id.restype = ctypes.c_int
        bli_arch_string.restype = ctypes.c_char_p
        return bli_arch_string(bli_arch_query_id()).decode("utf-8")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class FlexiBLASController(LibController):
    """Controller class for FlexiBLAS"""

    user_api = "blas"
    internal_api = "flexiblas"
    filename_prefixes = ("libflexiblas",)
    check_symbols = (
        "flexiblas_get_num_threads",
        "flexiblas_set_num_threads",
        "flexiblas_get_version",
        "flexiblas_list",
        "flexiblas_list_loaded",
        "flexiblas_current_backend",
    )

    @property
    def loaded_backends(self):
        """List of the currently loaded backends (queried on each access)"""
        return self._get_backend_list(loaded=True)

    @property
    def current_backend(self):
        """Name of the backend currently in use (queried on each access)"""
        return self._get_current_backend()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # We override the info method because the loaded and current backends
        # are dynamic properties
        exposed_attrs = super().info()
        exposed_attrs["loaded_backends"] = self.loaded_backends
        exposed_attrs["current_backend"] = self.current_backend

        return exposed_attrs

    def set_additional_attributes(self):
        """Expose the list of configured backends in the info dict"""
        self.available_backends = self._get_backend_list(loaded=False)

    def get_num_threads(self):
        """Return the thread count, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None)
        num_threads = get_func()
        # A value of -1 is mapped to 1 for consistency with other libraries
        # (same convention as the BLIS controller).
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Set the thread count; no-op if the symbol is missing"""
        set_func = getattr(
            self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the version as "major.minor.patch", or None if unavailable"""
        get_version_ = getattr(self.dynlib, "flexiblas_get_version", None)
        if get_version_ is None:
            return None

        # The library fills the three integers through output pointers.
        major = ctypes.c_int()
        minor = ctypes.c_int()
        patch = ctypes.c_int()
        get_version_(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
        return f"{major.value}.{minor.value}.{patch.value}"

    def _get_backend_list(self, loaded=False):
        """Return the list of available backends for FlexiBLAS.

        If loaded is False, return the list of available backends from the FlexiBLAS
        configuration. If loaded is True, return the list of actually loaded backends.
        Return None when the corresponding symbol is missing.
        """
        func_name = f"flexiblas_list{'_loaded' if loaded else ''}"
        get_backend_list_ = getattr(self.dynlib, func_name, None)
        if get_backend_list_ is None:
            return None

        # Called without a buffer, the function's return value is used as
        # the number of backends to enumerate.
        n_backends = get_backend_list_(None, 0, 0)

        backends = []
        for i in range(n_backends):
            backend_name = ctypes.create_string_buffer(1024)
            get_backend_list_(backend_name, 1024, i)
            name = backend_name.value.decode("utf-8")
            if name != "__FALLBACK__":
                # We don't know when to expect __FALLBACK__ but it is not a real
                # backend and does not show up when running flexiblas list.
                backends.append(name)
        return backends

    def _get_current_backend(self):
        """Return the backend of FlexiBLAS"""
        get_backend_ = getattr(self.dynlib, "flexiblas_current_backend", None)
        if get_backend_ is None:
            return None

        backend = ctypes.create_string_buffer(1024)
        get_backend_(backend, ctypes.sizeof(backend))
        return backend.value.decode("utf-8")

    def switch_backend(self, backend):
        """Switch the backend of FlexiBLAS

        Parameters
        ----------
        backend : str
            The name or the path to the shared library of the backend to switch to. If
            the backend is not already loaded, it will be loaded first.

        Raises
        ------
        RuntimeError
            If the backend cannot be loaded or the switch itself fails.
        """
        if backend not in self.loaded_backends:
            if backend in self.available_backends:
                load_func = getattr(self.dynlib, "flexiblas_load_backend", lambda _: -1)
            else:  # assume backend is a path to a shared library
                load_func = getattr(
                    self.dynlib, "flexiblas_load_backend_library", lambda _: -1
                )
            res = load_func(str(backend).encode("utf-8"))
            if res == -1:
                raise RuntimeError(
                    f"Failed to load backend {backend!r}. It must either be the name of"
                    " a backend available in the FlexiBLAS configuration "
                    f"{self.available_backends} or the path to a valid shared library."
                )

            # Trigger a new search of loaded shared libraries since loading a new
            # backend caused a dlopen.
            self.parent._load_libraries()

        switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1)
        idx = self.loaded_backends.index(backend)
        res = switch_func(idx)
        if res == -1:
            raise RuntimeError(f"Failed to switch to backend {backend!r}.")
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = (
        "MKL_Get_Max_Threads",
        "MKL_Set_Num_Threads",
        "MKL_Get_Version_String",
        "MKL_Set_Threading_Layer",
    )

    def set_additional_attributes(self):
        """Expose the threading layer in the info dict"""
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        """Return the thread count, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        """Set the thread count; no-op if the symbol is missing"""
        set_func = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        """Return the MKL version, or None if the symbol is missing

        When the version string does not match the expected
        "Version <x> " pattern, the whole stripped string is returned.
        """
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        res = ctypes.create_string_buffer(200)
        self.dynlib.MKL_Get_Version_String(res, 200)

        version = res.value.decode("utf-8")
        group = re.search(r"Version ([^ ]+) ", version)
        if group is not None:
            version = group.groups()[0]
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        layer_map = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        return layer_map[set_threading_layer(-1)]
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
    check_symbols = (
        "omp_get_max_threads",
        "omp_get_num_threads",
    )

    def get_num_threads(self):
        """Return the thread count, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        """Set the thread count; no-op if the symbol is missing"""
        set_func = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        """Always return None"""
        # There is no way to get the version number programmatically in OpenMP.
        return None
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
    FlexiBLASController,
]

# Helpers for the doc and test names
# dict.fromkeys removes duplicates while keeping the registration order, so
# the lists (and the docstrings built from them) are stable across runs.
_ALL_USER_APIS = list(dict.fromkeys(lib.user_api for lib in _ALL_CONTROLLERS))
_ALL_INTERNAL_APIS = [lib.internal_api for lib in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    dict.fromkeys(
        prefix for lib in _ALL_CONTROLLERS for prefix in lib.filename_prefixes
    )
)
_ALL_BLAS_LIBRARIES = [
    lib.internal_api for lib in _ALL_CONTROLLERS if lib.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def register(controller):
    """Register a new controller

    The controller class is appended to the module-level registries. Its
    `user_api` and `filename_prefixes` are only added when not already
    present, since those lists are built without duplicates.
    """
    _ALL_CONTROLLERS.append(controller)
    if controller.user_api not in _ALL_USER_APIS:
        _ALL_USER_APIS.append(controller.user_api)
    _ALL_INTERNAL_APIS.append(controller.internal_api)
    for prefix in controller.filename_prefixes:
        if prefix not in _ALL_PREFIXES:
            _ALL_PREFIXES.append(prefix)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def _format_docstring(*args, **kwargs):
|
| 537 |
+
def decorator(o):
|
| 538 |
+
if o.__doc__ is not None:
|
| 539 |
+
o.__doc__ = o.__doc__.format(*args, **kwargs)
|
| 540 |
+
return o
|
| 541 |
+
|
| 542 |
+
return decorator
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
@lru_cache(maxsize=10000)
|
| 546 |
+
def _realpath(filepath):
|
| 547 |
+
"""Small caching wrapper around os.path.realpath to limit system calls"""
|
| 548 |
+
return os.path.realpath(filepath)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    Return a list with all the supported libraries that have been found. Each
    library is represented by a dict with the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    return ThreadpoolController().info()
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class _ThreadpoolLimiter:
|
| 571 |
+
"""The guts of ThreadpoolController.limit
|
| 572 |
+
|
| 573 |
+
Refer to the docstring of ThreadpoolController.limit for more details.
|
| 574 |
+
|
| 575 |
+
It will only act on the library controllers held by the provided `controller`.
|
| 576 |
+
Using the default constructor sets the limits right away such that it can be used as
|
| 577 |
+
a callable. Setting the limits can be delayed by using the `wrap` class method such
|
| 578 |
+
that it can be used as a decorator.
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
def __init__(self, controller, *, limits=None, user_api=None):
|
| 582 |
+
self._controller = controller
|
| 583 |
+
self._limits, self._user_api, self._prefixes = self._check_params(
|
| 584 |
+
limits, user_api
|
| 585 |
+
)
|
| 586 |
+
self._original_info = self._controller.info()
|
| 587 |
+
self._set_threadpool_limits()
|
| 588 |
+
|
| 589 |
+
def __enter__(self):
|
| 590 |
+
return self
|
| 591 |
+
|
| 592 |
+
def __exit__(self, type, value, traceback):
|
| 593 |
+
self.restore_original_limits()
|
| 594 |
+
|
| 595 |
+
@classmethod
|
| 596 |
+
def wrap(cls, controller, *, limits=None, user_api=None):
|
| 597 |
+
"""Return an instance of this class that can be used as a decorator"""
|
| 598 |
+
return _ThreadpoolLimiterDecorator(
|
| 599 |
+
controller=controller, limits=limits, user_api=user_api
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
def restore_original_limits(self):
|
| 603 |
+
"""Set the limits back to their original values"""
|
| 604 |
+
for lib_controller, original_info in zip(
|
| 605 |
+
self._controller.lib_controllers, self._original_info
|
| 606 |
+
):
|
| 607 |
+
lib_controller.set_num_threads(original_info["num_threads"])
|
| 608 |
+
|
| 609 |
+
# Alias of `restore_original_limits` for backward compatibility
|
| 610 |
+
unregister = restore_original_limits
|
| 611 |
+
|
| 612 |
+
def get_original_num_threads(self):
|
| 613 |
+
"""Original num_threads from before calling threadpool_limits
|
| 614 |
+
|
| 615 |
+
Return a dict `{user_api: num_threads}`.
|
| 616 |
+
"""
|
| 617 |
+
num_threads = {}
|
| 618 |
+
warning_apis = []
|
| 619 |
+
|
| 620 |
+
for user_api in self._user_api:
|
| 621 |
+
limits = [
|
| 622 |
+
lib_info["num_threads"]
|
| 623 |
+
for lib_info in self._original_info
|
| 624 |
+
if lib_info["user_api"] == user_api
|
| 625 |
+
]
|
| 626 |
+
limits = set(limits)
|
| 627 |
+
n_limits = len(limits)
|
| 628 |
+
|
| 629 |
+
if n_limits == 1:
|
| 630 |
+
limit = limits.pop()
|
| 631 |
+
elif n_limits == 0:
|
| 632 |
+
limit = None
|
| 633 |
+
else:
|
| 634 |
+
limit = min(limits)
|
| 635 |
+
warning_apis.append(user_api)
|
| 636 |
+
|
| 637 |
+
num_threads[user_api] = limit
|
| 638 |
+
|
| 639 |
+
if warning_apis:
|
| 640 |
+
warnings.warn(
|
| 641 |
+
"Multiple value possible for following user apis: "
|
| 642 |
+
+ ", ".join(warning_apis)
|
| 643 |
+
+ ". Returning the minimum."
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
return num_threads
|
| 647 |
+
|
| 648 |
+
def _check_params(self, limits, user_api):
|
| 649 |
+
"""Suitable values for the _limits, _user_api and _prefixes attributes"""
|
| 650 |
+
|
| 651 |
+
if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
|
| 652 |
+
(
|
| 653 |
+
limits,
|
| 654 |
+
user_api,
|
| 655 |
+
) = self._controller._get_params_for_sequential_blas_under_openmp().values()
|
| 656 |
+
|
| 657 |
+
if limits is None or isinstance(limits, int):
|
| 658 |
+
if user_api is None:
|
| 659 |
+
user_api = _ALL_USER_APIS
|
| 660 |
+
elif user_api in _ALL_USER_APIS:
|
| 661 |
+
user_api = [user_api]
|
| 662 |
+
else:
|
| 663 |
+
raise ValueError(
|
| 664 |
+
f"user_api must be either in {_ALL_USER_APIS} or None. Got "
|
| 665 |
+
f"{user_api} instead."
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
if limits is not None:
|
| 669 |
+
limits = {api: limits for api in user_api}
|
| 670 |
+
prefixes = []
|
| 671 |
+
else:
|
| 672 |
+
if isinstance(limits, list):
|
| 673 |
+
# This should be a list of dicts of library info, for
|
| 674 |
+
# compatibility with the result from threadpool_info.
|
| 675 |
+
limits = {
|
| 676 |
+
lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
|
| 677 |
+
}
|
| 678 |
+
elif isinstance(limits, ThreadpoolController):
|
| 679 |
+
# To set the limits from the library controllers of a
|
| 680 |
+
# ThreadpoolController object.
|
| 681 |
+
limits = {
|
| 682 |
+
lib_controller.prefix: lib_controller.num_threads
|
| 683 |
+
for lib_controller in limits.lib_controllers
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
if not isinstance(limits, dict):
|
| 687 |
+
raise TypeError(
|
| 688 |
+
"limits must either be an int, a list, a dict, or "
|
| 689 |
+
f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
# With a dictionary, can set both specific limit for given
|
| 693 |
+
# libraries and global limit for user_api. Fetch each separately.
|
| 694 |
+
prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
|
| 695 |
+
user_api = [api for api in limits if api in _ALL_USER_APIS]
|
| 696 |
+
|
| 697 |
+
return limits, user_api, prefixes
|
| 698 |
+
|
| 699 |
+
def _set_threadpool_limits(self):
|
| 700 |
+
"""Change the maximal number of threads in selected thread pools.
|
| 701 |
+
|
| 702 |
+
Return a list with all the supported libraries that have been found
|
| 703 |
+
matching `self._prefixes` and `self._user_api`.
|
| 704 |
+
"""
|
| 705 |
+
if self._limits is None:
|
| 706 |
+
return
|
| 707 |
+
|
| 708 |
+
for lib_controller in self._controller.lib_controllers:
|
| 709 |
+
# self._limits is a dict {key: num_threads} where key is either
|
| 710 |
+
# a prefix or a user_api. If a library matches both, the limit
|
| 711 |
+
# corresponding to the prefix is chosen.
|
| 712 |
+
if lib_controller.prefix in self._limits:
|
| 713 |
+
num_threads = self._limits[lib_controller.prefix]
|
| 714 |
+
elif lib_controller.user_api in self._limits:
|
| 715 |
+
num_threads = self._limits[lib_controller.user_api]
|
| 716 |
+
else:
|
| 717 |
+
continue
|
| 718 |
+
|
| 719 |
+
if num_threads is not None:
|
| 720 |
+
lib_controller.set_num_threads(num_threads)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # Only validate and store the parameters here; unlike the parent
        # constructor, no limit is applied at construction time.
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        self._controller = controller

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        # A fresh ThreadpoolController is created for every instance.
        super().__init__(ThreadpoolController(), limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        """Return a limiter usable as a decorator, bound to a fresh controller"""
        return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class ThreadpoolController:
|
| 802 |
+
"""Collection of LibController objects for all loaded supported libraries
|
| 803 |
+
|
| 804 |
+
Attributes
|
| 805 |
+
----------
|
| 806 |
+
lib_controllers : list of `LibController` objects
|
| 807 |
+
The list of library controllers of all loaded supported libraries.
|
| 808 |
+
"""
|
| 809 |
+
|
| 810 |
+
# Cache for libc under POSIX and a few system libraries under Windows.
|
| 811 |
+
# We use a class level cache instead of an instance level cache because
|
| 812 |
+
# it's very unlikely that a shared library will be unloaded and reloaded
|
| 813 |
+
# during the lifetime of a program.
|
| 814 |
+
_system_libraries = dict()
|
| 815 |
+
|
| 816 |
+
def __init__(self):
|
| 817 |
+
self.lib_controllers = []
|
| 818 |
+
self._load_libraries()
|
| 819 |
+
self._warn_if_incompatible_openmp()
|
| 820 |
+
|
| 821 |
+
@classmethod
|
| 822 |
+
def _from_controllers(cls, lib_controllers):
|
| 823 |
+
new_controller = cls.__new__(cls)
|
| 824 |
+
new_controller.lib_controllers = lib_controllers
|
| 825 |
+
return new_controller
|
| 826 |
+
|
| 827 |
+
def info(self):
    """Return lib_controllers info as a list of dicts"""
    return [lib_controller.info() for lib_controller in self.lib_controllers]
|
| 830 |
+
|
| 831 |
+
def select(self, **kwargs):
    """Return a ThreadpoolController containing a subset of its current
    library controllers

    It will select all libraries matching at least one pair (key, value) from kwargs
    where key is an entry of the library info dict (like "user_api", "internal_api",
    "prefix", ...) and value is the value or a list of acceptable values for that
    entry.

    For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
    will select all library controllers whose internal_api is either "blis" or
    "openblas".
    """
    # Normalize every criterion to a list of acceptable values. A separate
    # dict is built so that kwargs is not modified while being iterated.
    criteria = {
        key: vals if isinstance(vals, list) else [vals]
        for key, vals in kwargs.items()
    }

    lib_controllers = [
        lib_controller
        for lib_controller in self.lib_controllers
        if any(
            # A controller lacking the attribute is compared as None.
            getattr(lib_controller, key, None) in vals
            for key, vals in criteria.items()
        )
    ]

    return ThreadpoolController._from_controllers(lib_controllers)
|
| 857 |
+
|
| 858 |
+
def _get_params_for_sequential_blas_under_openmp(self):
    """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

    This function takes into account the unexpected behavior of OpenBLAS with the
    OpenMP threading layer: when such a library is loaded, no limit is
    requested at all.

    Returns
    -------
    dict
        ``{"limits": None, "user_api": None}`` if an OpenBLAS library using the
        OpenMP threading layer is loaded, ``{"limits": 1, "user_api": "blas"}``
        otherwise.
    """
    # `select` keeps libraries matching *any* of its criteria, so passing both
    # criteria in a single call would also match OpenBLAS built with pthreads
    # or any other BLAS using OpenMP. Chaining two calls yields the intended
    # "OpenBLAS AND OpenMP threading layer" selection.
    openblas_with_openmp = self.select(internal_api="openblas").select(
        threading_layer="openmp"
    )
    if openblas_with_openmp.lib_controllers:
        return {"limits": None, "user_api": None}
    return {"limits": 1, "user_api": "blas"}
|
| 869 |
+
|
| 870 |
+
@_format_docstring(
    USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def limit(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used either as a callable (the
    construction of this object limits the number of threads) or as a context
    manager, in a `with` block to automatically restore the original state of the
    controlled libraries when exiting the block.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread
    level isolation as these libraries do not offer thread-local APIs to configure
    the number of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate
          `limits` and `user_api` parameters for the specific use case of
          sequential BLAS calls within an OpenMP parallel region. The
          `user_api` parameter is ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    # Constructing the limiter is what applies the limits.
    limiter = _ThreadpoolLimiter(self, limits=limits, user_api=user_api)
    return limiter
|
| 922 |
+
|
| 923 |
+
@_format_docstring(
    USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def wrap(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used as a decorator.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    Parameters
    ----------
    limits : int, dict or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    # Delegate to the limiter's own decorator factory, bound to this controller.
    decorator = _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)
    return decorator
|
| 963 |
+
|
| 964 |
+
def __len__(self):
    """Return the number of library controllers held by this object."""
    controllers = self.lib_controllers
    return len(controllers)
|
| 966 |
+
|
| 967 |
+
def _load_libraries(self):
    """Loop through loaded shared libraries and store the supported ones"""
    # Pick the discovery strategy for the current platform. The Pyodide check
    # only applies when the platform is neither macOS nor Windows.
    platform = sys.platform
    if platform == "darwin":
        find_libraries = self._find_libraries_with_dyld
    elif platform == "win32":
        find_libraries = self._find_libraries_with_enum_process_module_ex
    elif "pyodide" in sys.modules:
        find_libraries = self._find_libraries_pyodide
    else:
        find_libraries = self._find_libraries_with_dl_iterate_phdr
    find_libraries()
|
| 977 |
+
|
| 978 |
+
def _find_libraries_with_dl_iterate_phdr(self):
    """Loop through loaded libraries and return binders on supported ones

    This function is expected to work on POSIX system only.
    This code is adapted from code by Intel developer @anton-malakhov
    available at https://github.com/IntelPython/smp

    Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
    license
    """
    libc = self._get_libc()
    if not hasattr(libc, "dl_iterate_phdr"):  # pragma: no cover
        warnings.warn(
            "Could not find dl_iterate_phdr in the C standard library.",
            RuntimeWarning,
        )
        return []

    # `dl_iterate_phdr` invokes this callback for every library loaded in the
    # current process until it returns 1; always returning 0 visits them all.
    def match_library_callback(info, size, data):
        raw_path = info.contents.dlpi_name
        if not raw_path:
            # Entries without a name are skipped.
            return 0

        # Store the library controller if it is supported and selected
        self._make_controller_from_path(raw_path.decode("utf-8"))
        return 0

    callback_type = ctypes.CFUNCTYPE(
        ctypes.c_int,  # Return type
        ctypes.POINTER(_dl_phdr_info),
        ctypes.c_size_t,
        ctypes.c_char_p,
    )
    # Keep named references to the C callback and its data for the whole call.
    c_callback = callback_type(match_library_callback)
    user_data = ctypes.c_char_p(b"")
    libc.dl_iterate_phdr(c_callback, user_data)
|
| 1018 |
+
|
| 1019 |
+
def _find_libraries_with_dyld(self):
    """Loop through loaded libraries and return binders on supported ones

    This function is expected to work on OSX system only
    """
    libc = self._get_libc()
    if not hasattr(libc, "_dyld_image_count"):  # pragma: no cover
        warnings.warn(
            "Could not find _dyld_image_count in the C standard library.",
            RuntimeWarning,
        )
        return []

    image_count = libc._dyld_image_count()
    libc._dyld_get_image_name.restype = ctypes.c_char_p

    for index in range(image_count):
        raw_path = ctypes.string_at(libc._dyld_get_image_name(index))

        # Store the library controller if it is supported and selected
        self._make_controller_from_path(raw_path.decode("utf-8"))
|
| 1041 |
+
|
| 1042 |
+
def _find_libraries_with_enum_process_module_ex(self):
    """Loop through loaded libraries and return binders on supported ones

    This function is expected to work on windows system only.
    This code is adapted from code by Philipp Hagemeister @phihag available
    at https://stackoverflow.com/questions/17474574

    Raises
    ------
    OSError
        If the current process cannot be opened or if one of the Windows API
        calls used to enumerate the modules fails.
    """
    from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

    PROCESS_QUERY_INFORMATION = 0x0400
    PROCESS_VM_READ = 0x0010

    LIST_LIBRARIES_ALL = 0x03

    ps_api = self._get_windll("Psapi")
    kernel_32 = self._get_windll("kernel32")

    h_process = kernel_32.OpenProcess(
        PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
    )
    if not h_process:  # pragma: no cover
        raise OSError(f"Could not open PID {os.getpid()}")

    try:
        buf_count = 256
        needed = DWORD()
        # Grow the buffer until it becomes large enough to hold all the
        # module headers
        while True:
            buf = (HMODULE * buf_count)()
            buf_size = ctypes.sizeof(buf)
            if not ps_api.EnumProcessModulesEx(
                h_process,
                ctypes.byref(buf),
                buf_size,
                ctypes.byref(needed),
                LIST_LIBRARIES_ALL,
            ):
                raise OSError("EnumProcessModulesEx failed")
            if buf_size >= needed.value:
                break
            # `buf_size // buf_count` is the size in bytes of one HMODULE.
            buf_count = needed.value // (buf_size // buf_count)

        count = needed.value // (buf_size // buf_count)
        h_modules = map(HMODULE, buf[:count])

        # Loop through all the module headers and get the library path
        # Allocate a buffer for the path 10 times the size of MAX_PATH to take
        # into account long path names.
        max_path = 10 * MAX_PATH
        buf = ctypes.create_unicode_buffer(max_path)
        for h_module in h_modules:
            # Get the path of the current module. The last argument of
            # GetModuleFileNameExW is the capacity of the buffer in characters,
            # passed by value (not a pointer to a DWORD).
            if not ps_api.GetModuleFileNameExW(
                h_process, h_module, ctypes.byref(buf), DWORD(max_path)
            ):
                raise OSError("GetModuleFileNameEx failed")
            filepath = buf.value

            # A path filling the whole buffer (with or without room left for
            # the terminating NUL) is assumed to have been truncated.
            if len(filepath) >= max_path - 1:  # pragma: no cover
                warnings.warn(
                    "Could not get the full path of a dynamic library (path too "
                    "long). This library will be ignored and threadpoolctl might "
                    "not be able to control or display information about all "
                    f"loaded libraries. Here's the truncated path: {filepath!r}",
                    RuntimeWarning,
                )
            else:
                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
    finally:
        kernel_32.CloseHandle(h_process)
|
| 1115 |
+
|
| 1116 |
+
def _find_libraries_pyodide(self):
    """Pyodide specific implementation for finding loaded libraries.

    Adapted from suggestion in https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449.

    One day, we may have a simpler solution. libc dl_iterate_phdr needs to
    be implemented in Emscripten and exposed in Pyodide, see
    https://github.com/emscripten-core/emscripten/issues/21354 for more
    details.
    """
    try:
        from pyodide_js._module import LDSO
    except ImportError:
        warnings.warn(
            "Unable to import LDSO from pyodide_js._module. This should never "
            "happen."
        )
        return

    loaded_libraries = LDSO.loadedLibsByName.as_object_map()
    for filepath in loaded_libraries:
        # Some libraries are duplicated by Pyodide and do not exist in the
        # filesystem, so only paths that actually exist are considered. For
        # more details, see
        # https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1947946728
        if not os.path.exists(filepath):
            continue
        self._make_controller_from_path(filepath)
|
| 1142 |
+
|
| 1143 |
+
def _make_controller_from_path(self, filepath):
    """Store a library controller if it is supported and selected

    Parameters
    ----------
    filepath : str
        Path of a shared library loaded in the current process. It is
        resolved with `_realpath` before being matched against the filename
        prefixes of the supported libraries.

    At most one controller is stored per resolved path; nothing is returned.
    """
    # Required to resolve symlinks
    filepath = _realpath(filepath)

    # The same library can be reported several times. Bail out before doing
    # any work (in particular before instantiating a controller, which loads
    # and queries the library) if we already have a controller for it.
    if any(lib.filepath == filepath for lib in self.lib_controllers):
        return

    # `lower` required to take account of OpenMP dll case on Windows
    # (vcomp, VCOMP, Vcomp, ...)
    filename = os.path.basename(filepath).lower()

    # Loop through supported libraries to find if this filename corresponds
    # to a supported one.
    for controller_class in _ALL_CONTROLLERS:
        # check if filename matches a supported prefix
        prefix = self._check_prefix(filename, controller_class.filename_prefixes)

        # filename does not match any of the prefixes of the candidate
        # library. move to next library.
        if prefix is None:
            continue

        # workaround for BLAS libraries packaged by conda-forge on windows, which
        # are all renamed "libblas.dll". We thus have to check to which BLAS
        # implementation it actually corresponds looking for implementation
        # specific symbols.
        if prefix == "libblas":
            if not filename.endswith(".dll"):
                # We ignore libblas on other platforms than windows because there
                # might be a libblas dso coming with openblas for instance that
                # can't be used to instantiate a pertinent LibController (many
                # symbols are missing) and would create confusion by making a
                # duplicate entry in threadpool_info.
                continue
            libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
            if not any(
                hasattr(libblas, func) for func in controller_class.check_symbols
            ):
                continue

        # filename matches a prefix. Now we check if the library has the symbols we
        # are looking for. If none of the symbols exists, it's very likely not the
        # expected library (e.g. a library having a common prefix with one of
        # our supported libraries). Otherwise, create and store the library
        # controller.
        lib_controller = controller_class(
            filepath=filepath, prefix=prefix, parent=self
        )

        if not hasattr(controller_class, "check_symbols") or any(
            hasattr(lib_controller.dynlib, func)
            for func in controller_class.check_symbols
        ):
            self.lib_controllers.append(lib_controller)
            # A path is registered at most once: nothing left to do.
            return
|
| 1200 |
+
|
| 1201 |
+
def _check_prefix(self, library_basename, filename_prefixes):
    """Return the prefix library_basename starts with

    Prefixes are tried in the order given and the first match wins.
    Return None if none matches.
    """
    matching = (
        candidate
        for candidate in filename_prefixes
        if library_basename.startswith(candidate)
    )
    return next(matching, None)
|
| 1210 |
+
|
| 1211 |
+
def _warn_if_incompatible_openmp(self):
    """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
    loaded_prefixes = {
        lib_controller.prefix for lib_controller in self.lib_controllers
    }
    if not {"libomp", "libiomp"} <= loaded_prefixes:
        # At most one of the two runtimes is loaded: nothing to report.
        return

    msg = textwrap.dedent(
        """
        Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
        the same time. Both libraries are known to be incompatible and this
        can cause random crashes or deadlocks on Linux when loaded in the
        same Python program.
        Using threadpoolctl may cause crashes or deadlocks. For more
        information and possible workarounds, please see
        https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
        """
    )
    warnings.warn(msg, RuntimeWarning)
|
| 1227 |
+
|
| 1228 |
+
@classmethod
def _get_libc(cls):
    """Load the lib-C for unix systems."""
    cached = cls._system_libraries.get("libc")
    if cached is not None:
        return cached

    # Remark: If libc is statically linked or if Python is linked against an
    # alternative implementation of libc like musl, find_library will return
    # None and CDLL will load the main program itself which should contain the
    # libc symbols. We still name it libc for convenience.
    # If the main program does not contain the libc symbols, it's ok because
    # we check their presence later anyway.
    libc_path = find_library("c")
    libc = ctypes.CDLL(libc_path, mode=_RTLD_NOLOAD)
    cls._system_libraries["libc"] = libc
    return libc
|
| 1242 |
+
|
| 1243 |
+
@classmethod
def _get_windll(cls, dll_name):
    """Load a windows DLL"""
    cached = cls._system_libraries.get(dll_name)
    if cached is not None:
        return cached

    # First request for this DLL: load it and remember it on the class.
    dll = ctypes.WinDLL(f"{dll_name}.dll")
    cls._system_libraries[dll_name] = dll
    return dll
|
| 1251 |
+
|
| 1252 |
+
|
| 1253 |
+
def _main():
    """Commandline interface to display thread-pool information and exit."""
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
        description="Display thread-pool information and exit.",
    )
    # Modules listed here are imported first so that the native libraries they
    # load are visible when the thread pools are inspected.
    parser.add_argument(
        "-i",
        "--import",
        dest="modules",
        nargs="*",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    options = parser.parse_args(sys.argv[1:])

    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            # A missing module is reported on stderr but is not fatal.
            print("WARNING: could not import", module, file=sys.stderr)

    if options.command:
        # NOTE: runs the statement given on the command line as is, with
        # access to the names visible in this function.
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    _main()
|
venv/Lib/site-packages/typing_extensions.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|